1# Owner(s): ["module: sparse"] 2# 3# Test to ensure sparsity information propagates properly into traced graph. 4# 5 6import sys 7import unittest 8 9import torch 10from torch._dynamo.config import is_fbcode 11from torch._subclasses.fake_tensor import FakeTensor 12from torch.testing._internal.common_utils import ( 13 instantiate_parametrized_tests, 14 parametrize, 15 run_tests, 16 subtest, 17 TestCase, 18) 19 20 21# Various data types (preserved over operations). 22DTYPES = [ 23 torch.int64, 24 torch.float16, 25 torch.bfloat16, 26 torch.float32, 27 torch.float64, 28] 29 30# Various index types. 31ITYPES = [torch.int32, torch.int64] 32 33 34# Constructs a subtest for every sparse layout currently supported in torch.sparse. 35def all_sparse_layouts(test_name="layout"): 36 return parametrize( 37 test_name, 38 [ 39 subtest(torch.sparse_coo, name="SparseCOO"), 40 subtest(torch.sparse_csr, name="SparseCSR"), 41 subtest(torch.sparse_csc, name="SparseCSC"), 42 subtest(torch.sparse_bsr, name="SparseBSR"), 43 subtest(torch.sparse_bsc, name="SparseBSC"), 44 ], 45 ) 46 47 48# 49# Various network examples. 50# 51 52 53class IdNet(torch.nn.Module): 54 def forward(self, x): 55 return x 56 57 58class SumNet(torch.nn.Module): 59 def forward(self, x): 60 return x.sum() 61 62 63class EltwiseNet(torch.nn.Module): 64 def forward(self, x): 65 return torch.nn.functional.relu(2 * torch.abs(-x)) 66 67 68class ToDenseNet(torch.nn.Module): 69 def forward(self, x): 70 return x.to_dense() 71 72 73class AddNet(torch.nn.Module): 74 def forward(self, x, y): 75 return torch.add(x, y) 76 77 78class SparseActivationCOO(torch.nn.Module): 79 def forward(self, x): 80 return [xi.to_sparse() for xi in x] 81 82 83class SparseActivationCSR(torch.nn.Module): 84 def forward(self, x): 85 return [xi.to_sparse_csr() for xi in x] 86 87 88# 89# The test driver. 
#


