Searched refs:dist_tensor (Results 1 – 9 of 9) sorted by relevance
/external/pytorch/test/distributed/_tensor/ |
D | test_init.py | 51 dist_tensor = dist_init_op( 59 eq_op(ones_expected, dist_tensor.to_local()) 64 dist_tensor = dist_init_op( 81 eq_op(exp_tensor_list[self.rank], dist_tensor.to_local()) 84 eq_op(exp_tensor, dist_tensor.to_local()) 138 dist_tensor = zeros(size, device_mesh=mesh, placements=placements) 139 self.assertEqual(dist_tensor.size(), torch.Size(size)) 140 local_tensor = dist_tensor.to_local() 144 self.assertEqual(dist_tensor.to_local(), local_tensor) 146 self.assertEqual(dist_tensor.device.type, self.device_type) [all …]
|
D | test_xla_integration.py | 70 dist_tensor = distribute_tensor( 74 assert type(dist_tensor).__name__ == "XLAShardedTensor" 75 global_tensor = dist_tensor.global_tensor # type:ignore[attr-defined] 79 local_tensor = dist_tensor.local_shards[0].data 82 self.assertTrue(dist_tensor.global_tensor.requires_grad) 83 self.assertTrue(dist_tensor.is_leaf) 97 dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_spec) 99 assert type(dist_tensor).__name__ == "XLAShardedTensor" 100 global_tensor = dist_tensor.global_tensor # type:ignore[attr-defined] 102 local_tensor = dist_tensor.local_shards[0].data [all …]
|
D | test_tensor_ops.py | 54 dist_tensor = DTensor.from_local(tensor, device_mesh, sharding) 55 self.assertTrue(dist_tensor.is_contiguous()) 57 self.assertEqual(dist_tensor.stride(), tensor.stride()) 59 new_dt = dist_tensor.transpose(0, 2) 69 self.assertEqual(dist_tensor.stride(), tensor.stride()) 129 dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) 130 empty_like_dt = torch.empty_like(dist_tensor) 140 dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) 141 full_like_dt = torch.fill_(dist_tensor, 42.0) 144 self.assertEqual(full_expected, dist_tensor.to_local()) [all …]
|
D | test_dtensor.py | 75 dist_tensor = DTensor( 80 self.assertEqual(dist_tensor.size(), torch.Size((self.world_size * 3, 3))) 153 dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard0_spec) 155 self.assertEqual(dist_tensor.stride(), (8, 1)) 160 dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard1_spec) 162 self.assertEqual(dist_tensor.stride(), (4 * self.world_size, 1)) 169 dist_tensor = DTensor.from_local(local_tensor_t, device_mesh, shard1_spec) 171 self.assertEqual(dist_tensor.stride(), global_stride) 195 dist_tensor = DTensor.from_local(local_tensor_temp, device_mesh, placements) 196 self.assertFalse(dist_tensor.is_leaf) [all …]
|
D | test_api.py | 52 dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_spec) 53 self.assertEqual(dist_tensor.size(), torch.Size([3 * self.world_size, 3])) 54 local_tensor = dist_tensor.to_local() 57 self.assertTrue(dist_tensor.requires_grad) 58 self.assertTrue(dist_tensor.is_leaf) 63 dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_minus_spec) 64 self.assertEqual(dist_tensor.placements[0].dim, 1) 112 dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_spec) 113 self.assertEqual(dist_tensor.size(), torch.Size(input_size)) 114 local_tensor = dist_tensor.to_local()
|
/external/pytorch/torch/testing/_internal/distributed/ |
D | common_state_dict.py | 23 def _compare_tensor(self, orig_tensor, dist_tensor, offload_to_cpu=False): argument 24 if isinstance(dist_tensor, (DTensor, ShardedTensor)): 25 dist_tensor = _gather_state_dict({"mykey": dist_tensor}).pop("mykey") 29 dist_tensor = dist_tensor.cpu() 30 self.assertTrue(isinstance(dist_tensor, torch.Tensor)) 31 self.assertTrue(torch.allclose(orig_tensor, dist_tensor))
|
D | distributed_test.py | 7129 for (_, local_tensor), (_, dist_tensor) in zip( 7132 self.assertEqual(local_tensor, dist_tensor)
|
/external/pytorch/test/distributed/checkpoint/ |
D | test_state_dict_utils.py | 42 dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec) 43 state_dict = {"dtensor": dist_tensor} 47 dist_tensor.to_local(), gather_dim=0, group=(device_mesh, 0) 59 dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec) 60 state_dict = {"dtensor": dist_tensor} 66 dist_tensor.to_local(), gather_dim=0, group=(device_mesh, 0) 100 dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec) 102 dist_tensor.to_local(), gather_dim=0, group=(device_mesh, 0) 104 return tensor, dist_tensor
|
/external/pytorch/torch/distributed/tensor/ |
D | _api.py | 178 dist_tensor = DTensor( 185 return dist_tensor
|