# mypy: allow-untyped-defs
"""Utilities for manipulating the onnx and onnx-script dependencies and ONNX proto."""

from __future__ import annotations

import glob
import io
import os
import shutil
import zipfile
from typing import Any, Mapping

import torch
import torch.jit._trace
import torch.serialization
from torch.onnx import _constants, _exporter_states, errors
from torch.onnx._internal import jit_utils, registration


def export_as_test_case(
    model_bytes: bytes, inputs_data, outputs_data, name: str, dir: str
) -> str:
    """Export an ONNX model as a self-contained ONNX test case.

    The test case contains the model and the inputs/outputs data. The directory structure
    is as follows:

    dir
    ├── test_<name>
    │   ├── model.onnx
    │   └── test_data_set_0
    │       ├── input_0.pb
    │       ├── input_1.pb
    │       ├── output_0.pb
    │       └── output_1.pb

    Args:
        model_bytes: The ONNX model in bytes.
        inputs_data: The inputs data, nested data structure of numpy.ndarray.
        outputs_data: The outputs data, nested data structure of numpy.ndarray.
        name: The name of the test case.
        dir: The directory in which the test case directory is created.

    Returns:
        The path to the test case directory.
    """
    try:
        import onnx
    except ImportError as exc:
        raise ImportError(
            "Export test case to ONNX format failed: Please install ONNX."
        ) from exc

    test_case_dir = os.path.join(dir, "test_" + name)
    os.makedirs(test_case_dir, exist_ok=True)
    _export_file(
        model_bytes,
        os.path.join(test_case_dir, "model.onnx"),
        _exporter_states.ExportTypes.PROTOBUF_FILE,
        {},
    )
    data_set_dir = os.path.join(test_case_dir, "test_data_set_0")
    if os.path.exists(data_set_dir):
        shutil.rmtree(data_set_dir)
    os.makedirs(data_set_dir)

    proto = onnx.load_model_from_string(model_bytes)  # type: ignore[attr-defined]

    for i, (input_proto, input) in enumerate(zip(proto.graph.input, inputs_data)):
        export_data(input, input_proto, os.path.join(data_set_dir, f"input_{i}.pb"))
    for i, (output_proto, output) in enumerate(zip(proto.graph.output, outputs_data)):
        export_data(output, output_proto, os.path.join(data_set_dir, f"output_{i}.pb"))

    return test_case_dir


def load_test_case(dir: str) -> tuple[bytes, Any, Any]:
    """Load a self-contained ONNX test case from a directory.

    The test case must contain the model and the inputs/outputs data. The directory
    (typically the test_<name> directory produced by export_as_test_case) should be
    structured as follows:

    dir
    ├── model.onnx
    └── test_data_set_0
        ├── input_0.pb
        ├── input_1.pb
        ├── output_0.pb
        └── output_1.pb

    Args:
        dir: The directory containing the test case.

    Returns:
        model_bytes: The ONNX model in bytes.
        inputs: The inputs data, mapping from input name to numpy.ndarray.
        outputs: The outputs data, mapping from output name to numpy.ndarray.
    """
    try:
        import onnx
        from onnx import numpy_helper  # type: ignore[attr-defined]
    except ImportError as exc:
        raise ImportError(
            "Load test case from ONNX format failed: Please install ONNX."
        ) from exc

    with open(os.path.join(dir, "model.onnx"), "rb") as f:
        model_bytes = f.read()

    test_data_dir = os.path.join(dir, "test_data_set_0")

    inputs = {}
    input_files = glob.glob(os.path.join(test_data_dir, "input_*.pb"))
    for input_file in input_files:
        tensor = onnx.load_tensor(input_file)  # type: ignore[attr-defined]
        inputs[tensor.name] = numpy_helper.to_array(tensor)
    outputs = {}
    output_files = glob.glob(os.path.join(test_data_dir, "output_*.pb"))
    for output_file in output_files:
        tensor = onnx.load_tensor(output_file)  # type: ignore[attr-defined]
        outputs[tensor.name] = numpy_helper.to_array(tensor)

    return model_bytes, inputs, outputs
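

# A minimal round-trip sketch for ``export_as_test_case`` and
# ``load_test_case``. Illustrative only: the Linear model, the data, and the
# "/tmp/onnx_test_cases" directory are hypothetical, onnx must be installed,
# and ``torch.onnx.export`` is used here only to produce model bytes.
#
#     import io
#
#     import torch
#
#     model = torch.nn.Linear(3, 2)
#     x = torch.randn(1, 3)
#     f = io.BytesIO()
#     torch.onnx.export(model, (x,), f)
#
#     test_dir = export_as_test_case(
#         f.getvalue(),
#         [x.numpy()],
#         [model(x).detach().numpy()],
#         name="linear",
#         dir="/tmp/onnx_test_cases",
#     )
#     # load_test_case takes the test case directory itself.
#     model_bytes, inputs, outputs = load_test_case(test_dir)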


def export_data(data, value_info_proto, f: str) -> None:
    """Export data to ONNX protobuf format.

    Args:
        data: The data to export, nested data structure of numpy.ndarray.
        value_info_proto: The ValueInfoProto of the data. The type of the ValueInfoProto
            determines how the data is stored.
        f: The file to write the data to.
    """
    try:
        from onnx import numpy_helper  # type: ignore[attr-defined]
    except ImportError as exc:
        raise ImportError(
            "Export data to ONNX format failed: Please install ONNX."
        ) from exc

    with open(f, "wb") as opened_file:
        if value_info_proto.type.HasField("map_type"):
            opened_file.write(
                numpy_helper.from_dict(data, value_info_proto.name).SerializeToString()
            )
        elif value_info_proto.type.HasField("sequence_type"):
            opened_file.write(
                numpy_helper.from_list(data, value_info_proto.name).SerializeToString()
            )
        elif value_info_proto.type.HasField("optional_type"):
            opened_file.write(
                numpy_helper.from_optional(
                    data, value_info_proto.name
                ).SerializeToString()
            )
        else:
            assert value_info_proto.type.HasField("tensor_type")
            opened_file.write(
                numpy_helper.from_array(data, value_info_proto.name).SerializeToString()
            )


def _export_file(
    model_bytes: bytes,
    f: io.BytesIO | str,
    export_type: str,
    export_map: Mapping[str, bytes],
) -> None:
    """Export/write model bytes into directory/protobuf/zip."""
    if export_type == _exporter_states.ExportTypes.PROTOBUF_FILE:
        assert len(export_map) == 0
        with torch.serialization._open_file_like(f, "wb") as opened_file:
            opened_file.write(model_bytes)
    elif export_type in {
        _exporter_states.ExportTypes.ZIP_ARCHIVE,
        _exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE,
    }:
        compression = (
            zipfile.ZIP_DEFLATED
            if export_type == _exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE
            else zipfile.ZIP_STORED
        )
        with zipfile.ZipFile(f, "w", compression=compression) as z:
            z.writestr(_constants.ONNX_ARCHIVE_MODEL_PROTO_NAME, model_bytes)
            for k, v in export_map.items():
                z.writestr(k, v)
    elif export_type == _exporter_states.ExportTypes.DIRECTORY:
        if isinstance(f, io.BytesIO) or not os.path.isdir(f):  # type: ignore[arg-type]
            raise ValueError(
                f"f should be a directory when export_type is set to DIRECTORY, instead got type(f): {type(f)}"
            )
        if not os.path.exists(f):  # type: ignore[arg-type]
            os.makedirs(f)  # type: ignore[arg-type]

        model_proto_file = os.path.join(f, _constants.ONNX_ARCHIVE_MODEL_PROTO_NAME)  # type: ignore[arg-type]
        with torch.serialization._open_file_like(model_proto_file, "wb") as opened_file:
            opened_file.write(model_bytes)

        for k, v in export_map.items():
            weight_proto_file = os.path.join(f, k)  # type: ignore[arg-type]
            with torch.serialization._open_file_like(
                weight_proto_file, "wb"
            ) as opened_file:
                opened_file.write(v)
    else:
        raise ValueError("Unknown export type")
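

# A minimal sketch of how the export types map onto on-disk artifacts
# (``_export_file`` is an internal helper; the paths and ``weight_bytes``
# below are hypothetical):
#
#     # Single protobuf file; export_map must be empty.
#     _export_file(
#         model_bytes, "model.onnx", _exporter_states.ExportTypes.PROTOBUF_FILE, {}
#     )
#
#     # Zip archive holding the model proto plus externally stored tensors.
#     _export_file(
#         model_bytes,
#         "model.zip",
#         _exporter_states.ExportTypes.ZIP_ARCHIVE,
#         {"weight_0": weight_bytes},
#     )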


def _add_onnxscript_fn(
    model_bytes: bytes,
    custom_opsets: Mapping[str, int],
) -> bytes:
    """Insert model-included custom onnx-script functions into the ModelProto."""
    try:
        import onnx
    except ImportError as e:
        raise errors.OnnxExporterError("Module onnx is not installed!") from e

    # onnx.load_model_from_string fails for models larger than 2GB. However,
    # _export_onnx already saves tensors separately when the proto exceeds
    # 2GB, and if for some reason it did not, serialization would fail on the
    # protobuf size limit anyway, so a > 2GB model can never reach this point.
    model_proto = onnx.load_model_from_string(model_bytes)  # type: ignore[attr-defined]

    # Iterate over graph nodes to insert only the included custom
    # function_protos into model_proto.
    onnx_function_list = []  # type: ignore[var-annotated]
    included_node_func: set[str] = set()
    # onnx_function_list and included_node_func are expanded in-place.
    _find_onnxscript_op(
        model_proto.graph, included_node_func, custom_opsets, onnx_function_list
    )

    if onnx_function_list:
        model_proto.functions.extend(onnx_function_list)
        model_bytes = model_proto.SerializeToString()
    return model_bytes


def _find_onnxscript_op(
    graph_proto,
    included_node_func: set[str],
    custom_opsets: Mapping[str, int],
    onnx_function_list: list,
):
    """Recursively iterate the graph proto to find onnx-script function ops, since control-flow ops may contain subgraphs."""
    for node in graph_proto.node:
        node_kind = node.domain + "::" + node.op_type
        # Recursion is needed for control-flow nodes (If/Loop), which carry an
        # inner graph_proto in their attributes.
        for attr in node.attribute:
            if attr.g is not None:
                _find_onnxscript_op(
                    attr.g, included_node_func, custom_opsets, onnx_function_list
                )
        # Only custom ops with an ONNX function and aten ops with a symbolic_fn
        # should be found in the registry.
        onnx_function_group = registration.registry.get_function_group(node_kind)
        # Rule out corner cases: onnx/prim ops that are in the registry.
        if (
            node.domain
            and not jit_utils.is_aten(node.domain)
            and not jit_utils.is_prim(node.domain)
            and not jit_utils.is_onnx(node.domain)
            and onnx_function_group is not None
            and node_kind not in included_node_func
        ):
            specified_version = custom_opsets.get(node.domain, 1)
            onnx_fn = onnx_function_group.get(specified_version)
            if onnx_fn is not None:
                if hasattr(onnx_fn, "to_function_proto"):
                    onnx_function_proto = onnx_fn.to_function_proto()  # type: ignore[attr-defined]
                    onnx_function_list.append(onnx_function_proto)
                    included_node_func.add(node_kind)
                continue

            raise errors.UnsupportedOperatorError(
                node_kind,
                specified_version,
                onnx_function_group.get_min_supported()
                if onnx_function_group
                else None,
            )
    return onnx_function_list, included_node_func
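

# A minimal usage sketch for ``_add_onnxscript_fn`` (internal helper). The
# setup is hypothetical: it assumes an onnx-script function has been
# registered for a node with domain "custom_domain" via the regular symbolic
# registration path, so the exporter's registry can resolve it.
#
#     model_bytes = _add_onnxscript_fn(
#         model_bytes,
#         custom_opsets={"custom_domain": 1},
#     )
#     # Any onnx-script FunctionProtos referenced by graph nodes are now
#     # embedded in the model's ``functions`` field.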