• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2# mypy: disable-error-code=arg-type
3from __future__ import annotations
4
5import functools
6import sys
7import warnings
8from typing import Sequence
9
10import torch
11import torch._C._onnx as _C_onnx
12import torch.onnx
13from torch import _C
14
15# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics
16from torch.onnx import (
17    _constants,
18    _type_utils,
19    errors,
20    symbolic_helper,
21    symbolic_opset9 as opset9,
22)
23from torch.onnx._globals import GLOBALS
24from torch.onnx._internal import jit_utils, registration
25
26
27# EDITING THIS FILE? READ THIS FIRST!
28# see Note [Edit Symbolic Files] in README.md
29
30# This file exports ONNX ops for opset 10
31# Opset 10 is supported by ONNX release 1.5.0,
32# released on 2019-04-24.
33
34
# Public symbolic functions of this module. The order below is kept stable;
# comments only group related entries.
__all__ = [
    # aten:: ops (first group).
    "dequantize",
    "div",
    "embedding_bag",
    "fake_quantize_per_tensor_affine",
    "flip",
    "fmod",
    "isfinite",
    "isinf",
    "nan_to_num",
    "quantize_per_tensor",
    # quantized:: ops.
    "quantized_add_relu",
    "quantized_add",
    "quantized_cat",
    "quantized_conv1d_relu",
    "quantized_conv2d_relu",
    "quantized_conv3d_relu",
    "quantized_conv1d",
    "quantized_conv2d",
    "quantized_conv3d",
    "quantized_conv_transpose1d",
    "quantized_conv_transpose2d",
    "quantized_conv_transpose3d",
    "quantized_group_norm",
    "quantized_hardswish",
    "quantized_instance_norm",
    "quantized_layer_norm",
    "quantized_leaky_relu",
    "quantized_linear",
    "quantized_linear_relu",
    "quantized_mul",
    "quantized_sigmoid",
    # aten:: ops (second group).
    "slice",
    "sort",
    "topk",
]


# Registration decorator with the opset pinned to 10, so each symbolic below
# only has to name the op it implements.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=10)
74
75
@_onnx_symbolic("aten::div")
def div(g: jit_utils.GraphContext, self, other, *args):
    """Export ``aten::div``.

    With no extra positional arguments this is true division (opset 9
    ``true_divide``). Any extra arguments are forwarded to
    ``_div_rounding_mode``, which parses them as the ``rounding_mode`` string.
    """
    if args:
        return _div_rounding_mode(g, self, other, *args)
    return opset9.true_divide(g, self, other)
82
83
@symbolic_helper.parse_args("v", "v", "s")
def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode):
    """Export ``aten::div`` with an explicit ``rounding_mode`` string.

    ``"floor"`` is handled by this opset's ``_floor_divide``. Every other mode
    is delegated unchanged to the opset 9 implementation.
    """
    if rounding_mode != "floor":
        return opset9._div_rounding_mode(g, self, other, rounding_mode)
    return _floor_divide(g, self, other)
90
91
@_onnx_symbolic("aten::_floor_divide")
def _floor_divide(g: jit_utils.GraphContext, self, other):
    """Export floor division.

    Floating-point operands: ``Floor(true_divide(self, other))``.

    Integer operands: ONNX ``Div`` truncates toward zero. The quotient is
    therefore decremented by one wherever the exact result is negative and the
    division leaves a remainder.
    """
    if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other):
        quotient = opset9.true_divide(g, self, other)
        return g.op("Floor", quotient)

    # Integer division uses truncation rounding.
    truncated = g.op("Div", self, other)

    # The exact quotient is negative iff exactly one operand is negative.
    zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
    self_negative = g.op("Less", self, zero)
    other_negative = g.op("Less", other, zero)
    result_negative = g.op("Xor", self_negative, other_negative)

    # Truncation rounded those results up (toward zero). When there is a
    # non-zero remainder, subtract one so they round down instead.
    remainder = g.op("Mod", self, other, fmod_i=0)
    inexact = g.op("Not", g.op("Equal", remainder, zero))
    needs_fixup = g.op("And", result_negative, inexact)

    one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
    rounded_down = g.op("Sub", truncated, one)
    return g.op("Where", needs_fixup, rounded_down, truncated)
111
112
@_onnx_symbolic("aten::sort")
@symbolic_helper.parse_args("v", "i", "i", "none")
def sort(g: jit_utils.GraphContext, self, dim, decending, out=None):
    """Export ``aten::sort`` through the shared ``_sort_helper``.

    The misspelled parameter name ``decending`` is kept as-is: it is both
    part of this function's signature and the keyword the helper expects.
    """
    sort_helper = symbolic_helper._sort_helper
    return sort_helper(g, self, dim, decending=decending, out=out)
117
118
@_onnx_symbolic("aten::topk")
@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none")
def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
    """Export ``aten::topk`` through the shared ``_topk_helper``.

    ``k`` stays a graph value, while ``dim``, ``largest`` and ``sorted`` are
    parsed as Python ints by ``parse_args``.
    """
    options = {"largest": largest, "sorted": sorted, "out": out}
    return symbolic_helper._topk_helper(g, self, k, dim, **options)
125
126
def _aten_max_pool_onnx(
    g: jit_utils.GraphContext,
    self: _C.Value,
    kernel_shape: Sequence[int],
    strides: Sequence[int],
    pads: Sequence[int],
    dilations: Sequence[int],
    ceil_mode: bool,
    unbatched_rank: int,
) -> _C.Value:
    """Emit an ONNX ``MaxPool`` and return only its values output.

    ONNX ``MaxPool`` expects a leading batch dimension. If ``self`` is known
    at export time to have rank ``unbatched_rank`` (no batch dimension), a
    size-1 batch axis is inserted before pooling and removed afterwards.

    Args:
        g: Graph context to emit nodes into.
        self: Input tensor value.
        kernel_shape, strides, pads, dilations: ONNX ``MaxPool`` attributes.
        ceil_mode: Forwarded to the ``ceil_mode`` attribute.
        unbatched_rank: Rank of an input that lacks the batch dimension.

    Returns:
        The pooled tensor, with the same batchedness as ``self``.
    """
    # Use the statically known rank. Comparing a graph Value (such as the
    # output of ``Size(Shape(self))``) with a Python int is always False, so
    # unbatched input would never be detected. An unknown rank (None) keeps
    # the input untouched.
    is_unbatched = symbolic_helper._get_tensor_rank(self) == unbatched_rank
    if is_unbatched:  # C,H,W -> N,C,H,W and N=1
        # Use the helper instead of a raw Unsqueeze with an axes *input*. That
        # input form only exists from opset 13; earlier opsets take an
        # ``axes`` attribute.
        self = symbolic_helper._unsqueeze_helper(g, self, [0])

    pool_result, _ = g.op(
        "MaxPool",
        self,
        outputs=2,
        ceil_mode_i=ceil_mode,
        dilations_i=dilations,
        kernel_shape_i=kernel_shape,
        pads_i=pads,
        strides_i=strides,
    )

    if is_unbatched:
        pool_result = symbolic_helper._squeeze_helper(g, pool_result, [0])

    return pool_result
