• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2"""Utilities for manipulating the onnx and onnx-script dependencies and ONNX proto."""
3
4from __future__ import annotations
5
6import glob
7import io
8import os
9import shutil
10import zipfile
11from typing import Any, Mapping
12
13import torch
14import torch.jit._trace
15import torch.serialization
16from torch.onnx import _constants, _exporter_states, errors
17from torch.onnx._internal import jit_utils, registration
18
19
def export_as_test_case(
    model_bytes: bytes, inputs_data, outputs_data, name: str, dir: str
) -> str:
    """Write an ONNX model plus sample inputs/outputs as a standalone test case.

    The files are laid out on disk as follows:

    dir
    ├── test_<name>
    │   ├── model.onnx
    │   └── test_data_set_0
    │       ├── input_0.pb
    │       ├── input_1.pb
    │       ├── output_0.pb
    │       └── output_1.pb

    An existing ``test_data_set_0`` directory is deleted and recreated, so
    stale data files from a previous export do not survive. Data entries are
    paired positionally with the model's graph inputs/outputs; pairing stops
    at the shorter of the two sequences.

    Args:
        model_bytes: The ONNX model in bytes.
        inputs_data: The inputs data, nested data structure of numpy.ndarray.
        outputs_data: The outputs data, nested data structure of numpy.ndarray.
        name: Suffix of the test case directory, which is named ``test_<name>``.
        dir: Parent directory under which the test case directory is created.

    Returns:
        The path to the test case directory.

    Raises:
        ImportError: If the ``onnx`` package cannot be imported.
    """
    try:
        import onnx
    except ImportError as exc:
        raise ImportError(
            "Export test case to ONNX format failed: Please install ONNX."
        ) from exc

    case_root = os.path.join(dir, "test_" + name)
    os.makedirs(case_root, exist_ok=True)
    model_path = os.path.join(case_root, "model.onnx")
    _export_file(
        model_bytes, model_path, _exporter_states.ExportTypes.PROTOBUF_FILE, {}
    )

    # Start from an empty data directory every time.
    samples_root = os.path.join(case_root, "test_data_set_0")
    if os.path.exists(samples_root):
        shutil.rmtree(samples_root)
    os.makedirs(samples_root)

    model_proto = onnx.load_model_from_string(model_bytes)  # type: ignore[attr-defined]

    # Inputs are written first, then outputs; each value is paired with the
    # ValueInfoProto at the same position in the graph.
    streams = (
        ("input", model_proto.graph.input, inputs_data),
        ("output", model_proto.graph.output, outputs_data),
    )
    for prefix, value_infos, values in streams:
        for index, (value_info, value) in enumerate(zip(value_infos, values)):
            target = os.path.join(samples_root, f"{prefix}_{index}.pb")
            export_data(value, value_info, target)

    return case_root
73
74
def load_test_case(dir: str) -> tuple[bytes, Any, Any]:
    """Load a self contained ONNX test case from a directory.

    ``dir`` is the test case directory itself (the ``test_<name>`` directory),
    not its parent. It must contain the model and the inputs/outputs data:

    dir
    ├── model.onnx
    └── test_data_set_0
        ├── input_0.pb
        ├── input_1.pb
        ├── output_0.pb
        └── output_1.pb

    Every ``input_*.pb`` / ``output_*.pb`` file is read with
    ``onnx.load_tensor``. Files are visited in ascending numeric index
    (``input_2.pb`` before ``input_10.pb``) so the insertion order of the
    returned mappings is deterministic rather than filesystem dependent.

    Args:
        dir: The directory containing the test case.

    Returns:
        model_bytes: The ONNX model in bytes.
        inputs: the inputs data, mapping from input name to numpy.ndarray.
        outputs: the outputs data, mapping from output name to numpy.ndarray.

    Raises:
        ImportError: If the ``onnx`` package cannot be imported.
    """
    try:
        import onnx
        from onnx import numpy_helper  # type: ignore[attr-defined]
    except ImportError as exc:
        raise ImportError(
            "Load test case from ONNX format failed: Please install ONNX."
        ) from exc

    with open(os.path.join(dir, "model.onnx"), "rb") as f:
        model_bytes = f.read()

    test_data_dir = os.path.join(dir, "test_data_set_0")

    def _numeric_order(path: str) -> tuple[int, int, str]:
        # Sort by the integer after the last underscore; names that do not
        # end in a number are kept, but placed after the numbered ones.
        stem = os.path.splitext(os.path.basename(path))[0]
        index = stem.rpartition("_")[2]
        if index.isdecimal():
            return (0, int(index), path)
        return (1, 0, path)

    def _load_tensors(pattern: str) -> dict[str, Any]:
        # Escape the directory part so glob metacharacters in ``dir`` are
        # matched literally; only ``pattern`` is treated as a wildcard.
        matches = glob.glob(os.path.join(glob.escape(test_data_dir), pattern))
        tensors: dict[str, Any] = {}
        for path in sorted(matches, key=_numeric_order):
            tensor = onnx.load_tensor(path)  # type: ignore[attr-defined]
            tensors[tensor.name] = numpy_helper.to_array(tensor)
        return tensors

    inputs = _load_tensors("input_*.pb")
    outputs = _load_tensors("output_*.pb")

    return model_bytes, inputs, outputs
