• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Owner(s): ["oncall: mobile"]
2
3import unittest
4import torch
5from torch.nn import functional as F
6
7from torch.testing._internal.common_utils import TestCase, run_tests
8from torch.testing import FileCheck
9import io
10
11@unittest.skipUnless(torch.is_vulkan_available(),
12                     "Vulkan backend must be available for these tests.")
13class TestVulkanRewritePass(TestCase):
14    @staticmethod
15    def validate_transformed_module(
16            # To please flake
17            self,
18            pattern_count_map,
19            data_shape,
20            prepack_removal=False,
21            fuse_clamping_ops=False):
22        module_instance = self
23        scripted_model = torch.jit.script(module_instance)
24        scripted_model.eval()
25        input_data = torch.normal(1, 20, size=data_shape)
26        ref_result = scripted_model(input_data)
27        torch._C._jit_pass_vulkan_insert_prepacked_ops(scripted_model._c)
28        if fuse_clamping_ops or prepack_removal:
29            scripted_model._c = torch._C._freeze_module(scripted_model._c)
30        if fuse_clamping_ops:
31            torch._C._jit_pass_vulkan_fuse_clamp_w_prepacked_conv(scripted_model._c)
32        if prepack_removal:
33            torch._C._jit_pass_vulkan_fold_prepacking_ops(scripted_model._c)
34
35        buffer = io.BytesIO()
36        torch.jit.save(scripted_model, buffer)
37        buffer.seek(0)
38        deserialized_scripted_model = torch.jit.load(buffer)
39        for pattern, v in pattern_count_map.items():
40            if (v == 0):
41                FileCheck().check(pattern).run(deserialized_scripted_model.graph)
42            elif (v == -1):
43                FileCheck().check_not(pattern).run(deserialized_scripted_model.graph)
44            else:
45                FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph)
46
47    def test_conv(self):
48        # Conv params
49        batch_size = 2
50        input_channels_per_group = 6
51        height = 16
52        width = 16
53        output_channels_per_group = 6
54        groups = 4
55        kernel_h = kernel_w = 3
56        stride_h = stride_w = 1
57        pad_h = pad_w = 1
58        dilation = 1
59        input_channels = input_channels_per_group * groups
60        output_channels = output_channels_per_group * groups
61        kernels = (kernel_h, kernel_w)
62        strides = (stride_h, stride_w)
63        paddings = (pad_h, pad_w)
64        dilations = (dilation, dilation)
65        conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
66        conv_bias_shape = (output_channels)
67
68        class Conv2D(torch.nn.Module):
69            def __init__(self) -> None:
70                super().__init__()
71                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
72                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
73                self.strides = strides
74                self.paddings = paddings
75                self.dilations = dilations
76                self.groups = groups
77
78            def forward(self, x):
79                return F.conv2d(x, self.weight, self.bias,
80                                self.strides, self.paddings, self.dilations, self.groups)
81
82        data_shape = (batch_size, input_channels, height, width)
83        pattern_count_map = {"Tensor = aten::conv2d": -1,
84                             "vulkan_prepack::conv2d_clamp_prepack": 1,
85                             "vulkan_prepack::conv2d_clamp_run": 1}
86        TestVulkanRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape)
87
88        class Conv2DRelu(torch.nn.Module):
89            def __init__(self) -> None:
90                super().__init__()
91                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
92                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
93                self.strides = strides
94                self.paddings = paddings
95                self.dilations = dilations
96                self.groups = groups
97
98            def forward(self, x):
99                o = F.conv2d(x, self.weight, self.bias,
100                             self.strides, self.paddings, self.dilations, self.groups)
101                o = F.relu(o)
102                return o
103
104        data_shape = (batch_size, input_channels, height, width)
105        pattern_count_map = {"Tensor = aten::conv2d": -1,
106                             "vulkan_prepack::conv2d_clamp_prepack": 1,
107                             "vulkan_prepack::conv2d_clamp_run": 1}
108        TestVulkanRewritePass.validate_transformed_module(
109            Conv2DRelu(), pattern_count_map, data_shape)
110
111        pattern_count_map["aten::relu"] = 1
112        pattern_count_map["vulkan_prepack::conv2d_clamp_prepack"] = -1
113        TestVulkanRewritePass.validate_transformed_module(
114            Conv2DRelu(),
115            pattern_count_map,
116            data_shape,
117            prepack_removal=True)
118        pattern_count_map["aten::relu"] = -1
119        TestVulkanRewritePass.validate_transformed_module(
120            Conv2DRelu(),
121            pattern_count_map,
122            data_shape,
123            prepack_removal=True,
124            fuse_clamping_ops=True)
125
126
127        class Conv2DHardtanh(torch.nn.Module):
128            def __init__(self) -> None:
129                super().__init__()
130                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
131                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
132                self.strides = strides
133                self.paddings = paddings
134                self.dilations = dilations
135                self.groups = groups
136
137            def forward(self, x):
138                o = F.conv2d(x, self.weight, self.bias,
139                             self.strides, self.paddings, self.dilations, self.groups)
140                o = F.hardtanh(o)
141                return o
142
143        data_shape = (batch_size, input_channels, height, width)
144        pattern_count_map = {"Tensor = aten::conv2d": -1,
145                             "vulkan_prepack::conv2d_clamp_prepack": 1,
146                             "vulkan_prepack::conv2d_clamp_run": 1}
147        TestVulkanRewritePass.validate_transformed_module(Conv2DHardtanh(), pattern_count_map, data_shape)
148        pattern_count_map["aten::hardtanh"] = 1
149        pattern_count_map["vulkan_prepack::conv2d_clamp_prepack"] = -1
150        TestVulkanRewritePass.validate_transformed_module(
151            Conv2DHardtanh(),
152            pattern_count_map,
153            data_shape,
154            prepack_removal=True)
155        pattern_count_map["aten::hardtanh"] = -1
156        TestVulkanRewritePass.validate_transformed_module(
157            Conv2DRelu(),
158            pattern_count_map,
159            data_shape,
160            prepack_removal=True,
161            fuse_clamping_ops=True)
162
if __name__ == "__main__":
    # Run this file's tests through PyTorch's common test runner.
    run_tests()
165