164
165
166# For MaxPool
def _adjust_attributes_of_max_pool(
    expand_size: int,
    kernel_size: Sequence[int] | int,
    stride: Sequence[int] | int,
    padding: Sequence[int] | int,
    dilation: Sequence[int] | int,
) -> tuple[Sequence[int], Sequence[int], Sequence[int], Sequence[int]]:
    """Adjust attributes of max_pool to match ONNX specification.

    Args:
        expand_size: Number of spatial dimensions.
        kernel_size: One int for all spatial dims, or one value per dim.
        stride: One int for all spatial dims, a per-dim sequence, or an
            empty/falsy value meaning "same as the kernel shape".
        padding: One of the following forms.
            - An int or 1-element sequence: the same padding on both ends of
              every dim.
            - A 2- or 3-element sequence: per-dim padding, applied to both
              ends.
            - Any other length: taken to be in ONNX begin/end form already
              and passed through unchanged.
        dilation: One int for all spatial dims, or one value per dim.

    Returns:
        ``(kernel_shape, strides, pads, dilations)`` for ONNX ``MaxPool``.
    """
    if isinstance(dilation, int):
        dilation = [dilation] * expand_size

    if isinstance(kernel_size, int):
        kernel_shape = [kernel_size] * expand_size
    else:
        kernel_shape = kernel_size  # type: ignore[assignment]

    if isinstance(padding, int):
        pads = [padding] * expand_size * 2  # type: ignore[operator, assignment]
    elif len(padding) == 1:
        pads = padding * expand_size * 2  # type: ignore[operator, assignment]
    elif len(padding) in (2, 3):
        # 2D or 3D per-dimension padding. ONNX pads list all begin values and
        # then all end values ([x1_begin, x2_begin, ..., x1_end, x2_end, ...]),
        # so repeating the sequence pads each dim symmetrically.
        pads = padding * 2  # type: ignore[operator, assignment]
    else:
        # Padding is already given for both ends of every dimension,
        # e.g. (1, 1, 1, 1, 1, 1), so it must not be doubled again.
        pads = padding  # type: ignore[assignment]

    if isinstance(stride, int):
        strides = [stride] * expand_size
    elif not stride:
        strides = kernel_shape
    else:
        strides = stride  # type: ignore[assignment]

    return (kernel_shape, strides, pads, dilation)
208
209
def _aten_max_pool_with_indices_onnx(
    g: jit_utils.GraphContext,
    self: _C.Value,
    kernel_shape: Sequence[int],
    strides: Sequence[int],
    pads: Sequence[int],
    dilations: Sequence[int],
    ceil_mode: bool,
    unbatched_rank: int,
    n_dims_one: Sequence[int],
    n_dims_zero: Sequence[int],
    n_dims_axes: Sequence[int],
) -> tuple[_C.Value, _C.Value]:
    """Emit ONNX ``MaxPool`` and return both values and rebased indices.

    A second ``MaxPool`` with a kernel and stride of all ones produces the
    index of every input element. Slicing out the first spatial element of
    each (N, C) plane gives that plane's base index. Subtracting it makes the
    pooled indices relative to their own plane, as PyTorch expects.

    Args:
        g: Graph context to emit nodes into.
        self: Input tensor value.
        kernel_shape, strides, pads, dilations: ONNX ``MaxPool`` attributes.
        ceil_mode: Forwarded to the ``ceil_mode`` attribute.
        unbatched_rank: Rank of an input that lacks the batch dimension.
        n_dims_one: All-ones list, one entry per spatial dim (kernel/stride
            of the index-probing pool, and ``Slice`` ends).
        n_dims_zero: All-zeros list, one per spatial dim (``Slice`` starts).
        n_dims_axes: Spatial axes to slice along.

    Returns:
        ``(pool_result, indices)`` graph values, with the same batchedness as
        ``self``.
    """
    # Use the statically known rank: comparing a graph Value such as
    # ``Size(Shape(self))`` with a Python int is always False.
    is_unbatched = symbolic_helper._get_tensor_rank(self) == unbatched_rank
    if is_unbatched:  # C,H,W -> N,C,H,W and N=1
        self = symbolic_helper._unsqueeze_helper(g, self, [0])

    pool_result, indices = g.op(
        "MaxPool",
        self,
        outputs=2,
        ceil_mode_i=ceil_mode,
        dilations_i=dilations,
        kernel_shape_i=kernel_shape,
        pads_i=pads,
        strides_i=strides,
    )
    # Identity pooling: its indices output holds the index of every element.
    _, flatten_indices = g.op(
        "MaxPool",
        self,
        outputs=2,
        dilations_i=dilations,
        kernel_shape_i=n_dims_one,
        strides_i=n_dims_one,
    )

    ends = g.op("Constant", value_t=torch.tensor(n_dims_one))
    starts = g.op("Constant", value_t=torch.tensor(n_dims_zero))
    axes = g.op("Constant", value_t=torch.tensor(n_dims_axes))

    # Base index of each (N, C) plane, broadcast over the pooled indices.
    delta = g.op("Slice", flatten_indices, starts, ends, axes)
    indices = g.op("Sub", indices, delta)

    if is_unbatched:
        # Squeeze takes axes (attribute or input, depending on the opset), not
        # a ``value_t`` attribute; the helper emits the right form.
        pool_result = symbolic_helper._squeeze_helper(g, pool_result, [0])
        indices = symbolic_helper._squeeze_helper(g, indices, [0])

    return (pool_result, indices)
