Home
last modified time | relevance | path

Searched refs:output_tensor_list (Results 1 – 9 of 9) sorted by relevance

/external/pytorch/torch/distributed/_shard/sharded_tensor/
Dreshard.py212 output_tensor_list = [torch.tensor(1)] * world_size
223 output_tensor_list[
235 output_tensor_list = all_to_all(
236 output_tensor_list,
243output_tensor_list = [output_tensor_list[idx] for idx in indices] # type: ignore[call-overload]
244 local_tensor = torch.cat(output_tensor_list, dim=current_sharding_dim)
/external/pytorch/torch/testing/_internal/distributed/
Dmulti_threaded_pg.py74 output_tensor_list, _ = data[dest_rank]
77 output_tensor_list[src_rank].copy_(input_tensor_list[dest_rank])
330 def alltoall(self, output_tensor_list, input_tensor_list, opts=AllToAllOptions()): argument
332 res = coll.join(self._rank, (output_tensor_list, input_tensor_list))
399 …def allgather_into_tensor_coalesced(self, output_tensor_list, input_tensor_list, opts=AllgatherOpt… argument
401 for o_t, i_t in zip(output_tensor_list, input_tensor_list):
/external/pytorch/test/distributed/
Dtest_c10d_common.py1440 output_tensor_list = [
1447 dist.all_gather(output_tensor_list, input_tensor, group=new_pg)
1453 self.assertEqual(output_tensor_list, expected)
1481 output_tensor_list = [
1488 dist.all_gather(output_tensor_list, input_tensor, group=new_pg)
1494 self.assertEqual(output_tensor_list, expected)
1562 for output_tensor_list, input_tensor in zip(
1565 for output_tensor in output_tensor_list:
1601 def reduce_scatter(self, output_tensor_list, input_tensor_lists, opts=None): argument
1603 output_tensor_list, input_tensor_lists
[all …]
Dtest_multi_threaded_pg.py210 output_tensor_list = [torch.empty_like(tensor) for tensor in input_tensor_list]
211 dist.all_to_all(output_tensor_list, input_tensor_list)
216 self.assertEqual(expected_tensor_list, output_tensor_list)
Dtest_c10d_gloo.py2450 output_tensor_list = [torch.zeros_like(input_tensor)]
2451 dist.all_gather_coalesced([output_tensor_list], [input_tensor])
2452 self.assertEqual(output_tensor_list, [input_tensor])
/external/pytorch/torch/nn/parallel/
Ddistributed.py130 output_tensor_list, treespec = tree_flatten(output.local_value())
132 output_tensor_list, treespec = tree_flatten(output)
135 return output_tensor_list, treespec, output_is_rref
1602 output_tensor_list,
1607 None for _ in range(len(output_tensor_list))
1611 for i, output in enumerate(output_tensor_list):
1622 *output_tensor_list,
/external/pytorch/torch/distributed/nn/
Dfunctional.py158 def all_to_all(output_tensor_list, input_tensor_list, group=group.WORLD): argument
171 return _AlltoAll.apply(group, output_tensor_list, *input_tensor_list)
/external/pytorch/torch/distributed/
Ddistributed_c10d.py3537 for output_tensor_list in output_tensor_lists:
3538 _check_tensor_list(output_tensor_list, "output_tensor_lists")
3539 _ensure_all_tensors_same_dtype(output_tensor_list)
4007 def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False): argument
4102 _check_tensor_list(output_tensor_list, "output_tensor_list")
4104 _ensure_all_tensors_same_dtype(output_tensor_list, input_tensor_list)
4109 output_tensor_list = [
4110 t if not t.is_complex() else torch.view_as_real(t) for t in output_tensor_list
4114 work = group.alltoall(output_tensor_list, input_tensor_list, opts)
/external/tensorflow/tensorflow/python/eager/
Dfunction_test.py1356 output_tensor_list = compiled()
1358 output_tensor_list, element_dtype=dtypes.float32)