# mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type
"""This file exports ONNX ops for opset 14.

Note [ONNX operators that are added/updated in opset 14]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
New operators:
    HardSwish, Trilu

Updated operators:
    Reshape
    Add, Sub, Mul, Div
    GRU, LSTM, RNN
    BatchNorm, Cumsum, Relu
"""

# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md
from __future__ import annotations

import functools

import torch
from torch.onnx import _constants, _type_utils, symbolic_helper
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import jit_utils, registration


__all__ = [
    "hardswish",
    "tril",
    "triu",
    "reshape",
    "batch_norm",
    "quantized_hardswish",
    "scaled_dot_product_attention",
]

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=14)


@_onnx_symbolic("aten::hardswish")
@symbolic_helper.parse_args("v")
def hardswish(g: jit_utils.GraphContext, self):
    return g.op("HardSwish", self)


@_onnx_symbolic("aten::tril")
def tril(g: jit_utils.GraphContext, self, diagonal, out=None):
    return g.op("Trilu", self, diagonal, upper_i=0)


@_onnx_symbolic("aten::triu")
def triu(g: jit_utils.GraphContext, self, diagonal, out=None):
    return g.op("Trilu", self, diagonal, upper_i=1)


@_onnx_symbolic("aten::reshape")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v")
def reshape(g: jit_utils.GraphContext, self, shape):
    # NOTE: Due to bug in ORT https://github.com/microsoft/onnxruntime/issues/10664
    # Reshape export cannot utilize the new allowzero attribute introduced in opset 14.
    return symbolic_helper._reshape_helper(g, self, shape, allowzero=0)


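# NOTE on the batch_norm symbolic below: PyTorch and ONNX use opposite momentum
# conventions. PyTorch updates running statistics as
#     running = (1 - momentum) * running + momentum * batch_stat
# while ONNX BatchNormalization defines
#     running = momentum * running + (1 - momentum) * batch_stat
# which is why `momentum_f` is exported as `1 - momentum`. In training mode the
# ONNX node produces three outputs (Y, running_mean, running_var); only Y is
# returned to the caller.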
@_onnx_symbolic("aten::batch_norm")
@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i")
def batch_norm(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    running_mean,
    running_var,
    training,
    momentum,
    eps,
    cudnn_enabled,
):
    if (
        torch.is_autocast_enabled()
        and not symbolic_helper.args_have_same_dtype(
            [input, weight, bias, running_mean, running_var]
        )
        and GLOBALS.export_onnx_opset_version < 15
    ):
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "BatchNormalization",
            14,
            15,
            "All input tensors must have the same `dtype`."
            " Turn off Autocast or export using opset version 15.",
            input,
        )

    symbolic_helper.check_training_mode(training, "batch_norm")
    weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper(
        g, input, weight, bias, running_mean, running_var
    )
    out = g.op(
        "BatchNormalization",
        input,
        weight,
        bias,
        running_mean,
        running_var,
        epsilon_f=eps,
        momentum_f=1 - momentum,
        training_mode_i=0 if not training else 1,
        outputs=1 if not training else 3,
    )
    if not training:
        return out
    else:
        res, new_running_mean, new_running_var = out
        new_running_mean.setType(running_mean.type())
        new_running_var.setType(running_var.type())
        return res


@_onnx_symbolic("quantized::hardswish")
def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
    x, _, _, _ = symbolic_helper.dequantize_helper(g, x)

    output = hardswish(g, x)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


# Ported from
# https://github.com/microsoft/onnxscript/blob/6b1b81700b4523f31d8c6d3321e5d8ef5d42b764/onnxscript/function_libs/torch_aten/ops/nn.py#L1504
# aten_scaled_dot_product_attention
# NOTE: Need op.Trilu
@_onnx_symbolic("aten::scaled_dot_product_attention")
@symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "b")
def scaled_dot_product_attention(
    g: jit_utils.GraphContext,
    query: torch._C.Value,
    key: torch._C.Value,
    value: torch._C.Value,
    attn_mask: torch._C.Value | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: torch._C.Value | None = None,
    enable_gqa: bool = False,
):
    assert (not is_causal) or (
        is_causal and symbolic_helper._is_none(attn_mask)
    ), "is_causal and attn_mask cannot be set at the same time"
    assert not enable_gqa, "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"

    if symbolic_helper._is_none(scale):
        scale = _attention_scale(g, query)

    if is_causal:
        attn_mask = _causal_attention_mask(g, query, key)

    # Swap the last two axes of key
    # NOTE: onnx-script has different logic here, because the attribute perms in
    # transpose needs list of ints
    key_shape_builtin = symbolic_helper._get_tensor_rank(key)
    key_transposed_axes = list(range(key_shape_builtin))
    key_transposed_axes[-1], key_transposed_axes[-2] = (
        key_transposed_axes[-2],
        key_transposed_axes[-1],
    )
    key_transposed = g.op("Transpose", key, perm_i=key_transposed_axes)

    # https://github.com/pytorch/pytorch/blob/12da0c70378b5be9135c6fda62a9863bce4a4818/aten/src/ATen/native/transformers/attention.cpp#L653
    # Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math
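    # Multiplying both operands by sqrt(scale) is algebraically identical to scaling
    # the product once, (q * sqrt(s)) @ (k^T * sqrt(s)) == s * (q @ k^T), but keeps
    # the intermediate magnitudes smaller.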
    query_scaled = g.op("Mul", query, g.op("Sqrt", scale))
    key_transposed_scaled = g.op("Mul", key_transposed, g.op("Sqrt", scale))
    mul_qk = g.op("MatMul", query_scaled, key_transposed_scaled)

    if symbolic_helper._is_none(attn_mask):
        mul_qk_add = mul_qk
    elif (
        _type_utils.JitScalarType.from_value(attn_mask)
        == _type_utils.JitScalarType.BOOL
    ):
        # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
        const_zero = g.op("Constant", value_t=torch.tensor([0.0]))
        const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
        attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf)
        mul_qk_add = g.op("Add", mul_qk, attn_mask)
    elif _type_utils.JitScalarType.from_value(attn_mask) in (
        _type_utils.JitScalarType.FLOAT,
        _type_utils.JitScalarType.HALF,
        _type_utils.JitScalarType.BFLOAT16,
    ):
        mul_qk_add = g.op("Add", mul_qk, attn_mask)
    else:
        raise ValueError(
            f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}"
        )

    attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)

    if dropout_p != 0:
        attn_weight = g.op(
            "Dropout",
            attn_weight,
            g.op("Constant", value_t=torch.tensor(dropout_p, dtype=torch.float)),
        )

    return g.op("MatMul", attn_weight, value)


def _attention_scale(
    g: jit_utils.GraphContext, query: torch._C.Value
) -> torch._C.Value:
    """Calculate the scale factor for the attention result.

    Args:
        query: Tensor of shape [..., L, E]

    Returns:
        Scalar scale factor := 1 / math.sqrt(query.size(-1))
    """
    query_shape = g.op("Shape", query)
    query_shape_last = g.op(
        "Slice",
        query_shape,
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)),
        g.op(
            "Constant", value_t=torch.tensor([_constants.INT64_MAX], dtype=torch.int64)
        ),
    )
    embedding_size = g.op(
        "Cast",
        query_shape_last,
        to_i=_type_utils.JitScalarType.from_value(query).onnx_type(),
    )
    const_one = g.op("Constant", value_t=torch.tensor([1.0], dtype=torch.float))
    scale = g.op("Div", const_one, g.op("Sqrt", embedding_size))
    # Add a Cast to convert the scale back to original type
    scale = g.op(
        "Cast",
        scale,
        to_i=_type_utils.JitScalarType.from_value(query).onnx_type(),
    )
    return scale


def _causal_attention_mask(
    g: jit_utils.GraphContext, query: torch._C.Value, key: torch._C.Value
) -> torch._C.Value:
    """Create a causal mask for the given query and key tensors.

    Equivalent to::
        mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_mask = torch.zeros(L, S, dtype=torch.float)
        attn_mask = attn_mask.masked_fill(not mask, -float("inf"))

    Args:
        query: Tensor of shape [..., L, E]
        key: Tensor of shape [..., S, E]

    Returns:
        Tensor of shape [L, S]
    """

    query_shape = g.op("Shape", query)
    key_shape = g.op("Shape", key)

    last_idx = g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
    second_last_idx = g.op("Constant", value_t=torch.tensor([-2], dtype=torch.int64))
    target_length = g.op("Slice", query_shape, second_last_idx, last_idx)
    source_length = g.op("Slice", key_shape, second_last_idx, last_idx)
    # attn_mask = torch.ones(L, S) := {
    size = g.op("Concat", target_length, source_length, axis_i=0)
    const_one = g.op("Constant", value_t=torch.tensor([1.0]))
    attn_mask = g.op("Expand", const_one, size)
    # }
    attn_mask = g.op("Trilu", attn_mask, upper_i=0)
    # The causal mask has 0s in the lower triangle and -inf in the upper triangle.
    const_zero = g.op("Constant", value_t=torch.tensor([0.0]))
    const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
    attn_mask = g.op(
        "Where", g.op("Equal", attn_mask, const_zero), const_neg_inf, const_zero
    )
    return attn_mask
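

# Usage sketch (illustrative only, not executed as part of this module): the
# symbolics above are dispatched by the TorchScript-based exporter when targeting
# opset 14, e.g.
#
#     model = torch.nn.Hardswish()
#     torch.onnx.export(
#         model, (torch.randn(2, 3),), "hardswish.onnx", opset_version=14
#     )
#
# which emits a single ONNX HardSwish node instead of the decomposition used at
# earlier opsets.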