264
265
@_onnx_symbolic(
    "aten::max_pool1d",
    decorate=[symbolic_helper._apply_params("max_pool1d", 1, return_indices=False)],
)
@_onnx_symbolic(
    "aten::max_pool2d",
    decorate=[symbolic_helper._apply_params("max_pool2d", 2, return_indices=False)],
)
@_onnx_symbolic(
    "aten::max_pool3d",
    decorate=[symbolic_helper._apply_params("max_pool3d", 3, return_indices=False)],
)
@_onnx_symbolic(
    "aten::max_pool1d_with_indices",
    decorate=[
        symbolic_helper._apply_params(
            "max_pool1d_with_indices",
            1,
            return_indices=True,
        )
    ],
)
@_onnx_symbolic(
    "aten::max_pool2d_with_indices",
    decorate=[
        symbolic_helper._apply_params(
            "max_pool2d_with_indices",
            2,
            return_indices=True,
        )
    ],
)
@_onnx_symbolic(
    "aten::max_pool3d_with_indices",
    decorate=[
        symbolic_helper._apply_params(
            "max_pool3d_with_indices",
            3,
            return_indices=True,
        )
    ],
)
def _max_pool(name: str, expand_size: int, return_indices: bool):
    """Build the symbolic for one ``aten::max_pool{1,2,3}d[_with_indices]`` op.

    Args:
        name: aten op name supplied by ``_apply_params``; not used in the body.
        expand_size: Number of spatial dimensions.
        return_indices: Whether the op also returns the argmax indices.

    Returns:
        The symbolic function registered for the op.
    """
    # An input without the batch dim has channels + spatial dims.
    unbatched_rank = expand_size + 1

    @symbolic_helper.quantized_args(True, False, False, False, False, False)
    @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
    def symbolic_fn(
        g: jit_utils.GraphContext,
        input: _C.Value,
        kernel_size: Sequence[int],
        stride: Sequence[int],
        padding: int | Sequence[int],
        dilation: Sequence[int],
        ceil_mode: bool,
    ):
        kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool(
            expand_size, kernel_size, stride, padding, dilation
        )
        common = (
            g,
            input,
            kernel_shape,
            strides,
            pads,
            dilations,
            ceil_mode,
            unbatched_rank,
        )
        if not return_indices:
            return _aten_max_pool_onnx(*common)

        # Per-spatial-dim parameters used to rebase ONNX indices per plane.
        ones = [1] * expand_size
        zeros = [0] * expand_size
        spatial_axes = list(range(2, 2 + expand_size))
        return _aten_max_pool_with_indices_onnx(*common, ones, zeros, spatial_axes)

    return symbolic_fn
351
352
353# For AvgPool
def _adjust_attributes_of_avg_pool(
    expand_size: int,
    kernel_size: Sequence[int] | int,
    stride: Sequence[int] | int,
    padding: Sequence[int] | int,
) -> tuple[Sequence[int], Sequence[int], Sequence[int]]:
    """Normalize avg_pool kernel/stride/padding into ONNX ``AveragePool`` form.

    Args:
        expand_size: Number of spatial dimensions.
        kernel_size: One int for all spatial dims, or one value per dim.
        stride: One int for all spatial dims, a per-dim sequence, or an
            empty/falsy value meaning "same as the kernel shape".
        padding: An int is broadcast to both ends of every dim. A sequence is
            repeated by its length:
            - 1 element: repeated ``expand_size * 2`` times.
            - 2 elements: repeated ``expand_size`` times.
            - any other length: repeated twice.

    Returns:
        ``(kernel_shape, strides, pads)``.
    """
    kernel_shape = (
        [kernel_size] * expand_size if isinstance(kernel_size, int) else kernel_size
    )

    if isinstance(padding, int):
        pads = [padding] * expand_size * 2
    else:
        repeat_count = {1: expand_size * 2, 2: expand_size}.get(len(padding), 2)
        pads = padding * repeat_count  # type: ignore[operator, assignment]

    if isinstance(stride, int):
        strides = [stride] * expand_size
    else:
        strides = stride if stride else kernel_shape  # type: ignore[assignment]

    return (kernel_shape, strides, pads)
384
385
@_onnx_symbolic(
    "aten::avg_pool1d",
    decorate=[symbolic_helper._apply_params("avg_pool1d", 1)],
)
@_onnx_symbolic(
    "aten::avg_pool2d",
    decorate=[symbolic_helper._apply_params("avg_pool2d", 2)],
)
@_onnx_symbolic(
    "aten::avg_pool3d",
    decorate=[symbolic_helper._apply_params("avg_pool3d", 3)],
)
def _avg_pool(name, expand_size):
    """Build the ``AveragePool`` symbolic for ``aten::avg_pool{1,2,3}d``.

    Args:
        name: aten op name, used in "unimplemented" diagnostics.
        expand_size: Number of spatial dimensions.

    Returns:
        The symbolic function registered for the op.
    """

    @symbolic_helper.quantized_args(True, False, False, False, False, False, False)
    @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
    def symbolic_fn(
        g,
        input: _C.Value,
        kernel_size: Sequence[int],
        stride: Sequence[int],
        padding: int | Sequence[int],
        ceil_mode: int,
        count_include_pad: int,
        divisor_override=None,
    ):
        # ONNX AveragePool always divides by the (padded or unpadded) window
        # size and has no divisor override. Exporting a model that sets one
        # would silently change its numerics, so report it as unimplemented,
        # the same way _interpolate rejects align_corners.
        if not symbolic_helper._is_none(divisor_override):
            return symbolic_helper._unimplemented(name, "divisor_override", input)

        kernel_shape, strides, pads = _adjust_attributes_of_avg_pool(
            expand_size, kernel_size, stride, padding
        )

        result = g.op(
            "AveragePool",
            input,
            ceil_mode_i=ceil_mode,
            count_include_pad_i=count_include_pad,
            kernel_shape_i=kernel_shape,
            pads_i=pads,
            strides_i=strides,
        )

        return result

    return symbolic_fn
