# Owner(s): ["module: inductor"]

import unittest
from functools import partial

import torch
from torch._inductor.ir import Pointwise
from torch._inductor.lowering import make_pointwise, register_lowering
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.virtualized import ops
from torch.testing._internal.common_utils import skipIfRocm
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA


# These tests exercise custom lowerings for ops that aren't in the main pytorch repo
class TestCustomLowering(InductorTestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.test_inductor_ops = torch.library.Library(  # noqa: TOR901
            "test_inductor_ops", "DEF"
        )
        cls.impl_cuda = torch.library.Library(  # noqa: TOR901
            "test_inductor_ops", "IMPL", "CUDA"
        )
        cls.impl_meta = torch.library.Library(  # noqa: TOR901
            "test_inductor_ops", "IMPL", "Meta"
        )
        cls._register_jagged_to_padded_dense()
        cls._register_asm_op()

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()

    @classmethod
    def _register_jagged_to_padded_dense(cls):
        # Approximation of fbgemm.jagged_to_padded_dense_forward
        cls.test_inductor_ops.define(
            "jagged_to_padded_dense(Tensor input, Tensor offsets, SymInt max_seq_len, Scalar pad_value) -> Tensor"
        )

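        # Meta (fake tensor) implementation: only the output shape, dtype, and
        # device matter, so an empty tensor of the padded shape is enough.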
        def j2pd_meta(inp, offsets, max_seq_len, pad_value):
            return torch.empty(
                (offsets.shape[0] - 1, max_seq_len, inp.shape[1]),
                device=inp.device,
                dtype=inp.dtype,
            )

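        # Eager CUDA reference implementation: start from a pad_value-filled
        # tensor and copy each jagged row into its (batch, seq) slot. The
        # Python loop is slow, but it only serves as a correctness reference.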
        def j2pd_cuda(inp, offsets, max_seq_len, pad_value):
            res = torch.full(
                (offsets.shape[0] - 1, max_seq_len, inp.shape[1]),
                pad_value,
                device=inp.device,
                dtype=inp.dtype,
            )
            for b in range(offsets.shape[0] - 1):
                for r in range(offsets[b + 1] - offsets[b]):
                    res[b][r] = inp[offsets[b] + r]
            return res

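        # Inductor lowering: emit a single Pointwise node over the padded
        # output shape. Each output element gathers from the jagged input
        # under a mask and falls back to pad_value past the sequence end.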
        def j2pd_lowering(inp, offsets, max_seq_len, pad_value):
            offsets_loader = offsets.make_loader()
            inp_loader = inp.make_loader()
            jagged_len = inp.get_size()[0]
            offsets_dtype = offsets.get_dtype()

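            # inner_fn is called with the (batch, seq, emb) index of one
            # element of the padded output and returns the ops-IR expression
            # that computes it.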
            def inner_fn(index):
                batch_idx, seq_idx, emb_idx = index

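                # indirect_indexing turns the loaded offsets[batch_idx] value
                # into an index expression usable in later loads; the second
                # argument is the size used for bounds checking.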
                begin_idx = ops.indirect_indexing(
                    offsets_loader([batch_idx]),
                    jagged_len + 1,
                )
                end_idx = offsets_loader([batch_idx + 1])
                jagged_idx = begin_idx + seq_idx

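                # Masked load: only read inp[jagged_idx] while jagged_idx is
                # below this batch's end offset; otherwise produce pad_value.
                # This mask is what the zero-size test below relies on.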
                return ops.masked(
                    ops.lt(
                        ops.index_expr(jagged_idx, offsets_dtype),
                        end_idx,
                    ),
                    lambda: inp_loader([jagged_idx, emb_idx]),
                    pad_value,
                )

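            # The op lowers to a dense (batch, max_seq_len, emb) pointwise
            # kernel; ranges gives the iteration space of the output.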
            return Pointwise.create(
                device=inp.get_device(),
                dtype=inp.get_dtype(),
                inner_fn=inner_fn,
                ranges=[offsets.get_size()[0] - 1, max_seq_len, inp.get_size()[1]],
            )

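        # Hook the lowering into Inductor, then register the Meta and CUDA
        # kernels with the dispatcher so the op also works in eager mode.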
        register_lowering(
            torch.ops.test_inductor_ops.jagged_to_padded_dense, type_promotion_kind=None
        )(j2pd_lowering)

        cls.impl_meta.impl("jagged_to_padded_dense", j2pd_meta)
        cls.impl_cuda.impl("jagged_to_padded_dense", j2pd_cuda)

    @classmethod
    def _register_asm_op(cls):
        # Ops whose lowerings emit inline PTX assembly
        cls.test_inductor_ops.define("tanh_approx(Tensor input) -> Tensor")

        def tanh_approx_meta(inp):
            return torch.tanh(inp)

        cls.impl_meta.impl("tanh_approx", tanh_approx_meta)

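        # Lowering via inline PTX: tanh.approx.f32 is the hardware-approximate
        # tanh instruction; $0 is the output register and $1 the input.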
        def tanh_approx_lowering(inp):
            fn = partial(ops.inline_asm_elementwise, asm="tanh.approx.f32 $0, $1;")
            return make_pointwise(fn)(inp)

        register_lowering(
            torch.ops.test_inductor_ops.tanh_approx, type_promotion_kind=None
        )(tanh_approx_lowering)

        cls.test_inductor_ops.define("add_custom(Tensor a, Tensor b) -> Tensor")

        def add_custom(a, b):
            return a + b

        cls.impl_meta.impl("add_custom", add_custom)

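        # Same pattern with two inputs: $0 is the output, $1 and $2 the operands.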
        def add_custom_lowering(a, b):
            fn = partial(ops.inline_asm_elementwise, asm="add.f32 $0, $1, $2;")
            return make_pointwise(fn)(a, b)

        register_lowering(
            torch.ops.test_inductor_ops.add_custom, type_promotion_kind=None
        )(add_custom_lowering)

    @unittest.skipIf(not HAS_CUDA, "CUDA needed")
    def test_jagged_to_padded_dense_sanity_cuda(self):
        def fn(inp, offsets, max_seq_len):
            return torch.ops.test_inductor_ops.jagged_to_padded_dense(
                inp, offsets, max_seq_len, 60.0
            )

        inp = torch.rand((9, 96), device="cuda")
        offsets = torch.tensor([0, 2, 5, 9], dtype=torch.int32, device="cuda")
        max_seq_len = 4

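        # The eager call dispatches to the registered CUDA impl; spot-check
        # that jagged rows land in the expected padded slots, then compare
        # against the compiled (Inductor-lowered) version.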
        res = fn(inp, offsets, max_seq_len)
        self.assertEqual(inp[0], res[0][0])
        self.assertEqual(inp[1], res[0][1])
        self.assertEqual(inp[2], res[1][0])
        self.assertEqual(inp[3], res[1][1])
        self.assertEqual(inp[5], res[2][0])
        self.assertEqual(inp[8], res[2][3])

        fn_opt = torch.compile(fn)

        self.assertEqual(
            fn(inp, offsets, max_seq_len), fn_opt(inp, offsets, max_seq_len)
        )

    @unittest.skipIf(not HAS_CUDA, "CUDA needed")
    def test_jagged_to_padded_dense_zero_size(self):
        # Previously, the masking was completely stripped from the masked
        # load of the input value, which led to an illegal memory access
        # (IMA): CUDA tried to read index 0 of a zero-size tensor.
        def fn(inp, offsets, max_seq_len):
            inp = torch.bmm(inp, torch.ones((1, 96, 1), device="cuda")).view((0, 1))
            return torch.ops.test_inductor_ops.jagged_to_padded_dense(
                inp, offsets, max_seq_len, 60.0
            )

        inp = torch.rand((1, 0, 96), device="cuda")
        offsets = torch.zeros(1025, device="cuda", dtype=torch.int32)
        max_seq_len = 20

        fn_opt = torch.compile(fn)

        self.assertEqual(
            fn(inp, offsets, max_seq_len), fn_opt(inp, offsets, max_seq_len)
        )

    @unittest.skipIf(not HAS_CUDA, "CUDA needed")
    @skipIfRocm
    def test_tanh_approx(self):
        def fn(inp):
            return torch.ops.test_inductor_ops.tanh_approx(inp)

        inp = torch.randn(32, device="cuda")
        fn_opt = torch.compile(fn)

        a = torch.tanh(inp)
        b = fn_opt(inp)
        self.assertEqual(a, b)

    @unittest.skipIf(not HAS_CUDA, "CUDA needed")
    @skipIfRocm
    def test_multi_inp_asm(self):
        def fn(a, b):
            return torch.ops.test_inductor_ops.add_custom(a, b)

        a = torch.randn(32, device="cuda")
        b = torch.randn(32, device="cuda")
        fn_opt = torch.compile(fn)

        out1 = a + b
        out2 = fn_opt(a, b)
        self.assertEqual(out1, out2)


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_CPU or HAS_CUDA:
        run_tests(needs="filelock")