Home
last modified time | relevance | path

Searched refs:dist_tensor (Results 1 – 9 of 9) sorted by relevance

/external/pytorch/test/distributed/_tensor/
Dtest_init.py51 dist_tensor = dist_init_op(
59 eq_op(ones_expected, dist_tensor.to_local())
64 dist_tensor = dist_init_op(
81 eq_op(exp_tensor_list[self.rank], dist_tensor.to_local())
84 eq_op(exp_tensor, dist_tensor.to_local())
138 dist_tensor = zeros(size, device_mesh=mesh, placements=placements)
139 self.assertEqual(dist_tensor.size(), torch.Size(size))
140 local_tensor = dist_tensor.to_local()
144 self.assertEqual(dist_tensor.to_local(), local_tensor)
146 self.assertEqual(dist_tensor.device.type, self.device_type)
[all …]
Dtest_xla_integration.py70 dist_tensor = distribute_tensor(
74 assert type(dist_tensor).__name__ == "XLAShardedTensor"
75 global_tensor = dist_tensor.global_tensor # type:ignore[attr-defined]
79 local_tensor = dist_tensor.local_shards[0].data
82 self.assertTrue(dist_tensor.global_tensor.requires_grad)
83 self.assertTrue(dist_tensor.is_leaf)
97 dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_spec)
99 assert type(dist_tensor).__name__ == "XLAShardedTensor"
100 global_tensor = dist_tensor.global_tensor # type:ignore[attr-defined]
102 local_tensor = dist_tensor.local_shards[0].data
[all …]
Dtest_tensor_ops.py54 dist_tensor = DTensor.from_local(tensor, device_mesh, sharding)
55 self.assertTrue(dist_tensor.is_contiguous())
57 self.assertEqual(dist_tensor.stride(), tensor.stride())
59 new_dt = dist_tensor.transpose(0, 2)
69 self.assertEqual(dist_tensor.stride(), tensor.stride())
129 dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
130 empty_like_dt = torch.empty_like(dist_tensor)
140 dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
141 full_like_dt = torch.fill_(dist_tensor, 42.0)
144 self.assertEqual(full_expected, dist_tensor.to_local())
[all …]
Dtest_dtensor.py75 dist_tensor = DTensor(
80 self.assertEqual(dist_tensor.size(), torch.Size((self.world_size * 3, 3)))
153 dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard0_spec)
155 self.assertEqual(dist_tensor.stride(), (8, 1))
160 dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard1_spec)
162 self.assertEqual(dist_tensor.stride(), (4 * self.world_size, 1))
169 dist_tensor = DTensor.from_local(local_tensor_t, device_mesh, shard1_spec)
171 self.assertEqual(dist_tensor.stride(), global_stride)
195 dist_tensor = DTensor.from_local(local_tensor_temp, device_mesh, placements)
196 self.assertFalse(dist_tensor.is_leaf)
[all …]
Dtest_api.py52 dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_spec)
53 self.assertEqual(dist_tensor.size(), torch.Size([3 * self.world_size, 3]))
54 local_tensor = dist_tensor.to_local()
57 self.assertTrue(dist_tensor.requires_grad)
58 self.assertTrue(dist_tensor.is_leaf)
63 dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_minus_spec)
64 self.assertEqual(dist_tensor.placements[0].dim, 1)
112 dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_spec)
113 self.assertEqual(dist_tensor.size(), torch.Size(input_size))
114 local_tensor = dist_tensor.to_local()
/external/pytorch/torch/testing/_internal/distributed/
Dcommon_state_dict.py23 def _compare_tensor(self, orig_tensor, dist_tensor, offload_to_cpu=False): argument
24 if isinstance(dist_tensor, (DTensor, ShardedTensor)):
25 dist_tensor = _gather_state_dict({"mykey": dist_tensor}).pop("mykey")
29 dist_tensor = dist_tensor.cpu()
30 self.assertTrue(isinstance(dist_tensor, torch.Tensor))
31 self.assertTrue(torch.allclose(orig_tensor, dist_tensor))
Ddistributed_test.py7129 for (_, local_tensor), (_, dist_tensor) in zip(
7132 self.assertEqual(local_tensor, dist_tensor)
/external/pytorch/test/distributed/checkpoint/
Dtest_state_dict_utils.py42 dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec)
43 state_dict = {"dtensor": dist_tensor}
47 dist_tensor.to_local(), gather_dim=0, group=(device_mesh, 0)
59 dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec)
60 state_dict = {"dtensor": dist_tensor}
66 dist_tensor.to_local(), gather_dim=0, group=(device_mesh, 0)
100 dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec)
102 dist_tensor.to_local(), gather_dim=0, group=(device_mesh, 0)
104 return tensor, dist_tensor
/external/pytorch/torch/distributed/tensor/
D_api.py178 dist_tensor = DTensor(
185 return dist_tensor