428
429
@_onnx_symbolic(
    "aten::upsample_nearest1d",
    decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest2d",
    decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest3d",
    decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_linear1d",
    decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_bilinear2d",
    decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_trilinear3d",
    decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")],
)
def _interpolate(name, dim, interpolate_mode):
    """Build the ``Resize``-based symbolic for an ``aten::upsample_*`` op.

    Args:
        name: aten op name, used in "unimplemented" diagnostics.
        dim: Input rank (3, 4 or 5 in the registrations above).
        interpolate_mode: ``"nearest"`` or ``"linear"``, passed to ``Resize``.
    """

    @symbolic_helper.quantized_args(True, False, False)
    def symbolic_fn(g, input, output_size, *args):
        scales, align_corners = symbolic_helper._get_interpolate_attributes(
            g, interpolate_mode, args
        )
        symbolic_helper._interpolate_warning(interpolate_mode)
        # This opset's Resize export does not support align_corners=True.
        if symbolic_helper._maybe_get_scalar(align_corners):
            return symbolic_helper._unimplemented(name, "align_corners == True", input)
        # Without explicit scales, derive them from the requested output size.
        if scales is None:
            scales = symbolic_helper._interpolate_size_to_scales(
                g, input, output_size, dim
            )
        return g.op("Resize", input, scales, mode_s=interpolate_mode)

    return symbolic_fn
471
472
@_onnx_symbolic("aten::__interpolate")
def __interpolate(
    g: jit_utils.GraphContext,
    input,
    size,
    scale_factor,
    mode,
    align_corners,
    recompute_scale_factor,
    antialias,
):
    """Export ``aten::__interpolate`` as ONNX ``Resize``.

    ``recompute_scale_factor`` and ``antialias`` are accepted for signature
    compatibility but are not used.
    """
    scales, resize_mode = symbolic_helper._interpolate_get_scales_and_mode(
        g, input, size, scale_factor, mode, align_corners
    )
    return g.op("Resize", input, scales, mode_s=resize_mode)
488
489
def _slice(
    g: jit_utils.GraphContext,
    input: torch._C.Value,
    axes: list | torch.Tensor | torch._C.Value,
    starts: list | torch.Tensor | torch._C.Value,
    ends: list | torch.Tensor | torch._C.Value,
    steps: list | torch.Tensor | torch._C.Value | None = None,
):
    """Emit an ONNX ``Slice`` whose starts/ends/axes(/steps) are graph inputs.

    Each of ``axes``, ``starts``, ``ends`` and ``steps`` may be a Python list,
    a tensor, or a graph Value of rank 0 or 1. ``input`` is returned unchanged
    when the slice is statically a no-op: start 0, end ``INT64_MAX``, and
    step 1 or omitted.

    Raises:
        errors.SymbolicValueError: if a Value argument has rank other than 0 or 1.
    """

    def _is_none(value):
        # True for Python None or a prim::Constant of NoneType.
        if value is None:
            return True
        if not isinstance(value, torch._C.Value):
            return False
        return value.node().kind() == "prim::Constant" and isinstance(
            value.type(), _C.NoneType
        )

    def _as_1d_value(arg, default=None):
        # Convert the argument into a 1-D graph Value.
        if default is not None and _is_none(arg):
            arg = [default]
        if isinstance(arg, (list, torch.Tensor)):
            return g.op("Constant", value_t=torch.tensor(arg))
        rank = symbolic_helper._get_tensor_rank(arg)
        if rank == 1:
            return arg
        if rank == 0:
            return symbolic_helper._unsqueeze_helper(g, arg, [0])
        raise errors.SymbolicValueError(f"Rank must be 0 or 1, not {rank}", arg)

    def _single_constant(arg):
        # The lone element of a 1-element list/tensor, else a constant-folded
        # value (or the argument itself) via _maybe_get_const.
        if not isinstance(arg, (list, torch.Tensor)):
            return symbolic_helper._maybe_get_const(arg, "i")
        return arg[0] if len(arg) == 1 else None

    is_noop = (
        _single_constant(starts) == 0
        and _single_constant(ends) == _constants.INT64_MAX
        and (steps is None or _single_constant(steps) == 1)
    )
    if is_noop:
        return input

    axes = _as_1d_value(axes)
    starts = _as_1d_value(starts, default=0)
    ends = _as_1d_value(ends, default=_constants.INT64_MAX)
    slice_inputs = [input, starts, ends, axes]
    if steps is not None:
        slice_inputs.append(_as_1d_value(steps, default=1))
    return g.op("Slice", *slice_inputs)
546
547
@_onnx_symbolic("aten::slice")
def slice(g: jit_utils.GraphContext, self, *args):
    """Export both ``aten::slice`` overloads via ``_slice_helper``.

    The list overload has no ``dim`` argument and slices along axis 0.
    """
    if len(args) == 3:
        # aten::slice(t[] l, int? start=None, int? end=None, int step=1) -> t[]
        dims = [0]
        start, end, step = args
    elif len(args) == 4:
        # aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor
        dims, start, end, step = args
    else:
        raise errors.SymbolicValueError("Unknown aten::slice signature", self)

    return symbolic_helper._slice_helper(
        g, self, axes=dims, starts=start, ends=end, steps=step
    )
568
569
@_onnx_symbolic("aten::flip")
@symbolic_helper.parse_args("v", "is")
def flip(g: jit_utils.GraphContext, input, dims):
    """Export ``aten::flip`` as a step -1 slice over every axis in ``dims``.

    Each axis starts at its last element (-1) and runs toward
    ``-INT64_MAX`` so that the whole axis is covered.
    """
    n_flipped = len(dims)
    return symbolic_helper._slice_helper(
        g,
        input,
        axes=dims,
        starts=[-1] * n_flipped,
        ends=[-_constants.INT64_MAX] * n_flipped,
        steps=[-1] * n_flipped,
    )
581
582
@_onnx_symbolic("aten::fmod")
def fmod(g: jit_utils.GraphContext, input, other):
    """Export ``aten::fmod`` as ONNX ``Mod`` in fmod mode.

    With ``fmod=1``, ONNX ``Mod`` gives the remainder whose sign follows the
    dividend.
    """
    use_fmod = 1
    return g.op("Mod", input, other, fmod_i=use_fmod)
