• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7from typing import Dict
8
9import torch
10from executorch.backends.xnnpack.operators.node_visitor import (
11    NodeVisitor,
12    register_node_visitor,
13)
14from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import XNNGraph
15
16
class OpSkipOps(NodeVisitor):
    """
    Shared base for visitors whose FX nodes produce no XNNPACK node.

    Concrete subclasses only provide a ``target`` string; the inherited
    ``define_node`` is a no-op.
    """

    def __init__(self, *args) -> None:
        # Pure pass-through: every positional argument goes to NodeVisitor.
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """
        Intentionally define nothing for ``node``.

        The parameters exist only to satisfy the NodeVisitor interface; none of
        them is read, and ``xnn_graph`` / ``vals_to_ids`` are not modified.
        """
        return None
33
34
@register_node_visitor
class OpChooseQparamsTensor(OpSkipOps):
    """Skip visitor for ``quantized_decomposed.choose_qparams.tensor`` nodes."""

    target = "quantized_decomposed.choose_qparams.tensor"
42
43
@register_node_visitor
class OpDequantizePerChannelDefault(OpSkipOps):
    """Skip visitor for ``quantized_decomposed.dequantize_per_channel.default`` nodes."""

    target = "quantized_decomposed.dequantize_per_channel.default"
51
52
@register_node_visitor
class OpGetItem(OpSkipOps):
    """Skip visitor for ``getitem`` nodes; no XNNPACK node is defined for them."""

    target = "getitem"
60
61
@register_node_visitor
class OpQuantizePerChannelDefault(OpSkipOps):
    """Skip visitor for ``quantized_decomposed.quantize_per_channel.default`` nodes."""

    target = "quantized_decomposed.quantize_per_channel.default"
69
70
@register_node_visitor
class OpTCopyDefault(OpSkipOps):
    """Skip visitor for ``aten.t_copy.default`` nodes."""

    target = "aten.t_copy.default"
78
79
@register_node_visitor
class OpViewCopyDefault(OpSkipOps):
    """
    Temporarily treat ``aten.view_copy.default`` as a skip op.

    TODO: give this op a real lowering instead of skipping it.
    """

    target = "aten.view_copy.default"
88
89
@register_node_visitor
class OpSymSizeInt(OpSkipOps):
    """
    Temporarily treat ``sym_size.int`` as a skip op.

    TODO: give this op a real lowering instead of skipping it.
    """

    target = "sym_size.int"
98
99
@register_node_visitor
class OpChooseQparamsAffine(OpSkipOps):
    """Skip visitor for ``quant.choose_qparams_affine.default`` nodes."""

    target = "quant.choose_qparams_affine.default"
107
108
@register_node_visitor
class OpChooseQparamsToken(OpSkipOps):
    """
    Do nothing if node is choose_qparams_per_token_asymmetric.default.

    This visitor matches the ``.default`` overload named in ``target``, not a
    ``.tensor`` overload (that suffix is used by OpChooseQparamsTensor for
    ``choose_qparams``).
    """

    target = "quantized_decomposed.choose_qparams_per_token_asymmetric.default"
116
117
@register_node_visitor
class OpQuantizePerChannelGroupDefault(OpSkipOps):
    """Skip visitor for ``quantized_decomposed.quantize_per_channel_group.default`` nodes."""

    target = "quantized_decomposed.quantize_per_channel_group.default"
125
126
@register_node_visitor
class OpDequantizePerChannelGroupDefault(OpSkipOps):
    """Skip visitor for ``quantized_decomposed.dequantize_per_channel_group.default`` nodes."""

    target = "quantized_decomposed.dequantize_per_channel_group.default"
134