# Owner(s): ["module: onnx"]

"""Test the support for onnxscript functions in the PyTorch-ONNX converter."""

import io
from typing import List

import onnx
import onnxscript
from onnxscript.onnx_types import FLOAT

import torch
from torch.onnx._internal import jit_utils
from torch.testing._internal import common_utils


class TestONNXScriptExport(common_utils.TestCase):
    """Export tests for custom ops whose symbolics emit onnxscript functions."""

    # The opset version is pinned because:
    # 1. local functions are supported from opset 15 onward;
    # 2. onnx-script requires users to choose the opset used in a local function.
    opset_version = 15

    def _register_custom_op(self, symbolic_name, symbolic_fn):
        """Register ``symbolic_fn`` for ``symbolic_name`` at ``self.opset_version``.

        Registrations made through ``torch.onnx.register_custom_op_symbolic``
        outlive the test that made them, so each one is undone with
        ``addCleanup`` to keep it from affecting other tests in the process.
        """
        torch.onnx.register_custom_op_symbolic(
            symbolic_name=symbolic_name,
            symbolic_fn=symbolic_fn,
            opset_version=self.opset_version,
        )
        self.addCleanup(
            torch.onnx.unregister_custom_op_symbolic,
            symbolic_name,
            self.opset_version,
        )

    def test_onnxscript_registration_with_multiple_models(self):
        """Each exported model carries only the onnxscript function it uses.

        Selu and layer_norm are both registered as onnxscript custom ops, then
        a SELU model and a LayerNorm model are exported separately; each proto
        must contain exactly one local function with the expected name.
        """
        from onnxscript.onnx_opset import opset15 as op

        # 1. Register Selu onnxscript function as custom Op
        custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1)

        @onnxscript.script(custom_opset)
        def Selu(X):
            # default argument values are not supported by onnxscript,
            # so the SELU constants are defined as locals instead
            alpha = 1.6732632423543772848170429916717  # auto wrapped as Constants
            gamma = 1.0507009873554804934193349852946
            alphaX = op.CastLike(alpha, X)
            gammaX = op.CastLike(gamma, X)
            neg = gammaX * (alphaX * op.Exp(X) - alphaX)
            pos = gammaX * X
            zero = op.CastLike(0, X)
            return op.Where(X <= zero, neg, pos)

        def custom_selu(g: jit_utils.GraphContext, X):
            return g.onnxscript_op(Selu, X).setType(X.type())

        self._register_custom_op("aten::selu", custom_selu)

        # 2. Register layer_norm onnxscript function as custom Op
        @onnxscript.script(custom_opset)
        def layer_norm(
            X, axes: List[int], weight: FLOAT[...], bias: FLOAT[...], eps: float
        ):
            mean = op.ReduceMean(X, axes=axes)
            D = X - mean  # op.Sub(X, mean)
            DD = D * D  # op.Mul(D, D)
            var = op.ReduceMean(DD, axes=axes)
            vareps = var + eps  # op.Add(var, eps)
            stddev = op.Sqrt(vareps)
            invstddev = op.Reciprocal(stddev)
            normalized = D * invstddev  # op.Mul(D, invstddev)
            normalizedw = op.CastLike(
                normalized, weight
            )  # Type issue if missing this Op
            normalizedscaled = normalizedw * weight  # op.Mul(normalized, weight)
            return normalizedscaled + bias

        @torch.onnx.symbolic_helper.parse_args("v", "is", "v", "v", "f", "none")
        def custom_layer_norm(
            g, input, normalized_shape, weight, bias, eps, cudnn_enable
        ):
            # comprehension is not supported by onnxscript, so the axes are
            # computed here in Python: the last len(normalized_shape) axes,
            # e.g. [-2, -1] for a two-element normalized_shape.
            axes = [-i for i in range(len(normalized_shape), 0, -1)]
            return g.onnxscript_op(
                layer_norm, input, weight, bias, axes_i=axes, eps_f=eps
            ).setType(input.type())

        self._register_custom_op("aten::layer_norm", custom_layer_norm)

        # 3. export two models
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        model_selu = torch.nn.SELU()
        selu_onnx = io.BytesIO()
        torch.onnx.export(model_selu, x, selu_onnx, opset_version=self.opset_version)

        N, C = 3, 4
        y = torch.randn(N, C)
        model_layer_norm = torch.nn.LayerNorm(C)
        layer_norm_onnx = io.BytesIO()
        torch.onnx.export(
            model_layer_norm, y, layer_norm_onnx, opset_version=self.opset_version
        )

        # 4. test on models
        selu_proto = onnx.load(io.BytesIO(selu_onnx.getvalue()))
        layer_norm_proto = onnx.load(io.BytesIO(layer_norm_onnx.getvalue()))

        self.assertEqual(len(selu_proto.functions), 1)
        self.assertEqual(len(layer_norm_proto.functions), 1)
        self.assertEqual(selu_proto.functions[0].name, "Selu")
        self.assertEqual(layer_norm_proto.functions[0].name, "layer_norm")

    def test_loop_registration(self):
        """An onnxscript custom op used inside control flow is still exported."""
        # Control flow exercises the _find_onnxscript_op function in
        # torch/onnx/utils.py, which recursively visits every node that has a
        # subgraph in the model proto.
        class NestedLoopsModel(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.selu = torch.nn.SELU()

            @torch.jit.script_method
            def forward(self, x):
                y = x
                for i in range(x.size(3)):
                    if i == 0:
                        y = self.selu(x)
                    else:
                        y += i
                return y

        model = NestedLoopsModel()
        inputs = torch.zeros(1, 2, 3, 4)

        from onnxscript.onnx_opset import opset15 as op

        custom_opset = onnxscript.values.Opset(domain="onnx-script", version=2)

        @onnxscript.script(custom_opset)
        def Selu(X):
            alpha = 1.6732632423543772848170429916717
            gamma = 1.0507009873554804934193349852946
            alphaX = op.CastLike(alpha, X)
            gammaX = op.CastLike(gamma, X)
            neg = gammaX * (alphaX * op.Exp(X) - alphaX)
            pos = gammaX * X
            zero = op.CastLike(0, X)
            return op.Where(X <= zero, neg, pos)

        def custom_selu(g: jit_utils.GraphContext, X):
            # domain of the Op should be aligned with onnx-script
            # setType API is required for custom Op to support
            # torchscript shape type inference
            return g.onnxscript_op(Selu, X).setType(X.type())

        self._register_custom_op("aten::selu", custom_selu)

        saved_model = io.BytesIO()
        torch.onnx.export(
            torch.jit.script(model),
            inputs,
            f=saved_model,
            opset_version=self.opset_version,
        )
        loop_selu_proto = onnx.load(io.BytesIO(saved_model.getvalue()))
        self.assertEqual(len(loop_selu_proto.functions), 1)
        self.assertEqual(loop_selu_proto.functions[0].name, "Selu")