586
587
@_onnx_symbolic("aten::embedding_bag")
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
def embedding_bag(
    g: jit_utils.GraphContext,
    embedding_matrix,
    indices,
    offsets,
    scale_grad_by_freq,
    mode,
    sparse,
    per_sample_weights,
    include_last_offset,
    padding_idx,
):
    """Export ``aten::embedding_bag`` by unrolling one Gather + reduce per bag.

    The number of bags must be known at export time from the first dimension
    of ``offsets``. For each bag, the code below does the following:
    - Slice that bag's indices.
    - Gather their embeddings.
    - Optionally multiply by the matching per-sample weights.
    - Reduce over the bag: ``mode`` 0 is sum, 1 is mean, anything else is max.

    Returns:
        ``(output, None, None, None)``. Only the first of aten's four outputs
        is produced.

    Raises:
        errors.OnnxExporterError (a RuntimeError): for scale_grad_by_freq in
        training export, a non-negative padding_idx, or an unknown offsets shape.
    """
    if scale_grad_by_freq and GLOBALS.export_training:
        return symbolic_helper._onnx_unsupported(
            "embedding_bag with scale_grad_by_freq for training mode"
        )
    if padding_idx is not None and padding_idx >= 0:
        return symbolic_helper._onnx_unsupported("embedding_bag with padding_idx")

    warnings.warn(
        "Export of embedding_bag with dynamic input/offsets shape is not supported in opset 10. "
        "Please use opset 11 or higher to export model for dynamic input shape."
    )
    offsets_dim_0 = symbolic_helper._get_tensor_dim_size(offsets, 0)
    if offsets_dim_0 is None:
        return symbolic_helper._onnx_unsupported(
            "embedding_bag with unknown shape of offsets for opset 10 is not supported. "
            "please use opset 11 or higher."
        )

    if include_last_offset:
        # offsets already ends with the end position of the last bag.
        num_bags = offsets_dim_0 - 1
        offsets_extended = offsets
    else:
        num_bags = offsets_dim_0
        # Append a very large end offset, so bag i always spans
        # [offsets[i], offsets[i + 1]). The last bag then runs to the end of
        # ``indices``, because Slice clamps out-of-range ends.
        sentinel_end = g.op("Constant", value_t=torch.tensor([sys.maxsize]))
        offsets_extended = g.op("Concat", offsets, sentinel_end, axis_i=0)

    # Every per-bag slice runs along axis 0; one shared constant suffices.
    axes_ = g.op("Constant", value_t=torch.tensor([0]))
    has_weights = not symbolic_helper._is_none(per_sample_weights)

    bag_outputs = []
    for i in range(num_bags):
        start_ = symbolic_helper._unsqueeze_helper(
            g,
            opset9.select(g, offsets_extended, torch.tensor(0), torch.tensor(i)),
            [0],
        )
        end_ = symbolic_helper._unsqueeze_helper(
            g,
            opset9.select(g, offsets_extended, torch.tensor(0), torch.tensor(i + 1)),
            [0],
        )
        indices_row = g.op("Slice", indices, start_, end_, axes_)

        embeddings = g.op("Gather", embedding_matrix, indices_row)
        if has_weights:
            per_sample_weights_row = g.op(
                "Slice", per_sample_weights, start_, end_, axes_
            )
            # Add a trailing axis so the weights broadcast across the
            # embedding dimension.
            per_sample_weights_row = symbolic_helper._unsqueeze_helper(
                g, per_sample_weights_row, [1]
            )
            embeddings = g.op("Mul", embeddings, per_sample_weights_row)

        if mode == 0:
            embeddings = symbolic_helper._reducesum_helper(
                g, embeddings, axes_i=[0], keepdims_i=0
            )
        elif mode == 1:
            embeddings = g.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0)
        else:
            embeddings = g.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0)

        bag_outputs.append(symbolic_helper._unsqueeze_helper(g, embeddings, [0]))

    output = g.op("Concat", *bag_outputs, axis_i=0)
    # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
    # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
    return output, None, None, None
672
673
@_onnx_symbolic("aten::fake_quantize_per_tensor_affine")
@symbolic_helper.parse_args("v", "v", "v", "i", "i")
def fake_quantize_per_tensor_affine(
    g: jit_utils.GraphContext,
    inputs,
    scale,
    zero_point,
    quant_min=-128,
    quant_max=127,
):
    """Export fake quantization as a ``QuantizeLinear`` → ``DequantizeLinear`` pair.

    Only the (0, 255) range (uint8 zero point) and the (-128, 127) range
    (int8 zero point) can be expressed. ``scale`` must be a constant scalar.

    Raises:
        errors.SymbolicValueError: for any other (quant_min, quant_max).
        An opset-unsupported error: for the (0, 127) range or a non-constant
            scale, both of which need opset 13.
    """
    # NOTE: (0, 127) is a special case. PyTorch restricts activations to be in the range (0, 127).
    #   https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
    if (quant_min, quant_max) == (0, 127):
        symbolic_helper._onnx_opset_unsupported_detailed(
            "fake_quantize_per_tensor_affine",
            10,
            13,
            "Quantize range (0, 127) not supported, requires opset 13 Clip",
            inputs,
        )
    if (quant_min, quant_max) not in [(0, 255), (-128, 127)]:
        raise errors.SymbolicValueError(
            f"For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). "
            f"Got ({quant_min}, {quant_max})",
            inputs,
        )
    scale = symbolic_helper._maybe_get_scalar(scale)
    # Guard on the type that ``.float()`` below actually needs. When
    # _maybe_get_scalar cannot fold ``scale`` to a constant scalar, it hands
    # back the graph Value instead of None. An ``is None`` check would let
    # that Value through, and ``.float()`` would then fail with AttributeError.
    if not isinstance(scale, torch.Tensor):
        symbolic_helper._onnx_opset_unsupported_detailed(
            "fake_quantize_per_tensor_affine",
            10,
            13,
            "Non-constant scale not supported",
            inputs,
        )
    scale = scale.float().data  # Avoid exporter generating double type
    # An unsigned range needs a uint8 zero point, a signed range an int8 one.
    zero_point_type = (
        _C_onnx.TensorProtoDataType.UINT8
        if quant_min == 0
        else _C_onnx.TensorProtoDataType.INT8
    )
    zero_point = g.op("Cast", zero_point, to_i=zero_point_type)
    return g.op(
        "DequantizeLinear",
        g.op("QuantizeLinear", inputs, scale, zero_point),
        scale,
        zero_point,
    )
