# Owner(s): ["module: onnx"] from __future__ import annotations import logging import tempfile from typing import Mapping, Tuple, TYPE_CHECKING import onnx import onnx.inliner import pytorch_test_common import transformers # type: ignore[import] import torch from torch import nn from torch._subclasses import fake_tensor from torch.nn import functional as F from torch.onnx import dynamo_export, ExportOptions from torch.onnx._internal.fx import diagnostics, registration from torch.testing._internal import common_utils if TYPE_CHECKING: from torch.onnx._internal.diagnostics import infra def assert_has_diagnostics( diagnostic_context: diagnostics.DiagnosticContext, rule: infra.Rule, level: infra.Level, expected_node: str, ): rule_level_pairs = (rule.id, level.name.lower()) sarif_log = diagnostic_context.sarif_log() actual_results = [] for run in sarif_log.runs: if run.results is None: continue for result in run.results: id_level_pair = (result.rule_id, result.level) actual_results.append(id_level_pair) if ( rule_level_pairs == id_level_pair and result.message.text and result.message.markdown and expected_node in result.message.text ): return raise AssertionError( f"Expected diagnostic results of rule id and level pair {rule_level_pairs} " f"not found with expected error node {expected_node} and " f"Actual diagnostic results: {actual_results}" ) @common_utils.instantiate_parametrized_tests class TestFxToOnnx(pytorch_test_common.ExportTestCase): def setUp(self): super().setUp() self.export_options = ExportOptions() def tearDown(self): super().tearDown() def test_simple_function(self): def func(x): y = x + 1 z = y.relu() return (y, z) _ = dynamo_export( func, torch.randn(1, 1, 2), export_options=self.export_options ) def test_empty(self): # Since `torch.empty` returns tensor with uninitialized data, we cannot # test this under `test_fx_to_onnx_with_onnxruntime.py` with result comparison. 
    def test_args_used_for_export_is_not_converted_to_fake_tensors(self):
        def func(x, y):
            return x + y

        tensor_x = torch.randn(1, 1, 2)
        tensor_y = torch.randn(1, 1, 2)
        _ = dynamo_export(func, tensor_x, tensor_y, export_options=self.export_options)
        self.assertNotIsInstance(tensor_x, fake_tensor.FakeTensor)
        self.assertNotIsInstance(tensor_y, fake_tensor.FakeTensor)

    @common_utils.parametrize(
        "diagnostic_rule",
        [
            common_utils.subtest(
                diagnostics.rules.find_opschema_matched_symbolic_function,
                name="optional_inputs",
            ),
        ],
    )
    def test_mnist_exported_with_no_warnings(self, diagnostic_rule):
        class MNISTModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=False)
                self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=False)
                self.fc1 = nn.Linear(9216, 128, bias=False)
                self.fc2 = nn.Linear(128, 10, bias=False)

            def forward(self, tensor_x: torch.Tensor):
                tensor_x = self.conv1(tensor_x)
                tensor_x = F.sigmoid(tensor_x)
                tensor_x = self.conv2(tensor_x)
                tensor_x = F.sigmoid(tensor_x)
                tensor_x = F.max_pool2d(tensor_x, 2)
                tensor_x = torch.flatten(tensor_x, 1)
                tensor_x = self.fc1(tensor_x)
                tensor_x = F.sigmoid(tensor_x)
                tensor_x = self.fc2(tensor_x)
                output = F.log_softmax(tensor_x, dim=1)
                return output

        tensor_x = torch.rand((64, 1, 28, 28), dtype=torch.float32)
        onnx_program = dynamo_export(MNISTModel(), tensor_x)
        assert_has_diagnostics(
            onnx_program.diagnostic_context,
            diagnostic_rule,
            diagnostics.levels.NONE,
            expected_node="aten.convolution.default",
        )

    def test_trace_only_op_with_evaluator(self):
        model_input = torch.tensor([[1.0, 2.0, 3.0], [1.0, 1.0, 2.0]])

        class ArgminArgmaxModel(torch.nn.Module):
            def forward(self, input):
                return (
                    torch.argmin(input),
                    torch.argmax(input),
                    torch.argmin(input, keepdim=True),
                    torch.argmax(input, keepdim=True),
                    torch.argmin(input, dim=0, keepdim=True),
                    torch.argmax(input, dim=1, keepdim=True),
                )

        _ = dynamo_export(
            ArgminArgmaxModel(), model_input, export_options=self.export_options
        )

    def test_multiple_outputs_op_with_evaluator(self):
        class TopKModel(torch.nn.Module):
            def forward(self, x):
                values, _ = torch.topk(x, 3)
                return torch.sum(values)

        x = torch.arange(1.0, 6.0, requires_grad=True)
        _ = dynamo_export(TopKModel(), x, export_options=self.export_options)

    def test_unsupported_function_schema_raises_diagnostic_warning_when_found_nearest_match(
        self,
    ):
        class TraceModel(torch.nn.Module):
            def forward(self, input):
                return input.new_zeros(())

        x = torch.randn((2, 3), dtype=torch.float32)
        onnx_program = dynamo_export(TraceModel(), x)
        assert_has_diagnostics(
            onnx_program.diagnostic_context,
            diagnostics.rules.find_opschema_matched_symbolic_function,
            diagnostics.levels.WARNING,
            expected_node="aten.new_zeros.default",
        )

    def test_perfect_match_on_sequence_and_bool_attributes(
        self,
    ):
        class TraceModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv2 = torch.nn.Conv2d(
                    16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)
                )

            def forward(self, input):
                return self.conv2(input)

        x = torch.randn(20, 16, 50, 50)
        onnx_program = dynamo_export(TraceModel(), x)
        assert_has_diagnostics(
            onnx_program.diagnostic_context,
            diagnostics.rules.find_opschema_matched_symbolic_function,
            diagnostics.levels.NONE,
            expected_node="aten.convolution.default",
        )
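    # NOTE: The two tests above pin down both sides of opschema dispatch:
    # `find_opschema_matched_symbolic_function` is reported at level WARNING
    # when only a nearest-match overload is found (e.g. aten.new_zeros.default)
    # and at level NONE when the attributes match the schema exactly
    # (e.g. aten.convolution.default).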
    def test_aten_clone_does_not_raise_warning_of_lack_of_memory_format(self):
        class CustomModule(torch.nn.Module):
            def forward(self, input):
                return torch.ops.aten.clone(input, memory_format=torch.preserve_format)

        x = torch.tensor(3)
        onnx_program = dynamo_export(CustomModule(), x)
        assert_has_diagnostics(
            onnx_program.diagnostic_context,
            diagnostics.rules.find_opschema_matched_symbolic_function,
            diagnostics.levels.NONE,
            expected_node="aten.clone.default",
        )

    def test_missing_complex_onnx_variant_raises_errors_in_dispatcher(self):
        registry = torch.onnx.OnnxRegistry()

        # NOTE: simulate unsupported nodes
        aten_mul_tensor = registration.OpName.from_name_parts(
            namespace="aten", op_name="mul", overload="Tensor"
        )

        # Only keep real aten.mul to test missing complex aten.mul
        registry._registry[aten_mul_tensor] = [
            onnx_func
            for onnx_func in registry._registry[aten_mul_tensor]
            if not onnx_func.is_complex
        ]

        class TraceModel(torch.nn.Module):
            def forward(self, input):
                return torch.ops.aten.mul.Tensor(input, input)

        x = torch.tensor([1 + 2j, 3 + 4j], dtype=torch.complex64)

        with self.assertRaises(torch.onnx.OnnxExporterError):
            torch.onnx.dynamo_export(
                TraceModel(),
                x,
                export_options=torch.onnx.ExportOptions(onnx_registry=registry),
            )
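    # NOTE: The registry surgery above goes through the private `_registry`
    # mapping, an implementation detail of `torch.onnx.OnnxRegistry`, not a
    # public API; it is only safe here because the test merely filters the
    # already-registered overloads to drop the complex-typed variants.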
    def test_symbolic_shape_of_values_inside_function_is_exported_as_graph_value_info(
        self,
    ):
        class SubModule(torch.nn.Module):
            def forward(self, x, y, bias):
                output = x @ y
                return output + bias

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submodule = SubModule()

            def forward(self, x, y, bias):
                return self.submodule(x, y, bias)

        x = torch.randn(2, 3)
        y = torch.randn(3, 4)
        bias = torch.randn(4)
        onnx_program = torch.onnx.dynamo_export(
            Module(),
            x,
            y,
            bias,
            export_options=torch.onnx.ExportOptions(dynamic_shapes=True),
        )
        model_proto = onnx_program.model_proto

        # Assert value_info for values inside local function can be retrieved
        def _assert_node_outputs_has_value_info(
            node: onnx.NodeProto,
            value_infos: Mapping[str, onnx.ValueInfoProto],
            local_functions: Mapping[Tuple[str, str], onnx.FunctionProto],
            exclude_names_in_value_info,
            function_id: str = "",
        ):
            for output in node.output:
                name = f"{function_id}/{output}" if function_id else output
                if name not in exclude_names_in_value_info:
                    self.assertIn(name, value_infos)
            if node.domain.startswith("pkg.onnxscript.torch_lib"):
                # No shape info available for values inside torchlib functions.
                return
            if (
                function := local_functions.get((node.domain, node.op_type))
            ) is not None:
                for node in function.node:
                    function_id = f"{function.domain}::{function.name}"
                    _assert_node_outputs_has_value_info(
                        node,
                        value_infos,
                        local_functions,
                        exclude_names_in_value_info,
                        function_id,
                    )

        type_infos = {vi.name: vi for vi in model_proto.graph.value_info}
        functions = {(f.domain, f.name): f for f in model_proto.functions}
        # NOTE: inputs, outputs, and initializers are not included in value_info spec
        exclude_names_in_value_info = (
            [input.name for input in model_proto.graph.input]
            + [output.name for output in model_proto.graph.output]
            + [init.name for init in model_proto.graph.initializer]
        )
        for node in model_proto.graph.node:
            _assert_node_outputs_has_value_info(
                node, type_infos, functions, exclude_names_in_value_info
            )

    def test_dynamo_export_retains_readable_parameter_and_buffer_names(self):
        class SubModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=False)
                self.fc1 = nn.Linear(9216, 128, bias=False)
                self.buffer = torch.nn.Buffer(torch.randn(1, 128))

            def forward(self, tensor_x: torch.Tensor):
                tensor_x = self.conv2(tensor_x)
                tensor_x = F.sigmoid(tensor_x)
                tensor_x = F.max_pool2d(tensor_x, 2)
                tensor_x = torch.flatten(tensor_x, 1)
                tensor_x = self.fc1(tensor_x)
                tensor_x = tensor_x + self.buffer
                tensor_x = F.sigmoid(tensor_x)
                return tensor_x

        class MNISTModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=False)
                self.submodule = SubModule()
                self.fc2 = nn.Linear(128, 10, bias=False)

            def forward(self, tensor_x: torch.Tensor):
                tensor_x = self.conv1(tensor_x)
                tensor_x = F.sigmoid(tensor_x)
                tensor_x = self.submodule(tensor_x)
                tensor_x = self.fc2(tensor_x)
                output = F.log_softmax(tensor_x, dim=1)
                return output

        tensor_x = torch.rand((64, 1, 28, 28), dtype=torch.float32)

        model = MNISTModel()
        onnx_program = torch.onnx.dynamo_export(model, tensor_x)
        model_proto = onnx_program.model_proto

        # NOTE: initializers could be optimized away by onnx optimizer
        onnx_initializers = {init.name for init in model_proto.graph.initializer}
        torch_weights = {*model.state_dict().keys()}
        self.assertTrue(onnx_initializers.issubset(torch_weights))
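    # NOTE: A hedged sketch of the fake-mode export workflow the next test
    # exercises (`real_state_dict` is a placeholder for weights created outside
    # fake mode):
    #     with torch.onnx.enable_fake_mode() as fake_context:
    #         model, x = Model(), torch.rand(5, 2, 2)  # no real storage allocated
    #         onnx_program = torch.onnx.dynamo_export(
    #             model, x, export_options=ExportOptions(fake_context=fake_context)
    #         )
    #     onnx_program.save("model.onnx", model_state=real_state_dict)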
    @common_utils.parametrize(
        "checkpoint_type",
        [
            common_utils.subtest(
                "state_dict",
                name="state_dict",
            ),
            common_utils.subtest(
                "checkpoint_file",
                name="checkpoint_file",
            ),
        ],
    )
    def test_fake_tensor_mode_simple(self, checkpoint_type):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                out = self.linear(x)
                return out

        with torch.onnx.enable_fake_mode() as fake_context:
            x = torch.rand(5, 2, 2)
            model = Model()
            export_options = ExportOptions(fake_context=fake_context)
            onnx_program = torch.onnx.dynamo_export(
                model, x, export_options=export_options
            )

        assert (
            onnx_program is not None
        ), "ONNXProgram must be created on successful export"
        assert (
            onnx_program.model_proto is not None
        ), "A model protobuf must be created on a successful export"
        onnx.checker.check_model(onnx_program.model_proto, full_check=True)
        assert (
            len(onnx_program.model_proto.graph.initializer) == 0
        ), "Initializers cannot exist when fake mode is enabled"

        if checkpoint_type == "state_dict":
            # Variant 1: Save ONNX proto using Model's state_dict()
            with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp_onnx_file:
                model_state_dict = (
                    Model().state_dict()
                )  # Create a state_dict for testing
                onnx_program.save(tmp_onnx_file.name, model_state=model_state_dict)
                assert (
                    len(onnx.load(tmp_onnx_file.name).graph.initializer) == 2
                ), "Initializers must be present after loading them from model_state"
                # Let's make sure consecutive `save` calls don't create dupes
                onnx_program.save(tmp_onnx_file.name, model_state=model_state_dict)
                assert (
                    len(onnx.load(tmp_onnx_file.name).graph.initializer) == 2
                ), "Initializers must be present after loading them from model_state"
        elif checkpoint_type == "checkpoint_file":
            # Variant 2: Save ONNX proto using Model checkpoint file
            with tempfile.NamedTemporaryFile(
                suffix=".onnx"
            ) as tmp_onnx_file, tempfile.NamedTemporaryFile(
                suffix=".pt"
            ) as tmp_checkpoint_file:
                torch.save(
                    Model().state_dict(), tmp_checkpoint_file.name
                )  # Create checkpoint file for testing
                onnx_program.save(
                    tmp_onnx_file.name, model_state=tmp_checkpoint_file.name
                )
                assert (
                    len(onnx.load(tmp_onnx_file.name).graph.initializer) == 2
                ), "Initializers must be present after loading them from model_state"
                # Let's make sure consecutive `save` calls don't create dupes
                onnx_program.save(
                    tmp_onnx_file.name, model_state=tmp_checkpoint_file.name
                )
                assert (
                    len(onnx.load(tmp_onnx_file.name).graph.initializer) == 2
                ), "Initializers must be present after loading them from model_state"

    def test_fake_tensor_mode_simple_invalid_input(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                out = self.linear(x)
                return out

        real_model = Model()
        real_x = torch.rand(5, 2, 2)
        with torch.onnx.enable_fake_mode() as fake_context:
            fake_model = Model()
            fake_x = torch.rand(5, 2, 2)

            # TODO: Split each scenario on its own test case
            # Scenario 1: Fake model and fake input WITHOUT ExportOptions(fake_context=...)
            with self.assertRaises(torch.onnx.OnnxExporterError):
                export_options = ExportOptions(fake_context=None)
                _ = torch.onnx.dynamo_export(
                    fake_model, fake_x, export_options=export_options
                )

            # Scenario 2: Fake model and real input WITHOUT fake_context
            with self.assertRaises(torch.onnx.OnnxExporterError):
                export_options = ExportOptions(fake_context=None)
                _ = torch.onnx.dynamo_export(
                    fake_model, real_x, export_options=export_options
                )

            # Scenario 3: Real model and real input WITH fake_context
            with self.assertRaises(torch.onnx.OnnxExporterError):
                export_options = ExportOptions(fake_context=fake_context)
                _ = torch.onnx.dynamo_export(
                    real_model, real_x, export_options=export_options
                )

            # Scenario 4: Fake model and real input WITH fake_context
            with self.assertRaises(torch.onnx.OnnxExporterError):
                export_options = ExportOptions(fake_context=fake_context)
                _ = torch.onnx.dynamo_export(
                    fake_model, real_x, export_options=export_options
                )
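    # NOTE: Taken together, the four scenarios above say that fake-mode
    # artifacts and `ExportOptions(fake_context=...)` must be used as a unit:
    # a model or input created under `enable_fake_mode` exports only with the
    # matching fake_context, and real models/inputs only without one; every
    # mixed combination raises OnnxExporterError.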
    @pytorch_test_common.xfail(
        error_message="Dynamic control flow is not supported at the moment."
    )
    def test_fake_tensor_mode_huggingface_llama(self):
        config = transformers.LlamaConfig(
            vocab_size=8096, hidden_size=256, num_hidden_layers=2, num_attention_heads=2
        )
        batch, seq = 4, 256

        with torch.onnx.enable_fake_mode() as fake_context:
            model = transformers.LlamaModel(config).eval()
            input_ids = torch.randint(0, config.vocab_size, (batch, seq))
            attention_mask = torch.ones(batch, seq, dtype=torch.bool)
            position_ids = torch.arange(0, seq, dtype=torch.long)
            position_ids = position_ids.unsqueeze(0).view(-1, seq)

            export_options = torch.onnx.ExportOptions(fake_context=fake_context)
            onnx_program = torch.onnx.dynamo_export(
                model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                export_options=export_options,
            )
            onnx.checker.check_model(onnx_program.model_proto)
            onnx.shape_inference.infer_shapes(onnx_program.model_proto)

    @pytorch_test_common.xfail(
        error_message="Dynamic control flow is not supported at the moment."
    )
    def test_fake_tensor_mode_huggingface_tiiuae_falcon(self):
        config = transformers.FalconConfig()
        batch, seq = 4, 256

        with torch.onnx.enable_fake_mode() as fake_context:
            model = transformers.FalconModel(config).eval()
            input_ids = torch.randint(0, config.vocab_size, (batch, seq))
            attention_mask = torch.ones(batch, seq, dtype=torch.bool)

            export_options = torch.onnx.ExportOptions(fake_context=fake_context)
            onnx_program = torch.onnx.dynamo_export(
                model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                export_options=export_options,
            )
            onnx.checker.check_model(onnx_program.model_proto)
            onnx.shape_inference.infer_shapes(onnx_program.model_proto)

    def test_exported_program_torch_distributions_normal_Normal(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.normal = torch.distributions.normal.Normal(0, 1)

            def forward(self, x):
                return self.normal.sample(x.shape)

        x = torch.randn(2, 3)
        with torch.no_grad():
            exported_program = torch.export.export(Model(), args=(x,))
            _ = torch.onnx.dynamo_export(
                exported_program,
                x,
            )

    def test_aten_div_no_opmath_type_promotion(self):
        class Model(torch.nn.Module):
            def forward(self, input):
                return input / 2

        model = Model()
        input = torch.randn(3, 5, requires_grad=True, dtype=torch.float16)

        model_proto = torch.onnx.dynamo_export(model, input).model_proto
        model_proto = onnx.inliner.inline_local_functions(model_proto)
        div_node = next(
            node for node in model_proto.graph.node if node.op_type == "Div"
        )
        # The input of the Div node should be the input of the model,
        # with no Cast node in between.
        self.assertEqual(div_node.input[0], model_proto.graph.input[0].name)
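    # NOTE: The test above inlines local functions first because the dynamo
    # exporter emits torchlib ops as ONNX local functions; flattening with
    # `onnx.inliner.inline_local_functions` makes the bare `Div` node visible
    # in the top-level graph so its input can be checked for an absent Cast.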
    @common_utils.parametrize(
        "float8_type",
        [
            common_utils.subtest(
                torch.float8_e5m2,
                name="torch_float8_e5m2",
            ),
            common_utils.subtest(
                torch.float8_e5m2fnuz,
                name="torch_float8_e5m2fnuz",
            ),
            common_utils.subtest(
                torch.float8_e4m3fn,
                name="torch_float8_e4m3fn",
            ),
            common_utils.subtest(
                torch.float8_e4m3fnuz,
                name="torch_float8_e4m3fnuz",
            ),
        ],
    )
    def test_float8_support(self, float8_type):
        class Float8Module(torch.nn.Module):
            def forward(self, input: torch.Tensor):
                input = input.to(float8_type)
                return input + torch.tensor(1.0, dtype=float8_type)

        # NOTE: shape inference error raised in optimizer due to unsupported dtype
        with self.assertWarnsOnceRegex(
            UserWarning, "ONNXScript optimizer failed. Skipping optimization."
        ):
            _ = torch.onnx.dynamo_export(Float8Module(), torch.randn(1, 2, 3, 4))

    def test_export_with_logging_logger(self):
        logger = logging.getLogger(__name__)

        class LoggingLoggerModule(torch.nn.Module):
            def forward(self, x):
                # `Logger.log` requires a numeric level as its first argument.
                logger.log(logging.INFO, "abc")
                return x + 1

        input = torch.randn(2, 3)
        model = LoggingLoggerModule()
        _ = torch.onnx.dynamo_export(model, input)

    def test_export_with_hf_logging_logger(self):
        logger = transformers.utils.logging.get_logger(__name__)

        class HFLoggingLoggerModule(torch.nn.Module):
            def forward(self, x):
                logger.warning_once("abc")
                return x + 1

        input = torch.randn(2, 3)
        model = HFLoggingLoggerModule()
        _ = torch.onnx.dynamo_export(model, input)

    def test_checkpoint_cast(self):
        model_id = "openai/whisper-large-v3"
        feature_extractor = transformers.WhisperFeatureExtractor(feature_size=128)
        batch = 4

        with torch.onnx.enable_fake_mode() as ctx:
            model = transformers.AutoModelForSpeechSeq2Seq.from_pretrained(
                model_id, low_cpu_mem_usage=False, use_safetensors=False
            )
            input = {
                "input_features": torch.randn(
                    (
                        batch,
                        feature_extractor.feature_size,
                        feature_extractor.nb_max_frames,
                    )
                ),
                "decoder_input_ids": torch.tensor([[1, 1]]) * 8001,
                "return_dict": False,
            }

        export_options = torch.onnx.ExportOptions(fake_context=ctx)
        onnx_program = torch.onnx.dynamo_export(
            model, **input, export_options=export_options
        )

        with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp_onnx_file:
            onnx_program.save(tmp_onnx_file.name)
            onnx.checker.check_model(tmp_onnx_file.name, full_check=True)
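    # NOTE: The next test relies on `ONNXProgram.apply_weights(state_dict)` to
    # bind real tensors to the exported graph after the fact; that is what lets
    # the fake-mode variants still save a model with initializers when
    # `include_initializers=True`.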
    @common_utils.parametrize(
        "include_initializer",
        [
            common_utils.subtest(
                True,
                name="include_initializer",
            ),
            common_utils.subtest(
                False,
                name="dont_include_initializer",
            ),
        ],
    )
    @common_utils.parametrize(
        "use_fake_mode",
        [
            common_utils.subtest(
                True,
                name="use_fake_mode",
            ),
            common_utils.subtest(
                False,
                name="no_fake_mode",
            ),
        ],
    )
    @common_utils.parametrize(
        "use_exported_program",
        [
            common_utils.subtest(
                True,
                name="use_exported_program",
            ),
            common_utils.subtest(
                False,
                name="no_exported_program",
            ),
        ],
    )
    def test_save_with_without_initializer(
        self, include_initializer, use_fake_mode, use_exported_program
    ):
        class MNISTModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=False)
                self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=False)
                self.fc1 = nn.Linear(9216, 128, bias=False)
                self.fc2 = nn.Linear(128, 10, bias=False)

            def forward(self, tensor_x: torch.Tensor):
                tensor_x = self.conv1(tensor_x)
                tensor_x = F.sigmoid(tensor_x)
                tensor_x = self.conv2(tensor_x)
                tensor_x = F.sigmoid(tensor_x)
                tensor_x = F.max_pool2d(tensor_x, 2)
                tensor_x = torch.flatten(tensor_x, 1)
                tensor_x = self.fc1(tensor_x)
                tensor_x = F.sigmoid(tensor_x)
                tensor_x = self.fc2(tensor_x)
                output = F.log_softmax(tensor_x, dim=1)
                return output

        state_dict = MNISTModel().state_dict()
        if use_fake_mode:
            with torch.onnx.enable_fake_mode() as ctx:
                model = MNISTModel()
                tensor_x = torch.rand((64, 1, 28, 28), dtype=torch.float32)
                if use_exported_program:
                    model = torch.export.export(model, args=(tensor_x,))
                export_options = torch.onnx.ExportOptions(fake_context=ctx)
        else:
            model = MNISTModel()
            tensor_x = torch.rand((64, 1, 28, 28), dtype=torch.float32)
            if use_exported_program:
                model = torch.export.export(model, args=(tensor_x,))
            export_options = torch.onnx.ExportOptions()

        onnx_program = torch.onnx.dynamo_export(
            model, tensor_x, export_options=export_options
        )
        onnx_program.apply_weights(state_dict)
        with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp_onnx_file:
            onnx_program.save(
                tmp_onnx_file.name,
                include_initializers=include_initializer,
            )
            onnx_model = onnx.load(tmp_onnx_file.name)
            if include_initializer:
                self.assertGreater(len(onnx_model.graph.initializer), 0)
            else:
                self.assertEqual(len(onnx_model.graph.initializer), 0)

    def test_export_with_print(self):
        class PrintModule(torch.nn.Module):
            def forward(self, x):
                print("abc")
                return x + 1

        input = torch.randn(2, 3)
        model = PrintModule()
        _ = torch.onnx.dynamo_export(model, input)


if __name__ == "__main__":
    common_utils.run_tests()