# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestStaticConstantPad(unittest.TestCase):
    class StaticConstantPadFunctional(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, y, z):
            pad_6 = (1, 2, 3, 4, 5, 6)
            pad_4 = (1, 2, 3, 4)
            pad_2 = (1, 2)
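            # F.pad takes (before, after) pairs starting from the last
            # dimension: pad_6 pads the last three dims, pad_4 the last two,
            # and pad_2 only the last.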
            a = torch.nn.functional.pad(
                input=x,
                pad=pad_6,
                mode="constant",
                value=2.3,
            )
            b = torch.nn.functional.pad(
                input=x,
                pad=pad_4,
                mode="constant",
                value=1.3,
            )
            c = torch.nn.functional.pad(
                input=x,
                pad=pad_2,
                mode="constant",
                value=2.1,
            )
            d = torch.nn.functional.pad(
                input=y,
                pad=pad_6,
                mode="constant",
                value=2.7,
            )
            e = torch.nn.functional.pad(
                input=y,
                pad=pad_4,
                mode="constant",
                value=1.9,
            )
            f = torch.nn.functional.pad(
                input=y,
                pad=pad_2,
                mode="constant",
                value=3.1,
            )
            g = torch.nn.functional.pad(
                input=z,
                pad=pad_4,
                mode="constant",
                value=2.9,
            )
            h = torch.nn.functional.pad(
                input=z,
                pad=pad_2,
                mode="constant",
                value=1.2,
            )

            # Pad is quantized by propagation: its output shares its input's
            # quantization parameters, so each pad is followed by an add to
            # give the pattern a quantized consumer.

            return (a + a, b + b, c + c, d + d, e + e, f + f, g + g, h + h)

    class StaticConstantPad2d(torch.nn.Module):
        def __init__(self):
            super().__init__()
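            # ConstantPad2d padding is given as [left, right, top, bottom];
            # the second argument is the fill value.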
            self.pad = torch.nn.ConstantPad2d([1, 2, 3, 4], 2.3)

        def forward(self, x):
            y = self.pad(x)
            # Pad is quantized by propagation, so add a quantized consumer.
            z = y + y
            return z

    def _test_static_constant_pad_functional(self, inputs):
        (
            Tester(self.StaticConstantPadFunctional(), inputs)
            .export()
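            # All eight functional pads should appear as aten.pad.default in
            # the exported graph.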
            .check_count({"torch.ops.aten.pad.default": 8})
            .to_edge_transform_and_lower()
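            # After lowering, all eight pads should be fused into a single
            # XNNPACK delegate call, leaving no constant_pad_nd in the edge
            # graph.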
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                ["executorch_exir_dialects_edge__ops_aten_constant_pad_nd_default"]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp16_static_constant_pad_functional(self):
        inputs = (
            torch.randn(size=(5, 4, 3, 2)).to(torch.float16),
            torch.randn(size=(5, 3, 2)).to(torch.float16),
            torch.randn(size=(4, 3)).to(torch.float16),
        )
        self._test_static_constant_pad_functional(inputs)

    def test_fp32_static_constant_pad_functional(self):
        inputs = (
            torch.randn(size=(5, 4, 3, 2)),
            torch.randn(size=(5, 3, 2)),
            torch.randn(size=(4, 3)),
        )
        self._test_static_constant_pad_functional(inputs)

    def test_qs8_static_constant_pad_functional(self):
        class Pad(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                z = torch.nn.functional.pad(
                    input=x,
                    pad=(2, 1),
                    mode="constant",
                    value=2.3,
                )
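                # As with the modules above, the add gives the pad a quantized
                # consumer so it can be annotated by propagation.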
                return z + z

        inputs = (torch.randn(size=(1, 2)),)
        (
            Tester(Pad(), inputs)
            .quantize()
            .export()
            .check_count({"torch.ops.aten.pad.default": 1})
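            # Quantization should insert quantize/dequantize ops from the
            # quantized_decomposed namespace.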
            .check(["torch.ops.quantized_decomposed"])
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
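            # Both the pad and its q/dq ops should be absorbed into the
            # delegate, leaving neither in the edge graph.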
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_constant_pad_nd_default",
                    "torch.ops.quantized_decomposed",
                ]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_qs8_static_constant_pad_2d(self):
        inputs = (torch.randn(size=(5, 4, 3, 2)),)
        (
            Tester(self.StaticConstantPad2d(), inputs)
            .quantize()
            .export()
            .check_count({"torch.ops.aten.pad.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_constant_pad_nd_default",
                    "torch.ops.quantized_decomposed",
                ]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

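
# Convenience entry point (an assumption, not part of the original file): lets
# the tests run directly with `python`, using the standard unittest CLI.
if __name__ == "__main__":
    unittest.main()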