720
721
@_onnx_symbolic("aten::isinf")
def isinf(g: jit_utils.GraphContext, input):
    """Export ``aten::isinf``.

    The input is first cast to DOUBLE. ONNX ``IsInf`` (opset 10) only
    accepts float and double tensors.
    """
    as_double = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE)
    return g.op("IsInf", as_double)
725
726
@_onnx_symbolic("aten::isfinite")
def isfinite(g: jit_utils.GraphContext, input):
    """Export ``aten::isfinite`` as ``not (isinf(x) or isnan(x))``."""
    inf_mask = isinf(g, input)
    nan_mask = opset9.isnan(g, input)
    non_finite = opset9.__or_(g, inf_mask, nan_mask)
    return opset9.__not_(g, non_finite)
732
733
@_onnx_symbolic("aten::quantize_per_tensor")
def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype):
    """Export ``aten::quantize_per_tensor`` via ``quantize_helper``.

    ``dtype`` must be a constant. The zero point is cast to the ONNX type of
    that dtype, and the scale to FLOAT.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    # TODO(justinchuby): Extract all the cast ops into a helper function.
    zero_point_onnx_type = _type_utils.JitScalarType(dtype).onnx_type()
    zero_point = g.op("Cast", zero_point, to_i=zero_point_onnx_type)
    scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    return symbolic_helper.quantize_helper(g, input, scale, zero_point)
743
744
@_onnx_symbolic("aten::dequantize")
def dequantize(g: jit_utils.GraphContext, input):
    """Export ``aten::dequantize``.

    Only the dequantized tensor, the first element of ``dequantize_helper``'s
    result, is returned.
    """
    helper_outputs = symbolic_helper.dequantize_helper(g, input)
    dequantized = helper_outputs[0]
    return dequantized
748
749
@_onnx_symbolic("aten::nan_to_num")
@symbolic_helper.parse_args("v", "f", "f", "f")
def nan_to_num(g: jit_utils.GraphContext, input, nan, posinf, neginf):
    """Export ``aten::nan_to_num`` as three chained ``Where`` replacements.

    NaN becomes ``nan`` (default 0.0). +inf becomes ``posinf`` (default: the
    dtype's largest finite value). -inf becomes ``neginf`` (default: the
    dtype's lowest finite value). Non-floating inputs are returned as-is.
    """
    # Cannot create an int-typed tensor holding inf/nan, and such a tensor
    # cannot contain them anyway, so return the original tensor.
    if not symbolic_helper._is_fp(input):
        return input
    input_dtype = _type_utils.JitScalarType.from_value(input).dtype()

    def _where_fill(mask, fill_value, otherwise):
        # Constant of input's dtype selected wherever mask is true.
        fill = g.op("Constant", value_t=torch.tensor([fill_value], dtype=input_dtype))
        return g.op("Where", mask, fill, otherwise)

    def _signed_inf_mask(value, compare):
        # Elements that are infinite and compare against zero as requested.
        inf_mask = isinf(g, value)
        zero = g.op("Constant", value_t=torch.LongTensor([0]))
        return opset9.logical_and(g, inf_mask, compare(g, value, zero))

    nan_value = 0.0 if nan is None else nan
    without_nan = _where_fill(opset9.isnan(g, input), nan_value, input)

    # For None values of posinf/neginf, use the greatest/lowest finite value
    # representable by the input's dtype.
    finfo = torch.finfo(input_dtype)
    posinf_value = finfo.max if posinf is None else posinf
    without_posinf = _where_fill(
        _signed_inf_mask(without_nan, opset9.gt), posinf_value, without_nan
    )

    neginf_value = finfo.min if neginf is None else neginf
    return _where_fill(
        _signed_inf_mask(without_posinf, opset9.lt), neginf_value, without_posinf
    )
800
801
802# Quantized symbolics ---------------------------------------------------------
803# https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export
804# Support starts from opset 10 because `DequantizeLinear` and `QuantizeLinear` were
805# introduced in opset version 10.
@_onnx_symbolic("quantized::linear")
def quantized_linear(
    g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
):
    """Export ``quantized::linear`` as dequantize → linear → quantize.

    The bias is requantized with the input and weight scales (via
    ``requantize_bias_helper``) and then dequantized before the float linear.
    """
    dq_input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    dq_weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale
    )
    dq_bias, _, _, _ = symbolic_helper.dequantize_helper(g, requantized_bias)

    float_out = opset9.linear(g, dq_input, dq_weight, dq_bias)
    return symbolic_helper.quantize_helper(g, float_out, op_scale, op_zero_point)
818
819
@_onnx_symbolic("quantized::linear_relu")
def quantized_linear_relu(
    g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
):
    """Export ``quantized::linear_relu`` as dequantize → linear → relu → quantize."""
    dq_input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    dq_weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale
    )
    dq_bias, _, _, _ = symbolic_helper.dequantize_helper(g, requantized_bias)

    activated = opset9.relu(g, opset9.linear(g, dq_input, dq_weight, dq_bias))
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
833
834
@_onnx_symbolic("quantized::add")
def quantized_add(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
    """Export ``quantized::add`` as dequantize both → add → quantize."""
    x_float = symbolic_helper.dequantize_helper(g, x)[0]
    y_float = symbolic_helper.dequantize_helper(g, y)[0]
    total = opset9.add(g, x_float, y_float)
    return symbolic_helper.quantize_helper(g, total, op_scale, op_zero_point)
843
844
@_onnx_symbolic("quantized::add_relu")
def quantized_add_relu(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
    """Export ``quantized::add_relu`` as dequantize both → add → relu → quantize."""
    x_float = symbolic_helper.dequantize_helper(g, x)[0]
    y_float = symbolic_helper.dequantize_helper(g, y)[0]
    activated = opset9.relu(g, opset9.add(g, x_float, y_float))
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
854
855
@_onnx_symbolic("quantized::mul")
def quantized_mul(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
    """Export ``quantized::mul`` as dequantize both → mul → quantize."""
    x_float = symbolic_helper.dequantize_helper(g, x)[0]
    y_float = symbolic_helper.dequantize_helper(g, y)[0]
    product = opset9.mul(g, x_float, y_float)
    return symbolic_helper.quantize_helper(g, product, op_scale, op_zero_point)
864
865
@_onnx_symbolic("quantized::hardswish")
def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
    """Export ``quantized::hardswish`` as dequantize → hardswish → quantize."""
    x_float = symbolic_helper.dequantize_helper(g, x)[0]
    activated = opset9.hardswish(g, x_float)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
873
874
@_onnx_symbolic("quantized::sigmoid")
def quantized_sigmoid(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
    """Export ``quantized::sigmoid`` as dequantize → sigmoid → quantize."""
    x_float = symbolic_helper.dequantize_helper(g, x)[0]
    activated = opset9.sigmoid(g, x_float)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
882
883
@_onnx_symbolic("quantized::leaky_relu")
def quantized_leaky_relu(
    g: jit_utils.GraphContext, x, negative_slope, inplace, op_scale, op_zero_point
):
    """Export ``quantized::leaky_relu`` as dequantize → leaky_relu → quantize."""
    x_float = symbolic_helper.dequantize_helper(g, x)[0]
    activated = opset9.leaky_relu(g, x_float, negative_slope, inplace)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
893
894
@_onnx_symbolic("quantized::layer_norm")
def quantized_layer_norm(
    g: jit_utils.GraphContext,
    x,
    normalized_shape,
    weight,
    bias,
    eps,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::layer_norm`` as dequantize → layer_norm → quantize."""
    x_float = symbolic_helper.dequantize_helper(g, x)[0]
    # Trailing False: presumably the cudnn_enable flag of opset9.layer_norm
    # (verify against its signature).
    normalized = opset9.layer_norm(
        g, x_float, normalized_shape, weight, bias, eps, False
    )
    return symbolic_helper.quantize_helper(g, normalized, op_scale, op_zero_point)
911
912
@_onnx_symbolic("quantized::group_norm")
def quantized_group_norm(
    g: jit_utils.GraphContext,
    x,
    num_groups,
    weight,
    bias,
    eps,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::group_norm`` as dequantize → group_norm → quantize."""
    x_float = symbolic_helper.dequantize_helper(g, x)[0]
    # Trailing False: presumably the cudnn_enabled flag of opset9.group_norm
    # (verify against its signature).
    normalized = opset9.group_norm(g, x_float, num_groups, weight, bias, eps, False)
    return symbolic_helper.quantize_helper(g, normalized, op_scale, op_zero_point)
929
930
@_onnx_symbolic("quantized::instance_norm")
@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v")
def quantized_instance_norm(
    g: jit_utils.GraphContext,
    q_input,
    weight,
    bias,
    eps,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::instance_norm`` through the float instance_norm symbolic.

    ``parse_args`` delivers ``eps`` as a Python float. The dequantized input is
    passed to ``opset9.instance_norm`` with the fixed arguments
    ``None, None, False, 0.0`` before ``eps`` and ``False`` after it.
    Presumably these fill the running-mean/var, use-input-stats, momentum
    and cudnn slots; confirm against opset9. The result is re-quantized with
    ``op_scale`` / ``op_zero_point``.
    """
    float_input = symbolic_helper.dequantize_helper(g, q_input)[0]
    normalized = opset9.instance_norm(
        g, float_input, weight, bias, None, None, False, 0.0, eps, False
    )
    return symbolic_helper.quantize_helper(g, normalized, op_scale, op_zero_point)
949
950
@_onnx_symbolic("quantized::conv1d_relu")
def quantized_conv1d_relu(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export fused ``quantized::conv1d_relu`` as dequantize -> Conv -> Relu -> quantize.

    Input and weight are dequantized, and their scales are used to re-quantize
    the bias. That bias is then dequantized back to float before the
    convolution. The rectified output is quantized with ``op_scale`` /
    ``op_zero_point``.
    """
    dq_input, in_scale, *_ = symbolic_helper.dequantize_helper(g, q_input)
    dq_weight, w_scale, *_ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(
        g, bias, in_scale, w_scale
    )
    dq_bias = symbolic_helper.dequantize_helper(g, requantized_bias)[0]

    conv_out = opset9.conv1d(
        g, dq_input, dq_weight, dq_bias, stride, padding, dilation, groups
    )
    relu_out = opset9.relu(g, conv_out)
    return symbolic_helper.quantize_helper(g, relu_out, op_scale, op_zero_point)
973
974
@_onnx_symbolic("quantized::conv2d_relu")
def quantized_conv2d_relu(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export fused ``quantized::conv2d_relu`` as dequantize -> Conv -> Relu -> quantize.

    The bias is re-quantized with the input and weight scales, then
    dequantized, before being fed to the float 2-D convolution.
    """
    dq_input, in_scale, *_ = symbolic_helper.dequantize_helper(g, q_input)
    dq_weight, w_scale, *_ = symbolic_helper.dequantize_helper(g, q_weight)
    dq_bias = symbolic_helper.dequantize_helper(
        g, symbolic_helper.requantize_bias_helper(g, bias, in_scale, w_scale)
    )[0]

    activated = opset9.relu(
        g,
        opset9.conv2d(
            g, dq_input, dq_weight, dq_bias, stride, padding, dilation, groups
        ),
    )
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
997
998
@_onnx_symbolic("quantized::conv3d_relu")
def quantized_conv3d_relu(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export fused ``quantized::conv3d_relu`` as dequantize -> Conv -> Relu -> quantize.

    Mirrors the 1-D/2-D variants. Input and weight are dequantized, and the
    bias goes through a re-quantize/dequantize round trip using their scales.
    A 3-D convolution is followed by ReLU, and the result is quantized with
    ``op_scale`` / ``op_zero_point``.
    """
    dq_input, in_scale, *_ = symbolic_helper.dequantize_helper(g, q_input)
    dq_weight, w_scale, *_ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(
        g, bias, in_scale, w_scale
    )
    dq_bias = symbolic_helper.dequantize_helper(g, requantized_bias)[0]

    conv_out = opset9.conv3d(
        g, dq_input, dq_weight, dq_bias, stride, padding, dilation, groups
    )
    return symbolic_helper.quantize_helper(
        g, opset9.relu(g, conv_out), op_scale, op_zero_point
    )
1021
1022
@_onnx_symbolic("quantized::conv1d")
def quantized_conv1d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv1d`` as dequantize -> Conv -> quantize.

    The float bias is obtained by re-quantizing ``bias`` with the input and
    weight scales and dequantizing the result.
    """
    dq_input, in_scale, *_ = symbolic_helper.dequantize_helper(g, q_input)
    dq_weight, w_scale, *_ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(
        g, bias, in_scale, w_scale
    )
    dq_bias = symbolic_helper.dequantize_helper(g, requantized_bias)[0]

    conv_out = opset9.conv1d(
        g, dq_input, dq_weight, dq_bias, stride, padding, dilation, groups
    )
    return symbolic_helper.quantize_helper(g, conv_out, op_scale, op_zero_point)
1044
1045
@_onnx_symbolic("quantized::conv2d")
def quantized_conv2d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv2d`` as dequantize -> Conv -> quantize.

    Input and weight are dequantized, and the bias round-trips through
    ``requantize_bias_helper`` using their scales. The float 2-D convolution
    output is quantized with ``op_scale`` / ``op_zero_point``.
    """
    dq_input, in_scale, *_ = symbolic_helper.dequantize_helper(g, q_input)
    dq_weight, w_scale, *_ = symbolic_helper.dequantize_helper(g, q_weight)
    dq_bias = symbolic_helper.dequantize_helper(
        g, symbolic_helper.requantize_bias_helper(g, bias, in_scale, w_scale)
    )[0]

    return symbolic_helper.quantize_helper(
        g,
        opset9.conv2d(
            g, dq_input, dq_weight, dq_bias, stride, padding, dilation, groups
        ),
        op_scale,
        op_zero_point,
    )
1067
1068
@_onnx_symbolic("quantized::conv3d")
def quantized_conv3d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv3d`` as dequantize -> Conv -> quantize.

    Same lowering as the 1-D/2-D variants, using the float 3-D convolution
    symbolic.
    """
    dq_input, in_scale, *_ = symbolic_helper.dequantize_helper(g, q_input)
    dq_weight, w_scale, *_ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(
        g, bias, in_scale, w_scale
    )
    dq_bias = symbolic_helper.dequantize_helper(g, requantized_bias)[0]

    conv_out = opset9.conv3d(
        g, dq_input, dq_weight, dq_bias, stride, padding, dilation, groups
    )
    return symbolic_helper.quantize_helper(g, conv_out, op_scale, op_zero_point)
1090
1091
@_onnx_symbolic("quantized::conv_transpose1d")
def quantized_conv_transpose1d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    output_padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv_transpose1d`` as dequantize -> ConvTranspose -> quantize.

    Input and weight are dequantized, and their scales are used to re-quantize
    the bias. That bias is then dequantized back to float. The float
    transposed convolution output is quantized with ``op_scale`` /
    ``op_zero_point``.
    """
    dq_input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    dq_weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
    dq_bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    # Delegate to the 1-D float symbolic so the op's dimensionality matches
    # its export path. This is consistent with conv1d -> opset9.conv1d and
    # conv_transpose3d -> opset9.conv_transpose3d.
    # The float symbolic is called with ``groups`` before ``dilation``,
    # i.e. swapped relative to this function's own parameter order.
    output = opset9.conv_transpose1d(
        g,
        dq_input,
        dq_weight,
        dq_bias,
        stride,
        padding,
        output_padding,
        groups,
        dilation,
    )

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
1116
1117
@_onnx_symbolic("quantized::conv_transpose2d")
def quantized_conv_transpose2d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    output_padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv_transpose2d`` as dequantize -> ConvTranspose -> quantize.

    The bias is re-quantized from the input/weight scales and dequantized
    before the float transposed convolution. Note that ``groups`` is passed
    before ``dilation`` to ``opset9.conv_transpose2d``.
    """
    dq_input, in_scale, *_ = symbolic_helper.dequantize_helper(g, q_input)
    dq_weight, w_scale, *_ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(
        g, bias, in_scale, w_scale
    )
    dq_bias = symbolic_helper.dequantize_helper(g, requantized_bias)[0]

    deconv_out = opset9.conv_transpose2d(
        g,
        dq_input,
        dq_weight,
        dq_bias,
        stride,
        padding,
        output_padding,
        groups,
        dilation,
    )
    return symbolic_helper.quantize_helper(g, deconv_out, op_scale, op_zero_point)
1142
1143
@_onnx_symbolic("quantized::conv_transpose3d")
def quantized_conv_transpose3d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    output_padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Export ``quantized::conv_transpose3d`` as dequantize -> ConvTranspose -> quantize.

    Same lowering as the 2-D variant, using ``opset9.conv_transpose3d``
    (called with ``groups`` before ``dilation``).
    """
    dq_input, in_scale, *_ = symbolic_helper.dequantize_helper(g, q_input)
    dq_weight, w_scale, *_ = symbolic_helper.dequantize_helper(g, q_weight)
    dq_bias = symbolic_helper.dequantize_helper(
        g, symbolic_helper.requantize_bias_helper(g, bias, in_scale, w_scale)
    )[0]

    deconv_out = opset9.conv_transpose3d(
        g,
        dq_input,
        dq_weight,
        dq_bias,
        stride,
        padding,
        output_padding,
        groups,
        dilation,
    )
    return symbolic_helper.quantize_helper(g, deconv_out, op_scale, op_zero_point)
1168
1169
@_onnx_symbolic("quantized::cat")
@symbolic_helper.parse_args("v", "i", "v", "v")
def quantized_cat(
    g: jit_utils.GraphContext,
    q_inputs: _C.Value,
    dim: int,
    op_scale: _C.Value,
    op_zero_point: _C.Value,
) -> _C.Value:
    """Export ``quantized::cat`` as per-input dequantize -> Concat -> quantize.

    ``q_inputs`` is unpacked with ``_unpack_list`` and each element is
    dequantized in list order. The float tensors are joined along ``dim``
    (parsed as an int) with an ONNX ``Concat`` node, and the result is
    quantized with ``op_scale`` / ``op_zero_point``.
    """
    members = symbolic_helper._unpack_list(q_inputs)
    float_members = [
        symbolic_helper.dequantize_helper(g, member)[0] for member in members
    ]
    joined = g.op("Concat", *float_members, axis_i=dim)
    return symbolic_helper.quantize_helper(g, joined, op_scale, op_zero_point)
1185