# mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type
"""This file exports ONNX ops for opset 14.

Note [ONNX operators that are added/updated in opset 14]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
New operators:
    HardSwish, Trilu

Updated operators:
    Reshape
    Add, Sub, Mul, Div
    GRU, LSTM, RNN
    BatchNorm, Cumsum, Relu
"""

# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md
from __future__ import annotations

import functools

import torch
from torch.onnx import _constants, _type_utils, symbolic_helper
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import jit_utils, registration


__all__ = [
    "hardswish",
    "tril",
    "triu",
    "reshape",
    "batch_norm",
    "quantized_hardswish",
    "scaled_dot_product_attention",
]

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=14)


@_onnx_symbolic("aten::hardswish")
@symbolic_helper.parse_args("v")
def hardswish(g: jit_utils.GraphContext, self):
    return g.op("HardSwish", self)


@_onnx_symbolic("aten::tril")
def tril(g: jit_utils.GraphContext, self, diagonal, out=None):
    return g.op("Trilu", self, diagonal, upper_i=0)


@_onnx_symbolic("aten::triu")
def triu(g: jit_utils.GraphContext, self, diagonal, out=None):
    return g.op("Trilu", self, diagonal, upper_i=1)


@_onnx_symbolic("aten::reshape")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v")
def reshape(g: jit_utils.GraphContext, self, shape):
    # NOTE: Due to a bug in ORT (https://github.com/microsoft/onnxruntime/issues/10664),
    #       Reshape export cannot utilize the new allowzero attribute introduced in opset 14.
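    #       For reference: with allowzero=0 a 0 in `shape` copies the corresponding input
    #       dimension (the pre-opset-14 behavior), while allowzero=1 would treat it as a
    #       literal zero-sized dimension.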
    return symbolic_helper._reshape_helper(g, self, shape, allowzero=0)


@_onnx_symbolic("aten::batch_norm")
@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i")
def batch_norm(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    running_mean,
    running_var,
    training,
    momentum,
    eps,
    cudnn_enabled,
):
    if (
        torch.is_autocast_enabled()
        and not symbolic_helper.args_have_same_dtype(
            [input, weight, bias, running_mean, running_var]
        )
        and GLOBALS.export_onnx_opset_version < 15
    ):
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "BatchNormalization",
            14,
            15,
            "All input tensors must have the same `dtype`."
            " Turn off Autocast or export using opset version 15.",
            input,
        )

    symbolic_helper.check_training_mode(training, "batch_norm")
    weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper(
        g, input, weight, bias, running_mean, running_var
    )
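    # NOTE: ONNX BatchNormalization weights the *previous* running stats by momentum
    # (running_mean = momentum * running_mean + (1 - momentum) * batch_mean), the reverse
    # of PyTorch's convention, hence momentum_f=1 - momentum below.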
    out = g.op(
        "BatchNormalization",
        input,
        weight,
        bias,
        running_mean,
        running_var,
        epsilon_f=eps,
        momentum_f=1 - momentum,
        training_mode_i=0 if not training else 1,
        outputs=1 if not training else 3,
    )
    if not training:
        return out
    else:
        res, new_running_mean, new_running_var = out
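        # Carry the type/shape info of the original running stats over to the new
        # outputs so the exported graph keeps the expected tensor types.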
        new_running_mean.setType(running_mean.type())
        new_running_var.setType(running_var.type())
        return res


@_onnx_symbolic("quantized::hardswish")
def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
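    # Dequantize the input, run HardSwish in float, then requantize with the
    # provided output scale and zero point.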
    x, _, _, _ = symbolic_helper.dequantize_helper(g, x)

    output = hardswish(g, x)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


# Ported from
# https://github.com/microsoft/onnxscript/blob/6b1b81700b4523f31d8c6d3321e5d8ef5d42b764/onnxscript/function_libs/torch_aten/ops/nn.py#L1504
# aten_scaled_dot_product_attention
# NOTE: Needs the Trilu operator, which is new in opset 14.
@_onnx_symbolic("aten::scaled_dot_product_attention")
@symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "b")
def scaled_dot_product_attention(
    g: jit_utils.GraphContext,
    query: torch._C.Value,
    key: torch._C.Value,
    value: torch._C.Value,
    attn_mask: torch._C.Value | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: torch._C.Value | None = None,
    enable_gqa: bool = False,
):
    assert (not is_causal) or (
        is_causal and symbolic_helper._is_none(attn_mask)
    ), "is_causal and attn_mask cannot be set at the same time"
    assert not enable_gqa, (
        "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
    )

    if symbolic_helper._is_none(scale):
        scale = _attention_scale(g, query)

    if is_causal:
        attn_mask = _causal_attention_mask(g, query, key)

    # Swap the last two axes of key.
    # NOTE: onnx-script has different logic here, because the perm attribute of
    # Transpose needs a list of ints.
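    # NOTE (assumption): the rank of `key` must be statically known here;
    # _get_tensor_rank returns None for inputs of unknown rank, in which case the
    # permutation below cannot be built.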
    key_shape_builtin = symbolic_helper._get_tensor_rank(key)
    key_transposed_axes = list(range(key_shape_builtin))
    key_transposed_axes[-1], key_transposed_axes[-2] = (
        key_transposed_axes[-2],
        key_transposed_axes[-1],
    )
    key_transposed = g.op("Transpose", key, perm_i=key_transposed_axes)

    # https://github.com/pytorch/pytorch/blob/12da0c70378b5be9135c6fda62a9863bce4a4818/aten/src/ATen/native/transformers/attention.cpp#L653
    # Scale q and k before the matmul for numerical stability; see https://tinyurl.com/sudb9s96 for the math.
    query_scaled = g.op("Mul", query, g.op("Sqrt", scale))
    key_transposed_scaled = g.op("Mul", key_transposed, g.op("Sqrt", scale))
    mul_qk = g.op("MatMul", query_scaled, key_transposed_scaled)

    if symbolic_helper._is_none(attn_mask):
        mul_qk_add = mul_qk
    elif (
        _type_utils.JitScalarType.from_value(attn_mask)
        == _type_utils.JitScalarType.BOOL
    ):
        # Convert the Boolean mask to float: attn_mask.masked_fill(~attn_mask, -float('inf'))
        const_zero = g.op("Constant", value_t=torch.tensor([0.0]))
        const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
        attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf)
        mul_qk_add = g.op("Add", mul_qk, attn_mask)
    elif _type_utils.JitScalarType.from_value(attn_mask) in (
        _type_utils.JitScalarType.FLOAT,
        _type_utils.JitScalarType.HALF,
        _type_utils.JitScalarType.BFLOAT16,
    ):
        mul_qk_add = g.op("Add", mul_qk, attn_mask)
    else:
        raise ValueError(
            f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}"
        )

    attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)

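    # ONNX Dropout takes the dropout ratio as an optional second input, hence the
    # Constant tensor below.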
    if dropout_p != 0:
        attn_weight = g.op(
            "Dropout",
            attn_weight,
            g.op("Constant", value_t=torch.tensor(dropout_p, dtype=torch.float)),
        )

    return g.op("MatMul", attn_weight, value)


def _attention_scale(
    g: jit_utils.GraphContext, query: torch._C.Value
) -> torch._C.Value:
    """Calculate the scale factor for the attention result.

    Args:
        query: Tensor of shape [..., L, E]

    Returns:
        Scalar scale factor := 1 / math.sqrt(query.size(-1))
    """
    query_shape = g.op("Shape", query)
    query_shape_last = g.op(
        "Slice",
        query_shape,
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)),
        g.op(
            "Constant", value_t=torch.tensor([_constants.INT64_MAX], dtype=torch.int64)
        ),
    )
    embedding_size = g.op(
        "Cast",
        query_shape_last,
        to_i=_type_utils.JitScalarType.from_value(query).onnx_type(),
    )
    const_one = g.op("Constant", value_t=torch.tensor([1.0], dtype=torch.float))
    scale = g.op("Div", const_one, g.op("Sqrt", embedding_size))
    # Add a Cast to convert the scale back to the original type.
    scale = g.op(
        "Cast",
        scale,
        to_i=_type_utils.JitScalarType.from_value(query).onnx_type(),
    )
    return scale


def _causal_attention_mask(
    g: jit_utils.GraphContext, query: torch._C.Value, key: torch._C.Value
) -> torch._C.Value:
    """Create a causal mask for the given query and key tensors.

    Equivalent to::
        mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_mask = torch.zeros(L, S, dtype=torch.float)
        attn_mask = attn_mask.masked_fill(~mask, -float("inf"))

    Args:
        query: Tensor of shape [..., L, E]
        key: Tensor of shape [..., S, E]

    Returns:
        Tensor of shape [L, S]
    """

    query_shape = g.op("Shape", query)
    key_shape = g.op("Shape", key)

    last_idx = g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
    second_last_idx = g.op("Constant", value_t=torch.tensor([-2], dtype=torch.int64))
    target_length = g.op("Slice", query_shape, second_last_idx, last_idx)
    source_length = g.op("Slice", key_shape, second_last_idx, last_idx)
    # attn_mask = torch.ones(L, S) := {
    size = g.op("Concat", target_length, source_length, axis_i=0)
    const_one = g.op("Constant", value_t=torch.tensor([1.0]))
    attn_mask = g.op("Expand", const_one, size)
    # }
    attn_mask = g.op("Trilu", attn_mask, upper_i=0)
    # The causal mask has 0s in the lower triangle and -inf in the upper triangle.
    const_zero = g.op("Constant", value_t=torch.tensor([0.0]))
    const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
    attn_mask = g.op(
        "Where", g.op("Equal", attn_mask, const_zero), const_neg_inf, const_zero
    )
    return attn_mask