Home
last modified time | relevance | path

Searched refs:replica_groups (Results 1 – 14 of 14) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/
Dar_crs_combiner_test.cc405 auto replica_groups_before = crs_before->replica_groups(); in TEST_F()
414 auto replica_groups_after = crs_after->replica_groups(); in TEST_F()
473 auto replica_groups_before = crs_before->replica_groups(); in TEST_F()
482 auto replica_groups_after = crs_after->replica_groups(); in TEST_F()
540 auto replica_groups_before = crs_before->replica_groups(); in TEST_F()
550 auto replica_groups_after = crs_after->replica_groups(); in TEST_F()
621 auto replica_groups_before = crs_before->replica_groups(); in TEST_F()
634 auto replica_groups_after = crs_after->replica_groups(); in TEST_F()
763 auto replica_groups_before = crs_before->replica_groups(); in TEST_F()
772 auto replica_groups_after = crs_after->replica_groups(); in TEST_F()
[all …]
Dhlo_instructions.cc479 const std::vector<ReplicaGroup>& replica_groups) in HloCollectiveInstruction() argument
480 : HloInstruction(opcode, shape), replica_groups_(replica_groups) { in HloCollectiveInstruction()
497 for (const ReplicaGroup& group : replica_groups()) { in ExtraAttributesToStringImpl()
512 return absl::c_equal(replica_groups(), casted_other.replica_groups(), in IdenticalSlowPath()
521 const std::vector<ReplicaGroup>& replica_groups, absl::string_view barrier, in HloAllReduceInstruction() argument
524 replica_groups), in HloAllReduceInstruction()
546 for (auto replica_group : replica_groups()) { in IsNoop()
583 shape, new_operands, to_apply(), replica_groups(), all_reduce_barrier(), in CloneWithNewOperandsImpl()
589 const std::vector<ReplicaGroup>& replica_groups) in HloAllToAllInstruction() argument
591 replica_groups) {} in HloAllToAllInstruction()
[all …]
Dhlo_parser_test.cc1351 ROOT crs = f32[8]{0} all-reduce(input), replica_groups={}, to_apply=add in CreateTestCases()
1369 …ROOT all-reduce = f32[128,32]{0,1} all-reduce(input), replica_groups={{0,1},{2,3}}, barrier="abc",… in CreateTestCases()
1387 crs.1 = f32[8]{0} all-reduce(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add in CreateTestCases()
1388 ROOT crs.0 = f32[8]{0} all-reduce(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add in CreateTestCases()
1400 ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={} in CreateTestCases()
1412 ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={{1,2},{3,0}} in CreateTestCases()
Dhlo_instruction.cc386 std::vector<ReplicaGroup>(proto.replica_groups().begin(), in CreateFromProto()
387 proto.replica_groups().end()), in CreateFromProto()
396 std::vector<ReplicaGroup>(proto.replica_groups().begin(), in CreateFromProto()
397 proto.replica_groups().end())); in CreateFromProto()
835 const std::vector<ReplicaGroup>& replica_groups, absl::string_view barrier, in CreateAllReduce() argument
838 shape, operands, reduce_computation, replica_groups, barrier, in CreateAllReduce()
844 const std::vector<ReplicaGroup>& replica_groups) { in CreateAllToAll() argument
846 replica_groups); in CreateAllToAll()
3412 const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const { in replica_groups() function in xla::HloInstruction
3413 return Cast<HloCollectiveInstruction>(this)->replica_groups(); in replica_groups()
Dhlo_instructions.h290 const std::vector<ReplicaGroup>& replica_groups() const { in replica_groups() function
298 const std::vector<ReplicaGroup>& replica_groups);
317 const std::vector<ReplicaGroup>& replica_groups,
361 const std::vector<ReplicaGroup>& replica_groups);
Dhlo_instruction.h490 const std::vector<ReplicaGroup>& replica_groups,
508 const std::vector<ReplicaGroup>& replica_groups);
1522 const std::vector<ReplicaGroup>& replica_groups() const;
Dhlo_parser.cc395 std::vector<ReplicaGroup> replica_groups; in CreateReplicaGroups() local
396 absl::c_transform(groups, std::back_inserter(replica_groups), in CreateReplicaGroups()
402 return replica_groups; in CreateReplicaGroups()
833 std::vector<ReplicaGroup> replica_groups; in ParseInstructionRhs() local
835 replica_groups = CreateReplicaGroups(*tmp_groups); in ParseInstructionRhs()
838 shape, operands, *to_apply, replica_groups, barrier ? *barrier : "", in ParseInstructionRhs()
850 std::vector<ReplicaGroup> replica_groups; in ParseInstructionRhs() local
852 replica_groups = CreateReplicaGroups(*tmp_groups); in ParseInstructionRhs()
855 HloInstruction::CreateAllToAll(shape, operands, replica_groups)); in ParseInstructionRhs()
Dhlo.proto173 repeated ReplicaGroup replica_groups = 49; field
/external/tensorflow/tensorflow/compiler/xla/python/
Dxla_client.py1209 replica_groups=None): argument
1225 if replica_groups is None:
1228 replica_groups = list(replica_groups)
1230 _make_replica_group_proto(group) for group in replica_groups
1232 if not replica_groups:
1235 split_count = len(replica_groups[0])
1236 if not all(split_count == len(g) for g in replica_groups):
1241 def CrossReplicaSum(self, operand, replica_groups=None): argument
1254 if replica_groups is None:
1255 replica_groups = [] # special value for XLA API
[all …]
Dlocal_computation_builder.cc500 int64 split_count, absl::Span<const ReplicaGroup> replica_groups) { in AllToAll() argument
502 rg.reserve(replica_groups.size()); in AllToAll()
503 for (int i = 0; i < replica_groups.size(); ++i) { in AllToAll()
504 rg.push_back(replica_groups[i]); in AllToAll()
511 const LocalOp& operand, absl::Span<const ReplicaGroup> replica_groups) { in CrossReplicaSum() argument
512 return xla::CrossReplicaSum(operand.op(), replica_groups); in CrossReplicaSum()
Dlocal_computation_builder.h257 absl::Span<const ReplicaGroup> replica_groups);
260 absl::Span<const ReplicaGroup> replica_groups);
/external/tensorflow/tensorflow/compiler/xla/client/
Dxla_builder.h467 absl::Span<const ReplicaGroup> replica_groups = {});
471 absl::Span<const ReplicaGroup> replica_groups = {},
476 const std::vector<ReplicaGroup>& replica_groups);
885 absl::Span<const ReplicaGroup> replica_groups);
888 absl::Span<const ReplicaGroup> replica_groups,
892 const std::vector<ReplicaGroup>& replica_groups);
1568 absl::Span<const ReplicaGroup> replica_groups = {});
1589 absl::Span<const ReplicaGroup> replica_groups = {},
1595 const std::vector<ReplicaGroup>& replica_groups = {});
Dxla_builder.cc2121 const XlaOp& operand, absl::Span<const ReplicaGroup> replica_groups) { in CrossReplicaSum() argument
2129 return CrossReplicaSum(operand, computation, replica_groups, in CrossReplicaSum()
2136 absl::Span<const ReplicaGroup> replica_groups, in CrossReplicaSum() argument
2145 for (const ReplicaGroup& group : replica_groups) { in CrossReplicaSum()
2161 const std::vector<ReplicaGroup>& replica_groups) { in AllToAll() argument
2197 for (const ReplicaGroup& group : replica_groups) { in AllToAll()
3233 absl::Span<const ReplicaGroup> replica_groups) { in CrossReplicaSum() argument
3234 return operand.builder()->CrossReplicaSum(operand, replica_groups); in CrossReplicaSum()
3238 absl::Span<const ReplicaGroup> replica_groups, in CrossReplicaSum() argument
3241 replica_groups, channel_id); in CrossReplicaSum()
[all …]
/external/tensorflow/tensorflow/compiler/xla/g3doc/
Doperation_semantics.md48 - `replica_groups`: each ReplicaGroup contains a list of replica id. If empty,
63 replica_groups)` </b>
79 : : : `replica_groups` is empty, this :
84 | `replica_groups` | `ReplicaGroup` vector | each group contains a list of |