#
# Copyright (c) 2024 Apple Inc. All rights reserved.
# Provided subject to the LICENSE file in the top level directory.
#

import logging
from typing import cast

import torch
from executorch.backends.apple.mps.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
    MPSDataType,
    MPSDequantizePerChannelGroup,
    MPSGraph,
    MPSNode,
)
from executorch.backends.apple.mps.utils.mps_utils import get_input_node

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.DEBUG, format=FORMAT)


@register_node_visitor
class OpDequantizePerChannelGroupDefault(NodeVisitor):
    """Visitor that serializes ``dequantize_per_channel_group`` nodes."""

    target = "quantized_decomposed.dequantize_per_channel_group.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append an ``MPSDequantizePerChannelGroup`` node to *mps_graph*.

        The weight placeholder (input 0) must not have been defined before
        this point; it is registered here with an explicit int4 MPS dtype.
        """
        weight_node = get_input_node(node, 0)
        # Weights placeholders shouldn't have been defined until this point
        if weight_node in self.tensor_to_id:
            raise RuntimeError(
                f"Placeholder for {node.target.__name__} already visited"
            )

        output_id = self.define_tensor(node, mps_graph)
        input_id = self.define_tensor(
            weight_node, mps_graph, MPSDataType.mps_data_type_int4
        )
        scales_id = self.define_tensor(get_input_node(node, 1), mps_graph)

        # there are no zero points in this quantization method (node.args[2] is all zeros)
        mps_graph.mps_nodes.append(
            MPSNode(
                mpsnode_union=MPSDequantizePerChannelGroup(
                    input1_id=input_id,
                    output_id=output_id,
                    scales_id=scales_id,
                    zero_points_id=-1,
                    quant_min=cast(int, node.args[3]),
                    quant_max=cast(int, node.args[4]),
                    dtype=self.torch_dtype_to_mps_dtype(node.args[5]),
                    group_size=cast(int, node.args[6]),
                    output_dtype=self.torch_dtype_to_mps_dtype(node.args[7]),
                )
            )
        )


@register_node_visitor
class OpQuantizePerToken(NodeVisitor):
    """
    Dynamic Quantize Per Token Node visitor
    """

    target = "quantized_decomposed.quantize_per_token.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """
        Skip activation dynamic quantization for now.
        Currently all matmuls are going through [FP16/BF16] @ [QInt4/QInt8].
        Issue: #133407308
        """
        # Pass-through: map this node straight to its (float) input tensor.
        self.tensor_to_id[node] = self.define_tensor(
            get_input_node(node, 0), mps_graph
        )


@register_node_visitor
class OpDequantizePerToken(NodeVisitor):
    """
    Dequantize Per Token Node visitor
    """

    target = "quantized_decomposed.dequantize_per_token.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """
        Skip activation dynamic quantization for now.
        Currently all matmuls are going through [FP16/BF16] @ [QInt4/QInt8].
        Issue: #133407308
        """
        # Pass-through: map this node straight to its input tensor.
        self.tensor_to_id[node] = self.define_tensor(
            get_input_node(node, 0), mps_graph
        )


@register_node_visitor
class OpChooseQparamsToken(NodeVisitor):
    """
    do nothing if node is choose_qparams_per_token_asymmetric.tensor
    """

    target = "quantized_decomposed.choose_qparams_per_token_asymmetric.default"

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """
        Skip activation dynamic quantization for now.
        Currently all matmuls are going through [FP16/BF16] @ [QInt4/QInt8].
        Issue: #133407308
        """
        # choose_qparams returns (scale, zero_point); alias both outputs to
        # the input tensor id since quantization is skipped.
        input_id = self.define_tensor(get_input_node(node, 0), mps_graph)
        self.tensor_to_id[node] = [input_id, input_id]