# Owner(s): ["module: onnx"] import io import numpy as np import onnx import pytorch_test_common from pytorch_test_common import skipIfUnsupportedMinOpsetVersion import torch from torch.onnx import _constants, utils from torch.onnx._globals import GLOBALS from torch.onnx._internal import jit_utils from torch.testing._internal import common_utils def expect_tensor(scalar_type, shape=None): def verify(actual_type): np.testing.assert_equal(actual_type.scalarType(), scalar_type) # if shape is not None: # np.testing.assert_equal(actual_type.sizes(), shape) if shape is not None: np.testing.assert_equal(actual_type.varyingSizes(), shape) return verify def as_graphcontext(graph: torch.Graph) -> jit_utils.GraphContext: return jit_utils.GraphContext( graph=graph, block=graph.block(), opset=_constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET, original_node=None, # type: ignore[arg-type] params_dict={}, env={}, values_in_env=set(), ) def g_op(graph: torch.Graph, op_name: str, *args, **kwargs): return as_graphcontext(graph).op(op_name, *args, **kwargs) class TestONNXShapeInference(pytorch_test_common.ExportTestCase): def setUp(self): self.opset_version = _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET GLOBALS.export_onnx_opset_version = self.opset_version def run_test(self, g, n, type_assertion_funcs): if not isinstance(type_assertion_funcs, list): type_assertion_funcs = [type_assertion_funcs] torch._C._jit_pass_onnx_graph_shape_type_inference(g, {}, self.opset_version) for out, type_assertion_func in zip(n.outputs(), type_assertion_funcs): type_assertion_func(out.type()) def create_empty_graph(self): g = torch._C.Graph() # kick off initialization for ConstantMap. torch._C._jit_pass_onnx_graph_shape_type_inference(g, {}, self.opset_version) return g def insert_tensor_constant(self, g, tensor): return g_op(g, "Constant", value_t=tensor) def test_cast(self): # Test cast with input of unknown scalar type. g = self.create_empty_graph() input = g.addInput() cast_out = g_op(g, "Cast", input, to_i=1) self.run_test(g, cast_out.node(), expect_tensor("Float")) def test_constant_of_shape(self): # Test ConstantOfShape with input of onnx::Shape node. g = self.create_empty_graph() constant = self.insert_tensor_constant(g, torch.ones(1, 2, 3, 4)) shape = g_op(g, "Shape", constant) constant_of_shape = g_op( g, "ConstantOfShape", shape, value_t=torch.tensor([2.0]) ) self.run_test( g, constant_of_shape.node(), expect_tensor("Float", shape=(1, 2, 3, 4)) ) def test_constant_of_shape_static(self): # Test ConstantOfShape with input of prim::ListConstruct of static tensor rank = 4 g = self.create_empty_graph() constants = [ self.insert_tensor_constant(g, torch.tensor(i + 1)) for i in range(rank) ] shape = g_op(g, "prim::ListConstruct", *constants) shape.setType(torch._C.ListType.ofInts()) constant_of_shape = g_op( g, "ConstantOfShape", shape, value_t=torch.tensor([2.0]) ) self.run_test( g, constant_of_shape.node(), expect_tensor("Float", shape=(1, 2, 3, 4)) ) def test_constant_of_shape_dynamic(self): # Test ConstantOfShape with input of prim::ListConstruct of dynamic tensor rank = 4 g = self.create_empty_graph() inputs = [g.addInput() for i in range(rank)] shape = g_op(g, "prim::ListConstruct", *inputs) shape.setType(torch._C.ListType.ofInts()) constant_of_shape = g_op( g, "ConstantOfShape", shape, value_t=torch.tensor([2.0]) ) self.run_test( g, constant_of_shape.node(), expect_tensor("Float", shape=(None, None, None, None)), ) def test_gather_dynamic_index(self): g = self.create_empty_graph() input = g.addInput() input.setType( input.type().with_dtype(torch.float).with_sizes([None, 3, 16, 16]) ) indices = g.addInput() indices.setType(indices.type().with_dtype(torch.int64).with_sizes([None])) output = g_op(g, "Gather", input, indices, axis_i=1) self.run_test( g, output.node(), expect_tensor("Float", shape=([None, None, 16, 16])) ) def test_gather_scalar_index(self): g = self.create_empty_graph() input = g.addInput() input.setType( input.type().with_dtype(torch.float).with_sizes([None, 3, 16, 16]) ) indices = self.insert_tensor_constant(g, torch.tensor(1)) output = g_op(g, "Gather", input, indices, axis_i=1) self.run_test(g, output.node(), expect_tensor("Float", shape=([None, 16, 16]))) def test_reshape(self): g = self.create_empty_graph() constant = self.insert_tensor_constant(g, torch.ones(2, 16, 5, 5)) constant_2 = self.insert_tensor_constant(g, torch.tensor([2, 0, -1])) shape = g_op(g, "Reshape", constant, constant_2) self.run_test(g, shape.node(), expect_tensor("Float", shape=(2, 16, 25))) g = self.create_empty_graph() constant = self.insert_tensor_constant(g, torch.ones(2, 16, 5, 4)) constant_2 = self.insert_tensor_constant(g, torch.tensor([-1, 0, 4])) shape = g_op(g, "Reshape", constant, constant_2) self.run_test(g, shape.node(), expect_tensor("Float", shape=(10, 16, 4))) g = self.create_empty_graph() constant = self.insert_tensor_constant(g, torch.ones(2, 16, 5, 4)) constant_2 = self.insert_tensor_constant(g, torch.tensor([-1, 0, 0])) shape = g_op(g, "Reshape", constant, constant_2) self.run_test(g, shape.node(), expect_tensor("Float", shape=(8, 16, 5))) def test_reshape_symbolic(self): g = self.create_empty_graph() input = g.addInput() input.setType(input.type().with_sizes([None, None, 2, 8])) constant = self.insert_tensor_constant(g, torch.tensor([0, 0, -1])) output = g_op(g, "Reshape", input, constant) self.run_test(g, output.node(), expect_tensor(None, shape=(None, None, 16))) @skipIfUnsupportedMinOpsetVersion(14) def test_reshape_allowzero(self): g = self.create_empty_graph() input = g.addInput() input.setType(input.type().with_sizes([3, 4, 0])) constant = self.insert_tensor_constant(g, torch.tensor([0, 4, 3])) output = g_op(g, "Reshape", input, constant, allowzero_i=1) self.run_test(g, output.node(), expect_tensor(None, shape=(0, 4, 3))) def test_slice(self): g = self.create_empty_graph() input = g.addInput() input.setType(input.type().with_sizes([None, None])) start_input = g.addInput() start_input.setType(start_input.type().with_sizes([None])) end = self.insert_tensor_constant(g, torch.tensor([3])) axis = self.insert_tensor_constant(g, torch.tensor([0])) step = self.insert_tensor_constant(g, torch.tensor([1])) slice = g_op(g, "Slice", input, start_input, end, axis, step) self.run_test(g, slice.node(), expect_tensor(None, shape=(None, None))) def test_slice_with_dynamic_start_index(self): g = self.create_empty_graph() input = self.insert_tensor_constant(g, torch.ones(2, 3, 4, 5)) start_input = g.addInput() start_input.setType(start_input.type().with_sizes([2])) end = self.insert_tensor_constant(g, torch.tensor([3, 4])) axis = self.insert_tensor_constant(g, torch.tensor([1, -1])) slice = g_op(g, "Slice", input, start_input, end, axis) self.run_test(g, slice.node(), expect_tensor(None, shape=(2, None, 4, None))) def test_broadcast_matmul(self): g = self.create_empty_graph() constant = self.insert_tensor_constant(g, torch.ones(5, 1, 2)) constant_2 = self.insert_tensor_constant(g, torch.ones(3, 1, 2, 1)) shape = g_op(g, "MatMul", constant, constant_2) self.run_test(g, shape.node(), expect_tensor("Float", shape=(3, 5, 1, 1))) # test when first input is of rank 1 g = self.create_empty_graph() constant = self.insert_tensor_constant(g, torch.ones(2)) constant_2 = self.insert_tensor_constant(g, torch.ones(3, 1, 2, 1)) shape = g_op(g, "MatMul", constant, constant_2) self.run_test(g, shape.node(), expect_tensor("Float", shape=(3, 1, 1))) # test when second input is of rank 1 g = self.create_empty_graph() constant = self.insert_tensor_constant(g, torch.ones(5, 1, 2)) constant_2 = self.insert_tensor_constant(g, torch.ones(2)) shape = g_op(g, "MatMul", constant, constant_2) self.run_test(g, shape.node(), expect_tensor("Float", shape=(5, 1))) # test when both inputs are of rank 1 g = self.create_empty_graph() constant = self.insert_tensor_constant(g, torch.ones(2)) constant_2 = self.insert_tensor_constant(g, torch.ones(2)) shape = g_op(g, "MatMul", constant, constant_2) self.run_test(g, shape.node(), expect_tensor("Float", shape=())) def test_expand(self): g = self.create_empty_graph() input = g.addInput() constant = self.insert_tensor_constant(g, torch.ones(2, 4)) input.setType(constant.type().with_sizes([None, None])) shape = g_op(g, "Shape", input) expand = g_op(g, "Expand", constant, shape) self.run_test(g, expand.node(), expect_tensor("Float", shape=(None, None))) def test_pad(self): g = self.create_empty_graph() input = g.addInput() input.setType(input.type().with_dtype(torch.float).with_sizes([3, 320, 100])) constant = self.insert_tensor_constant(g, torch.ones(6, dtype=torch.long)) none = g_op(g, "prim::Constant").setType(torch.NoneType.get()) pad = g_op(g, "Pad", input, constant, none, mode_s="constant") self.run_test(g, pad.node(), expect_tensor("Float", shape=(5, 322, 102))) def test_pad_with_dynamic_input_shape(self): g = self.create_empty_graph() input = g.addInput() input.setType(input.type().with_dtype(torch.float).with_sizes([3, None, None])) constant = self.insert_tensor_constant(g, torch.ones(6, dtype=torch.long)) none = g_op(g, "prim::Constant").setType(torch.NoneType.get()) pad = g_op(g, "Pad", input, constant, none, mode_s="constant") self.run_test(g, pad.node(), expect_tensor("Float", shape=(5, None, None))) def test_pad_with_dynamic_pad_size(self): g = self.create_empty_graph() input = g.addInput() input.setType(input.type().with_dtype(torch.float).with_sizes([3, 320, 100])) pad_size = g.addInput() pad_size.setType(pad_size.type().with_dtype(torch.long).with_sizes([6])) none = g_op(g, "prim::Constant").setType(torch.NoneType.get()) pad = g_op(g, "Pad", input, pad_size, none, mode_s="constant") self.run_test(g, pad.node(), expect_tensor("Float", shape=(None, None, None))) def test_resize(self): g = self.create_empty_graph() input = g.addInput() input.setType(input.type().with_dtype(torch.float).with_sizes([4, 32, 64, 64])) none = g_op(g, "prim::Constant").setType(torch.NoneType.get()) scales = self.insert_tensor_constant( g, torch.tensor([1, 1, 2, 2], dtype=torch.float) ) resize = g_op( g, "Resize", input, none, scales, coordinate_transformation_mode_s="align_corners", cubic_coeff_a_f=-0.75, mode_s="linear", nearest_mode_s="floor", ) self.run_test(g, resize.node(), expect_tensor("Float", shape=(4, 32, 128, 128))) def test_resize_after_concat(self): g = self.create_empty_graph() input = g.addInput() input.setType(input.type().with_dtype(torch.float).with_sizes([4, 32, 64, 64])) none = g_op(g, "prim::Constant").setType(torch.NoneType.get()) scale_1 = self.insert_tensor_constant( g, torch.tensor([1, 1], dtype=torch.float) ) scale_2 = self.insert_tensor_constant( g, torch.tensor([2, 2], dtype=torch.float) ) # `scales` values should be statically known due to constant folding in shape inference. scales = g_op(g, "Concat", scale_1, scale_2, axis_i=0) resize = g_op( g, "Resize", input, none, scales, coordinate_transformation_mode_s="align_corners", cubic_coeff_a_f=-0.75, mode_s="linear", nearest_mode_s="floor", ) self.run_test(g, resize.node(), expect_tensor("Float", shape=(4, 32, 128, 128))) def test_reduce_prod_with_axes(self): g = self.create_empty_graph() input = g.addInput() input.setType(input.type().with_dtype(torch.long).with_sizes([2])) reduce_prod = g_op(g, "ReduceProd", input, axes_i=[0]) self.run_test(g, reduce_prod.node(), expect_tensor("Long", shape=(1,))) def test_reduce_prod_without_axes(self): g = self.create_empty_graph() input = g.addInput() input.setType(input.type().with_dtype(torch.long).with_sizes([2])) reduce_prod = g_op(g, "ReduceProd", input) self.run_test(g, reduce_prod.node(), expect_tensor("Long", shape=(1,))) def test_proceeding_nodes_use_prim_pack_padded_output_dtype_correctly(self): g = self.create_empty_graph() input = g.addInput() input.setType(input.type().with_dtype(torch.float).with_sizes([4, 16])) length = g.addInput() length.setType(length.type().with_dtype(torch.long).with_sizes([4])) padded, batch_size = g_op(g, "prim::PackPadded", input, length, outputs=2) # `prim::PackPadded` only occurs in tracing mode. Hence its outputs inherits # shape and data type from traced graph. padded.setType(padded.type().with_dtype(torch.float).with_sizes([None, None])) batch_size.setType(batch_size.type().with_dtype(torch.long).with_sizes([None])) # `Gather` should use the data type of `batch_size` as the data type of its output. gather_idx = self.insert_tensor_constant(g, torch.tensor([0], dtype=torch.long)) gather = g_op(g, "Gather", batch_size, gather_idx, axis_i=0) self.run_test(g, gather.node(), expect_tensor("Long", shape=(None,))) def test_squeeze_after_dynamic_if(self): from torch.onnx.symbolic_opset11 import squeeze as squeeze11 g = self.create_empty_graph() input = g.addInput() input.setType(input.type().with_dtype(torch.float).with_sizes([1, None, 5])) # Type is intentionally not bool to test that # the added "Cast" node doesn't stop shape inference. cond = g.addInput() cond.setType(input.type().with_dtype(torch.int32).with_sizes([1])) if_op, (if_context, else_context), new_node = jit_utils.add_op_with_blocks( as_graphcontext(g), "If", cond, n_blocks=2 ) block1_output = if_context.op("Add", input, input) block2_output = else_context.op("Identity", input) utils._add_output_to_block(if_context.block, block1_output) utils._add_output_to_block(else_context.block, block2_output) if_output = torch._C._jit_pass_fixup_onnx_controlflow_node( new_node, _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET )[0] torch._C._jit_pass_onnx_node_shape_type_inference( new_node, {}, _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET ) # Exporter will add "If" instead of raw "Squeeze" if it does not know # that if the dimension it is squeezing has size 1. squeezed = squeeze11(as_graphcontext(g), if_output, dim=0) assert squeezed.node().kind() == "onnx::Squeeze" self.run_test(g, squeezed.node(), expect_tensor("Float", shape=(None, 5))) class TestONNXCustomOpShapeInference(pytorch_test_common.ExportTestCase): def setUp(self): super().setUp() self.opset_version = _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET def test_setType_maintains_output_shape_for_single_custom_op(self): self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9) class CustomInverse(torch.nn.Module): def forward(self, x): return torch.inverse(x) + x def linalg_inv_settype(g, self): return g.op("com.microsoft::Inverse", self).setType(self.type()) torch.onnx.register_custom_op_symbolic("::linalg_inv", linalg_inv_settype, 9) model = CustomInverse() x = torch.randn(2, 3, 3) f = io.BytesIO() torch.onnx.export( model, (x,), f, opset_version=self.opset_version, custom_opsets={"com.microsoft": 1}, ) model_proto = onnx.load(io.BytesIO(f.getvalue())) model_value_info = model_proto.graph.value_info self.assertIsNotNone(model_value_info) assert model_value_info dims = model_value_info[0].type.tensor_type.shape.dim for i in range(len(dims)): # If node output has shape info, it should have dim_value # Otherwise, it has dim_params with dynamic shape self.assertTrue(dims[i].HasField("dim_value")) for dim, rank in zip(dims, x.size()): self.assertEqual(dim.dim_value, rank) def test_no_setType_for_single_custom_op(self): self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9) class CustomInverse(torch.nn.Module): def forward(self, x): return torch.inverse(x) + x def linalg_inv_no_settype(g, self): return g.op("com.microsoft::Inverse", self) torch.onnx.register_custom_op_symbolic("::linalg_inv", linalg_inv_no_settype, 9) model = CustomInverse() x = torch.randn(2, 3, 3) f = io.BytesIO() torch.onnx.export( model, (x,), f, opset_version=self.opset_version, custom_opsets={"com.microsoft": 1}, ) model_proto = onnx.load(io.BytesIO(f.getvalue())) model_value_info = model_proto.graph.value_info self.assertIsNotNone(model_value_info) assert model_value_info dims = model_value_info[0].type.tensor_type.shape.dim for i in range(len(dims)): # If node output has shape info, it should have dim_value # Otherwise, it has dim_params with dynamic shape self.assertTrue(dims[i].HasField("dim_param")) def test_setType_maintains_output_shape_for_single_custom_op_with_dynamic_axes( self, ): self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9) class CustomInverse(torch.nn.Module): def forward(self, x): return torch.inverse(x) + x def linalg_inv_settype(g, self): return g.op("com.microsoft::Inverse", self).setType( self.type().with_dtype(torch.float).with_sizes([None, 3, 3]) ) torch.onnx.register_custom_op_symbolic("::linalg_inv", linalg_inv_settype, 9) model = CustomInverse() x = torch.randn(2, 3, 3) f = io.BytesIO() torch.onnx.export( model, (x,), f, opset_version=self.opset_version, custom_opsets={"com.microsoft": 1}, input_names=["x"], dynamic_axes={"x": {0: "batch"}}, ) model_proto = onnx.load(io.BytesIO(f.getvalue())) model_value_info = model_proto.graph.value_info self.assertIsNotNone(model_value_info) assert model_value_info dims = model_value_info[0].type.tensor_type.shape.dim # The first axe should be dynamic as we defined when exporting self.assertTrue(dims[0].HasField("dim_param")) for i in range(1, len(dims)): # If node output has shape info, it should have dim_value # Otherwise, it has dim_params with dynamic shape self.assertTrue(dims[i].HasField("dim_value")) self.assertEqual(dims[i].dim_value, x.size()[i]) def test_setType_maintains_output_shape_for_single_custom_op_with_onnx_ops(self): self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9) class CustomInverse(torch.nn.Module): def forward(self, x, y, z): x = torch.inverse(x) return x + y + z def linalg_inv_settype(g, self): return g.op("com.microsoft::Inverse", self).setType( self.type().with_dtype(torch.float).with_sizes([2, 3, 10, 10]) ) torch.onnx.register_custom_op_symbolic("::linalg_inv", linalg_inv_settype, 9) model = CustomInverse() x = torch.randn(2, 3, 10, 10) y = torch.randn(2, 3, 10, 10) z = torch.randn(2, 3, 10, 10) f = io.BytesIO() torch.onnx.export( model, (x, y, z), f, opset_version=self.opset_version, custom_opsets={"com.microsoft": 1}, ) model_proto = onnx.load(io.BytesIO(f.getvalue())) # To validate the shape of inverse Op, we need to find inverse output name, # and then use it to identify its value_info for the shape. output_name = "" for node in model_proto.graph.node: if node.op_type == "Inverse": output_name = node.output[0] break assert output_name model_value_info = model_proto.graph.value_info self.assertIsNotNone(model_value_info) assert model_value_info for value_info in model_value_info: assert value_info.name if value_info.name == output_name: dims = value_info.type.tensor_type.shape.dim for i in range(len(dims)): # If node output has shape info, it should have dim_value # Otherwise, it has dim_params with dynamic shape self.assertTrue(dims[i].HasField("dim_value")) for dim, rank in zip(dims, x.size()): self.assertEqual(dim.dim_value, rank) if __name__ == "__main__": common_utils.run_tests()