1# Owner(s): ["oncall: mobile"] 2 3import unittest 4import torch 5from torch.nn import functional as F 6 7from torch.testing._internal.common_utils import TestCase, run_tests 8from torch.testing import FileCheck 9import io 10 11@unittest.skipUnless(torch.is_vulkan_available(), 12 "Vulkan backend must be available for these tests.") 13class TestVulkanRewritePass(TestCase): 14 @staticmethod 15 def validate_transformed_module( 16 # To please flake 17 self, 18 pattern_count_map, 19 data_shape, 20 prepack_removal=False, 21 fuse_clamping_ops=False): 22 module_instance = self 23 scripted_model = torch.jit.script(module_instance) 24 scripted_model.eval() 25 input_data = torch.normal(1, 20, size=data_shape) 26 ref_result = scripted_model(input_data) 27 torch._C._jit_pass_vulkan_insert_prepacked_ops(scripted_model._c) 28 if fuse_clamping_ops or prepack_removal: 29 scripted_model._c = torch._C._freeze_module(scripted_model._c) 30 if fuse_clamping_ops: 31 torch._C._jit_pass_vulkan_fuse_clamp_w_prepacked_conv(scripted_model._c) 32 if prepack_removal: 33 torch._C._jit_pass_vulkan_fold_prepacking_ops(scripted_model._c) 34 35 buffer = io.BytesIO() 36 torch.jit.save(scripted_model, buffer) 37 buffer.seek(0) 38 deserialized_scripted_model = torch.jit.load(buffer) 39 for pattern, v in pattern_count_map.items(): 40 if (v == 0): 41 FileCheck().check(pattern).run(deserialized_scripted_model.graph) 42 elif (v == -1): 43 FileCheck().check_not(pattern).run(deserialized_scripted_model.graph) 44 else: 45 FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph) 46 47 def test_conv(self): 48 # Conv params 49 batch_size = 2 50 input_channels_per_group = 6 51 height = 16 52 width = 16 53 output_channels_per_group = 6 54 groups = 4 55 kernel_h = kernel_w = 3 56 stride_h = stride_w = 1 57 pad_h = pad_w = 1 58 dilation = 1 59 input_channels = input_channels_per_group * groups 60 output_channels = output_channels_per_group * groups 61 kernels = (kernel_h, kernel_w) 62 strides = (stride_h, stride_w) 63 paddings = (pad_h, pad_w) 64 dilations = (dilation, dilation) 65 conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w) 66 conv_bias_shape = (output_channels) 67 68 class Conv2D(torch.nn.Module): 69 def __init__(self) -> None: 70 super().__init__() 71 self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False) 72 self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False) 73 self.strides = strides 74 self.paddings = paddings 75 self.dilations = dilations 76 self.groups = groups 77 78 def forward(self, x): 79 return F.conv2d(x, self.weight, self.bias, 80 self.strides, self.paddings, self.dilations, self.groups) 81 82 data_shape = (batch_size, input_channels, height, width) 83 pattern_count_map = {"Tensor = aten::conv2d": -1, 84 "vulkan_prepack::conv2d_clamp_prepack": 1, 85 "vulkan_prepack::conv2d_clamp_run": 1} 86 TestVulkanRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape) 87 88 class Conv2DRelu(torch.nn.Module): 89 def __init__(self) -> None: 90 super().__init__() 91 self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False) 92 self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False) 93 self.strides = strides 94 self.paddings = paddings 95 self.dilations = dilations 96 self.groups = groups 97 98 def forward(self, x): 99 o = F.conv2d(x, self.weight, self.bias, 100 self.strides, self.paddings, self.dilations, self.groups) 101 o = F.relu(o) 102 return o 103 104 data_shape = (batch_size, input_channels, height, width) 105 pattern_count_map = {"Tensor = aten::conv2d": -1, 106 "vulkan_prepack::conv2d_clamp_prepack": 1, 107 "vulkan_prepack::conv2d_clamp_run": 1} 108 TestVulkanRewritePass.validate_transformed_module( 109 Conv2DRelu(), pattern_count_map, data_shape) 110 111 pattern_count_map["aten::relu"] = 1 112 pattern_count_map["vulkan_prepack::conv2d_clamp_prepack"] = -1 113 TestVulkanRewritePass.validate_transformed_module( 114 Conv2DRelu(), 115 pattern_count_map, 116 data_shape, 117 prepack_removal=True) 118 pattern_count_map["aten::relu"] = -1 119 TestVulkanRewritePass.validate_transformed_module( 120 Conv2DRelu(), 121 pattern_count_map, 122 data_shape, 123 prepack_removal=True, 124 fuse_clamping_ops=True) 125 126 127 class Conv2DHardtanh(torch.nn.Module): 128 def __init__(self) -> None: 129 super().__init__() 130 self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False) 131 self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False) 132 self.strides = strides 133 self.paddings = paddings 134 self.dilations = dilations 135 self.groups = groups 136 137 def forward(self, x): 138 o = F.conv2d(x, self.weight, self.bias, 139 self.strides, self.paddings, self.dilations, self.groups) 140 o = F.hardtanh(o) 141 return o 142 143 data_shape = (batch_size, input_channels, height, width) 144 pattern_count_map = {"Tensor = aten::conv2d": -1, 145 "vulkan_prepack::conv2d_clamp_prepack": 1, 146 "vulkan_prepack::conv2d_clamp_run": 1} 147 TestVulkanRewritePass.validate_transformed_module(Conv2DHardtanh(), pattern_count_map, data_shape) 148 pattern_count_map["aten::hardtanh"] = 1 149 pattern_count_map["vulkan_prepack::conv2d_clamp_prepack"] = -1 150 TestVulkanRewritePass.validate_transformed_module( 151 Conv2DHardtanh(), 152 pattern_count_map, 153 data_shape, 154 prepack_removal=True) 155 pattern_count_map["aten::hardtanh"] = -1 156 TestVulkanRewritePass.validate_transformed_module( 157 Conv2DRelu(), 158 pattern_count_map, 159 data_shape, 160 prepack_removal=True, 161 fuse_clamping_ops=True) 162 163if __name__ == "__main__": 164 run_tests() 165