@unittest.skipIf(is_fbcode(), "See torch._dynamo.config")
@unittest.skipIf(
    sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
)
class TestSparseProp(TestCase):
    """Check that tensor metadata (including sparse layouts) survives export.

    Each test exports a small network with ``torch.export.export`` and compares
    the ``"val"`` entry of every exported graph node's ``meta`` against tensors
    computed eagerly.
    """

    def assertEqualMeta(self, x, y):
        """Assert that fake tensor ``x`` matches tensor ``y`` in metadata.

        ``x`` is the value recorded in a graph node's metadata; ``y`` is the
        eagerly computed tensor it should describe. ``y`` is moved to the meta
        device before comparing, so only attributes are compared, not values.
        """
        self.assertIsInstance(x, FakeTensor)
        self.assertIsInstance(y, torch.Tensor)

        # Convert expected value to meta for comparison.
        y = y.to("meta")
        self.assertEqual(x, y, exact_layout=True, exact_is_coalesced=True)

        # When x or y is a meta tensor (say, `x.device == "meta"`), then
        # assertEqual(x, y) compares only x and y attributes but skips
        # comparing their values. In the case of sparse tensors, this means
        # that comparing indices and values attributes are skipped as well,
        # which is why we are doing that explicitly below.
        if x.layout is torch.strided:
            pass
        elif x.layout is torch.sparse_coo:
            self.assertEqual(x._indices(), y._indices(), exact_layout=True)
            self.assertEqual(x._values(), y._values(), exact_layout=True)
        else:
            if x.layout in {torch.sparse_csr, torch.sparse_bsr}:
                x_meta1, y_meta1 = (x.crow_indices(), y.crow_indices())
                x_meta2, y_meta2 = (x.col_indices(), y.col_indices())
            elif x.layout in {torch.sparse_csc, torch.sparse_bsc}:
                x_meta1, y_meta1 = (x.ccol_indices(), y.ccol_indices())
                x_meta2, y_meta2 = (x.row_indices(), y.row_indices())
            else:
                # Use an explicit failure rather than `assert`, which is
                # stripped under `python -O` and would leave the names above
                # unbound.
                self.fail(f"unexpected sparse layout {x.layout}")
            self.assertEqual(x_meta1, y_meta1, exact_layout=True)
            self.assertEqual(x_meta2, y_meta2, exact_layout=True)
            self.assertEqual(x.values(), y.values(), exact_layout=True)

    def _sparse_inputs(self, layout, dtype, itype):
        """Return the CPU ``generate_simple_inputs`` samples for the arguments."""
        return self.generate_simple_inputs(
            layout,
            device="cpu",
            dtype=dtype,
            index_dtype=itype,
        )

    def _assert_graph_meta(self, prog, expected):
        """Check the ``"val"`` metadata of every node in ``prog.graph``.

        The first ``len(expected)`` nodes must carry fake tensors matching
        ``expected`` in order (via ``assertEqualMeta``). Every remaining node,
        e.g. the graph output, must carry no ``"val"`` at all.
        """
        nodes = list(prog.graph.nodes)
        # At least one node (the output) must follow the expected ones; this
        # makes a graph shorter than expected fail with a clear message.
        self.assertGreater(len(nodes), len(expected))
        for i, node in enumerate(nodes):
            meta = node.meta.get("val", None)
            if i < len(expected):
                self.assertEqualMeta(meta, expected[i])
            else:
                self.assertEqual(meta, None)

    def _check_activation(self, net):
        """Export ``net`` on a list of three dense 3x3 inputs and check metadata."""
        x = [torch.randn(3, 3) for _ in range(3)]
        result = net(x)
        # Build the traced graph.
        prog = torch.export.export(net, args=(x,))
        # Test args/to_sparse/output.
        self._assert_graph_meta(prog, x + result)

    @parametrize("dtype", DTYPES)
    @parametrize("itype", ITYPES)
    @all_sparse_layouts("layout")
    def test_idnet(self, dtype, itype, layout):
        net = IdNet()
        for sparse_input in self._sparse_inputs(layout, dtype, itype):
            # Build the traced graph.
            prog = torch.export.export(net, (sparse_input,))
            # Test arg/output.
            self._assert_graph_meta(prog, [sparse_input])

    @parametrize("dtype", DTYPES)
    @parametrize("itype", ITYPES)
    @all_sparse_layouts("layout")
    def test_sumnet(self, dtype, itype, layout):
        net = SumNet()
        for sparse_input in self._sparse_inputs(layout, dtype, itype):
            result = net(sparse_input)
            # Build the traced graph.
            prog = torch.export.export(net, (sparse_input,))
            # Test arg/sum/output.
            self._assert_graph_meta(prog, [sparse_input, result])

    @parametrize("dtype", DTYPES)
    @parametrize("itype", ITYPES)
    @all_sparse_layouts("layout")
    def test_eltwisenet(self, dtype, itype, layout):
        net = EltwiseNet()
        for sparse_input in self._sparse_inputs(layout, dtype, itype):
            result = net(sparse_input)
            # Build the traced graph.
            prog = torch.export.export(net, (sparse_input,))
            # Test arg/neg/abs/mul/relu/output. Only metadata is compared, so
            # all five value-carrying nodes are checked against `result`.
            self._assert_graph_meta(prog, [result] * 5)

    @parametrize("dtype", DTYPES)
    @parametrize("itype", ITYPES)
    @all_sparse_layouts("layout")
    def test_todensenet(self, dtype, itype, layout):
        net = ToDenseNet()
        for sparse_input in self._sparse_inputs(layout, dtype, itype):
            result = net(sparse_input)
            # Build the traced graph.
            prog = torch.export.export(net, (sparse_input,))
            # Test arg/todense/output.
            self._assert_graph_meta(prog, [sparse_input, result])

    def test_add(self):
        net = AddNet()
        Y = torch.arange(16, 32, dtype=torch.float32).view(4, 4)
        A = torch.tensor(
            [
                [0.0, 1.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 2.0],
                [0.0, 0.0, 1.0, 1.0],
                [3.0, 0.0, 3.0, 0.0],
            ],
            dtype=torch.float32,
        )
        S = A.to_sparse_csr()
        result = net(S, Y)
        # Build the traced graph.
        prog = torch.export.export(net, (S, Y))
        # Test args/add/output.
        self._assert_graph_meta(prog, [S, Y, result])

    def test_activation_coo(self):
        self._check_activation(SparseActivationCOO())

    def test_activation_csr(self):
        self._check_activation(SparseActivationCSR())


instantiate_parametrized_tests(TestSparseProp)

if __name__ == "__main__":
    run_tests()