# Owner(s): ["module: inductor"]

import unittest
from functools import partial

import torch
from torch._inductor.ir import Pointwise
from torch._inductor.lowering import make_pointwise, register_lowering
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.virtualized import ops
from torch.testing._internal.common_utils import skipIfRocm
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA


# These tests check issues for lowerings that aren't in the main pytorch repo
class TestCustomLowering(InductorTestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.test_inductor_ops = torch.library.Library(  # noqa: TOR901
            "test_inductor_ops", "DEF"
        )
        cls.impl_cuda = torch.library.Library(  # noqa: TOR901
            "test_inductor_ops", "IMPL", "CUDA"
        )
        cls.impl_meta = torch.library.Library(  # noqa: TOR901
            "test_inductor_ops", "IMPL", "Meta"
        )
        cls._register_jagged_to_padded_dense()
        cls._register_asm_op()

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()

    @classmethod
    def _register_jagged_to_padded_dense(cls):
        # Approximation of fbgemm.jagged_to_padded_dense_forward
        cls.test_inductor_ops.define(
            "jagged_to_padded_dense(Tensor input, Tensor offsets, SymInt max_seq_len, Scalar pad_value) -> Tensor"
        )

        def j2pd_meta(inp, offsets, max_seq_len, pad_value):
            return torch.empty(
                (offsets.shape[0] - 1, max_seq_len, inp.shape[1]),
                device=inp.device,
                dtype=inp.dtype,
            )

        def j2pd_cuda(inp, offsets, max_seq_len, pad_value):
            res = torch.full(
                (offsets.shape[0] - 1, max_seq_len, inp.shape[1]),
                pad_value,
                device=inp.device,
                dtype=inp.dtype,
            )
            for b in range(offsets.shape[0] - 1):
                for r in range(offsets[b + 1] - offsets[b]):
                    res[b][r] = inp[offsets[b] + r]
            return res

        def j2pd_lowering(inp, offsets, max_seq_len, pad_value):
            offsets_loader = offsets.make_loader()
            inp_loader = inp.make_loader()
            jagged_len = inp.get_size()[0]
            offsets_dtype = offsets.get_dtype()

            def inner_fn(index):
                batch_idx, seq_idx, emb_idx = index

                # offsets[batch_idx] is the first jagged row of this batch;
                # indirect_indexing bounds-checks the loaded index value.
                begin_idx = ops.indirect_indexing(
                    offsets_loader([batch_idx]),
                    jagged_len + 1,
                )
                end_idx = offsets_loader([batch_idx + 1])
                jagged_idx = begin_idx + seq_idx

                # Only load positions that exist for this batch; anything at or
                # past end_idx is filled with pad_value.
                return ops.masked(
                    ops.lt(
                        ops.index_expr(jagged_idx, offsets_dtype),
                        end_idx,
                    ),
                    lambda: inp_loader([jagged_idx, emb_idx]),
                    pad_value,
                )

            return Pointwise.create(
                device=inp.get_device(),
                dtype=inp.get_dtype(),
                inner_fn=inner_fn,
                ranges=[offsets.get_size()[0] - 1, max_seq_len, inp.get_size()[1]],
            )

        register_lowering(
            torch.ops.test_inductor_ops.jagged_to_padded_dense, type_promotion_kind=None
        )(j2pd_lowering)

        cls.impl_meta.impl("jagged_to_padded_dense", j2pd_meta)
        cls.impl_cuda.impl("jagged_to_padded_dense", j2pd_cuda)

    @classmethod
    def _register_asm_op(cls):
        # tanh approximation lowered to inline PTX (tanh.approx.f32)
        cls.test_inductor_ops.define("tanh_approx(Tensor input) -> Tensor")

        def tanh_approx_meta(inp):
            return torch.tanh(inp)

        cls.impl_meta.impl("tanh_approx", tanh_approx_meta)

        def tanh_approx_lowering(inp):
            fn = partial(ops.inline_asm_elementwise, asm="tanh.approx.f32 $0, $1;")
            return make_pointwise(fn)(inp)

        register_lowering(
            torch.ops.test_inductor_ops.tanh_approx, type_promotion_kind=None
        )(tanh_approx_lowering)

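        # add_custom: a second inline-asm lowering, this time with two inputs.
        # The "$n" placeholders in the PTX string refer to the operands in
        # order ($0 is the result register, $1 and $2 the inputs), so
        # "add.f32 $0, $1, $2" computes a + b elementwise.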
        cls.test_inductor_ops.define("add_custom(Tensor a, Tensor b) -> Tensor")

        def add_custom(a, b):
            return a + b

        cls.impl_meta.impl("add_custom", add_custom)

        def add_custom_lowering(a, b):
            fn = partial(ops.inline_asm_elementwise, asm="add.f32 $0, $1, $2;")
            return make_pointwise(fn)(a, b)

        register_lowering(
            torch.ops.test_inductor_ops.add_custom, type_promotion_kind=None
        )(add_custom_lowering)

    @unittest.skipIf(not HAS_CUDA, "CUDA needed")
    def test_jagged_to_padded_dense_sanity_cuda(self):
        def fn(inp, offsets, max_seq_len):
            return torch.ops.test_inductor_ops.jagged_to_padded_dense(
                inp, offsets, max_seq_len, 60.0
            )

        inp = torch.rand((9, 96), device="cuda")
        offsets = torch.tensor([0, 2, 5, 9], dtype=torch.int32, device="cuda")
        max_seq_len = 4

        res = fn(inp, offsets, max_seq_len)
        self.assertEqual(inp[0], res[0][0])
        self.assertEqual(inp[1], res[0][1])
        self.assertEqual(inp[2], res[1][0])
        self.assertEqual(inp[3], res[1][1])
        self.assertEqual(inp[5], res[2][0])
        self.assertEqual(inp[8], res[2][3])

        fn_opt = torch.compile(fn)

        self.assertEqual(
            fn(inp, offsets, max_seq_len), fn_opt(inp, offsets, max_seq_len)
        )

    @unittest.skipIf(not HAS_CUDA, "CUDA needed")
    def test_jagged_to_padded_dense_zero_size(self):
        # Previously, the masking was being completely stripped for the
        # masked load of the input value. That would lead to an IMA
        # because cuda was trying to read index 0 of a zero-size tensor.
        def fn(inp, offsets, max_seq_len):
            inp = torch.bmm(inp, torch.ones((1, 96, 1), device="cuda")).view((0, 1))
            return torch.ops.test_inductor_ops.jagged_to_padded_dense(
                inp, offsets, max_seq_len, 60.0
            )

        inp = torch.rand((1, 0, 96), device="cuda")
        offsets = torch.zeros(1025, device="cuda", dtype=torch.int32)
        max_seq_len = 20

        fn_opt = torch.compile(fn)

        self.assertEqual(
            fn(inp, offsets, max_seq_len), fn_opt(inp, offsets, max_seq_len)
        )

    @unittest.skipIf(not HAS_CUDA, "CUDA needed")
    @skipIfRocm
    def test_tanh_approx(self):
        def fn(inp):
            return torch.ops.test_inductor_ops.tanh_approx(inp)

        inp = torch.randn(32, device="cuda")
        fn_opt = torch.compile(fn)

        a = torch.tanh(inp)
        b = fn_opt(inp)
        self.assertEqual(a, b)

    @unittest.skipIf(not HAS_CUDA, "CUDA needed")
    @skipIfRocm
    def test_multi_inp_asm(self):
        def fn(a, b):
            return torch.ops.test_inductor_ops.add_custom(a, b)

        a = torch.randn(32, device="cuda")
        b = torch.randn(32, device="cuda")
        fn_opt = torch.compile(fn)

        out1 = a + b
        out2 = fn_opt(a, b)
        self.assertEqual(out1, out2)


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_CPU or HAS_CUDA:
        run_tests(needs="filelock")