123
124
def export_data(data, value_info_proto, f: str) -> None:
    """Export data to ONNX protobuf format.

    The data is converted and serialized before the output file is opened, so
    a failed conversion does not leave an empty or truncated file at ``f``.

    Args:
        data: The data to export, nested data structure of numpy.ndarray.
        value_info_proto: The ValueInfoProto of the data. The type of the ValueInfoProto
            determines how the data is stored: ``map_type``, ``sequence_type``
            and ``optional_type`` are checked in that order, and anything else
            is expected to be ``tensor_type``.
        f: The file to write the data to.

    Raises:
        ImportError: If the ``onnx`` package cannot be imported.
        AssertionError: If the ValueInfoProto type is none of the above.
    """
    try:
        from onnx import numpy_helper  # type: ignore[attr-defined]
    except ImportError as exc:
        raise ImportError(
            "Export data to ONNX format failed: Please install ONNX."
        ) from exc

    value_type = value_info_proto.type
    value_name = value_info_proto.name

    # Pick the numpy_helper converter that matches the declared value type.
    if value_type.HasField("map_type"):
        proto = numpy_helper.from_dict(data, value_name)
    elif value_type.HasField("sequence_type"):
        proto = numpy_helper.from_list(data, value_name)
    elif value_type.HasField("optional_type"):
        proto = numpy_helper.from_optional(data, value_name)
    else:
        assert value_type.HasField("tensor_type")
        proto = numpy_helper.from_array(data, value_name)
    payload = proto.SerializeToString()

    # Only touch the filesystem once there is something valid to write.
    with open(f, "wb") as opened_file:
        opened_file.write(payload)
161
162
def _export_file(
    model_bytes: bytes,
    f: io.BytesIO | str,
    export_type: str,
    export_map: Mapping[str, bytes],
) -> None:
    """Write model bytes as a protobuf file, a zip archive, or a directory.

    Behavior per ``export_type`` (members of ``_exporter_states.ExportTypes``):

    * ``PROTOBUF_FILE``: ``model_bytes`` is written to ``f``; ``export_map``
      must be empty.
    * ``ZIP_ARCHIVE`` / ``COMPRESSED_ZIP_ARCHIVE``: a zip archive is written to
      ``f`` holding the model under ``_constants.ONNX_ARCHIVE_MODEL_PROTO_NAME``
      and one member per ``export_map`` entry. The compressed variant uses
      ``ZIP_DEFLATED``, the other ``ZIP_STORED``.
    * ``DIRECTORY``: ``f`` is a directory path, created if it does not exist.
      The model is written to ``_constants.ONNX_ARCHIVE_MODEL_PROTO_NAME``
      inside it, and each ``export_map`` entry to a file named by its key.

    Args:
        model_bytes: The serialized model.
        f: Destination file-like object or path; must be a path for ``DIRECTORY``.
        export_type: The export type selecting one of the layouts above.
        export_map: Extra files to write, mapping name to content.

    Raises:
        ValueError: If ``export_type`` is unknown, or if it is ``DIRECTORY`` and
            ``f`` is a ``BytesIO`` or an existing path that is not a directory.
    """
    if export_type == _exporter_states.ExportTypes.PROTOBUF_FILE:
        assert len(export_map) == 0
        with torch.serialization._open_file_like(f, "wb") as opened_file:
            opened_file.write(model_bytes)
    elif export_type in {
        _exporter_states.ExportTypes.ZIP_ARCHIVE,
        _exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE,
    }:
        compression = (
            zipfile.ZIP_DEFLATED
            if export_type == _exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE
            else zipfile.ZIP_STORED
        )
        with zipfile.ZipFile(f, "w", compression=compression) as z:
            z.writestr(_constants.ONNX_ARCHIVE_MODEL_PROTO_NAME, model_bytes)
            for k, v in export_map.items():
                z.writestr(k, v)
    elif export_type == _exporter_states.ExportTypes.DIRECTORY:
        # Reject in-memory buffers and paths that exist but are not directories.
        # A missing path is fine: it is created below. (Previously a missing
        # path was rejected here, which made the directory creation dead code.)
        if isinstance(f, io.BytesIO) or (
            os.path.exists(f) and not os.path.isdir(f)  # type: ignore[arg-type]
        ):
            raise ValueError(
                f"f should be directory when export_type is set to DIRECTORY, instead get type(f): {type(f)}"
            )
        os.makedirs(f, exist_ok=True)  # type: ignore[arg-type]

        model_proto_file = os.path.join(f, _constants.ONNX_ARCHIVE_MODEL_PROTO_NAME)  # type: ignore[arg-type]
        with torch.serialization._open_file_like(model_proto_file, "wb") as opened_file:
            opened_file.write(model_bytes)

        for k, v in export_map.items():
            weight_proto_file = os.path.join(f, k)  # type: ignore[arg-type]
            with torch.serialization._open_file_like(
                weight_proto_file, "wb"
            ) as opened_file:
                opened_file.write(v)
    else:
        raise ValueError("Unknown export type")
