# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import torch
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import OpSchema
from torch.distributed.tensor._ops._common_rules import einop_rule, pointwise_rule
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)

aten = torch.ops.aten


class CommonRulesTest(DTensorTestBase):
    @property
    def world_size(self) -> int:
        # hard code world size to 4 as we need to test
        # at least with 2d mesh
        return 4

    def _gen_tensor_meta(self, shape):
        empty_tensor = torch.empty(shape)
        return TensorMeta(
            empty_tensor.shape,
            empty_tensor.stride(),
            empty_tensor.dtype,
        )

    @with_comms
    def test_einop_basic_propagation(self):
        # plain einsum, mm
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        mm_call = aten.mm.default
        # propagate col-wise sharding
        mat1, mat2 = [-1, -1], [-1, 0]
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 4]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([4, 8]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )
        output_sharding = einop_rule(
            "mk,kn->mn", OpSchema(mm_call, (mat1_spec, mat2_spec), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertEqual(output_spec.dim_map, [-1, 0])

        # propagate row-wise sharding
        mat1, mat2 = [0, -1], [-1, -1]
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )
        output_sharding = einop_rule(
            "mk,kn->mn", OpSchema(mm_call, (mat1_spec, mat2_spec), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertEqual(output_spec.dim_map, [0, -1])

        # generate partial
        mat1, mat2 = [-1, 0], [0, -1]
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )
        output_sharding = einop_rule(
            "mk,kn->mn", OpSchema(mm_call, (mat1_spec, mat2_spec), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertTrue(output_spec.placements[0].is_partial())

    @with_comms
    def test_einop_pointwise_propagation(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        add_call = aten.add.Tensor
        # addition
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 8]))
        mat1 = [0, -1]
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )
        output_sharding = einop_rule(
            "ij,ij->ij", OpSchema(add_call, (mat1_spec, mat1_spec), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertEqual(output_spec.dim_map, [0, -1])

        # broadcast addition
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 8]))
        mat1 = [-1, 0, -1]
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([2]))
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, [-1], [], tensor_meta=mat2_tensor_meta
        )
        output_sharding = einop_rule(
            "ijk,k->ijk", OpSchema(add_call, (mat1_spec, mat2_spec), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
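        # the 1-d operand stays replicated, so the broadcasted add should simply
        # carry mat1's [-1, 0, -1] dim map through to the output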
        self.assertEqual(output_spec.dim_map, [-1, 0, -1])

        # broadcast to a common shape
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 8, 8]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([1, 8]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, [0, -1, -1], [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, [-1, -1], [], tensor_meta=mat2_tensor_meta
        )
        output_sharding = einop_rule(
            "ijk,1k->ijk", OpSchema(add_call, (mat1_spec, mat2_spec), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertEqual(output_spec.dim_map, [0, -1, -1])

    @with_comms
    def test_einop_merge_sharding(self):
        # 2d mesh einop merge sharding
        mesh_shape = torch.arange(self.world_size).reshape(
            self.world_size // 2, self.world_size // 2
        )
        mesh = DeviceMesh(self.device_type, mesh_shape)

        mm_call = aten.mm.default

        mat1, mat2 = [0, -1], [-1, 1]
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 4]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([4, 8]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )
        output_sharding = einop_rule(
            "mk,kn->mn", OpSchema(mm_call, (mat1_spec, mat2_spec), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertEqual(output_spec.dim_map, [0, 1])

    @with_comms
    def test_einop_linearity(self):
        mesh_shape = torch.arange(self.world_size).reshape(
            self.world_size // 2, self.world_size // 2
        )
        mesh = DeviceMesh(self.device_type, mesh_shape)

        mm_call = aten.mm.default

        mat1, mat2 = [0, -1], [-1, -1]
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 4]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([4, 8]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [1], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )
        # if linearity is not turned on, a partial sum is not eligible to propagate;
        # we return a suggestion to reshard the inputs with no partial sum
        # (i.e. all_reduce one input)
        output_sharding = einop_rule(
            "mk,kn->mn", OpSchema(mm_call, (mat1_spec, mat2_spec), {})
        )
        self.assertIsNone(output_sharding.output_spec)
        suggestions = output_sharding.redistribute_schema
        self.assertIsNotNone(suggestions)
        suggested_spec = suggestions.args_schema[0]
        self.assertFalse(suggested_spec.placements[1].is_partial())

        # einop prop with linearity on mm should give back a suggestion
        # to convert placements to partial
        output_sharding = einop_rule(
            "mk,kn->mn",
            OpSchema(mm_call, (mat1_spec, mat2_spec), {}),
            linearity=True,
        )
        self.assertIsNone(output_sharding.output_spec)
        suggestions = output_sharding.redistribute_schema
        self.assertIsNotNone(suggestions)
        mat2_spec = suggestions.args_schema[1]
        # mat2 mesh dim 1 should become partial now!
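        # (with linearity, the pending sum on mat1's mesh dim 1 can stay pending
        # through the op, so the rule suggests marking the other operand partial on
        # that mesh dim instead of all_reducing mat1 up front)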
        self.assertTrue(mat2_spec.placements[1].is_partial())

        # einop prop with linearity on a point-wise op should give back a suggestion
        # to convert placements to partial
        add_call = aten.add.Tensor
        mat1, mat2 = [0, -1], [0, -1]
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 6]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([8, 6]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [1], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )
        output_sharding = einop_rule(
            "ij,ij->ij",
            OpSchema(add_call, (mat1_spec, mat2_spec), {}),
            linearity=True,
        )
        self.assertIsNone(output_sharding.output_spec)
        suggestions = output_sharding.redistribute_schema
        self.assertIsNotNone(suggestions)
        mat2_spec = suggestions.args_schema[1]
        # mat2 mesh dim 1 should become partial now!
        self.assertTrue(mat2_spec.placements[1].is_partial())

    @with_comms
    def test_einop_multi_sharding_on_mesh_dim(self):
        # einop prop with multi sharding on the same mesh dim
        mesh_shape = torch.arange(self.world_size)
        mesh = DeviceMesh(self.device_type, mesh_shape)

        mm_call = aten.mm.default
        mat1, mat2 = [0, -1], [0, -1]
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 12]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([12, 4]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )
        output_sharding = einop_rule(
            "mk,kn->mn", OpSchema(mm_call, (mat1_spec, mat2_spec), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNone(output_spec)
        self.assertIsNotNone(output_sharding.redistribute_schema)

        # ensure that the suggestion is to reshard the second
        # arg by all_gathering its tensor dim sharding
        schema_suggestion = output_sharding.redistribute_schema
        self.assertEqual(schema_suggestion.args_schema[0].dim_map, [0, -1])
        self.assertEqual(schema_suggestion.args_schema[1].dim_map, [-1, -1])

    @with_comms
    def test_einop_errors(self):
        mesh_shape = torch.arange(self.world_size).reshape(
            self.world_size // 2, self.world_size // 2
        )
        mesh = DeviceMesh(self.device_type, mesh_shape)

        add_call = aten.add.Tensor
        mat1, mat2 = [0, -1], [1, -1]
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 4]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([8, 4]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )

        with self.assertRaisesRegex(RuntimeError, "sharded two different ways:"):
            einop_rule("ij,ij->ij", OpSchema(add_call, (mat1_spec, mat2_spec), {}))

    @with_comms
    def test_pointwise_rules_broadcasting(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        where_call = aten.where.self
        inp1, inp2, inp3 = [0], [], [-1, -1]
        inp1_tensor_meta = self._gen_tensor_meta(torch.Size([8]))
        inp2_tensor_meta = self._gen_tensor_meta(torch.Size([]))
        inp3_tensor_meta = self._gen_tensor_meta(torch.Size([1, 1]))
        condition = DTensorSpec.from_dim_map(
            mesh, inp1, [], tensor_meta=inp1_tensor_meta
        )
        self_tensor = DTensorSpec.from_dim_map(
            mesh, inp2, [], tensor_meta=inp2_tensor_meta
        )
        other_tensor = DTensorSpec.from_dim_map(
            mesh, inp3, [], tensor_meta=inp3_tensor_meta
        )
        # propagate point-wise sharding with broadcasting
        output_sharding = pointwise_rule(
            OpSchema(where_call, (condition, self_tensor, other_tensor), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertEqual(output_spec.dim_map, [-1, 0])
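        # (the [8] condition broadcast against the [1, 1] other tensor yields a
        # [1, 8] output, which is why the condition's sharding ends up on the
        # last output dim)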

    @with_comms
    def test_pointwise_rules_suggestion(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        lerp_call = aten.lerp.Scalar
        # propagate point-wise sharding
        inp1, inp2 = [-1, -1], [-1, 0]
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([8, 4]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([8, 4]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, inp1, [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, inp2, [], tensor_meta=mat2_tensor_meta
        )
        # adding a positional argument -1 to the arg schema
        output_sharding = pointwise_rule(
            OpSchema(lerp_call, (mat1_spec, mat2_spec, -1), {})
        )
        self.assertIsNone(output_sharding.output_spec)
        self.assertIsNotNone(output_sharding.redistribute_schema)

        # ensure that the suggestion from the pointwise rules still has
        # the positional args that are not DTensorSpec
        schema_suggestion = output_sharding.redistribute_schema
        self.assertEqual(len(schema_suggestion.args_schema), 3)
        self.assertEqual(schema_suggestion.args_schema[2], -1)

    @with_comms
    def test_pointwise_multi_sharding_on_mesh_dim(self):
        # 2d mesh pointwise sharding
        mesh_shape = torch.arange(self.world_size).reshape(
            self.world_size // 2, self.world_size // 2
        )
        mesh = DeviceMesh(self.device_type, mesh_shape)

        add_call = aten.add.Tensor

        # basic case to test implicit broadcasting shape alignment
        mat1, mat2 = [-1, 0], [0]
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([20, 6]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([6]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )
        output_sharding = pointwise_rule(
            OpSchema(add_call, (mat1_spec, mat2_spec), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNotNone(output_spec)
        self.assertEqual(output_spec.dim_map, [-1, 0])

        # more advanced case that needs to reshard one input to align sharding
        mat1, mat2 = [0, -1, -1, 1], [0, -1, 1]
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([12, 1, 1, 8]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([12, 4, 8]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )
        output_sharding = pointwise_rule(
            OpSchema(add_call, (mat1_spec, mat2_spec), {})
        )
        output_spec = output_sharding.output_spec
        self.assertIsNone(output_spec)
        self.assertIsNotNone(output_sharding.redistribute_schema)

        # ensure that the suggestion is to reshard the first
        # arg by all_gathering its first tensor dim sharding
        schema_suggestion = output_sharding.redistribute_schema
        self.assertEqual(schema_suggestion.args_schema[0].dim_map, [-1, -1, -1, 1])
        self.assertEqual(schema_suggestion.args_schema[1].dim_map, mat2)

    @with_comms
    def test_pointwise_enforce_sharding_multi_sharding_on_mesh_dim(self):
        # 2d mesh pointwise sharding
        mesh_shape = torch.arange(self.world_size).reshape(
            self.world_size // 2, self.world_size // 2
        )
        mesh = DeviceMesh(self.device_type, mesh_shape)

        add_call = aten.add_.Tensor

        # more advanced case that needs to reshard one input to align sharding
        mat1, mat2 = [0, -1, 1], [-1, -1, 0]
        mat1_tensor_meta = self._gen_tensor_meta(torch.Size([12, 4, 8]))
        mat2_tensor_meta = self._gen_tensor_meta(torch.Size([12, 1, 8]))
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], tensor_meta=mat1_tensor_meta
        )
        mat2_spec = DTensorSpec.from_dim_map(
            mesh, mat2, [], tensor_meta=mat2_tensor_meta
        )
        output_sharding = pointwise_rule(
            OpSchema(add_call, (mat1_spec, mat2_spec), {})
        )
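        # add_ is an in-place op: the output aliases the first (self) arg, so its
        # existing sharding has to be kept and any reshard should fall on the other arg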
        output_spec = output_sharding.output_spec
        self.assertIsNone(output_spec)
        self.assertIsNotNone(output_sharding.redistribute_schema)

        # ensure that the suggestion is to reshard the second
        # arg as we should enforce the sharding of the first arg
        schema_suggestion = output_sharding.redistribute_schema
        self.assertEqual(schema_suggestion.args_schema[0].dim_map, mat1)
        self.assertEqual(schema_suggestion.args_schema[1].dim_map, mat1)


if __name__ == "__main__":
    run_tests()