# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import itertools
from typing import cast, List

import torch
import torch.distributed as dist
from torch import rand, randn, Tensor
from torch.distributed._tensor import (
    DeviceMesh,
    distribute_tensor,
    init_device_mesh,
    Replicate,
    Shard,
)
from torch.distributed._tensor.placement_types import Placement
from torch.distributed.tensor._ops._view_ops import (
    Broadcast,
    dim_maps,
    Flatten,
    InputDim,
    Repeat,
    Singleton,
    Split,
    view_groups,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)
from torch.utils import _pytree as pytree


class TestViewOps(DTensorTestBase):
    @property
    def world_size(self) -> int:
        return 6

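    # `view_groups` decomposes a reshape between two shapes into one rule per
    # output dimension: InputDim passes an input dim through unchanged, Flatten
    # merges a run of input dims, Split carves one output dim out of a
    # (possibly flattened) input group, and Singleton produces a size-1 dim.
    # The cases below pin down that decomposition for a range of shape pairs.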
    def test_view_groups(self):
        self.assertEqual(
            view_groups([2, 3], [3, 2]),
            (
                Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 0),
                Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 1),
            ),
        )
        self.assertEqual(
            view_groups([3, 4, 5], [12, 5]),
            (Flatten((InputDim(0), InputDim(1))), InputDim(2)),
        )
        self.assertEqual(
            view_groups([2, 3, 4, 5, 7], [12, 70]),
            (
                Split(
                    Flatten(
                        (
                            InputDim(0),
                            InputDim(1),
                            InputDim(2),
                            InputDim(3),
                            InputDim(4),
                        )
                    ),
                    (12, 70),
                    0,
                ),
                Split(
                    Flatten(
                        (
                            InputDim(0),
                            InputDim(1),
                            InputDim(2),
                            InputDim(3),
                            InputDim(4),
                        )
                    ),
                    (12, 70),
                    1,
                ),
            ),
        )
        self.assertEqual(
            view_groups([2, 3, 4, 5, 7], [3, 8, 7, 5]),
            (
                Split(Flatten((InputDim(0), InputDim(1), InputDim(2))), (3, 8), 0),
                Split(Flatten((InputDim(0), InputDim(1), InputDim(2))), (3, 8), 1),
                Split(Flatten((InputDim(3), InputDim(4))), (7, 5), 0),
                Split(Flatten((InputDim(3), InputDim(4))), (7, 5), 1),
            ),
        )
        self.assertEqual(
            view_groups([3, 4, 8, 3], [12, 4, 2, 3]),
            (
                Flatten((InputDim(0), InputDim(1))),
                Split(InputDim(2), (4, 2), 0),
                Split(InputDim(2), (4, 2), 1),
                InputDim(3),
            ),
        )
        self.assertEqual(
            view_groups([3, 24], [1, 3, 2, 4, 1, 3, 1]),
            (
                Singleton(),
                InputDim(0),
                Split(InputDim(1), (2, 4, 3), 0),
                Split(InputDim(1), (2, 4, 3), 1),
                Singleton(),
                Split(InputDim(1), (2, 4, 3), 2),
                Singleton(),
            ),
        )
        self.assertEqual(
            view_groups([1, 1, 3, 2, 1, 1], [6, 1, 1, 1]),
            (
                Flatten((InputDim(2), InputDim(3))),
                InputDim(4),
                InputDim(5),
                Singleton(),
            ),
        )
        self.assertEqual(
            view_groups([1, 1, 12, 1, 1, 1, 2, 5, 1], [3, 4, 1, 10]),
            (
                Split(InputDim(2), (3, 4), 0),
                Split(InputDim(2), (3, 4), 1),
                InputDim(3),
                Flatten((InputDim(6), InputDim(7))),
            ),
        )
        self.assertEqual(
            view_groups([2, 3, 4], [2, -1, 4]),
            (InputDim(0), InputDim(1), InputDim(2)),
        )

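    # For a given op, `call_dt_test` first computes the expected output eagerly,
    # then distributes the input under every sharding choice the op's dim_map
    # admits (skipping dims that cannot be sharded, e.g. dims that are flattened
    # into a neighbor or repeated), and checks that the sharded op runs without
    # communication and reproduces the full output (compared on rank 0).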
    def call_dt_test(self, op, args, kwargs, device_mesh: DeviceMesh):
        dim_map = dim_maps[op]
        rules = dim_map(*args, **kwargs)
        outputs = op(*args, **kwargs)
        flat_args = pytree.arg_tree_leaves(*args)
        in_shape = flat_args[0].shape

        no_shard_dims = set()
        for rule in rules:
            if isinstance(rule, Repeat):
                if isinstance(rule.input_dim, InputDim):
                    no_shard_dims.add(rule.input_dim.input_dim)
            elif isinstance(rule, Flatten):
                for dim in rule.input_dims[1:]:
                    if isinstance(dim, InputDim):
                        no_shard_dims.add(dim.input_dim)
            elif isinstance(rule, Split):
                if isinstance(rule.input_dim, Flatten):
                    for dim in rule.input_dim.input_dims[1:]:
                        if isinstance(dim, InputDim):
                            no_shard_dims.add(dim.input_dim)

        if op == torch.unbind:
            no_shard_dims.add(kwargs.get("dim", 0))

        sharding_choices = cast(List[Placement], [Replicate()]) + [
            Shard(i)
            for i, s in enumerate(in_shape)
            if s > 1 and i not in no_shard_dims
        ]
        all_sharding_choices = itertools.product(
            *(device_mesh.ndim * [sharding_choices])
        )
        for in_shard in all_sharding_choices:
            in_dt = distribute_tensor(args[0], device_mesh, in_shard)

            comm_mode = CommDebugMode()
            with comm_mode:
                out_dt = op(in_dt, *args[1:], **kwargs)

            self.assertEqual(
                comm_mode.get_total_counts(), 0, "Expected no redistribution."
            )

            full_out = out_dt.full_tensor()

            if dist.get_rank() == 0:
                self.assertEqual(outputs, full_out)

    def dimmap_test(self, op, args, expected_rule_output):
        rules = dim_maps[op](*args)
        self.assertEqual(rules, expected_rule_output)
        self.call_dt_test(op, args, {}, self.device_mesh)

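    # `test_view_ops` pins down the dim_maps rules for each view-like op and,
    # through dimmap_test, also validates communication-free sharded execution
    # on a 3x2 device mesh (world_size == 6).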
    @with_comms
    def test_view_ops(self):
        self.device_mesh = DeviceMesh(
            self.device_type, torch.arange(dist.get_world_size()).view(-1, 2)
        )
        self.dimmap_test(torch.atleast_1d, (randn(()),), (Singleton(),))
        self.dimmap_test(torch.atleast_1d, (randn(24),), (InputDim(0),))
        self.dimmap_test(
            torch.atleast_1d, (randn(24, 36),), (InputDim(0), InputDim(1))
        )

        self.dimmap_test(torch.atleast_2d, (randn(()),), (Singleton(), Singleton()))
        self.dimmap_test(torch.atleast_2d, (randn(24),), (Singleton(), InputDim(0)))
        self.dimmap_test(
            torch.atleast_2d, (randn(24, 36),), (InputDim(0), InputDim(1))
        )
        self.dimmap_test(
            torch.atleast_2d,
            (randn(24, 36, 48),),
            (InputDim(0), InputDim(1), InputDim(2)),
        )

        self.dimmap_test(
            torch.atleast_3d,
            (randn(()),),
            (Singleton(), Singleton(), Singleton()),
        )
        self.dimmap_test(
            torch.atleast_3d,
            (randn(24),),
            (Singleton(), InputDim(0), Singleton()),
        )
        self.dimmap_test(
            torch.atleast_3d,
            (randn(24, 36),),
            (InputDim(0), InputDim(1), Singleton()),
        )
        self.dimmap_test(
            torch.atleast_3d,
            (randn(24, 36, 42),),
            (InputDim(0), InputDim(1), InputDim(2)),
        )
        self.dimmap_test(
            torch.atleast_3d,
            (randn(24, 36, 42, 24),),
            (InputDim(0), InputDim(1), InputDim(2), InputDim(3)),
        )

        with self.assertRaises(AssertionError):
            dim_maps[torch.broadcast_to](randn(24, 36), (1, 2, 4))

        self.dimmap_test(
            torch.broadcast_to,
            (rand(24, 36), (1, 24, 36)),
            (Singleton(), InputDim(0), InputDim(1)),
        )
        self.dimmap_test(
            torch.broadcast_to,
            (rand(24, 36), (42, 24, 36)),
            (Broadcast(Singleton(), 42), InputDim(0), InputDim(1)),
        )
        self.dimmap_test(
            torch.broadcast_to,
            (rand(24, 1, 36), (12, 24, 24, 36)),
            (
                Broadcast(Singleton(), 12),
                InputDim(0),
                Broadcast(InputDim(1), 24),
                InputDim(2),
            ),
        )
        self.dimmap_test(
            torch.broadcast_to,
            (rand(24, 36), (-1, 36)),
            (InputDim(0), InputDim(1)),
        )
        self.dimmap_test(
            torch.broadcast_to,
            (rand(24, 1, 36), (-1, 1, 36)),
            (InputDim(0), InputDim(1), InputDim(2)),
        )
        self.dimmap_test(
            torch.broadcast_to,
            (randn(36, 1, 24), (12, 36, 42, 24)),
            (
                Broadcast(Singleton(), 12),
                InputDim(0),
                Broadcast(InputDim(1), 42),
                InputDim(2),
            ),
        )

        self.dimmap_test(
            Tensor.expand,
            (randn(24, 1, 36, 1), 36, 24, 42, -1, 24),
            (
                Broadcast(Singleton(), 36),
                InputDim(0),
                Broadcast(InputDim(1), 42),
                InputDim(2),
                Broadcast(InputDim(3), 24),
            ),
        )
        self.dimmap_test(
            Tensor.expand,
            (randn(24, 1, 36, 1), (36, 24, 42, -1, 24)),
            (
                Broadcast(Singleton(), 36),
                InputDim(0),
                Broadcast(InputDim(1), 42),
                InputDim(2),
                Broadcast(InputDim(3), 24),
            ),
        )

        self.dimmap_test(
            torch.flatten,
            (randn(24, 36),),
            (Flatten((InputDim(0), InputDim(1))),),
        )
        self.dimmap_test(torch.flatten, (randn(42),), (InputDim(0),))
        self.dimmap_test(torch.flatten, (randn(()),), (Singleton(),))

        self.dimmap_test(
            torch.movedim,
            (randn(12, 24, 48, 96), 1, 2),
            (InputDim(0), InputDim(2), InputDim(1), InputDim(3)),
        )
        self.dimmap_test(
            torch.movedim,
            (randn(6, 12, 24), 1, 0),
            (InputDim(1), InputDim(0), InputDim(2)),
        )
        self.dimmap_test(
            torch.movedim,
            (randn(24, 12, 6), (1, 2), (0, 1)),
            (InputDim(1), InputDim(2), InputDim(0)),
        )
        self.dimmap_test(
            torch.movedim,
            (randn(24, 6, 12), (0, 2, 1), (2, 1, 0)),
            (InputDim(1), InputDim(2), InputDim(0)),
        )
        self.dimmap_test(
            torch.movedim,
            (randn(24, 12), (1, 0), (0, 1)),
            (InputDim(1), InputDim(0)),
        )
        self.dimmap_test(
            torch.movedim,
            (randn(36, 24, 12), (1, 2), (0, 1)),
            (InputDim(1), InputDim(2), InputDim(0)),
        )
        self.dimmap_test(
            torch.movedim,
            (randn(36, 24, 12), (1, 2), (-3, -2)),
            (InputDim(1), InputDim(2), InputDim(0)),
        )

        self.dimmap_test(
            torch.permute,
            (randn(24, 36, 42), (2, 0, 1)),
            (InputDim(2), InputDim(0), InputDim(1)),
        )
        self.dimmap_test(
            torch.permute,
            (randn(24, 36, 42), (-1, -3, -2)),
            (InputDim(2), InputDim(0), InputDim(1)),
        )

        self.dimmap_test(
            torch.ravel,
            (randn(24, 36),),
            (Flatten((InputDim(0), InputDim(1))),),
        )
        self.dimmap_test(torch.ravel, (randn(42),), (InputDim(0),))
        self.dimmap_test(torch.ravel, (randn(()),), (Singleton(),))

        self.dimmap_test(
            Tensor.repeat,
            (randn(24, 36), 1, 2, 1, 1, 2),
            (
                Singleton(),
                Broadcast(Singleton(), 2),
                Singleton(),
                InputDim(0),
                Repeat(InputDim(1), 2),
            ),
        )

        self.dimmap_test(
            torch.reshape,
            (randn(6, 12, 24), (72, 24)),
            (Flatten((InputDim(0), InputDim(1))), InputDim(2)),
        )

        self.dimmap_test(
            torch.tile,
            (randn(24, 36), (1, 2, 1, 1, 2)),
            (
                Singleton(),
                Broadcast(Singleton(), 2),
                Singleton(),
                InputDim(0),
                Repeat(InputDim(1), 2),
            ),
        )
        self.dimmap_test(
            torch.tile,
            (randn(42, 24, 36), (1, 3)),
            (InputDim(0), InputDim(1), Repeat(InputDim(2), 3)),
        )

        self.dimmap_test(
            torch.transpose,
            (randn(24, 60, 42, 60), 2, 0),
            (InputDim(2), InputDim(1), InputDim(0), InputDim(3)),
        )
        self.dimmap_test(
            torch.transpose,
            (randn(24, 60, 42, 60), -1, 0),
            (InputDim(3), InputDim(1), InputDim(2), InputDim(0)),
        )

        self.dimmap_test(
            torch.unsqueeze,
            (randn(42, 24, 36), 1),
            (InputDim(0), Singleton(), InputDim(1), InputDim(2)),
        )

        self.dimmap_test(
            Tensor.view,
            (randn(6, 12, 24), 72, 24),
            (Flatten((InputDim(0), InputDim(1))), InputDim(2)),
        )
        self.dimmap_test(Tensor.view, (randn(1, 1, 12), -1), (InputDim(2),))
        self.dimmap_test(
            Tensor.view,
            (randn(1, 1, 42, 24), -1),
            (Flatten((InputDim(2), InputDim(3))),),
        )
        self.dimmap_test(
            Tensor.view,
            (randn(1, 1, 42, 1, 24, 1), -1),
            (Flatten((InputDim(2), InputDim(3), InputDim(4))),),
        )
        self.dimmap_test(
            Tensor.view,
            (randn(48, 35, 26), (24, 4, 35, 13)),
            (
                Split(
                    Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))),
                    group_shape=(24, 4, 35, 13),
                    split_id=0,
                ),
                Split(
                    Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))),
                    group_shape=(24, 4, 35, 13),
                    split_id=1,
                ),
                Split(
                    Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))),
                    group_shape=(24, 4, 35, 13),
                    split_id=2,
                ),
                Split(
                    Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))),
                    group_shape=(24, 4, 35, 13),
                    split_id=3,
                ),
            ),
        )

        # TODO: Currently functional collectives on complex numbers are not
        # fully supported, so we have a standalone test that combines
        # view_as_complex and view_as_real. Once complex numbers are supported,
        # we can add the following to the dim_map test.
        #
        # self.dimmap_test(
        #     torch.view_as_complex,
        #     (randn(24, 13, 2),),
        #     (
        #         InputDim(0),
        #         Flatten((InputDim(1), InputDim(2))),
        #     ),
        # )
        # self.dimmap_test(
        #     torch.view_as_real,
        #     (torch.randn(24, 13, dtype=torch.cfloat),),
        #     (
        #         InputDim(0),
        #         Split(InputDim(1), (13, 2), 0),
        #         Split(InputDim(1), (13, 2), 1),
        #     ),
        # )

    @with_comms
    def test_complex_view_ops(self):
        self.device_mesh = DeviceMesh(
            self.device_type, torch.arange(dist.get_world_size()).view(-1, 2)
        )
        inp = randn(24, 13, 2)
        intermediate = torch.view_as_complex(inp)
        out = torch.view_as_real(intermediate)

        # test dim_map correctness
        expected_view_as_complex_rule = (
            InputDim(0),
            Flatten((InputDim(1), InputDim(2))),
        )
        view_as_complex_rule = dim_maps[torch.view_as_complex](inp)
        self.assertEqual(view_as_complex_rule, expected_view_as_complex_rule)
        expected_view_as_real_rule = (
            InputDim(0),
            Split(InputDim(1), (13, 2), 0),
            Split(InputDim(1), (13, 2), 1),
        )
        view_as_real_rule = dim_maps[torch.view_as_real](intermediate)
        self.assertEqual(view_as_real_rule, expected_view_as_real_rule)

        # test sharded computation correctness
        # NOTE: For the input to torch.view_as_complex, sharding on the last
        # two dimensions is not supported.
        sharding_choices: List[Placement] = [Replicate(), Shard(0)]
        all_sharding_choices = itertools.product(
            *(self.device_mesh.ndim * [sharding_choices])
        )
        for inp_shard in all_sharding_choices:
            inp_dt = distribute_tensor(inp, self.device_mesh, inp_shard)

            comm_mode = CommDebugMode()
            with comm_mode:
                intermediate_dt = torch.view_as_complex(inp_dt)
                out_dt = torch.view_as_real(intermediate_dt)

            self.assertEqual(
                comm_mode.get_total_counts(), 0, "Expected no redistribution."
            )
            self.assertEqual(out, out_dt.full_tensor())

    @with_comms
    def test_dtensor_view_op_uneven(self):
        """
        Test two uneven cases for the view op:
        1) the sharded tensor dim is 1, so that only the first rank has a
           non-empty shard.
        2) the sharded tensor dim is uneven, so that some ranks have full
           shards, some have smaller non-empty shards, and some have empty
           shards.
        """
        dim0_sizes = [1, self.world_size + 1]
        for dim0_size in dim0_sizes:
            p = torch.randn(dim0_size, 2, 2, 2)
            mesh = init_device_mesh(self.device_type, (self.world_size,))
            dtensor = distribute_tensor(p, mesh, [Shard(0)])

            with CommDebugMode() as comm_mode:
                view = dtensor.view(dim0_size, 2, 4)
                self.assertEqual(len(comm_mode.get_comm_counts()), 0)
                # when no communication happens, the data pointer should be the same.
                self.assertEqual(
                    view.to_local().data_ptr(), dtensor.to_local().data_ptr()
                )

                view = dtensor.view(dim0_size, 4, 2)
                self.assertEqual(
                    view.to_local().data_ptr(), dtensor.to_local().data_ptr()
                )
                self.assertEqual(len(comm_mode.get_comm_counts()), 0)

                view = dtensor.view(dim0_size, 8)
                self.assertEqual(
                    view.to_local().data_ptr(), dtensor.to_local().data_ptr()
                )
                self.assertEqual(len(comm_mode.get_comm_counts()), 0)

                view = dtensor.view(dtensor.shape)
                self.assertEqual(
                    view.to_local().data_ptr(), dtensor.to_local().data_ptr()
                )
                self.assertEqual(len(comm_mode.get_comm_counts()), 0)


if __name__ == "__main__":
    run_tests()