1# Owner(s): ["module: sparse"]
2#
3# Test to ensure sparsity information propagates properly into traced graph.
4#
5
6import sys
7import unittest
8
9import torch
10from torch._dynamo.config import is_fbcode
11from torch._subclasses.fake_tensor import FakeTensor
12from torch.testing._internal.common_utils import (
13    instantiate_parametrized_tests,
14    parametrize,
15    run_tests,
16    subtest,
17    TestCase,
18)
19
20
# Various data types (preserved over operations).
# Covers integer, half/bfloat16, and full-precision float dtypes so the
# parametrized tests exercise dtype propagation across all of them.
DTYPES = [
    torch.int64,
    torch.float16,
    torch.bfloat16,
    torch.float32,
    torch.float64,
]

# Various index types (the dtypes allowed for sparse index tensors).
ITYPES = [torch.int32, torch.int64]
33
34# Constructs a subtest for every sparse layout currently supported in torch.sparse.
35def all_sparse_layouts(test_name="layout"):
36    return parametrize(
37        test_name,
38        [
39            subtest(torch.sparse_coo, name="SparseCOO"),
40            subtest(torch.sparse_csr, name="SparseCSR"),
41            subtest(torch.sparse_csc, name="SparseCSC"),
42            subtest(torch.sparse_bsr, name="SparseBSR"),
43            subtest(torch.sparse_bsc, name="SparseBSC"),
44        ],
45    )
46
47
48#
49# Various network examples.
50#
51
52
class IdNet(torch.nn.Module):
    """Identity network: returns its input unchanged."""

    def forward(self, x):
        # No computation; the traced graph should simply pass x through.
        out = x
        return out
56
57
class SumNet(torch.nn.Module):
    """Reduces its input to a scalar by summing all elements."""

    def forward(self, x):
        total = x.sum()
        return total
61
62
class EltwiseNet(torch.nn.Module):
    """Chain of elementwise ops: negate, abs, scale by 2, then ReLU."""

    def forward(self, x):
        # Spelled out step by step; identical ops to relu(2 * abs(-x)).
        negated = -x
        magnitude = torch.abs(negated)
        scaled = 2 * magnitude
        return torch.nn.functional.relu(scaled)
66
67
class ToDenseNet(torch.nn.Module):
    """Converts its (possibly sparse) input to a dense (strided) tensor."""

    def forward(self, x):
        dense = x.to_dense()
        return dense
71
72
class AddNet(torch.nn.Module):
    """Elementwise addition of two inputs."""

    def forward(self, x, y):
        total = torch.add(x, y)
        return total
76
77
class SparseActivationCOO(torch.nn.Module):
    """Converts each tensor in a list of inputs to sparse COO format."""

    def forward(self, x):
        outputs = []
        for tensor in x:
            outputs.append(tensor.to_sparse())
        return outputs
81
82
class SparseActivationCSR(torch.nn.Module):
    """Converts each tensor in a list of inputs to sparse CSR format."""

    def forward(self, x):
        outputs = []
        for tensor in x:
            outputs.append(tensor.to_sparse_csr())
        return outputs
