Home
last modified time | relevance | path

Searched refs:replica_groups (Results 1 – 25 of 51) sorted by relevance

123

/external/tensorflow/tensorflow/compiler/xla/service/
Dcollective_ops_utils_test.cc36 std::vector<ReplicaGroup> replica_groups(3); in TEST() local
37 replica_groups[0].add_replica_ids(0); in TEST()
38 replica_groups[0].add_replica_ids(4); in TEST()
39 replica_groups[1].add_replica_ids(1); in TEST()
40 replica_groups[1].add_replica_ids(5); in TEST()
41 replica_groups[2].add_replica_ids(2); in TEST()
42 replica_groups[2].add_replica_ids(3); in TEST()
46 /*replica_id=*/1, /*total_replica_count=*/6, replica_groups) in TEST()
76 std::vector<ReplicaGroup> replica_groups(2); in TEST() local
77 replica_groups[0].add_replica_ids(0); in TEST()
[all …]
Dall_gather_decomposer.cc67 if (ag->replica_groups().empty()) { in DecomposeAllGather()
70 if (ag->replica_groups().size() == 1) { in DecomposeAllGather()
73 for (int64 i = 0; i < ag->replica_groups()[0].replica_ids_size(); ++i) { in DecomposeAllGather()
74 if (ag->replica_groups()[0].replica_ids(i) != i) { in DecomposeAllGather()
80 CHECK_EQ(partition_count, ag->replica_groups()[0].replica_ids_size()); in DecomposeAllGather()
86 std::vector<uint32> shard_ids(ag->replica_groups().size() * in DecomposeAllGather()
87 ag->replica_groups()[0].replica_ids_size()); in DecomposeAllGather()
88 for (const auto& group : ag->replica_groups()) { in DecomposeAllGather()
114 TF_RET_CHECK(!ag->replica_groups().empty()); in DecomposeAllGather()
115 TF_RET_CHECK(ag->replica_groups()[0].replica_ids_size() == 1); in DecomposeAllGather()
[all …]
Dar_crs_combiner_test.cc455 auto replica_groups_before = crs_before->replica_groups(); in TEST_F()
465 auto replica_groups_after = crs_after->replica_groups(); in TEST_F()
506 auto replica_groups_before = crs_before->replica_groups(); in TEST_F()
515 auto replica_groups_after = crs_after->replica_groups(); in TEST_F()
575 auto replica_groups_before = crs_before->replica_groups(); in TEST_F()
585 auto replica_groups_after = crs_after->replica_groups(); in TEST_F()
644 auto replica_groups_before = crs_before->replica_groups(); in TEST_F()
655 auto replica_groups_after = crs_after->replica_groups(); in TEST_F()
687 auto replica_groups_before = crs_before->replica_groups(); in TEST_F()
697 auto replica_groups_after = crs_after->replica_groups(); in TEST_F()
[all …]
Dcollective_ops_utils.cc55 absl::Span<const ReplicaGroup> replica_groups) { in GetParticipatingReplicas() argument
57 if (replica_groups.empty()) { in GetParticipatingReplicas()
65 for (const ReplicaGroup& g : replica_groups) { in GetParticipatingReplicas()
80 int total_replica_count, absl::Span<const ReplicaGroup> replica_groups) { in GetParticipatingDevices() argument
83 if (replica_groups.empty() && device_assignment.computation_count() == 1) { in GetParticipatingDevices()
99 replica_groups)); in GetParticipatingDevices()
Dall_reduce_combiner.cc84 to_combine.front()->replica_groups(), in CombineAllReduces()
115 replica_groups(hlo->replica_groups()) {} in GroupKey()
133 if (replica_groups.size() != other.replica_groups.size()) { in operator <()
134 return replica_groups.size() < other.replica_groups.size(); in operator <()
136 for (int64 i = 0; i < replica_groups.size(); ++i) { in operator <()
137 const auto& rg = replica_groups[i]; in operator <()
138 const auto& org = other.replica_groups[i]; in operator <()
156 std::vector<ReplicaGroup> replica_groups; member
Dall_to_all_decomposer.cc60 all_to_all->replica_groups().empty() in ExpandInstruction()
62 : all_to_all->replica_groups()[0].replica_ids_size(); in ExpandInstruction()
116 all_to_all_shape, slices, all_to_all->replica_groups(), false, in ExpandInstruction()
Dbfloat16_normalization_test.cc287 std::vector<ReplicaGroup> replica_groups(1); in TEST_F() local
288 replica_groups[0].add_replica_ids(0); in TEST_F()
289 replica_groups[0].add_replica_ids(1); in TEST_F()
292 replica_groups, /*constrain_layout=*/false, absl::nullopt)); in TEST_F()
316 std::vector<ReplicaGroup> replica_groups(1); in TEST_F() local
317 replica_groups[0].add_replica_ids(0); in TEST_F()
318 replica_groups[0].add_replica_ids(1); in TEST_F()
321 replica_groups, /*constrain_layout=*/false, absl::nullopt)); in TEST_F()
Dhlo_replication_analysis.cc72 return hlo->replica_groups().empty() || hlo->replica_groups().size() == 1; in DetermineHloInstructionIsReplicated()
85 for (const auto& group : hlo->replica_groups()) { in DetermineHloInstructionIsReplicated()
104 : hlo->replica_groups().empty() || in DetermineHloInstructionIsReplicated()
105 hlo->replica_groups().size() == 1; in DetermineHloInstructionIsReplicated()
Dhlo_verifier_test.cc872 string ReplicaGroupsStr(std::vector<std::vector<int64>> replica_groups) { in ReplicaGroupsStr() argument
874 for (const auto& g : replica_groups) { in ReplicaGroupsStr()
881 int64 ReplicaCount(const std::vector<std::vector<int64>>& replica_groups) { in ReplicaCount() argument
883 for (auto group : replica_groups) { in ReplicaCount()
890 std::vector<std::vector<int64>> replica_groups, in MakeAllReduceComputation() argument
908 config.set_replica_count(ReplicaCount(replica_groups)); in MakeAllReduceComputation()
912 kTemplate, {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)}}), in MakeAllReduceComputation()
963 std::vector<std::vector<int64>> replica_groups) { in MakeAllToAllComputation() argument
972 config.set_replica_count(ReplicaCount(replica_groups)); in MakeAllToAllComputation()
975 kTemplate, {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)}}), in MakeAllToAllComputation()
Dall_reduce_simplifier.cc40 if (all_reduce->replica_groups().empty()) { in Run()
44 for (const auto& group : all_reduce->replica_groups()) { in Run()
Dar_crs_combiner.cc60 if (ar->replica_groups().size() > 1) { in ReplaceReplicatedAllReduce()
97 auto replica_groups = all_reduce->replica_groups(); in HasCombinableReplicaGroup() local
101 if (replica_groups.size() != num_replicas) { in HasCombinableReplicaGroup()
104 for (const auto& group : replica_groups) { in HasCombinableReplicaGroup()
123 return replica_groups.size() == num_replicas; in HasCombinableReplicaGroup()
Dwhile_loop_all_reduce_code_motion_test.cc118 EXPECT_THAT(moved_all_reduce->replica_groups(), SizeIs(1)); in TEST_F()
120 std::equal(moved_all_reduce->replica_groups()[0].replica_ids().begin(), in TEST_F()
121 moved_all_reduce->replica_groups()[0].replica_ids().end(), in TEST_F()
Dcollective_ops_utils.h44 absl::Span<const ReplicaGroup> replica_groups);
50 int total_replica_count, absl::Span<const ReplicaGroup> replica_groups);
Dall_reduce_combiner_test.cc103 std::vector<ReplicaGroup> replica_groups(groups.size()); in CreateReplicaGroups() local
105 *replica_groups[i].mutable_replica_ids() = {groups[i].begin(), in CreateReplicaGroups()
108 return replica_groups; in CreateReplicaGroups()
Dhlo_verifier.cc198 for (const ReplicaGroup& g : hlo->replica_groups()) { in CheckReplicaGroups()
224 if (hlo->replica_groups().empty()) { in CheckReplicaGroups()
262 TF_RET_CHECK(shard_count == ag->replica_groups()[0].replica_ids_size()); in HandleAllGather()
264 if (ag->replica_groups().empty() || in HandleAllGather()
265 ag->replica_groups()[0].replica_ids_size() != 1) { in HandleAllGather()
271 } else if (!ag->replica_groups().empty()) { in HandleAllGather()
273 TF_RET_CHECK(shard_count == ag->replica_groups()[0].replica_ids_size()); in HandleAllGather()
297 if (hlo->replica_groups().empty()) { in HandleAllToAll()
307 const int64 split_count = hlo->replica_groups().empty() in HandleAllToAll()
309 : hlo->replica_groups()[0].replica_ids_size(); in HandleAllToAll()
[all …]
Dhlo_instructions.cc573 const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout, in HloCollectiveInstruction() argument
576 replica_groups_(replica_groups), in HloCollectiveInstruction()
596 StrCat("replica_groups=", ReplicaGroupsToString(replica_groups()))); in ExtraAttributesToStringImpl()
612 absl::c_equal(replica_groups(), casted_other.replica_groups(), in IdenticalSlowPathIgnoringChannelIdValues()
620 const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout, in HloAllGatherInstruction() argument
623 replica_groups, constrain_layout, channel_id),
643 shape, new_operands[0], all_gather_dimension(), replica_groups(), in CloneWithNewOperandsImpl()
668 const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout, in HloAllReduceInstruction() argument
671 replica_groups, constrain_layout, channel_id), in HloAllReduceInstruction()
677 for (const auto& replica_group : replica_groups()) { in IsNoop()
[all …]
/external/tensorflow/tensorflow/core/tpu/kernels/
Dcross_replica_ops.cc34 std::vector<xla::ReplicaGroup> replica_groups; in Convert() local
38 replica_groups.reserve(num_groups); in Convert()
46 replica_groups.push_back(replica_group); in Convert()
48 return replica_groups; in Convert()
59 std::vector<xla::ReplicaGroup> replica_groups = in Compile() local
61 ctx->SetOutput(0, xla::CrossReplicaSum(ctx->Input(0), replica_groups)); in Compile()
81 std::vector<xla::ReplicaGroup> replica_groups = in Compile() local
85 split_count_, replica_groups)); in Compile()
/external/tensorflow/tensorflow/compiler/xla/service/gpu/
Dnccl_collective_thunk.h71 std::vector<ReplicaGroup> replica_groups; member
88 config.replica_groups = in GetNcclCollectiveConfigForMlir()
89 ConvertReplicaGroups(op.replica_groups()).ValueOrDie(); in GetNcclCollectiveConfigForMlir()
Dnccl_collective_thunk.cc76 config().replica_count, config().replica_groups)); in ExecuteOnStream()
/external/tensorflow/tensorflow/compiler/mlir/xla/
Dattribute_exporter.cc100 std::vector<ReplicaGroup> replica_groups(type.getDimSize(0)); in ConvertReplicaGroups() local
101 for (ReplicaGroup& group : replica_groups) { in ConvertReplicaGroups()
112 return replica_groups; in ConvertReplicaGroups()
/external/tensorflow/tensorflow/compiler/xla/python/
Dxla_client.py677 def make_replica_groups(replica_groups): argument
678 if replica_groups is None:
681 replica_groups = list(replica_groups)
683 _make_replica_group_proto(group) for group in replica_groups
/external/tensorflow/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/
Dhlo_text_to_lhlo_no_opt.hlotxt218 // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>
219 …ROOT result = f32[8] all-reduce(input), channel_id=1, replica_groups={{0,1,2,3}, {4,5,6,7}}, to_ap…
239 // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3], [5, 6, 7, 4]]> : tensor<2x4xi64>
244 …ROOT result = (f32[8], f32[5]) all-reduce(%tuple), replica_groups={{0,1,2,3}, {5,6,7,4}}, to_apply…
268 // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, -1], [3, 4, 5, 6]]> : tensor<2x4xi64>
269 …ROOT result = f32[8] all-reduce(input), channel_id=1, replica_groups={{0,1,2}, {3,4,5,6}}, to_appl…
608 // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>
612 ROOT ag = f32[10,80] all-gather(param0), replica_groups={{0,1,2,3}},
621 // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
627 replica_groups={{0,1}}
/external/tensorflow/tensorflow/compiler/xla/service/spmd/
Dcanonicalize_all_gather_for_cse.cc84 real_data, /*all_gather_dimension=*/0, ag->replica_groups(), in RunOnComputation()
/external/tensorflow/tensorflow/compiler/xla/client/
Dxla_builder.h731 absl::Span<const ReplicaGroup> replica_groups = {});
735 absl::Span<const ReplicaGroup> replica_groups = {},
742 absl::Span<const ReplicaGroup> replica_groups = {},
748 const std::vector<ReplicaGroup>& replica_groups,
753 const std::vector<ReplicaGroup>& replica_groups,
1296 absl::Span<const ReplicaGroup> replica_groups);
1299 absl::Span<const ReplicaGroup> replica_groups,
1304 absl::Span<const ReplicaGroup> replica_groups,
1309 const std::vector<ReplicaGroup>& replica_groups,
1313 const std::vector<ReplicaGroup>& replica_groups,
[all …]
/external/tensorflow/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/
Dlhlo_ops.cc67 DenseIntElementsAttr attr = op.replica_groups(); in VerifyReplicaGroups()

123