Home
last modified time | relevance | path

Searched refs:reduce_scatter (Results 1 – 25 of 69) sorted by relevance

123

/external/tensorflow/tensorflow/dtensor/mlir/utils/
Dcollective_lowering.cc325 mlir::TF::DTensorReduceScatterOp reduce_scatter) { in LowerReduceScatterOp() argument
326 mlir::Location loc = reduce_scatter.getLoc(); in LowerReduceScatterOp()
329 ExtractRequiredSingleLayoutFromOp(reduce_scatter); in LowerReduceScatterOp()
331 return reduce_scatter.emitOpError(output_layout.status().error_message()); in LowerReduceScatterOp()
334 if (!matchPattern(reduce_scatter.group_assignment(), in LowerReduceScatterOp()
336 return reduce_scatter.emitOpError("group_assigment must be a constant."); in LowerReduceScatterOp()
338 return reduce_scatter.emitOpError( in LowerReduceScatterOp()
341 mlir::OpBuilder builder(reduce_scatter); in LowerReduceScatterOp()
342 if (reduce_scatter.device_type().endswith("TPU")) { in LowerReduceScatterOp()
344 reduce_scatter))) in LowerReduceScatterOp()
[all …]
/external/tensorflow/tensorflow/compiler/xla/service/gpu/
Dall_reduce_blueconnect_test.cc73 auto reduce_scatter = AllOf(op::Shape("f32[4]"), op::ReduceScatter(bitcast), in TEST_F() local
75 auto all_reduce = AllOf(op::Shape("f32[4]"), op::AllReduce(reduce_scatter), in TEST_F()
159 auto reduce_scatter = AllOf(op::Shape("(f32[4], f32[8])"), in TEST_F() local
163 op::AllReduce(op::GetTupleElement(reduce_scatter, 0), in TEST_F()
164 op::GetTupleElement(reduce_scatter, 1)), in TEST_F()
Dall_reduce_blueconnect.cc225 HloInstruction* reduce_scatter = in TryDecomposeAllReduce() local
234 reduce_scatter_shape, GetOutputs(*reduce_scatter), in TryDecomposeAllReduce()
/external/tensorflow/third_party/nccl/
Darchive.patch25 diff --git a/src/collectives/device/reduce_scatter.cu b/src/collectives/device/reduce_scatter.cu.cc
27 rename from src/collectives/device/reduce_scatter.cu
28 rename to src/collectives/device/reduce_scatter.cu.cc
Darchive.BUILD49 "src/collectives/device/reduce_scatter.cu.cc",
/external/pytorch/test/distributed/tensor/parallel/
Dtest_tp_examples.py47 reduce_scatter, all_gather, all_reduce = ( variable
289 fwd={reduce_scatter: 6, all_gather: 6},
290 bwd={reduce_scatter: 5, all_gather: 6},
345 bwd={reduce_scatter: 5, all_gather: 6}, optim={all_reduce: 30}
359 ExpCommCounts(bwd={reduce_scatter: 1}, optim={all_reduce: 6}),
365 ExpCommCounts(bwd={reduce_scatter: 5, all_gather: 5}),
377 bwd={reduce_scatter: 5, all_gather: 5}, optim={all_reduce: 6}
392 bwd={reduce_scatter: 5, all_gather: 5}, optim={all_reduce: 12}
407 bwd={reduce_scatter: 5, all_gather: 5}, optim={all_reduce: 3}
Dtest_micro_pipeline_tp.py150 for reduce_scatter in reduce_scatters:
152 reduce_scatter.input_node.op,
156 reduce_scatter.rs_node.target,
160 reduce_scatter.res_node.target,
163 self.assertEqual(reduce_scatter.group_name, group.group_name)
/external/pytorch/torch/_inductor/fx_passes/
Dmicro_pipeline_tp.py668 def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
693 reduce_scatter.input_node,
694 reduce_scatter.rs_node,
695 reduce_scatter.res_node,
696 reduce_scatter.reduce_op,
697 reduce_scatter.scatter_dim,
698 reduce_scatter.group_name,
737 reduce_scatter.replace_with(fused_node)
738 reduce_scatter.erase()
853 for reduce_scatter in reduce_scatters:
[all …]
/external/tensorflow/tensorflow/dtensor/mlir/
Ddtensor_allreduce_scatter_optimization.cc138 auto reduce_scatter = builder.create<mlir::TF::DTensorReduceScatterOp>( in ApplyOptimization() local
143 SetSingleLayoutOnOp(reduce_scatter, desired_layout); in ApplyOptimization()
145 all_scatter->replaceAllUsesWith(reduce_scatter); in ApplyOptimization()
/external/pytorch/test/distributed/_composable/fsdp/
Dtest_fully_shard_mixed_precision.py83 reduce_scatter = functools.partial(
95 with patch_reduce_scatter(reduce_scatter):
152 reduce_scatter = functools.partial(
160 with patch_reduce_scatter(reduce_scatter):
196 reduce_scatter = functools.partial(
204 with patch_reduce_scatter(reduce_scatter):
267 reduce_scatter = functools.partial(
289 with patch_reduce_scatter(reduce_scatter):
Dtest_fully_shard_frozen.py101 reduce_scatter = functools.partial(
115 reduce_scatter
/external/pytorch/test/distributed/
Dtest_c10d_ops_nccl.py95 pg.reduce_scatter(ys, xs).wait()
739 def reduce_scatter(outputs, input_lists, op): function
742 work = pg.reduce_scatter(outputs, input_lists, opts)
763 reduce_scatter(output, tensor_lists, c10d.ReduceOp.SUM)
776 reduce_scatter(output, tensor_lists, c10d.ReduceOp.MIN)
783 reduce_scatter(output, tensor_lists, c10d.ReduceOp.MAX)
790 reduce_scatter(output, tensor_lists, c10d.ReduceOp.PRODUCT)
811 pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.SUM).wait()
818 pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.MIN).wait()
823 pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.MAX).wait()
[all …]
Dtest_nccl.py170 nccl.reduce_scatter(t, t)
221 nccl.reduce_scatter(inputs, outputs)
229 nccl.reduce_scatter(tuple(inputs), tuple(outputs))
Dtest_c10d_spawn_nccl.py170 y = torch.distributed.nn.reduce_scatter(y, [x0, x1])
209 y = torch.distributed.nn.reduce_scatter(y, [x0, x1])
Dtest_multi_threaded_pg.py180 dist.reduce_scatter(output_tensor, to_reduce_scatter)
185 dist.reduce_scatter(output_tensor, to_reduce_scatter, op=dist.ReduceOp.AVG)
/external/pytorch/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/
Dembedding.py8 from torch.distributed.nn.functional import all_gather, reduce_scatter
290 return reduce_scatter(
Dembedding_bag.py11 from torch.distributed.nn.functional import all_gather, reduce_scatter
410 result = reduce_scatter(
/external/pytorch/torch/csrc/distributed/c10d/
DPyProcessGroup.hpp155 c10::intrusive_ptr<Work> reduce_scatter( in reduce_scatter() function in c10d::PyProcessGroup
162 reduce_scatter, /* Name of function in C++ */ in reduce_scatter()
DProcessGroupWrapper.hpp65 c10::intrusive_ptr<Work> reduce_scatter(
/external/pytorch/torch/distributed/nn/
Dfunctional.py88 def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=group.WORLD): function
310 dist.reduce_scatter(tensor, list(input_tensor_list), op=op, group=group)
/external/pytorch/test/cpp_extensions/
Dcpp_c10d_extension.hpp84 c10::intrusive_ptr<Work> reduce_scatter(
Dcpp_c10d_extension.cpp83 c10::intrusive_ptr<Work> ProcessGroupTest::reduce_scatter( in reduce_scatter() function in c10d::ProcessGroupTest
/external/pytorch/torch/cuda/
Dnccl.py142 def reduce_scatter( function
/external/pytorch/test/inductor/
Dtest_distributed_patterns.py24 def reduce_scatter(t): function
90 new_grad = reduce_scatter(grad)
99 m.sharded_weight = nn.Parameter(reduce_scatter(m.weight))
/external/pytorch/torch/distributed/checkpoint/
Dstate_dict_loader.py222 central_plan: LoadPlan = distW.reduce_scatter("plan", local_step, global_step)

123