86
87
88#
89# The test driver.
90#
91
92
@unittest.skipIf(is_fbcode(), "See torch._dynamo.config")
@unittest.skipIf(
    sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
)
class TestSparseProp(TestCase):
    """Checks that sparsity information (layout, indices, values) of inputs
    and intermediate results propagates into the "val" metadata of nodes in
    graphs produced by torch.export."""

    def setUp(self):
        # Cooperative super() call; the explicit TestCase.setUp(self) it
        # replaces had identical behavior but bypassed the MRO.
        super().setUp()

    def assertEqualMeta(self, x, y):
        """Assert that FakeTensor `x` carries the same metadata as real
        tensor `y` (after converting `y` to the meta device), including
        layout and, for sparse layouts, the indices/values structure."""
        self.assertIsInstance(x, FakeTensor)
        self.assertIsInstance(y, torch.Tensor)

        # Convert expected value to meta for comparison.
        y = y.to("meta")
        self.assertEqual(x, y, exact_layout=True, exact_is_coalesced=True)

        # When x or y is a meta tensor (say, `x.device == "meta"`), then
        # assertEqual(x, y) compares only x and y attributes but skips
        # comparing their values. In the case of sparse tensors, this means
        # that comparing indices and values attributes are skipped as well,
        # which is why we are doing that explicitly below.
        if x.layout is torch.strided:
            pass
        elif x.layout is torch.sparse_coo:
            self.assertEqual(x._indices(), y._indices(), exact_layout=True)
            self.assertEqual(x._values(), y._values(), exact_layout=True)
        else:
            # Compressed layouts: pick the pair of index tensors by layout.
            if x.layout in {torch.sparse_csr, torch.sparse_bsr}:
                x_meta1, y_meta1 = (x.crow_indices(), y.crow_indices())
                x_meta2, y_meta2 = (x.col_indices(), y.col_indices())
            elif x.layout in {torch.sparse_csc, torch.sparse_bsc}:
                x_meta1, y_meta1 = (x.ccol_indices(), y.ccol_indices())
                x_meta2, y_meta2 = (x.row_indices(), y.row_indices())
            else:
                assert 0  # unreachable
            self.assertEqual(x_meta1, y_meta1, exact_layout=True)
            self.assertEqual(x_meta2, y_meta2, exact_layout=True)
            self.assertEqual(x.values(), y.values(), exact_layout=True)

    @parametrize("dtype", DTYPES)
    @parametrize("itype", ITYPES)
    @all_sparse_layouts("layout")
    def test_idnet(self, dtype, itype, layout):
        """Identity network: the lone placeholder node keeps sparse meta."""
        net = IdNet()
        for sparse_input in self.generate_simple_inputs(
            layout,
            device="cpu",
            dtype=dtype,
            index_dtype=itype,
        ):
            # Build the traced graph.
            prog = torch.export.export(net, (sparse_input,))
            # Test arg/output.
            for i, node in enumerate(prog.graph.nodes):
                meta = node.meta.get("val", None)
                if i == 0:
                    self.assertEqualMeta(meta, sparse_input)
                else:
                    self.assertEqual(meta, None)

    @parametrize("dtype", DTYPES)
    @parametrize("itype", ITYPES)
    @all_sparse_layouts("layout")
    def test_sumnet(self, dtype, itype, layout):
        """Sum reduction: sparse arg meta plus dense scalar result meta."""
        net = SumNet()
        for sparse_input in self.generate_simple_inputs(
            layout,
            device="cpu",
            dtype=dtype,
            index_dtype=itype,
        ):
            result = net(sparse_input)
            # Build the traced graph.
            prog = torch.export.export(net, (sparse_input,))
            # Test arg/sum/output.
            for i, node in enumerate(prog.graph.nodes):
                meta = node.meta.get("val", None)
                if i == 0:
                    self.assertEqualMeta(meta, sparse_input)
                elif i == 1:
                    self.assertEqualMeta(meta, result)
                else:
                    self.assertEqual(meta, None)

    @parametrize("dtype", DTYPES)
    @parametrize("itype", ITYPES)
    @all_sparse_layouts("layout")
    def test_eltwisenet(self, dtype, itype, layout):
        """Elementwise chain: every intermediate node keeps sparse meta.

        Note: meta-device comparison skips values, so comparing every node
        (including the arg) against `result` checks layout/shape only.
        """
        net = EltwiseNet()
        for sparse_input in self.generate_simple_inputs(
            layout,
            device="cpu",
            dtype=dtype,
            index_dtype=itype,
        ):
            result = net(sparse_input)
            # Build the traced graph.
            prog = torch.export.export(net, (sparse_input,))
            # Test arg/neg/abs/mul/relu/output.
            for i, node in enumerate(prog.graph.nodes):
                meta = node.meta.get("val", None)
                if i <= 4:
                    self.assertEqualMeta(meta, result)
                else:
                    self.assertEqual(meta, None)

    @parametrize("dtype", DTYPES)
    @parametrize("itype", ITYPES)
    @all_sparse_layouts("layout")
    def test_todensenet(self, dtype, itype, layout):
        """Sparse-to-dense conversion: arg keeps sparse meta, result dense."""
        net = ToDenseNet()
        for sparse_input in self.generate_simple_inputs(
            layout,
            device="cpu",
            dtype=dtype,
            index_dtype=itype,
        ):
            result = net(sparse_input)
            # Build the traced graph.
            prog = torch.export.export(net, (sparse_input,))
            # Test arg/todense/output.
            for i, node in enumerate(prog.graph.nodes):
                meta = node.meta.get("val", None)
                if i == 0:
                    self.assertEqualMeta(meta, sparse_input)
                elif i == 1:
                    self.assertEqualMeta(meta, result)
                else:
                    self.assertEqual(meta, None)

    def test_add(self):
        """Mixed sparse/dense addition: CSR arg, dense arg, dense result."""
        net = AddNet()
        Y = torch.arange(16, 32, dtype=torch.float32).view(4, 4)
        A = torch.tensor(
            [
                [0.0, 1.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 2.0],
                [0.0, 0.0, 1.0, 1.0],
                [3.0, 0.0, 3.0, 0.0],
            ],
            dtype=torch.float32,
        )
        S = A.to_sparse_csr()
        result = net(S, Y)
        # Build the traced graph.
        prog = torch.export.export(net, (S, Y))
        # Test args/add/output.
        for i, node in enumerate(prog.graph.nodes):
            meta = node.meta.get("val", None)
            if i == 0:
                self.assertEqualMeta(meta, S)
            elif i == 1:
                self.assertEqualMeta(meta, Y)
            elif i == 2:
                self.assertEqualMeta(meta, result)
            else:
                self.assertEqual(meta, None)

    def _check_activation(self, net, x):
        # Shared driver for the dense-to-sparse activation tests: exports
        # `net` on the list `x` of three dense tensors and checks that args
        # (nodes 0-2) carry dense meta and conversions (nodes 3-5) carry
        # the corresponding sparse result meta.
        result = net(x)
        # Build the traced graph.
        prog = torch.export.export(net, args=(x,))
        # Test args/conversion/output.
        for i, node in enumerate(prog.graph.nodes):
            meta = node.meta.get("val", None)
            if i <= 2:
                self.assertEqualMeta(meta, x[i])
            elif i <= 5:
                self.assertEqualMeta(meta, result[i - 3])
            else:
                self.assertEqual(meta, None)

    def test_activation_coo(self):
        """Dense inputs converted to sparse COO activations."""
        self._check_activation(
            SparseActivationCOO(), [torch.randn(3, 3) for _ in range(3)]
        )

    def test_activation_csr(self):
        """Dense inputs converted to sparse CSR activations."""
        self._check_activation(
            SparseActivationCSR(), [torch.randn(3, 3) for _ in range(3)]
        )
282
283
# Expand the @parametrize/@all_sparse_layouts decorators into concrete
# test methods, one per (dtype, itype, layout) combination.
instantiate_parametrized_tests(TestSparseProp)

if __name__ == "__main__":
    run_tests()
288