def _update_hash_from_file(
    hasher, file_path: Path, chunk_size: int = 1 << 20
) -> None:
    """Feed the bytes of ``file_path`` into ``hasher`` without loading the whole file.

    Args:
        hasher: A hashlib-style object exposing ``update(bytes)``.
        file_path: File whose contents are hashed.
        chunk_size: Number of bytes read per iteration (1 MiB by default).
    """
    with file_path.open("rb") as stream:
        # iter() with a sentinel stops at the first empty read, i.e. at EOF.
        for chunk in iter(lambda: stream.read(chunk_size), b""):
            hasher.update(chunk)


def _hash_model(model_spec: ct.proto.Model_pb2, model_path: Path) -> str:  # pyre-ignore
    """Hash model deterministically, including both spec and weights.

    This function addresses three sources of non-determinism in CoreML models:

    1. Timestamps in metadata: coremltools embeds a conversion timestamp in the
       model's userDefined metadata
       (com.github.apple.coremltools.conversion_date). Only that key is removed
       (from a copy of the spec) before hashing.

    2. Random UUIDs in Manifest.json: The mlpackage's Manifest.json contains
       randomly generated UUIDs that change on every save, even for identical
       model content. That file is excluded from hashing.

    3. Non-deterministic protobuf serialization: Protobuf's SerializeToString()
       does not guarantee consistent field ordering across processes, so the
       spec is serialized with text_format instead.

    Args:
        model_spec: The model spec to hash. It is copied, never mutated.
        model_path: Directory of the saved model package; every file below it
            is hashed except Manifest.json and model.mlmodel.

    Returns:
        The first 32 hex characters of the SHA-256 digest over the serialized
        spec followed by each file's relative path and contents.
    """
    # Imported locally so the function carries its own protobuf dependency.
    from google.protobuf import text_format

    conversion_date_key = "com.github.apple.coremltools.conversion_date"
    # Manifest.json holds random UUIDs; model.mlmodel is the serialized spec,
    # which is already covered by the text_format hash below.
    skipped_names = {"Manifest.json", "model.mlmodel"}

    hasher = hashlib.sha256()

    # Work on a copy so the caller's spec keeps its metadata intact.
    spec_copy = ct.proto.Model_pb2.Model()  # pyre-ignore
    spec_copy.CopyFrom(model_spec)
    # Only clear the specific non-deterministic key, not all userDefined metadata.
    user_defined = spec_copy.description.metadata.userDefined
    if conversion_date_key in user_defined:
        del user_defined[conversion_date_key]
    hasher.update(text_format.MessageToString(spec_copy).encode())

    # Sorted traversal keeps the order of hashed files stable between runs.
    for file_path in sorted(model_path.rglob("*")):
        if not file_path.is_file() or file_path.name in skipped_names:
            continue
        # as_posix() keeps the hashed path independent of the OS separator.
        hasher.update(file_path.relative_to(model_path).as_posix().encode())
        # Weight files can be large, so stream them instead of read_bytes().
        _update_hash_from_file(hasher, file_path)

    return hasher.hexdigest()[:32]
+ content_hash = _hash_model(model_spec, model_path) + identifier = "executorch_" + content_hash + logger.warning( f"The model with identifier {identifier} was exported with CoreML specification version {model_spec.specificationVersion}, and it will not run on all version of iOS/macOS." " See https://apple.github.io/coremltools/mlmodel/Format/Model.html#model for information on what OS versions are compatible with this specifcation version." @@ -462,10 +519,6 @@ def preprocess_model( model_spec=model_spec, identifier=identifier, ) - - # Save model. - model_path = model_dir_path / MODEL_PATHS.MODEL.value - mlmodel.save(str(model_path)) # Extract delegate mapping file. model_debug_info: Optional[ModelDebugInfo] = CoreMLBackend.get_model_debug_info( model_path