# Owner(s): ["oncall: quantization"] import copy import itertools import sys from enum import Enum import torch import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq import torch.nn as nn from torch._export import capture_pre_autograd_graph from torch.ao.quantization import ObserverBase from torch.ao.quantization.quantize_pt2e import ( convert_pt2e, prepare_pt2e, prepare_qat_pt2e, ) from torch.ao.quantization.quantizer.x86_inductor_quantizer import ( QUANT_ANNOTATION_KEY, X86InductorQuantizer, ) from torch.testing._internal.common_quantization import ( NodeSpec as ns, QuantizationTestCase, skipIfNoInductorSupport, skipIfNoX86, ) from torch.testing._internal.common_quantized import override_quantized_engine from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS, skipIfTorchDynamo if IS_WINDOWS and IS_CI: sys.stderr.write("Windows CI still has some issue to be fixed.\n") sys.exit(0) class NodePosType(Enum): left = 1 right = 2 both = 3 class TestHelperModules: class SingleConv2dModule(torch.nn.Module): def __init__(self, with_bn=False) -> None: super().__init__() self.conv = nn.Conv2d(3, 6, (2, 2), stride=(1, 1), padding=(1, 1)) self.bn = torch.nn.BatchNorm2d(6) self.with_bn = with_bn def forward(self, x): x = self.conv(x) if self.with_bn: x = self.bn(x) return x class Conv2dUnaryModule(torch.nn.Module): def __init__(self, post_op, use_bias: bool = False, with_bn=False) -> None: super().__init__() self.conv = nn.Conv2d( 3, 6, (2, 2), stride=(1, 1), padding=(1, 1), bias=use_bias ) self.post_op = post_op self.bn = torch.nn.BatchNorm2d(6) self.with_bn = with_bn self.maxpool = torch.nn.MaxPool2d((3, 3)) def forward(self, x): x = self.conv(x) if self.with_bn: x = self.bn(x) x = self.post_op(x) x = self.maxpool(x) return x class Conv2dAddModule(torch.nn.Module): def __init__( self, inplace_add: bool = False, conv2d_type: NodePosType = NodePosType.left, use_bias: bool = False, with_bn: bool = False, ) -> None: super().__init__() self.conv = torch.nn.Conv2d( in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1, bias=use_bias, ) self.conv2 = torch.nn.Conv2d( in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1, bias=use_bias, ) self.relu = nn.ReLU() self.inplace_add = inplace_add self.conv2d_type = conv2d_type self.bn = torch.nn.BatchNorm2d(3) self.with_bn = with_bn def forward(self, x): if self.conv2d_type == NodePosType.left: if self.inplace_add: tmp = self.conv(x) if self.with_bn: tmp = self.bn(tmp) tmp += self.relu(x) return tmp else: tmp = self.conv(x) if self.with_bn: tmp = self.bn(tmp) return tmp + self.relu(x) elif self.conv2d_type == NodePosType.right: if self.inplace_add: tmp = self.relu(x) tmp += self.conv(x) return tmp else: return self.relu(x) + self.conv(x) elif self.conv2d_type == NodePosType.both: if self.inplace_add: tmp = self.conv(x) tmp += self.conv2(x) return tmp else: return self.conv(x) + self.conv2(x) class Conv2dAddReLUModule(torch.nn.Module): def __init__( self, inplace_add: bool = False, conv2d_type: NodePosType = NodePosType.left, inplace_relu: bool = False, use_bias: bool = False, with_bn: bool = False, ) -> None: super().__init__() self.conv = torch.nn.Conv2d( in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1, bias=use_bias, ) self.conv2 = torch.nn.Conv2d( in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1, bias=use_bias, ) self.relu = nn.ReLU() self.inplace_add = inplace_add self.conv2d_type = conv2d_type self.relu2 = nn.ReLU(inplace=inplace_relu) self.bn = torch.nn.BatchNorm2d(3) self.with_bn = with_bn def forward(self, x): if self.conv2d_type == NodePosType.left: if self.inplace_add: tmp = self.conv(x) if self.with_bn: tmp = self.bn(tmp) tmp += self.relu(x) return self.relu2(tmp) else: tmp = self.conv(x) if self.with_bn: tmp = self.bn(tmp) return self.relu2(tmp + self.relu(x)) elif self.conv2d_type == NodePosType.right: if self.inplace_add: tmp = self.relu(x) tmp += self.conv(x) return self.relu2(tmp) else: return self.relu2(self.relu(x) + self.conv(x)) elif self.conv2d_type == NodePosType.both: if self.inplace_add: tmp = self.conv(x) tmp += self.conv2(x) return self.relu2(tmp) else: return self.relu2(self.conv(x) + self.conv2(x)) class Conv2dSingleOpPowModule(nn.Module): def __init__(self, single_op): super().__init__() self.conv = nn.Conv2d(2, 2, 1) self.single_op = single_op def forward(self, x): x = self.conv(x) x = self.single_op(x) return torch.pow(x, 2) class SerialsConv2dAddReLUModule(torch.nn.Module): """Serials of 2 Conv2d -> Add -> ReLU Pattern.""" def __init__( self, ) -> None: super().__init__() self.conv = torch.nn.Conv2d( in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True, ) self.conv2 = torch.nn.Conv2d( in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True, ) self.conv3 = torch.nn.Conv2d( in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True, ) self.conv4 = torch.nn.Conv2d( in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True, ) self.relu = nn.ReLU() self.relu2 = nn.ReLU() def forward(self, x): x1 = self.conv(x) res1 = self.relu(self.conv2(x1) + self.conv3(x1)) res2 = self.relu2(self.conv4(res1) + res1) return res2 class Conv2dCatMaxpool2d(torch.nn.Module): def __init__( self, ): super().__init__() self.conv = torch.nn.Conv2d( 3, 16, 7, bias=True, stride=2, padding=3, dilation=1 ) self.conv2 = torch.nn.Conv2d( 3, 16, 7, bias=True, stride=2, padding=3, dilation=1 ) self.relu = torch.nn.ReLU() self.maxpool = torch.nn.MaxPool2d(3, stride=2, padding=1) self.conv3 = torch.nn.Conv2d( 32, 32, 7, bias=True, stride=2, padding=3, dilation=1 ) def forward(self, x): temp1 = self.relu(self.conv(x)) temp2 = self.conv2(x + 1) temp3 = torch.cat((temp1, temp2), 1) temp4 = self.maxpool(temp3) temp5 = self.conv3(temp4) return temp5 class Conv2dAvgPool2d(torch.nn.Module): def __init__( self, ): super().__init__() self.conv = torch.nn.Conv2d( 3, 16, 7, bias=True, stride=2, padding=3, dilation=1 ) self.avgpool = torch.nn.AvgPool2d(3, stride=2, padding=1) def forward(self, x): temp1 = self.avgpool(self.conv(x)) return temp1 class Conv2dCatSameInputs(torch.nn.Module): def __init__( self, ): super().__init__() self.conv = torch.nn.Conv2d( 3, 16, 7, bias=True, stride=2, padding=3, dilation=1 ) self.relu = torch.nn.ReLU() def forward(self, x): temp1 = self.relu(self.conv(x)) temp3 = torch.cat((temp1, temp1), 1) return temp3 class Conv2dCatSingleInput(torch.nn.Module): def __init__( self, ): super().__init__() self.conv = torch.nn.Conv2d( 3, 16, 7, bias=True, stride=2, padding=3, dilation=1 ) self.relu = torch.nn.ReLU() def forward(self, x): temp1 = self.relu(self.conv(x)) temp3 = torch.cat((temp1,), 1) return temp3 class SingleLinearModule(torch.nn.Module): def __init__(self, use_bias) -> None: super().__init__() self.linear = nn.Linear(4, 4, bias=use_bias) def forward(self, x): return self.linear(x) class LinearUnaryModule(torch.nn.Module): def __init__( self, use_bias, postop, inplace_postop=False, post_op_algo="none" ) -> None: super().__init__() self.linear = nn.Linear(4, 4, bias=use_bias) if postop == nn.GELU: self.postop = postop(approximate=post_op_algo) else: self.postop = postop(inplace=inplace_postop) def forward(self, x): return self.postop(self.linear(x)) class LinearAddModule(torch.nn.Module): def __init__( self, inplace_add: bool = False, linear_pos: NodePosType = NodePosType.left, use_bias: bool = False, ) -> None: super().__init__() self.linear = torch.nn.Linear( in_features=16, out_features=16, bias=use_bias ) self.linear2 = torch.nn.Linear( in_features=16, out_features=16, bias=use_bias ) self.relu = nn.ReLU() self.inplace_add = inplace_add self.linear_pos = linear_pos def forward(self, x): if self.linear_pos == NodePosType.left: if self.inplace_add: tmp = self.linear(x) tmp += self.relu(x) return tmp else: tmp = self.linear(x) return tmp + self.relu(x) elif self.linear_pos == NodePosType.right: if self.inplace_add: tmp = self.relu(x) tmp += self.linear(x) return tmp else: return self.relu(x) + self.linear(x) elif self.linear_pos == NodePosType.both: if self.inplace_add: tmp = self.linear(x) tmp += self.linear2(x) return tmp else: return self.linear(x) + self.linear2(x) class LinearAddReLUModule(torch.nn.Module): def __init__( self, inplace_add: bool = False, linear_pos: NodePosType = NodePosType.left, inplace_relu: bool = False, use_bias: bool = False, ) -> None: super().__init__() self.linear = torch.nn.Linear( in_features=16, out_features=16, bias=use_bias ) self.linear2 = torch.nn.Linear( in_features=16, out_features=16, bias=use_bias ) self.relu = nn.ReLU() self.inplace_add = inplace_add self.linear_pos = linear_pos self.relu2 = nn.ReLU(inplace=inplace_relu) def forward(self, x): if self.linear_pos == NodePosType.left: if self.inplace_add: tmp = self.linear(x) tmp += self.relu(x) return self.relu2(tmp) else: tmp = self.linear(x) return self.relu2(tmp + self.relu(x)) elif self.linear_pos == NodePosType.right: if self.inplace_add: tmp = self.relu(x) tmp += self.linear(x) return self.relu2(tmp) else: return self.relu2(self.relu(x) + self.linear(x)) elif self.linear_pos == NodePosType.both: if self.inplace_add: tmp = self.linear(x) tmp += self.linear2(x) return self.relu2(tmp) else: return self.relu2(self.linear(x) + self.linear2(x)) class SerialsLinearAddReLUModule(torch.nn.Module): """Serials of 2 Linear -> Add -> ReLU Pattern.""" def __init__( self, ) -> None: super().__init__() self.linear = torch.nn.Linear(in_features=16, out_features=16, bias=True) self.linear2 = torch.nn.Linear(in_features=16, out_features=16, bias=True) self.linear3 = torch.nn.Linear(in_features=16, out_features=16, bias=True) self.linear4 = torch.nn.Linear(in_features=16, out_features=16, bias=True) self.relu = nn.ReLU() self.relu2 = nn.ReLU() def forward(self, x): x1 = self.linear(x) res1 = self.relu(self.linear2(x1) + self.linear3(x1)) res2 = self.relu2(self.linear4(res1) + res1) return res2 class LinearAddModule2(torch.nn.Module): def __init__( self, inplace_add: bool = False, ) -> None: super().__init__() self.linear = torch.nn.Linear(in_features=16, out_features=16, bias=True) self.linear2 = torch.nn.Linear(in_features=16, out_features=16, bias=True) self.inplace_add = inplace_add def forward(self, x): if self.inplace_add: tmp = self.linear(x) tmp += self.linear2(tmp) return tmp else: tmp = self.linear(x) return tmp + self.linear2(tmp) class Conv2dAddModule2(torch.nn.Module): def __init__( self, inplace_add: bool = False, ) -> None: super().__init__() self.conv = torch.nn.Conv2d( in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 ) self.conv2 = torch.nn.Conv2d( in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 ) self.inplace_add = inplace_add self.bn = torch.nn.BatchNorm2d(3) self.bn2 = torch.nn.BatchNorm2d(3) def forward(self, x): if self.inplace_add: tmp = self.bn(self.conv(x)) tmp += self.bn2(self.conv2(tmp)) return tmp else: tmp = self.bn(self.conv(x)) return tmp + self.bn2(self.conv2(tmp)) class SelfAttnLikeModule(torch.nn.Module): def __init__( self, input_dim, transpose_for_score=False, num_attention_heads=None, attention_head_size=None, ) -> None: super().__init__() self.input_dim = input_dim self.q_proj = nn.Linear(input_dim, input_dim, bias=False) self.k_proj = nn.Linear(input_dim, input_dim, bias=False) self.v_proj = nn.Linear(input_dim, input_dim, bias=False) self.softmax = nn.Softmax(dim=-1) self.transpose_for_score = transpose_for_score if self.transpose_for_score: assert num_attention_heads is not None assert attention_head_size is not None self.num_attention_heads = num_attention_heads self.attention_head_size = attention_head_size def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + ( self.num_attention_heads, self.attention_head_size, ) x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward(self, x): q = self.q_proj(x) k = self.k_proj(x) v = self.v_proj(x) if self.transpose_for_score: q = self.transpose_for_scores(q) k = self.transpose_for_scores(k) v = self.transpose_for_scores(v) scores = torch.matmul(q, k.transpose(-1, -2)) / (self.input_dim**0.5) attention = self.softmax(scores) weighted = torch.matmul(attention, v) return weighted class X86InductorQuantTestCase(QuantizationTestCase): def _test_quantizer( self, model, example_inputs, quantizer, expected_node_occurrence, expected_node_list=None, is_qat=False, debug=False, ): m_eager = model.train() if is_qat else model.eval() # program capture m = copy.deepcopy(m_eager) m = capture_pre_autograd_graph( m, example_inputs, ) # QAT Model failed to deepcopy export_model = m if is_qat else copy.deepcopy(m) m = prepare_qat_pt2e(m, quantizer) if is_qat else prepare_pt2e(m, quantizer) # Calibrate m(*example_inputs) prepare_model = copy.deepcopy(m) m = convert_pt2e(m) convert_model = copy.deepcopy(m) if debug: convert_model.print_readable(True) pt2_quant_output = m(*example_inputs) node_occurrence = { ns.call_function(k): v for k, v in expected_node_occurrence.items() } if expected_node_list is None: expected_node_list = [] node_list = [ns.call_function(n) for n in expected_node_list] self.checkGraphModuleNodes( m, expected_node_occurrence=node_occurrence, expected_node_list=node_list ) return export_model, prepare_model, convert_model @skipIfNoInductorSupport class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase): @skipIfNoX86 def test_conv2d(self): """ Test pattern of single conv2d with X86InductorQuantizer. """ with override_quantized_engine("x86"), torch.no_grad(): m = TestHelperModules.SingleConv2dModule().eval() example_inputs = (torch.randn(2, 3, 16, 16),) quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config() ) node_occurrence = { # one for input and weight of the conv torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, # note: quantize op for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.conv2d.default, ] self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, ) @skipIfNoX86 def test_conv2d_unary(self): """ Test pattern of conv2d with unary post ops (such as relu, hardtanh, hardswish, relu6) with X86InductorQuantizer. """ unary_map = { "relu": [torch.nn.ReLU(inplace=False), torch.ops.aten.relu.default], "relu_inplace": [torch.nn.ReLU(inplace=True), torch.ops.aten.relu_.default], "hardtanh": [ torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=False), torch.ops.aten.hardtanh.default, ], "hardtanh_inplace": [ torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=True), torch.ops.aten.hardtanh_.default, ], "relu6": [torch.nn.ReLU6(inplace=False), torch.ops.aten.hardtanh.default], "relu6_inplace": [ torch.nn.ReLU6(inplace=True), torch.ops.aten.hardtanh_.default, ], "hardswish": [ torch.nn.Hardswish(inplace=False), torch.ops.aten.hardswish.default, ], "hardswish_inplace": [ torch.nn.Hardswish(inplace=True), torch.ops.aten.hardswish_.default, ], "swish": [torch.nn.SiLU(inplace=False), torch.ops.aten.silu.default], "swish_inplace": [ torch.nn.SiLU(inplace=True), torch.ops.aten.silu_.default, ], } use_bias_list = [True, False] with override_quantized_engine("x86"), torch.no_grad(): for unary_op, use_bias in itertools.product( unary_map.keys(), use_bias_list ): m = TestHelperModules.Conv2dUnaryModule( unary_map[unary_op][0], use_bias=use_bias ).eval() example_inputs = (torch.randn(2, 3, 16, 16),) quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config() ) node_occurrence = { # one for input and weight of the conv torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, # note: quantize op for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.conv2d.default, unary_map[unary_op][1], ] self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, ) @skipIfNoX86 def test_conv2d_binary(self): """ Test pattern of conv2d with binary post ops (such as add) with X86InductorQuantizer. Currently, only add as binary post op is supported. """ conv2d_type_list = [NodePosType.left, NodePosType.both] example_inputs = (torch.randn(2, 3, 6, 6),) quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config() ) with override_quantized_engine("x86"), torch.no_grad(): for conv2d_type in conv2d_type_list: m = TestHelperModules.Conv2dAddModule(conv2d_type=conv2d_type).eval() if conv2d_type != NodePosType.both: node_occurrence = { # one for input and weight of the conv # one for extra input node of add torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, # quantize_per_channel for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } else: node_occurrence = { # one for input of the conv # one for input of another conv # 2 conv will share same input quant/dequant # one for extra input node of add torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, # quantize_per_channel for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } node_list = [ torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.conv2d.default, torch.ops.aten.add.Tensor, ] self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, ) @skipIfNoX86 def test_conv2d_binary2(self): """ Test Pattern: tmp = conv2d_1(x) tmp2 = conv2d_2(tmp) return tmp + tmp2 Since conv2d_1 has 2 users, we should annotate conv2d_2 for binary fusion instead of conv2d_1 """ example_inputs = (torch.randn(2, 3, 6, 6),) quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config() ) inplace_add_list = [True, False] with override_quantized_engine("x86"), torch.no_grad(): for inplace_add in inplace_add_list: m = TestHelperModules.Conv2dAddModule2(inplace_add=inplace_add).eval() node_occurrence = { torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, # quantize_per_channel for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } node_list = [ torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.conv2d.default, torch.ops.quantized_decomposed.quantize_per_tensor.default, ( torch.ops.aten.add_.Tensor if inplace_add else torch.ops.aten.add.Tensor ), ] self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, ) @skipIfNoX86 def test_conv2d_binary_unary(self): """ Test pattern of conv2d with binary + unary post ops (such as add + relu) with X86InductorQuantizer. Currently, only add as binary post op and relu as unary post op are supported. """ conv2d_type_list = [NodePosType.left, NodePosType.both] example_inputs = (torch.randn(2, 3, 6, 6),) quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config() ) with override_quantized_engine("x86"), torch.no_grad(): for conv2d_type in conv2d_type_list: m = TestHelperModules.Conv2dAddReLUModule( conv2d_type=conv2d_type, ).eval() if conv2d_type != NodePosType.both: node_occurrence = { # one for input for conv # one for extra input node of add torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, # note: quantize op for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } else: node_occurrence = { # one for input of the conv # one for input of another conv # 2 conv will share same input quant/dequant # one for extra input node of add torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, # note: quantize op for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } node_list = [ torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.conv2d.default, torch.ops.aten.add.Tensor, ] self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, ) @skipIfNoX86 def test_conv2d_serials_binary_unary(self): """ Test pattern of 2 following up conv2d add relu with X86InductorQuantizer. """ with override_quantized_engine("x86"), torch.no_grad(): m = TestHelperModules.SerialsConv2dAddReLUModule().eval() example_inputs = (torch.randn(2, 3, 16, 16),) quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config() ) node_occurrence = { torch.ops.quantized_decomposed.quantize_per_tensor.default: 4, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 6, # quantize_per_channel for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 4, } node_list = [ torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.conv2d.default, torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.conv2d.default, torch.ops.aten.conv2d.default, torch.ops.aten.add.Tensor, torch.ops.aten.relu.default, ] self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, ) def _single_op_share_observer_recipe_test_helper(self, m, x, single_op): quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config() ) example_inputs = (x,) node_occurrence = { # one for input and weight of the conv, two for input/output for the maxpool2d torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, # quantize_per_channel for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.conv2d.default, torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, single_op, torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, ] _, prepare_model, _ = self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, ) # Check Maxpool2d has share observer at input and output for node in prepare_model.graph.nodes: if node.op == "call_function" and node.target is single_op: single_op_node = node input_obs_of_single_op = getattr( prepare_model, single_op_node.args[0].target ) output_obs_of_single_op = getattr( prepare_model, next(iter(single_op_node.users)).target ) elif ( node.op == "call_function" and node.target is torch.ops.aten.conv2d.default ): conv_node = node input_obs_of_conv = getattr(prepare_model, conv_node.args[0].target) self.assertTrue(isinstance(input_obs_of_single_op, ObserverBase)) self.assertTrue(isinstance(output_obs_of_single_op, ObserverBase)) self.assertTrue(isinstance(input_obs_of_conv, ObserverBase)) self.assertTrue(input_obs_of_single_op is output_obs_of_single_op) self.assertTrue(input_obs_of_single_op is not input_obs_of_conv) @skipIfNoX86 def test_maxpool2d_recipe(self): r""" Test pattern: int8_in_int8_out_ops(maxpool) - non_quantizable op(pow) Since maxpool is a int8_in_int8_out_op, there is obs between maxpool and pow. """ self._single_op_share_observer_recipe_test_helper( TestHelperModules.Conv2dSingleOpPowModule(nn.MaxPool2d(1, 1)).eval(), torch.rand(1, 2, 14, 14), torch.ops.aten.max_pool2d.default, ) @skipIfNoX86 def test_adaptive_avg_pool2d_recipe(self): r""" Test pattern: int8_in_int8_out_ops(adaptive_avg_pool2d) - non_quantizable op(pow) Since adaptive_avg_pool2d is a int8_in_int8_out_op, there is obs between adaptive_avg_pool2d and pow. """ self._single_op_share_observer_recipe_test_helper( TestHelperModules.Conv2dSingleOpPowModule( nn.AdaptiveAvgPool2d((1, 1)) ).eval(), torch.rand(1, 2, 14, 14), torch.ops.aten.adaptive_avg_pool2d.default, ) @skipIfNoX86 def test_flatten_recipe(self): r""" Test pattern: int8_in_int8_out_ops(flatten) - non_quantizable op(pow) Since flatten is a int8_in_int8_out_op, there is obs between flatten and pow. """ self._single_op_share_observer_recipe_test_helper( TestHelperModules.Conv2dSingleOpPowModule( lambda x: torch.flatten(x, 1) ).eval(), torch.rand(1, 2, 14, 14), torch.ops.aten.flatten.using_ints, ) @skipIfNoX86 def test_cat_recipe(self): r""" Test pattern: conv -> cat -> maxpool2d Since cat, maxpool is a int8_in_int8_out_op, the inputs and outputs should with same observer. """ m = TestHelperModules.Conv2dCatMaxpool2d().eval() x = torch.randn(16, 3, 16, 16).contiguous(memory_format=torch.channels_last) quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config() ) example_inputs = (x,) node_occurrence = { torch.ops.quantized_decomposed.quantize_per_tensor.default: 6, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 6, # quantize_per_channel for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, } node_list = [ torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.conv2d.default, torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.cat.default, torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.max_pool2d.default, torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, ] _, prepare_model, _ = self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, ) # Check Cat/Maxpool2d has share observer at input and output for node in prepare_model.graph.nodes: if node.op == "call_function" and node.target == torch.ops.aten.cat.default: cat_act_obs0 = getattr(prepare_model, node.all_input_nodes[0].target) cat_act_obs1 = getattr(prepare_model, node.all_input_nodes[1].target) cat_out_obs = getattr(prepare_model, next(iter(node.users)).target) elif ( node.op == "call_function" and node.target is torch.ops.aten.max_pool2d.default ): maxpool_node = node input_obs_of_maxpool = getattr( prepare_model, maxpool_node.args[0].target ) output_obs_of_maxpool = getattr( prepare_model, next(iter(maxpool_node.users)).target ) self.assertTrue(isinstance(cat_act_obs0, ObserverBase)) self.assertTrue(isinstance(cat_act_obs1, ObserverBase)) self.assertTrue(isinstance(cat_out_obs, ObserverBase)) self.assertTrue(isinstance(input_obs_of_maxpool, ObserverBase)) self.assertTrue(isinstance(output_obs_of_maxpool, ObserverBase)) self.assertTrue(cat_act_obs0 is cat_act_obs1) self.assertTrue(cat_act_obs0 is cat_out_obs) self.assertTrue(cat_out_obs is input_obs_of_maxpool) self.assertTrue(input_obs_of_maxpool is output_obs_of_maxpool) @skipIfNoX86 def test_cat_recipe_same_inputs(self): r""" Test pattern: conv -> cat([input0, input0]) Since cat has 2 input node of same tensor, they should also be with same observer. """ m = TestHelperModules.Conv2dCatSameInputs().eval() x = torch.randn(16, 3, 16, 16).contiguous(memory_format=torch.channels_last) quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config() ) example_inputs = (x,) node_occurrence = { torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, # quantize_per_channel for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.conv2d.default, torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.cat.default, torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, ] _, prepare_model, _ = self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, ) # Check Cat has share observer at input and output for node in prepare_model.graph.nodes: if node.op == "call_function" and node.target == torch.ops.aten.cat.default: cat_act_obs0 = getattr(prepare_model, node.args[0][0].target) cat_act_obs1 = getattr(prepare_model, node.args[0][1].target) cat_out_obs = getattr(prepare_model, next(iter(node.users)).target) self.assertTrue(isinstance(cat_act_obs0, ObserverBase)) self.assertTrue(isinstance(cat_act_obs1, ObserverBase)) self.assertTrue(isinstance(cat_out_obs, ObserverBase)) self.assertTrue(cat_act_obs0 is cat_act_obs1) self.assertTrue(cat_act_obs0 is cat_out_obs) @skipIfNoX86 def test_cat_recipe_single_input(self): r""" Test pattern: conv -> cat([input0,]) Since cat has 1 input node, they should also be with same observer. """ m = TestHelperModules.Conv2dCatSingleInput().eval() x = torch.randn(16, 3, 16, 16).contiguous(memory_format=torch.channels_last) quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config() ) example_inputs = (x,) node_occurrence = { torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, # quantize_per_channel for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.conv2d.default, torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.cat.default, torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, ] _, prepare_model, _ = self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, ) # Check Cat has share observer at input and output for node in prepare_model.graph.nodes: if node.op == "call_function" and node.target == torch.ops.aten.cat.default: cat_act_obs0 = getattr(prepare_model, node.args[0][0].target) cat_out_obs = getattr(prepare_model, next(iter(node.users)).target) self.assertTrue(isinstance(cat_act_obs0, ObserverBase)) self.assertTrue(isinstance(cat_out_obs, ObserverBase)) self.assertTrue(cat_act_obs0 is cat_out_obs) @skipIfNoX86 def test_avg_pool2d_recipe(self): r""" Test pattern: conv -> AvgPool2d Since AvgPool2d is a int8_in_int8_out_op, the inputs and outputs should with same observer. """ m = TestHelperModules.Conv2dAvgPool2d().eval() x = torch.randn(16, 3, 16, 16).contiguous(memory_format=torch.channels_last) quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config() ) example_inputs = (x,) node_occurrence = { torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, # quantize_per_channel for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.conv2d.default, torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.avg_pool2d.default, torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, ] _, prepare_model, _ = self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, ) for node in prepare_model.graph.nodes: if ( node.op == "call_function" and node.target is torch.ops.aten.avg_pool2d.default ): avgpool_node = node input_obs_of_avgpool = getattr( prepare_model, avgpool_node.args[0].target ) output_obs_of_avgpool = getattr( prepare_model, next(iter(avgpool_node.users)).target ) elif ( node.op == "call_function" and node.target is torch.ops.aten.conv2d.default ): conv_node = node output_obs_of_conv = getattr( prepare_model, next(iter(conv_node.users)).target ) self.assertTrue(isinstance(input_obs_of_avgpool, ObserverBase)) self.assertTrue(isinstance(output_obs_of_avgpool, ObserverBase)) self.assertTrue(isinstance(output_obs_of_conv, ObserverBase)) self.assertTrue(input_obs_of_avgpool is output_obs_of_avgpool) self.assertTrue(input_obs_of_avgpool is output_obs_of_conv) @skipIfNoX86 def test_linear(self): """ Test pattern of single linear with X86InductorQuantizer. """ with override_quantized_engine("x86"), torch.no_grad(): for use_bias in [True, False]: m = TestHelperModules.SingleLinearModule(use_bias).eval() example_inputs = (torch.randn(2, 4),) quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config() ) node_occurrence = { # one for input and weight, one for output torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, # quantize_per_channel for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.linear.default, ] self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, ) def _test_linear_unary_helper( self, post_op_module, post_op_aten, post_op_aten_inplace, post_op_algo_list=None, is_qat=False, is_dynamic=False, ): """ Test pattern of linear with unary post ops (e.g. relu) with X86InductorQuantizer. """ use_bias_list = [True, False] # TODO test for inplace add after refactoring of capture_pre_autograd_graph inplace_list = [False] if post_op_algo_list is None: post_op_algo_list = [None] cases = itertools.product(use_bias_list, inplace_list, post_op_algo_list) with override_quantized_engine("x86"), torch.no_grad(): for use_bias, inplace, post_op_algo in cases: if inplace and post_op_aten_inplace is None: continue m = TestHelperModules.LinearUnaryModule( use_bias=use_bias, postop=post_op_module, inplace_postop=inplace, post_op_algo=post_op_algo, ).eval() example_inputs = (torch.randn(2, 4),) quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config( is_qat=is_qat, is_dynamic=is_dynamic, ) ) quantize_per_tensor_op = ( torch.ops.quantized_decomposed.quantize_per_tensor.tensor if is_dynamic else torch.ops.quantized_decomposed.quantize_per_tensor.default ) dequantize_per_tensor_op = ( torch.ops.quantized_decomposed.dequantize_per_tensor.tensor if is_dynamic else torch.ops.quantized_decomposed.dequantize_per_tensor.default ) node_occurrence = { # one for input of the linear quantize_per_tensor_op: 1, dequantize_per_tensor_op: 1, # quantize_per_channel for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ quantize_per_tensor_op, dequantize_per_tensor_op, torch.ops.aten.linear.default, post_op_aten_inplace if inplace else post_op_aten, ] self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, is_qat=is_qat, ) @skipIfNoX86 def test_linear_unary(self): aten = torch.ops.aten self._test_linear_unary_helper(nn.ReLU, aten.relu.default, aten.relu_.default) self._test_linear_unary_helper( nn.LeakyReLU, aten.leaky_relu.default, aten.leaky_relu_.default ) self._test_linear_unary_helper( nn.GELU, aten.gelu.default, None, ["none", "tanh"] ) @skipIfNoX86 def test_linear_unary_qat(self): aten = torch.ops.aten self._test_linear_unary_helper( nn.ReLU, aten.relu.default, aten.relu_.default, is_qat=True ) self._test_linear_unary_helper( nn.LeakyReLU, aten.leaky_relu.default, aten.leaky_relu_.default, is_qat=True ) self._test_linear_unary_helper( nn.GELU, aten.gelu.default, None, ["none", "tanh"], is_qat=True ) @skipIfNoX86 def test_linear_unary_dynamic(self): aten = torch.ops.aten self._test_linear_unary_helper( nn.ReLU, aten.relu.default, aten.relu_.default, is_dynamic=True ) self._test_linear_unary_helper( nn.LeakyReLU, aten.leaky_relu.default, aten.leaky_relu_.default, is_dynamic=True, ) self._test_linear_unary_helper( nn.GELU, aten.gelu.default, None, ["none", "tanh"], is_dynamic=True ) @skipIfNoX86 def test_linear_unary_dynamic_qat(self): aten = torch.ops.aten self._test_linear_unary_helper( nn.ReLU, aten.relu.default, aten.relu_.default, is_qat=True, is_dynamic=True ) self._test_linear_unary_helper( nn.LeakyReLU, aten.leaky_relu.default, aten.leaky_relu_.default, is_qat=True, is_dynamic=True, ) self._test_linear_unary_helper( nn.GELU, aten.gelu.default, None, ["none", "tanh"], is_qat=True, is_dynamic=True, ) def _check_annotation_stat(self, gm, expected_stat_dict): # Check expected annotation statistics to ensure the annotation is correct def _check_annotation(node): annot = node.meta.get(QUANT_ANNOTATION_KEY, None) if annot is None: return False, False return annot._annotated, annot._is_output_of_quantized_pattern for node in gm.graph.nodes: if node.target in expected_stat_dict.keys(): annotated, is_quant_out = _check_annotation(node) expected_stat_dict[node.target]["annotated"] -= annotated expected_stat_dict[node.target]["is_quant_out"] -= is_quant_out for op_stat in expected_stat_dict.values(): assert all(v == 0 for v in op_stat.values()) def _test_linear_binary_helper(self, is_qat=False, is_dynamic=False): """ Test pattern of linear with binary post ops (such as add) with X86InductorQuantizer. Currently, only add as binary post op is supported. """ linear_pos_list = [NodePosType.left, NodePosType.right, NodePosType.both] # TODO test for inplace add after refactoring of capture_pre_autograd_graph inplace_add_list = [False] example_inputs = (torch.randn(2, 16),) quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config( is_qat=is_qat, is_dynamic=is_dynamic, ) ) quantize_per_tensor_op = ( torch.ops.quantized_decomposed.quantize_per_tensor.tensor if is_dynamic else torch.ops.quantized_decomposed.quantize_per_tensor.default ) dequantize_per_tensor_op = ( torch.ops.quantized_decomposed.dequantize_per_tensor.tensor if is_dynamic else torch.ops.quantized_decomposed.dequantize_per_tensor.default ) cases = itertools.product(linear_pos_list, inplace_add_list) with override_quantized_engine("x86"), torch.no_grad(): for linear_pos, inplace_add in cases: m = TestHelperModules.LinearAddModule( inplace_add=inplace_add, linear_pos=linear_pos ).eval() if linear_pos != NodePosType.both: node_occurrence = { # Only one 1 q-dq for input of the linear # No q-dq for extra input node of add quantize_per_tensor_op: 1, dequantize_per_tensor_op: 1, # quantize_per_channel for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } else: # convert_pt2e disables duplicate dequant for dynamic quant num_dequant = 1 if is_dynamic else 2 node_occurrence = { # One quantize_per_tensor for both linear nodes (shared) # Two dequantize_per_tensor for two linear nodes # No q-dq for extra input node of add quantize_per_tensor_op: 1, dequantize_per_tensor_op: num_dequant, # quantize_per_channel for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } node_list = [ quantize_per_tensor_op, dequantize_per_tensor_op, torch.ops.aten.linear.default, ( torch.ops.aten.add_.Tensor if inplace_add else torch.ops.aten.add.Tensor ), ] fq_m = self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, is_qat=is_qat, )[-1] # One linear and add are fused. The other linear is quantized alone if present aten = torch.ops.aten add_op = aten.add_.Tensor if inplace_add else aten.add.Tensor expected_annotation_stat = { aten.linear.default: { "annotated": 2 if linear_pos == NodePosType.both else 1, "is_quant_out": 1 if linear_pos == NodePosType.both else 0, }, add_op: {"annotated": 1, "is_quant_out": 1}, } self._check_annotation_stat(fq_m, expected_annotation_stat) @skipIfNoX86 def test_linear_binary(self): self._test_linear_binary_helper() @skipIfNoX86 def test_linear_binary_qat(self): self._test_linear_binary_helper(is_qat=True) @skipIfNoX86 def test_linear_binary_dynamic(self): self._test_linear_binary_helper(is_dynamic=True) @skipIfNoX86 def test_linear_binary_dynamic_qat(self): self._test_linear_binary_helper(is_qat=True, is_dynamic=True) @skipIfNoX86 def test_linear_binary2(self): """ Test Pattern: tmp = linear_1(x) tmp2 = linear_2(tmp) return tmp + tmp2 Since linear_1 has 2 users, we should annotate linear_2 for binary fusion instead of linear_1 """ example_inputs = (torch.randn(2, 16),) # TODO test for inplace add after refactoring of capture_pre_autograd_graph inplace_add_list = [False] is_qat_list = [False, True] is_dynamic_list = [False, True] cases = itertools.product(inplace_add_list, is_qat_list, is_dynamic_list) with override_quantized_engine("x86"), torch.no_grad(): for inplace_add, is_qat, is_dynamic in cases: quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config( is_qat=is_qat, is_dynamic=is_dynamic ) ) m = TestHelperModules.LinearAddModule2(inplace_add=inplace_add).eval() quantize_per_tensor_op = ( torch.ops.quantized_decomposed.quantize_per_tensor.tensor if is_dynamic else torch.ops.quantized_decomposed.quantize_per_tensor.default ) dequantize_per_tensor_op = ( torch.ops.quantized_decomposed.dequantize_per_tensor.tensor if is_dynamic else torch.ops.quantized_decomposed.dequantize_per_tensor.default ) # Two q-dq nodes for inputs of linear nodes # No q-dq for extra input node of add node_occurrence = { quantize_per_tensor_op: 2, dequantize_per_tensor_op: 2, # quantize_per_channel for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } node_list = [ torch.ops.quantized_decomposed.dequantize_per_channel.default, quantize_per_tensor_op, dequantize_per_tensor_op, torch.ops.aten.linear.default, ( torch.ops.aten.add_.Tensor if inplace_add else torch.ops.aten.add.Tensor ), ] fq_m = self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, )[-1] # One linear and add are fused. The other linear is quantized alone if present aten = torch.ops.aten add_op = aten.add_.Tensor if inplace_add else aten.add.Tensor expected_annotation_stat = { aten.linear.default: { "annotated": 2, "is_quant_out": 1, }, add_op: {"annotated": 1, "is_quant_out": 1}, } self._check_annotation_stat(fq_m, expected_annotation_stat) @skipIfNoX86 def _test_linear_binary_unary_helper(self, is_qat=False, is_dynamic=False): """ Test pattern of linear with binary + unary post ops (such as add + relu) with X86InductorQuantizer. Currently, only add as binary post op and relu as unary post op are supported. """ linear_pos_list = [NodePosType.left, NodePosType.right, NodePosType.both] # TODO test for inplace add after refactoring of capture_pre_autograd_graph inplace_add_list = [False] # TODO test for inplace relu after refactoring of capture_pre_autograd_graph inplace_relu_list = [False] example_inputs = (torch.randn(2, 16),) quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config( is_qat=is_qat, is_dynamic=is_dynamic, ) ) quantize_per_tensor_op = ( torch.ops.quantized_decomposed.quantize_per_tensor.tensor if is_dynamic else torch.ops.quantized_decomposed.quantize_per_tensor.default ) dequantize_per_tensor_op = ( torch.ops.quantized_decomposed.dequantize_per_tensor.tensor if is_dynamic else torch.ops.quantized_decomposed.dequantize_per_tensor.default ) cases = itertools.product(linear_pos_list, inplace_add_list, inplace_relu_list) with override_quantized_engine("x86"), torch.no_grad(): for linear_pos, inplace_add, inplace_relu in cases: m = TestHelperModules.LinearAddReLUModule( inplace_add=inplace_add, linear_pos=linear_pos, inplace_relu=inplace_relu, ).eval() if linear_pos != NodePosType.both: node_occurrence = { # Only one q-dq node for input of the linear # No q-dq node for extra input node of add quantize_per_tensor_op: 1, dequantize_per_tensor_op: 1, # note: quantize op for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } else: # convert_pt2e disables duplicate dequant for dynamic quant num_dequant = 1 if is_dynamic else 2 node_occurrence = { # One quantize_per_tensor for both linear nodes (shared) # Two dequantize_per_tensor for two linear nodes # No q-dq for extra input node of add quantize_per_tensor_op: 1, dequantize_per_tensor_op: num_dequant, # note: quantize op for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } node_list = [ quantize_per_tensor_op, dequantize_per_tensor_op, torch.ops.aten.linear.default, ( torch.ops.aten.add_.Tensor if inplace_add else torch.ops.aten.add.Tensor ), ] fq_m = self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, )[-1] # linear, add, relu are fused # The other linear is quantized alone if present aten = torch.ops.aten add_op = aten.add_.Tensor if inplace_add else aten.add.Tensor relu_op = aten.relu_.default if inplace_relu else aten.relu.default expected_annotation_stat = { aten.linear.default: { "annotated": 2 if linear_pos == NodePosType.both else 1, "is_quant_out": 1 if linear_pos == NodePosType.both else 0, }, add_op: {"annotated": 1, "is_quant_out": 0}, relu_op: {"annotated": 1, "is_quant_out": 1}, } self._check_annotation_stat(fq_m, expected_annotation_stat) @skipIfNoX86 def test_linear_binary_unary(self): self._test_linear_binary_unary_helper() @skipIfNoX86 def test_linear_binary_unary_qat(self): self._test_linear_binary_unary_helper(is_qat=True) @skipIfNoX86 def test_linear_binary_unary_dynamic(self): self._test_linear_binary_unary_helper(is_dynamic=True) @skipIfNoX86 def test_linear_binary_unary_dynamic_qat(self): self._test_linear_binary_unary_helper(is_qat=True, is_dynamic=True) @skipIfNoX86 def test_linear_binary_unary_serials(self): """ Test pattern of 2 following up linear add relu with X86InductorQuantizer. """ is_qat_list = [False, True] is_dynamic_list = [False, True] cases = itertools.product(is_qat_list, is_dynamic_list) with override_quantized_engine("x86"), torch.no_grad(): for is_qat, is_dynamic in cases: m = TestHelperModules.SerialsLinearAddReLUModule().eval() example_inputs = (torch.randn(2, 16),) quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config( is_qat=is_qat, is_dynamic=is_dynamic, ) ) quantize_per_tensor_op = ( torch.ops.quantized_decomposed.quantize_per_tensor.tensor if is_dynamic else torch.ops.quantized_decomposed.quantize_per_tensor.default ) dequantize_per_tensor_op = ( torch.ops.quantized_decomposed.dequantize_per_tensor.tensor if is_dynamic else torch.ops.quantized_decomposed.dequantize_per_tensor.default ) # convert_pt2e disables duplicate dequant for dynamic quant num_dequant = 3 if is_dynamic else 4 node_occurrence = { # quantize_per_tensor: 1 for linear_1, 1 for linear_2/3 (shared), 1 for linear_4 # dequantize_per_tensor: 1 for each linear # No q-dq for extra input node of add quantize_per_tensor_op: 3, dequantize_per_tensor_op: num_dequant, # quantize_per_channel for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 4, } node_list = [ torch.ops.quantized_decomposed.dequantize_per_channel.default, quantize_per_tensor_op, dequantize_per_tensor_op, torch.ops.aten.linear.default, torch.ops.aten.linear.default, torch.ops.aten.linear.default, torch.ops.aten.add.Tensor, torch.ops.aten.relu.default, ] fq_m = self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, )[-1] # Two linear nodes are quantized alone # The other two are fused with add and relu aten = torch.ops.aten expected_annotation_stat = { aten.linear.default: { "annotated": 4, "is_quant_out": 2, }, aten.add.Tensor: {"annotated": 2, "is_quant_out": 0}, aten.relu.default: {"annotated": 2, "is_quant_out": 2}, } self._check_annotation_stat(fq_m, expected_annotation_stat) @skipIfTorchDynamo("very slow") @skipIfNoX86 def test_qat_conv2d(self): """ Test QAT pattern of conv2d_bn with X86InductorQuantizer. """ with override_quantized_engine("x86"): m = TestHelperModules.SingleConv2dModule(with_bn=True) example_inputs = (torch.randn(2, 3, 16, 16),) quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config(is_qat=True) ) node_occurrence = { # one for input and weight of the conv, one for output for the conv torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, # note: quantize op for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, # BN should be folded into Conv torch.ops.aten._native_batch_norm_legit.default: 0, } node_list = [ torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.conv2d.default, torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, ] self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, is_qat=True, ) @skipIfTorchDynamo("very slow") @skipIfNoX86 def test_qat_conv2d_unary(self): """ Test QAT pattern of conv2d_bn with unary post ops (such as relu, sigmoid) with X86InductorQuantizer. Currently, only relu as unary post op is supported. """ unary_map = { "relu": [torch.nn.ReLU(inplace=False), torch.ops.aten.relu.default], "relu_inplace": [torch.nn.ReLU(inplace=True), torch.ops.aten.relu_.default], "hardtanh": [ torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=False), torch.ops.aten.hardtanh.default, ], "hardtanh_inplace": [ torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=True), torch.ops.aten.hardtanh_.default, ], "relu6": [torch.nn.ReLU6(inplace=False), torch.ops.aten.hardtanh.default], "relu6_inplace": [ torch.nn.ReLU6(inplace=True), torch.ops.aten.hardtanh_.default, ], "hardswish": [ torch.nn.Hardswish(inplace=False), torch.ops.aten.hardswish.default, ], "hardswish_inplace": [ torch.nn.Hardswish(inplace=True), torch.ops.aten.hardswish_.default, ], "swish": [torch.nn.SiLU(inplace=False), torch.ops.aten.silu.default], "swish_inplace": [ torch.nn.SiLU(inplace=True), torch.ops.aten.silu_.default, ], } with override_quantized_engine("x86"): for unary_op in unary_map.keys(): m = TestHelperModules.Conv2dUnaryModule( unary_map[unary_op][0], with_bn=True ) example_inputs = (torch.randn(2, 3, 16, 16),) quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config(is_qat=True) ) node_occurrence = { # one for input and weight of the conv, one for output for the relu torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, # note: quantize op for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, # BN should be folded into Conv torch.ops.aten._native_batch_norm_legit.default: 0, } node_list = [ torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.conv2d.default, unary_map[unary_op][1], torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, ] self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, is_qat=True, ) @skipIfTorchDynamo("very slow") @skipIfNoX86 def test_qat_conv2d_binary(self): """ Test qat pattern of conv2d_bn with binary post ops (such as add) with X86InductorQuantizer. Currently, only add as binary post op is supported. """ example_inputs = (torch.randn(2, 3, 6, 6),) quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config(is_qat=True) ) with override_quantized_engine("x86"): for inplace_add in [True, False]: m = TestHelperModules.Conv2dAddModule( inplace_add=inplace_add, with_bn=True ) node_occurrence = { # one for input and weight of the conv # one for output for the add # one for extra input node of add torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, # quantize_per_channel for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, # BN should be folded into Conv torch.ops.aten._native_batch_norm_legit.default: 0, } node_list = [ torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.conv2d.default, ( torch.ops.aten.add_.Tensor if inplace_add else torch.ops.aten.add.Tensor ), torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, ] self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, is_qat=True, ) @skipIfTorchDynamo("very slow") @skipIfNoX86 def test_qat_conv2d_binary2(self): """ Test qat Pattern: tmp = bn1(conv2d_1(x)) tmp2 = bn2(conv2d_2(tmp)) return tmp + tmp2 Since conv2d_1 has 2 users, we should annotate conv2d_2 for binary fusion instead of conv2d_1 """ example_inputs = (torch.randn(2, 3, 6, 6),) quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config(is_qat=True) ) inplace_add_list = [True, False] with override_quantized_engine("x86"), torch.no_grad(): for inplace_add in inplace_add_list: m = TestHelperModules.Conv2dAddModule2(inplace_add=inplace_add) node_occurrence = { torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4, # quantize_per_channel for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, # BN should be folded into Conv torch.ops.aten._native_batch_norm_legit.default: 0, } node_list = [ torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.conv2d.default, torch.ops.quantized_decomposed.quantize_per_tensor.default, ( torch.ops.aten.add_.Tensor if inplace_add else torch.ops.aten.add.Tensor ), ] self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, is_qat=True, ) @skipIfTorchDynamo("very slow") @skipIfNoX86 def test_qat_conv2d_binary_unary(self): """ Test QAT pattern of conv2d_bn with binary + unary post ops (such as add + relu) with X86InductorQuantizer. Currently, only add as binary post op and relu as unary post op are supported. """ example_inputs = (torch.randn(2, 3, 6, 6),) quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config(is_qat=True) ) with override_quantized_engine("x86"): m = TestHelperModules.Conv2dAddReLUModule(with_bn=True) node_occurrence = { # one for input for conv # one for output for the relu # one for extra input node of add torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, # note: quantize op for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, # BN should be folded into Conv torch.ops.aten._native_batch_norm_legit.default: 0, } node_list = [ torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.conv2d.default, torch.ops.aten.add.Tensor, torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, ] self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, is_qat=True, ) @skipIfNoX86 def test_dynamic_quant_linear(self): """ Test pattern of dynamic quantization of linear with X86InductorQuantizer. """ with override_quantized_engine("x86"), torch.no_grad(): m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval() example_inputs = (torch.randn(1, 4, 64),) quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config(is_dynamic=True) ) node_occurrence = { torch.ops.quantized_decomposed.choose_qparams.tensor: 1, torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, # quantize_per_channel for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, } node_list = [ torch.ops.quantized_decomposed.choose_qparams.tensor, torch.ops.quantized_decomposed.quantize_per_tensor.tensor, torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, torch.ops.aten.linear.default, ] self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, ) @skipIfNoX86 def test_qat_dynamic_quant_linear(self): """ Test pattern of qat dynamic quantization of linear with X86InductorQuantizer. """ with override_quantized_engine("x86"), torch.no_grad(): m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval() example_inputs = (torch.randn(1, 4, 64),) quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config( is_qat=True, is_dynamic=True ) ) node_occurrence = { torch.ops.quantized_decomposed.choose_qparams.tensor: 1, torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, # quantize_per_channel for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, } node_list = [ torch.ops.quantized_decomposed.choose_qparams.tensor, torch.ops.quantized_decomposed.quantize_per_tensor.tensor, torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, torch.ops.aten.linear.default, ] self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, is_qat=True, ) @skipIfNoX86 def test_set_module_name_qconfig(self): """Test case for quantizing a specific submodule by configuring `set_module_name_qconfig`. Expect that all linear layers within the submodule `sub` are quantized. """ class Sub(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear1 = torch.nn.Linear(5, 10) self.relu1 = torch.nn.ReLU(inplace=False) self.linear2 = torch.nn.Linear(10, 5) def forward(self, x): x = self.linear1(x) x = self.relu1(x) x = self.linear2(x) return x class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(5, 5) self.sub = Sub() def forward(self, x): x = self.linear(x) x = self.sub(x) return x m = M().eval() example_inputs = (torch.randn(3, 5),) # Set global to `None` and then default config for a specific submodule. quantizer = X86InductorQuantizer() quantizer.set_module_name_qconfig( "sub", xiq.get_default_x86_inductor_quantization_config() ) node_occurrence = { torch.ops.aten.linear.default: 3, # quantize and dequantize the input of two linear layers from `sub` torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, # dequantize the weight of two linear layers from `sub` torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } node_list = [ # first linear is not quantized torch.ops.aten.linear.default, # two Q/DQ pairs for two linear layers from `sub` torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.linear.default, torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.linear.default, ] self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, ) @skipIfNoX86 def test_set_module_name_qconfig_with_underscores(self) -> None: """Test that if a module name has an underscore, we can still quantize it.""" class M(torch.nn.Module): def __init__(self) -> None: super().__init__() # This module name has underscores, which can be part of a mangled name. self.foo_bar = torch.nn.Linear(2, 2) self.baz = torch.nn.Linear(2, 2) def forward(self, x): return self.baz(self.foo_bar(x)) # Set global to no quantization and then default config for a specific submodule whose name includes an underscore. quantizer = X86InductorQuantizer() quantizer.set_module_name_qconfig( "foo_bar", xiq.get_default_x86_inductor_quantization_config() ) example_inputs = (torch.randn(2, 2),) m = M().eval() m = capture_pre_autograd_graph(m, example_inputs) m = prepare_pt2e(m, quantizer) # Use a linear count instead of names because the names might change, but # the order should be the same. count = 0 for n in m.graph.nodes: if n.op == "call_function" and n.target == torch.ops.aten.linear.default: # Get the weight observer to see the per-channel vs per-tensor. weight_observer_node = n.args[1] if count == 0: # for foo_bar. self.assertEqual( weight_observer_node.op, "call_module", f"The op of linear({count})'s weight_observer_node is {weight_observer_node.op} instead call_module", ) observer_instance = getattr(m, weight_observer_node.target) self.assertEqual( observer_instance.qscheme, torch.per_channel_symmetric ) else: # For baz it should have no observer at all. self.assertNotEqual( weight_observer_node.op, "call_module", f"The op of linear({count})'s weight_observer_node is {weight_observer_node.op} instead call_module", ) count += 1 @skipIfNoX86 def test_set_module_name_and_module_type_case1(self): """Test that set `module_name_qconfig` and `module_type_qconfig` at the same time. Expect that all linear layers are not quantized except the last one. """ class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear1 = torch.nn.Linear(5, 10) self.linear2 = torch.nn.Linear(10, 5) self.sub = torch.nn.Linear(5, 5) def forward(self, x): x = self.linear1(x) x = self.linear2(x) x = self.sub(x) return x m = M().eval() example_inputs = (torch.randn(3, 5),) # Set `sub` with default config and then `None` for all `Linear`. # The config set by `set_module_name_qconfig` has higher priority than `set_module_type_qconfig`. quantizer = X86InductorQuantizer() quantizer.set_module_name_qconfig( "sub", xiq.get_default_x86_inductor_quantization_config() ).set_module_type_qconfig(torch.nn.Linear, None) node_occurrence = { torch.ops.aten.linear.default: 3, # quantize and dequantize the input of the last linear torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, # dequantize the weight of the last linear torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ # first and second linear are not quantized torch.ops.aten.linear.default, torch.ops.aten.linear.default, # last linear is quantized torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.linear.default, ] self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, ) @skipIfNoX86 def test_set_module_name_and_module_type_case2(self): """Test that set `module_name_qconfig` and `module_type_qconfig` at the same time. Expect that all linear layers are quantized except the last one. """ class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear1 = torch.nn.Linear(5, 10) self.linear2 = torch.nn.Linear(10, 5) self.sub = torch.nn.Linear(5, 5) def forward(self, x): x = self.linear1(x) x = self.linear2(x) x = self.sub(x) return x m = M().eval() example_inputs = (torch.randn(3, 5),) # Set `sub` with None and then default config for a all `Linear`. quantizer = X86InductorQuantizer() quantizer.set_module_name_qconfig("sub", None).set_module_type_qconfig( torch.nn.Linear, xiq.get_default_x86_inductor_quantization_config() ) node_occurrence = { torch.ops.aten.linear.default: 3, # quantize and dequantize the input and output of the first and second linear torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, # dequantize the weight of the first and second linear torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } node_list = [ # Q/DQ for first lienar torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.linear.default, # Q/DQ for second lienar torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.linear.default, # last linear is not quantized torch.ops.aten.linear.default, ] self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, ) @skipIfNoX86 def test_set_module_name_qconfig_for_dynamic_quant(self): """Test that quantize a specific submodule for dynamic quantization.""" with override_quantized_engine("x86"), torch.no_grad(): for is_qat in [False, True]: m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval() example_inputs = (torch.randn(1, 4, 64),) # only quantize `q_proj` `v_proj` dynamic_config = xiq.get_default_x86_inductor_quantization_config( is_dynamic=True, is_qat=is_qat ) quantizer = ( X86InductorQuantizer() .set_module_name_qconfig("q_proj", dynamic_config) .set_module_name_qconfig("v_proj", dynamic_config) ) node_occurrence = { # quantize and dequantize the input torch.ops.quantized_decomposed.choose_qparams.tensor: 1, torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, # dequantize the weight of q_proj and v_proj torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } node_list = [ # quantize and dequantize the input torch.ops.quantized_decomposed.choose_qparams.tensor, torch.ops.quantized_decomposed.quantize_per_tensor.tensor, torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, # q_proj torch.ops.aten.linear.default, # k_proj torch.ops.aten.linear.default, # v_proj torch.ops.aten.linear.default, ] self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, is_qat=is_qat, ) @skipIfNoX86 def test_set_module_name_with_mixed_configs(self): """Test case for setting module names with mixed static/dynamic or QAT/non-QAT configurations. The config for 'v_proj' will always be ignored and raise a warning. """ with override_quantized_engine("x86"), torch.no_grad(): with self.assertWarns(UserWarning) as context: for q_is_dynamic, v_is_dynamic, q_is_qat, v_is_qat in itertools.product( [False, True], repeat=4 ): if q_is_dynamic == v_is_dynamic and q_is_qat == v_is_qat: continue m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval() example_inputs = (torch.randn(1, 4, 64),) quantizer = ( X86InductorQuantizer() .set_module_name_qconfig( "q_proj", xiq.get_default_x86_inductor_quantization_config( is_qat=q_is_qat, is_dynamic=q_is_dynamic ), ) .set_module_name_qconfig( "v_proj", xiq.get_default_x86_inductor_quantization_config( is_qat=v_is_qat, is_dynamic=v_is_dynamic ), ) ) quant_op = ( torch.ops.quantized_decomposed.quantize_per_tensor.tensor if q_is_dynamic else torch.ops.quantized_decomposed.quantize_per_tensor.default ) dequant_op = ( torch.ops.quantized_decomposed.dequantize_per_tensor.tensor if q_is_dynamic else torch.ops.quantized_decomposed.dequantize_per_tensor.default ) node_occurrence = { # quantize and dequantize the input quant_op: 1, dequant_op: 1, # only `q_proj` was quantized, dequantize its weight torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ # quantize and dequantize the input quant_op, dequant_op, # q_proj torch.ops.aten.linear.default, # k_proj/v_proj torch.ops.aten.linear.default, torch.ops.aten.linear.default, ] self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, is_qat=q_is_qat, ) warning_msg = ( "Mixed QAT and Non-QAT" if q_is_qat != v_is_qat else "Mixed dynamic and static" ) self.assertTrue( any( warning_msg in msg for msg in [str(w.message) for w in context.warnings] ) ) @skipIfNoX86 def test_set_module_name_and_module_type_with_mixed_configs(self): """Test that set `module_name_qconfig` and `module_type_qconfig` at the same time with mixed the configs. Expect that only the last linear(`sub`) is quantized using static quantization. """ class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear1 = torch.nn.Linear(5, 10) self.linear2 = torch.nn.Linear(10, 5) self.sub = torch.nn.Linear(5, 5) def forward(self, x): x = self.linear1(x) x = self.linear2(x) x = self.sub(x) return x m = M().eval() example_inputs = (torch.randn(3, 5),) # Set `sub` with static config and then dynamic config for a all `Linear`(ignored). quantizer = X86InductorQuantizer() quantizer.set_module_name_qconfig( "sub", xiq.get_default_x86_inductor_quantization_config(is_dynamic=False) ).set_module_type_qconfig( torch.nn.Linear, xiq.get_default_x86_inductor_quantization_config(is_dynamic=True), ) node_occurrence = { torch.ops.aten.linear.default: 3, # quantize and dequantize the input of the last linear torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, # dequantize the weight of the last linear torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ # first and second linear are not quantized torch.ops.aten.linear.default, torch.ops.aten.linear.default, # Q/DQ pairs for the last linear torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.linear.default, ] self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, ) @skipIfNoX86 def test_filter_conv2d_recipe(self): """ Test removing conv2d from default recipe of X86InductorQuantizer. """ with override_quantized_engine("x86"), torch.no_grad(): m = TestHelperModules.Conv2dUnaryModule(torch.nn.ReLU(inplace=False)).eval() example_inputs = (torch.randn(2, 3, 16, 16),) quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config() ) quantizer.set_module_type_qconfig(torch.nn.Conv2d, None) node_occurrence = { # one for input and weight of the conv torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0, # note: quantize op for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 0, } node_list = [ torch.ops.aten.conv2d.default, torch.ops.aten.relu.default, ] self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, ) @skipIfNoX86 def test_filter_linear_recipe(self): """ Test removing linear from default recipe of X86InductorQuantizer. """ with override_quantized_engine("x86"), torch.no_grad(): m = TestHelperModules.LinearUnaryModule( use_bias=True, postop=nn.ReLU, ).eval() example_inputs = (torch.randn(2, 4),) quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config() ) quantizer.set_function_type_qconfig(torch.nn.functional.linear, None) node_occurrence = { # one for input and weight of the conv torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0, # note: quantize op for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 0, } node_list = [ torch.ops.aten.linear.default, torch.ops.aten.relu.default, ] self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, ) @skipIfNoX86 def test_filter_maxpool2d_recipe(self): """ Test removing maxpool2d from default recipe of X86InductorQuantizer. """ with override_quantized_engine("x86"), torch.no_grad(): m = TestHelperModules.Conv2dUnaryModule(torch.nn.ReLU(inplace=False)).eval() example_inputs = (torch.randn(2, 3, 16, 16),) quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config() ) quantizer.set_function_type_qconfig(torch.nn.functional.max_pool2d, None) node_occurrence = { # one for input and weight of the conv torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, # note: quantize op for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.conv2d.default, torch.ops.aten.relu.default, torch.ops.aten.max_pool2d.default, ] self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, ) @skipIfNoX86 def test_attention_block(self): """ Test pattern of Attention like Block with X86InductorQuantizer. """ for annotate_matmul in [False, True]: with override_quantized_engine("x86"), torch.no_grad(): m = TestHelperModules.SelfAttnLikeModule( input_dim=64 * 16, transpose_for_score=True, num_attention_heads=16, attention_head_size=64, ).eval() example_inputs = (torch.randn(2, 384, 1024),) m(*example_inputs) quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config() ) if annotate_matmul: quantizer.set_function_type_qconfig( torch.matmul, quantizer.get_global_quantization_config() ) node_occurrence = { torch.ops.quantized_decomposed.quantize_per_tensor.default: ( 5 if annotate_matmul else 1 ), torch.ops.quantized_decomposed.dequantize_per_tensor.default: ( 7 if annotate_matmul else 3 ), # quantize_per_channel for weights are const propagated torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, } if annotate_matmul: node_list = [ torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.linear.default, torch.ops.aten.view.default, torch.ops.aten.permute.default, torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.matmul.default, torch.ops.aten.div.Tensor, torch.ops.aten.softmax.int, ] else: node_list = [ torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.linear.default, torch.ops.aten.view.default, torch.ops.aten.permute.default, torch.ops.aten.matmul.default, torch.ops.aten.div.Tensor, torch.ops.aten.softmax.int, ] self._test_quantizer( m, example_inputs, quantizer, node_occurrence, node_list, )