• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#
#  Copyright (c) 2024 Apple Inc. All rights reserved.
#  Provided subject to the LICENSE file in the top level directory.
#

import logging
from typing import cast

import torch
from executorch.backends.apple.mps.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
    MPSDataType,
    MPSDequantizePerChannelGroup,
    MPSGraph,
    MPSNode,
)
from executorch.backends.apple.mps.utils.mps_utils import get_input_node

# Shared log-line layout: "[LEVEL timestamp file:lineno] message".
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
# NOTE(review): calling logging.basicConfig at import time — and at DEBUG —
# reconfigures the process-wide root logger from a library module; the usual
# pattern is a module-level logging.getLogger(__name__). Confirm this global
# side effect is intended before changing it.
logging.basicConfig(level=logging.DEBUG, format=FORMAT)


@register_node_visitor
class OpDequantizePerChannelGroupDefault(NodeVisitor):
    """Serializes a per-channel-group dequantize op into the MPS graph.

    Defines the int4 weight placeholder itself (with the packed int4 MPS
    dtype) and appends an MPSDequantizePerChannelGroup node wiring input,
    scales, and output tensor ids together.
    """

    target = "quantized_decomposed.dequantize_per_channel_group.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        # This visitor owns the weight placeholder's definition; if it has
        # already been assigned a tensor id, something visited it first.
        weight_node = get_input_node(node, 0)
        if weight_node in self.tensor_to_id:
            raise RuntimeError(
                f"Placeholder for {node.target.__name__} already visited"
            )

        output_id = self.define_tensor(node, mps_graph)
        # Register the weight with the packed int4 dtype rather than its
        # default serialized dtype.
        input_id = self.define_tensor(
            weight_node, mps_graph, MPSDataType.mps_data_type_int4
        )
        scales_id = self.define_tensor(get_input_node(node, 1), mps_graph)

        # There are no zero points in this quantization method
        # (node.args[2] is all zeros), so -1 marks "no tensor".
        mps_graph.mps_nodes.append(
            MPSNode(
                mpsnode_union=MPSDequantizePerChannelGroup(
                    input1_id=input_id,
                    output_id=output_id,
                    scales_id=scales_id,
                    zero_points_id=-1,
                    quant_min=cast(int, node.args[3]),
                    quant_max=cast(int, node.args[4]),
                    dtype=self.torch_dtype_to_mps_dtype(node.args[5]),
                    group_size=cast(int, node.args[6]),
                    output_dtype=self.torch_dtype_to_mps_dtype(node.args[7]),
                )
            )
        )


@register_node_visitor
class OpQuantizePerToken(NodeVisitor):
    """
    Dynamic Quantize Per Token Node visitor
    """

    target = "quantized_decomposed.quantize_per_token.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """
        Skip activation dynamic quantization for now.
        Currently all matmuls are going through [FP16/BF16] @ [QInt4/QInt8].
        Issue: #133407308
        """
        # Treat the quantize op as identity: alias this node's tensor id to
        # its floating-point input so downstream ops consume the fp tensor.
        fp_input_id = self.define_tensor(get_input_node(node, 0), mps_graph)
        self.tensor_to_id[node] = fp_input_id


@register_node_visitor
class OpDequantizePerToken(NodeVisitor):
    """
    Dequantize Per Token Node visitor
    """

    target = "quantized_decomposed.dequantize_per_token.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """
        Skip activation dynamic quantization for now.
        Currently all matmuls are going through [FP16/BF16] @ [QInt4/QInt8].
        Issue: #133407308
        """
        # Pass-through: map this node straight to its input's tensor id, so
        # the dequantize op adds no node to the serialized graph.
        self.tensor_to_id[node] = self.define_tensor(
            get_input_node(node, 0), mps_graph
        )


@register_node_visitor
class OpChooseQparamsToken(NodeVisitor):
    """
    do nothing if node is choose_qparams_per_token_asymmetric.tensor
    """

    target = "quantized_decomposed.choose_qparams_per_token_asymmetric.default"

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """
        Skip activation dynamic quantization for now.
        Currently all matmuls are going through [FP16/BF16] @ [QInt4/QInt8].
        Issue: #133407308
        """
        src_id = self.define_tensor(get_input_node(node, 0), mps_graph)
        # The node's tensor-id entry is a two-element list — presumably one
        # id per output (scale, zero_point) of choose_qparams; both alias the
        # input id since quantization is skipped. TODO(review): confirm
        # against how consumers index tensor_to_id for multi-output nodes.
        self.tensor_to_id[node] = [src_id, src_id]
143