207
208
def _add_onnxscript_fn(
    model_bytes: bytes,
    custom_opsets: Mapping[str, int],
) -> bytes:
    """Embed the custom onnx-script functions used by the model into its ModelProto.

    Args:
        model_bytes: The serialized ONNX model.
        custom_opsets: Mapping from custom domain to opset version, forwarded
            to ``_find_onnxscript_op`` to select the function version.

    Returns:
        The re-serialized model when at least one function was added,
        otherwise ``model_bytes`` unchanged.

    Raises:
        errors.OnnxExporterError: If the ``onnx`` package cannot be imported.
    """
    try:
        import onnx
    except ImportError as e:
        raise errors.OnnxExporterError("Module onnx is not installed!") from e

    # Loading a model larger than 2GB from a string would fail. That is not a
    # concern here: _export_onnx is expected to store tensors separately once
    # the proto exceeds 2GB, and if it did not, serialization would already
    # have failed on the protobuf size limit before reaching this point.
    model_proto = onnx.load_model_from_string(model_bytes)  # type: ignore[attr-defined]

    # Both containers are filled in place while the graph (including nested
    # control-flow graphs) is walked; only functions actually referenced by a
    # node end up in the list.
    collected_functions = []  # type: ignore[var-annotated]
    seen_kinds: set[str] = set()
    _find_onnxscript_op(
        model_proto.graph, seen_kinds, custom_opsets, collected_functions
    )

    if not collected_functions:
        return model_bytes

    model_proto.functions.extend(collected_functions)
    return model_proto.SerializeToString()
239
240
241def _find_onnxscript_op(
242    graph_proto,
243    included_node_func: set[str],
244    custom_opsets: Mapping[str, int],
245    onnx_function_list: list,
246):
247    """Recursively iterate ModelProto to find ONNXFunction op as it may contain control flow Op."""
248    for node in graph_proto.node:
249        node_kind = node.domain + "::" + node.op_type
250        # Recursive needed for control flow nodes: IF/Loop which has inner graph_proto
251        for attr in node.attribute:
252            if attr.g is not None:
253                _find_onnxscript_op(
254                    attr.g, included_node_func, custom_opsets, onnx_function_list
255                )
256        # Only custom Op with ONNX function and aten with symbolic_fn should be found in registry
257        onnx_function_group = registration.registry.get_function_group(node_kind)
258        # Ruled out corner cases: onnx/prim in registry
259        if (
260            node.domain
261            and not jit_utils.is_aten(node.domain)
262            and not jit_utils.is_prim(node.domain)
263            and not jit_utils.is_onnx(node.domain)
264            and onnx_function_group is not None
265            and node_kind not in included_node_func
266        ):
267            specified_version = custom_opsets.get(node.domain, 1)
268            onnx_fn = onnx_function_group.get(specified_version)
269            if onnx_fn is not None:
270                if hasattr(onnx_fn, "to_function_proto"):
271                    onnx_function_proto = onnx_fn.to_function_proto()  # type: ignore[attr-defined]
272                    onnx_function_list.append(onnx_function_proto)
273                    included_node_func.add(node_kind)
274                continue
275
276            raise errors.UnsupportedOperatorError(
277                node_kind,
278                specified_version,
279                onnx_function_group.get_min_supported()
280                if onnx_function_group
281                else None,
282            )
283    return onnx_function_list, included_node_func
284