# Owner(s): ["oncall: distributed"]

import sys

import torch
from torch.distributed._tensor import (
    distribute_module,
    distribute_tensor,
    DTensor,
    Replicate,
    Shard,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


funcol = torch.ops.c10d_functional


class TestEmbeddingOp(DTensorTestBase):
    def _apply_sharding(self, embedding_mod, shard_dim, device_mesh):
        def shard_embedding_fn(name, module, device_mesh):
            # Shard every parameter of the module along shard_dim.
            for name, param in module.named_parameters():
                dist_param = torch.nn.Parameter(
                    distribute_tensor(param, device_mesh, [Shard(shard_dim)])
                )
                module.register_parameter(name, dist_param)

        sharded_embedding = distribute_module(
            embedding_mod, device_mesh, shard_embedding_fn
        )
        return sharded_embedding

    def _run_embedding_op_test(
        self,
        device_mesh,
        shard_dim,
        input_size,
        num_embeddings,
        embedding_dim,
        **kwargs,
    ):
        # Use the same seed for both embeddings.
        torch.manual_seed(0)
        local_embedding = torch.nn.Embedding(
            num_embeddings,
            embedding_dim,
            device=self.device_type,
            **kwargs,
        )
        sharded_embedding = torch.nn.Embedding(
            num_embeddings,
            embedding_dim,
            device=self.device_type,
            **kwargs,
        )

        # Copy the local embedding's weight into the embedding that will be sharded.
        sharded_embedding.weight = torch.nn.Parameter(
            local_embedding.weight.clone().detach()
        )

        sharded_embedding = self._apply_sharding(
            sharded_embedding, shard_dim, device_mesh
        )

        # Run sharded computation
        torch.manual_seed(10)
        inp = torch.randint(
            0, num_embeddings, tuple(input_size), device=self.device_type
        )
        target = torch.empty(
            *inp.size(), embedding_dim, dtype=torch.float, device=self.device_type
        ).random_(0, 1)
        dist_inp = distribute_tensor(inp, device_mesh, [Replicate()])

        # fwd computation, ensure no comm happened
        with CommDebugMode() as fwd_mode:
            dist_output = sharded_embedding(dist_inp)
            self.assertEqual(fwd_mode.get_total_counts(), 0)

        output = dist_output.full_tensor()
        # Run local computation
        local_output = local_embedding(inp)

        # Verify
        self.assertEqual(local_output, output)

        # Use a sample cross entropy loss to verify backward and grad computation.
        loss = torch.nn.CrossEntropyLoss()
        emb_loss = loss(
            output,
            target,
        )
        emb_dup_loss = loss(
            local_output,
            target,
        )

        # local embedding backward
        emb_dup_loss.backward()

        # sharded embedding bwd computation, ensure no comm happened
        with CommDebugMode() as bwd_mode:
            emb_loss.backward()
            self.assertEqual(bwd_mode.get_total_counts(), 0)

        gradient = sharded_embedding.weight.grad.full_tensor()

        local_grad = local_embedding.weight.grad

        # Verify gradient.
        self.assertEqual(gradient, local_grad)

        # Validate the torch.nn.functional.embedding version.
        local_output = torch.nn.functional.embedding(
            inp,
            local_embedding.weight,
            **kwargs,
        )
        sharded_output = torch.nn.functional.embedding(
            DTensor.from_local(inp, device_mesh, [Replicate()], run_check=False),
            sharded_embedding.weight,
            **kwargs,
        )
        self.assertEqual(local_output, sharded_output.full_tensor())

    @with_comms
    def test_sharded_embedding_colwise(self):
        mesh = self.build_device_mesh()
        self._run_embedding_op_test(mesh, 1, [5, 4], 17, 12)
        self._run_embedding_op_test(mesh, 1, [6, 7, 6], 21, 11)
        self._run_embedding_op_test(mesh, 1, [8, 6, 5, 4], 23, 13)
        self._run_embedding_op_test(mesh, 1, [8, 6, 5, 4, 7], 23, 16)
        self._run_embedding_op_test(mesh, 1, [4], 15, 14)
        self._run_embedding_op_test(mesh, 1, [34], 15, 14, padding_idx=10)
        self._run_embedding_op_test(mesh, 1, [8, 6, 5, 4], 23, 13, padding_idx=12)

    @with_comms
    def test_sharded_embedding_colwise_max_norm_errors(self):
        mesh = self.build_device_mesh()
        with self.assertRaisesRegex(
            NotImplementedError,
            "aten.embedding_renorm_.default does not have a sharding strategy registered.",
        ):
            self._run_embedding_op_test(
                mesh, 1, [8, 6, 5, 4], 23, 13, padding_idx=12, max_norm=2.0
            )

    @with_comms
    def test_sharded_embedding_rowwise(self):
        mesh = self.build_device_mesh()
        # test correctness
        self._run_embedding_op_test(mesh, 0, [5, 12], 16, 22)
        self._run_embedding_op_test(mesh, 0, [6, 7, 6], 13, 22)
        self._run_embedding_op_test(mesh, 0, [34], 15, 14, padding_idx=10)

        from torch.distributed.tensor._ops._embedding_ops import _MaskPartial

        # test collectives
        embedding_mod = torch.nn.Embedding(10, 20, device=self.device_type)
        sharded_embedding = self._apply_sharding(embedding_mod, 0, mesh)
        inp = torch.randint(0, 10, (8, 8), device=self.device_type)
        replicated_inp = DTensor.from_local(inp, mesh, [Replicate()], run_check=False)
        output = sharded_embedding(replicated_inp)
        self.assertIsInstance(output.placements[0], _MaskPartial)

        comm_mode = CommDebugMode()

        with comm_mode:
            output.full_tensor()
            self.assertEqual(comm_mode.get_total_counts(), 1)
            self.assertEqual(comm_mode.get_comm_counts()[funcol.all_reduce], 1)

    @with_comms
    def test_multiple_embeddings_rowwise(self):
        mesh = self.build_device_mesh()

        inp = torch.randint(0, 10, (4, 4), device=self.device_type)
        replicated_inp = DTensor.from_local(inp, mesh, [Replicate()], run_check=False)

        from torch.distributed.tensor._ops._embedding_ops import _MaskPartial

        # case 1: two embeddings with the same shape, thus sharing the underlying
        # _MaskPartial and MaskBuffer because of a cache hit from sharding propagation

        emb1 = torch.nn.Embedding(10, 23, device=self.device_type)
        sharded_emb1 = self._apply_sharding(emb1, 0, mesh)
        output1 = sharded_emb1(replicated_inp)

        emb2 = torch.nn.Embedding(10, 23, device=self.device_type)
        sharded_emb2 = self._apply_sharding(emb2, 0, mesh)
        output2 = sharded_emb2(replicated_inp)

        partial_placement1 = output1.placements[0]
        self.assertIsInstance(partial_placement1, _MaskPartial)
        output1.full_tensor()

        partial_placement2 = output2.placements[0]
        self.assertIsInstance(partial_placement2, _MaskPartial)
        output2.full_tensor()

        # both outputs should carry the exact same _MaskPartial object
        self.assertEqual(id(partial_placement1), id(partial_placement2))

        # case 2: two embeddings with the same logical_dim_size but different
        # logical_shape, thus they will have different _MaskPartial placements
        # (with no cache hit)

        emb3 = torch.nn.Embedding(10, 29, device=self.device_type)
        sharded_emb3 = self._apply_sharding(emb3, 0, mesh)
        output3 = sharded_emb3(replicated_inp)
        partial_placement3 = output3.placements[0]
        self.assertIsInstance(partial_placement3, _MaskPartial)
        output3.full_tensor()

        # not equal because of different logical_shape, despite the same logical_dim_size
        self.assertNotEqual(partial_placement1, partial_placement3)


if __name__ == "__main__":
    run_tests()