# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestStaticConstantPad(unittest.TestCase):
    """Tests lowering of constant-mode padding to the XNNPACK delegate.

    Each test exports a small module containing constant padding, lowers it
    with ``to_edge_transform_and_lower``, asserts that the padding was
    absorbed into a single delegate call, and finally runs the serialized
    program and compares its outputs against eager execution.
    """

    class StaticConstantPadFunctional(torch.nn.Module):
        """Applies ``F.pad`` in constant mode with several pad lengths.

        ``x`` and ``y`` are padded with 6-, 4- and 2-element pad specs;
        ``z`` only with 4- and 2-element specs. Each call uses a distinct
        fill value.
        """

        def __init__(self):
            super().__init__()

        def forward(self, x, y, z):
            pad_6 = (1, 2, 3, 4, 5, 6)
            pad_4 = (1, 2, 3, 4)
            pad_2 = (1, 2)
            a = torch.nn.functional.pad(
                input=x,
                pad=pad_6,
                mode="constant",
                value=2.3,
            )
            b = torch.nn.functional.pad(
                input=x,
                pad=pad_4,
                mode="constant",
                value=1.3,
            )
            c = torch.nn.functional.pad(
                input=x,
                pad=pad_2,
                mode="constant",
                value=2.1,
            )
            d = torch.nn.functional.pad(
                input=y,
                pad=pad_6,
                mode="constant",
                value=2.7,
            )
            e = torch.nn.functional.pad(
                input=y,
                pad=pad_4,
                mode="constant",
                value=1.9,
            )
            f = torch.nn.functional.pad(
                input=y,
                pad=pad_2,
                mode="constant",
                value=3.1,
            )
            g = torch.nn.functional.pad(
                input=z,
                pad=pad_4,
                mode="constant",
                value=2.9,
            )
            h = torch.nn.functional.pad(
                input=z,
                pad=pad_2,
                mode="constant",
                value=1.2,
            )

            # Pad quantizes by propagation, so each padded tensor is fed into
            # an add to give the quantizer a neighbouring op to propagate from.
            return (a + a, b + b, c + c, d + d, e + e, f + f, g + g, h + h)

    class StaticConstantPad2d(torch.nn.Module):
        """Wraps ``torch.nn.ConstantPad2d([1, 2, 3, 4], 2.3)`` followed by an add."""

        def __init__(self):
            super().__init__()
            self.pad = torch.nn.ConstantPad2d([1, 2, 3, 4], 2.3)

        def forward(self, x):
            y = self.pad(x)
            # Pad quantizes by propagation
            z = y + y
            return z

    def _test_static_constant_pad_functional(self, inputs):
        """Lower ``StaticConstantPadFunctional`` and check it is fully delegated.

        Args:
            inputs: tuple of three tensors ``(x, y, z)`` passed to
                ``StaticConstantPadFunctional.forward``.
        """
        (
            Tester(self.StaticConstantPadFunctional(), inputs)
            .export()
            # All eight F.pad calls must appear in the exported graph.
            .check_count({"torch.ops.aten.pad.default": 8})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            # No constant_pad_nd may remain outside the delegate.
            .check_not(
                ["executorch_exir_dialects_edge__ops_aten_constant_pad_nd_default"]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp16_static_constant_pad_functional(self):
        inputs = (
            torch.randn(size=(5, 4, 3, 2)).to(torch.float16),
            torch.randn(size=(5, 3, 2)).to(torch.float16),
            torch.randn(size=(4, 3)).to(torch.float16),
        )
        self._test_static_constant_pad_functional(inputs)

    def test_fp32_static_constant_pad_functional(self):
        inputs = (
            torch.randn(size=(5, 4, 3, 2)),
            torch.randn(size=(5, 3, 2)),
            torch.randn(size=(4, 3)),
        )
        self._test_static_constant_pad_functional(inputs)

    def test_qs8_static_constant_pad_functional(self):
        class Pad(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                z = torch.nn.functional.pad(
                    input=x,
                    pad=(2, 1),
                    mode="constant",
                    value=2.3,
                )
                return z + z

        inputs = (torch.randn(size=(1, 2)),)
        (
            Tester(Pad(), inputs)
            .quantize()
            .export()
            .check_count({"torch.ops.aten.pad.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            # Both names must be separate list entries. Without the comma,
            # implicit string concatenation would merge them into one string
            # that never occurs, so this check would silently pass.
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_constant_pad_nd_default",
                    "torch.ops.quantized_decomposed",
                ]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_qs8_static_constant_pad_2d(self):
        inputs = (torch.randn(size=(5, 4, 3, 2)),)
        (
            Tester(self.StaticConstantPad2d(), inputs)
            .quantize()
            .export()
            .check_count({"torch.ops.aten.pad.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_constant_pad_nd_default",
                    "torch.ops.quantized_decomposed",
                ]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )