• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright © 2024 Apple Inc. All rights reserved.
2#
3# This source code is licensed under the BSD-style license found in the
4# LICENSE file in the root directory of this source tree.
5
6import argparse
7import os
8import shutil
9from pathlib import Path
10
11from typing import List, Optional
12
13import executorchcoreml
14
15from executorch.backends.apple.coreml.compiler import CoreMLBackend
16
17from executorch.exir._serialize._program import deserialize_pte_binary
18
19from executorch.exir.schema import (
20    BackendDelegate,
21    BackendDelegateDataReference,
22    DataLocation,
23)
24
25
26def extract_coreml_models(pte_data: bytes):
27    program = deserialize_pte_binary(pte_data)
28    delegates: List[BackendDelegate] = sum(
29        [execution_plan.delegates for execution_plan in program.execution_plan], []
30    )
31    coreml_delegates: List[BackendDelegate] = [
32        delegate for delegate in delegates if delegate.id == CoreMLBackend.__name__
33    ]
34    model_index: int = 1
35    for coreml_delegate in coreml_delegates:
36        coreml_delegate_data: BackendDelegateDataReference = coreml_delegate.processed
37        coreml_processed_bytes: Optional[bytes] = None
38        match coreml_delegate_data.location:
39            case DataLocation.INLINE:
40                coreml_processed_bytes = program.backend_delegate_data[
41                    coreml_delegate_data.index
42                ].data
43
44            case _:
45                AssertionError("The loaded Program must have inline data.")
46
47        model_name: str = f"model_{model_index}"
48        model_path: Path = Path() / "extracted_coreml_models" / model_name
49        if model_path.exists():
50            shutil.rmtree(model_path.absolute())
51        os.makedirs(model_path.absolute())
52
53        if executorchcoreml.unflatten_directory_contents(
54            coreml_processed_bytes, str(model_path.absolute())
55        ):
56            print(f"Core ML models are extracted and saved to path = {model_path}")
57        model_index += 1
58
59    if len(coreml_delegates) == 0:
60        print("The model isn't delegated to Core ML.")
61
62
if __name__ == "__main__":
    # Command-line entry point: read a .pte file and extract its Core ML models.
    # The description is passed to argparse so it is shown by ``--help``
    # (a bare string statement here would be a no-op).
    parser = argparse.ArgumentParser(
        description=(
            "Extracts the Core ML models embedded in the .pte file and saves "
            "them to the file system."
        )
    )
    parser.add_argument(
        "-m",
        "--model_path",
        required=True,
        help="Input must be a .pte file.",
    )

    args = parser.parse_args()
    # Read the whole file, close it, then run the (potentially long) extraction.
    with open(args.model_path, mode="rb") as pte_file:
        pte_data = pte_file.read()
    extract_coreml_models(pte_data)
81