# Copyright © 2024 Apple Inc. All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Extracts the Core ML models embedded in the ``.pte`` file and saves them to
the file system."""

import argparse
import os
import shutil
from pathlib import Path

from typing import List, Optional

import executorchcoreml

from executorch.backends.apple.coreml.compiler import CoreMLBackend

from executorch.exir._serialize._program import deserialize_pte_binary

from executorch.exir.schema import (
    BackendDelegate,
    BackendDelegateDataReference,
    DataLocation,
)


def extract_coreml_models(
    pte_data: bytes, output_dir: Path = Path("extracted_coreml_models")
) -> None:
    """Extract every Core ML delegate payload from a serialized ExecuTorch program.

    The program is deserialized, and every delegate whose ``id`` equals
    ``CoreMLBackend.__name__`` (across all execution plans) is selected. Each
    selected delegate's inline payload is unflattened via
    ``executorchcoreml.unflatten_directory_contents`` into
    ``<output_dir>/model_<k>``. ``k`` starts at 1 and advances only after a
    successful extraction. Any existing directory at that path is deleted
    first.

    Args:
        pte_data: Raw contents of a ``.pte`` file.
        output_dir: Root directory for the per-model directories. Defaults to
            ``extracted_coreml_models`` relative to the current working
            directory.

    Raises:
        AssertionError: If a Core ML delegate's processed data is not stored
            inline in the program (``DataLocation.INLINE``).
    """
    program = deserialize_pte_binary(pte_data)

    # Flatten delegates from all execution plans and keep only the Core ML ones.
    coreml_delegates: List[BackendDelegate] = [
        delegate
        for execution_plan in program.execution_plan
        for delegate in execution_plan.delegates
        if delegate.id == CoreMLBackend.__name__
    ]
    if not coreml_delegates:
        print("The model isn't delegated to Core ML.")
        return

    model_index: int = 1
    for coreml_delegate in coreml_delegates:
        coreml_delegate_data: BackendDelegateDataReference = coreml_delegate.processed
        # Only inline data is supported: the payload is read from
        # program.backend_delegate_data, so fail loudly instead of passing
        # None on to the unflatten call.
        if coreml_delegate_data.location != DataLocation.INLINE:
            raise AssertionError("The loaded Program must have inline data.")
        coreml_processed_bytes: bytes = program.backend_delegate_data[
            coreml_delegate_data.index
        ].data

        model_path: Path = Path(output_dir) / f"model_{model_index}"
        # Start from a clean directory so stale files from a previous run (or
        # a failed attempt at this index) do not mix with the new output.
        if model_path.exists():
            shutil.rmtree(model_path.absolute())
        os.makedirs(model_path.absolute())

        # A falsy return from the native helper is treated as failure; the
        # index is not advanced, so the next delegate reuses this directory.
        if executorchcoreml.unflatten_directory_contents(
            coreml_processed_bytes, str(model_path.absolute())
        ):
            print(f"Core ML models are extracted and saved to path = {model_path}")
            model_index += 1
        else:
            print(f"Failed to extract Core ML model to path = {model_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "-m",
        "--model_path",
        required=True,
        help="Input must be a .pte file.",
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        default="extracted_coreml_models",
        help="Directory where the extracted Core ML models are saved.",
    )

    args = parser.parse_args()
    with open(args.model_path, mode="rb") as pte_file:
        pte_data = pte_file.read()
    extract_coreml_models(pte_data, output_dir=Path(args.output_dir))