1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
17
18 #include "absl/types/optional.h"
19 #include "tensorflow/compiler/xla/service/global_device_id.h"
20 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
21 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
22 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
23 #include "tensorflow/compiler/xla/util.h"
24
25 namespace xla {
26
27 // Match the instruction to a reduction kind. We can represent and/or of pred as
28 // min/max. This works because pred is stored as an 8-bit int of value 0 or 1.
MatchReductionInstruction(const HloInstruction * hlo)29 absl::optional<ReductionKind> MatchReductionInstruction(
30 const HloInstruction* hlo) {
31 PrimitiveType type = hlo->shape().element_type();
32 switch (hlo->opcode()) {
33 case HloOpcode::kAdd:
34 return ReductionKind::SUM;
35 case HloOpcode::kMultiply:
36 return ReductionKind::PRODUCT;
37 case HloOpcode::kMinimum:
38 return ReductionKind::MIN;
39 case HloOpcode::kMaximum:
40 return ReductionKind::MAX;
41 case HloOpcode::kAnd:
42 return type == PRED ? absl::optional<ReductionKind>(ReductionKind::MIN)
43 : absl::nullopt;
44 case HloOpcode::kOr:
45 return type == PRED ? absl::optional<ReductionKind>(ReductionKind::MAX)
46 : absl::nullopt;
47 default:
48 return absl::nullopt;
49 }
50 }
51
MatchReductionComputation(const HloComputation * computation)52 absl::optional<ReductionKind> MatchReductionComputation(
53 const HloComputation* computation) {
54 namespace m = match;
55 const HloInstruction* root = computation->root_instruction();
56 auto kind = MatchReductionInstruction(root);
57 if (kind && !Match(root, m::Op()
58 .WithBinaryOperandsAnyOrder(m::Parameter(0),
59 m::Parameter(1))
60 .WithShape(m::Shape().IsEffectiveScalar()))) {
61 kind = absl::nullopt;
62 }
63 return kind;
64 }
65
GetParticipatingIDs(int current_id,absl::optional<int> total_participant_count,absl::Span<const ReplicaGroup> groups)66 StatusOr<std::vector<int>> GetParticipatingIDs(
67 int current_id, absl::optional<int> total_participant_count,
68 absl::Span<const ReplicaGroup> groups) {
69 // Empty replica_groups() means that all replicas participate.
70 if (groups.empty()) {
71 TF_RET_CHECK(total_participant_count.has_value());
72 std::vector<int> all_participants(*total_participant_count);
73 absl::c_iota(all_participants, 0);
74 return all_participants;
75 }
76
77 // Figure out the other replicas that go together with this one.
78 absl::optional<ReplicaGroup> group;
79 for (const ReplicaGroup& g : groups) {
80 if (absl::c_linear_search(g.replica_ids(), current_id)) {
81 TF_RET_CHECK(!group.has_value())
82 << "ID " << current_id << " appears twice in replica groups";
83 group = g;
84 }
85 }
86 TF_RET_CHECK(group.has_value())
87 << "ID " << current_id << " doesn't appear in replica groups";
88 return std::vector<int>(group->replica_ids().begin(),
89 group->replica_ids().end());
90 }
91
92 // Returns the group formation mode implied by (a) whether the operation has
93 // channel_id and (b) if it has use_global_device_ids and if yes, its value.
GetCollectiveOpGroupMode(bool has_channel_id,absl::optional<bool> use_global_device_ids)94 StatusOr<CollectiveOpGroupMode> GetCollectiveOpGroupMode(
95 bool has_channel_id, absl::optional<bool> use_global_device_ids) {
96 if (!has_channel_id) {
97 if (!use_global_device_ids.has_value() || !*use_global_device_ids) {
98 return CollectiveOpGroupMode::kCrossReplica;
99 } else {
100 return InvalidArgument(
101 "Invalid combination of has_channel_id and use_global_device_ids");
102 }
103 } else {
104 if (!use_global_device_ids.has_value()) {
105 return CollectiveOpGroupMode::kCrossPartition;
106 } else if (!*use_global_device_ids) {
107 return CollectiveOpGroupMode::kCrossReplicaAndPartition;
108 } else {
109 return CollectiveOpGroupMode::kFlattenedID;
110 }
111 }
112 }
113
CollectiveOpGroupModeToString(CollectiveOpGroupMode group_mode)114 absl::string_view CollectiveOpGroupModeToString(
115 CollectiveOpGroupMode group_mode) {
116 switch (group_mode) {
117 case CollectiveOpGroupMode::kCrossReplica:
118 return "kCrossReplica";
119 case CollectiveOpGroupMode::kCrossPartition:
120 return "kCrossPartition";
121 case CollectiveOpGroupMode::kCrossReplicaAndPartition:
122 return "kCrossReplicaAndPartition";
123 case CollectiveOpGroupMode::kFlattenedID:
124 return "kFlattenedID";
125 }
126 }
127
// Computes the groups of devices (as GlobalDeviceIds looked up in
// `device_assignment`) that participate together in a collective, for the
// given `replica_groups` interpreted according to `group_mode`. Returns one
// inner vector per participating group. Fails for kFlattenedID with empty
// replica groups (flattened ids cannot be defaulted).
StatusOr<std::vector<std::vector<GlobalDeviceId>>>
GetParticipatingDevicesGroups(const DeviceAssignment& device_assignment,
                              absl::Span<const ReplicaGroup> replica_groups,
                              CollectiveOpGroupMode group_mode) {
  int replica_count = device_assignment.replica_count();
  int partition_count = device_assignment.computation_count();

  std::vector<ReplicaGroup> participating_replica_groups =
      SpanToVector(replica_groups);

  // If replica groups are empty, assume a group with all replicas.
  if (replica_groups.empty()) {
    if (group_mode == CollectiveOpGroupMode::kFlattenedID) {
      // replica groups contain flattened-ids and cannot be empty.
      // NOTE: this check is inside the empty branch, so it always fails here
      // and returns the error status for kFlattenedID + empty groups.
      TF_RET_CHECK(!replica_groups.empty())
          << "replica groups cannot be empty for kFlattenedID mode";
    }

    int total_participant_count;
    if (group_mode == CollectiveOpGroupMode::kCrossPartition) {
      // replica group are partition ids.
      total_participant_count = partition_count;
    } else {
      // replica group are replica ids.
      total_participant_count = replica_count;
    }

    // Synthesize the single all-inclusive group [0, total_participant_count).
    ReplicaGroup replica_group = ReplicaGroup();
    for (int id = 0; id < total_participant_count; id++) {
      replica_group.add_replica_ids(id);
    }
    participating_replica_groups.push_back(replica_group);
  }

  std::vector<std::vector<GlobalDeviceId>> groups;
  switch (group_mode) {
    case CollectiveOpGroupMode::kCrossReplica: {
      // One device group per (replica_group, partition): groups of replicas,
      // replicated across every partition.
      for (const auto& replica_group : participating_replica_groups) {
        // replica_group contains replica id, participants contains all
        // replica_group's replica_ids for the current partition.
        for (int partition_id = 0; partition_id < partition_count;
             partition_id++) {
          std::vector<GlobalDeviceId> participants;
          participants.reserve(replica_group.replica_ids().size());

          for (int replica_id : replica_group.replica_ids()) {
            participants.emplace_back(
                device_assignment(replica_id, partition_id));
          }
          groups.push_back(participants);
        }
      }
      return groups;
    }
    case CollectiveOpGroupMode::kCrossPartition: {
      // One device group per (replica_group, replica): groups of partitions,
      // replicated across every replica.
      for (const auto& replica_group : participating_replica_groups) {
        // replica_group contains partition id, participants contains all
        // replica_group's partition_ids for the current replica_id.
        for (int replica_id = 0; replica_id < replica_count; replica_id++) {
          std::vector<GlobalDeviceId> participants;
          participants.reserve(replica_group.replica_ids().size());

          for (int partition_id : replica_group.replica_ids()) {
            participants.emplace_back(
                device_assignment(replica_id, partition_id));
          }
          groups.push_back(participants);
        }
      }
      return groups;
    }
    case CollectiveOpGroupMode::kCrossReplicaAndPartition: {
      // One device group per replica_group, spanning all partitions of each
      // listed replica (cartesian product of replica ids x partitions).
      for (const auto& replica_group : participating_replica_groups) {
        std::vector<GlobalDeviceId> participants;
        participants.reserve(replica_group.replica_ids().size() *
                             partition_count);

        // replica_group contains replica id, participants contains all
        // replica_group's replica_ids for all partitions.
        for (int replica_id : replica_group.replica_ids()) {
          for (int partition_id = 0; partition_id < partition_count;
               partition_id++) {
            participants.emplace_back(
                device_assignment(replica_id, partition_id));
          }
        }
        groups.push_back(participants);
      }
      return groups;
    }
    case CollectiveOpGroupMode::kFlattenedID: {
      // Ids are flattened as replica_id * partition_count + partition_id;
      // decode each back into (replica, partition) coordinates.
      for (const auto& replica_group : participating_replica_groups) {
        std::vector<GlobalDeviceId> participants;
        participants.reserve(replica_group.replica_ids().size());

        for (int flattened_id : replica_group.replica_ids()) {
          // Map from flattened id back to replica_id, partition_id.
          int replica_id = flattened_id / partition_count;
          int partition_id = flattened_id % partition_count;
          participants.emplace_back(
              device_assignment(replica_id, partition_id));
        }
        groups.push_back(participants);
      }
      return groups;
    }
  }
}
236
// Returns the set of devices that participate in the same collective group
// as `device_id`, interpreting `replica_groups` per `group_mode`. The
// device's own (replica, partition) coordinates are recovered from
// `device_assignment` and used to select its group.
StatusOr<std::vector<GlobalDeviceId>> GetParticipatingDevices(
    GlobalDeviceId device_id, const DeviceAssignment& device_assignment,
    absl::Span<const ReplicaGroup> replica_groups,
    CollectiveOpGroupMode group_mode) {
  int replica_count = device_assignment.replica_count();
  int partition_count = device_assignment.computation_count();

  TF_ASSIGN_OR_RETURN(const DeviceAssignment::LogicalID logical_id,
                      device_assignment.LogicalIdForDevice(device_id));
  int current_replica_id = logical_id.replica_id;
  int current_partition_id = logical_id.computation_id;

  std::vector<GlobalDeviceId> participants;
  switch (group_mode) {
    case CollectiveOpGroupMode::kCrossReplica: {
      // This is a cross replica operation. replica group contains replica id.
      // use current replica id to find the set of participating replicas. If
      // replica groups are empty, assume a group with all replicas.
      TF_ASSIGN_OR_RETURN(std::vector<int> participating_replicas,
                          GetParticipatingIDs(current_replica_id, replica_count,
                                              replica_groups));

      // The set of participating devices is the replicas from the current
      // partition.
      participants.reserve(participating_replicas.size());
      for (int replica_id : participating_replicas) {
        participants.emplace_back(
            device_assignment(replica_id, current_partition_id));
      }
      return participants;
    }

    case CollectiveOpGroupMode::kCrossPartition: {
      // replica_groups contain partition_id, group contains all partitions for
      // the current replica.
      TF_ASSIGN_OR_RETURN(std::vector<int> participating_partitions,
                          GetParticipatingIDs(current_partition_id,
                                              partition_count, replica_groups));
      participants.reserve(participating_partitions.size());
      for (int partition_id : participating_partitions) {
        participants.emplace_back(
            device_assignment(current_replica_id, partition_id));
      }
      return participants;
    }

    case CollectiveOpGroupMode::kCrossReplicaAndPartition: {
      // replica_groups contain replica_ids. Group contains replicas for all
      // partitions.
      TF_ASSIGN_OR_RETURN(std::vector<int> participating_replicas,
                          GetParticipatingIDs(current_replica_id, replica_count,
                                              replica_groups));
      // Cartesian product: every participating replica across every partition.
      participants.reserve(participating_replicas.size() * partition_count);
      for (int replica_id : participating_replicas) {
        for (int partition_id = 0; partition_id < partition_count;
             ++partition_id) {
          participants.emplace_back(
              device_assignment(replica_id, partition_id));
        }
      }
      return participants;
    }

    case CollectiveOpGroupMode::kFlattenedID: {
      // replica groups contain flattened-ids and cannot be empty.
      TF_RET_CHECK(!replica_groups.empty())
          << "replica groups cannot be empty for kFlattenedID mode";

      // Flattened encoding: replica_id * partition_count + partition_id.
      int current_flattened_id =
          current_replica_id * partition_count + current_partition_id;

      // Find participants based on flattened id. replica_groups cannot be empty
      // so no need to pass in total_participant_count.
      TF_ASSIGN_OR_RETURN(
          std::vector<int> participating_flattened_ids,
          GetParticipatingIDs(current_flattened_id,
                              /*total_participant_count=*/absl::nullopt,
                              replica_groups));

      participants.reserve(participating_flattened_ids.size());
      for (int flattened_id : participating_flattened_ids) {
        // Map from flattened id back to replica_id, partition_id.
        int replica_id = flattened_id / partition_count;
        int partition_id = flattened_id % partition_count;
        participants.emplace_back(device_assignment(replica_id, partition_id));
      }
      return participants;
    }
  }
}
327
ReplicaGroupsOrthogonal(absl::Span<const ReplicaGroup> first,absl::Span<const ReplicaGroup> second)328 bool ReplicaGroupsOrthogonal(absl::Span<const ReplicaGroup> first,
329 absl::Span<const ReplicaGroup> second) {
330 if (first.size() != second[0].replica_ids_size()) {
331 return false;
332 }
333 if (first[0].replica_ids_size() != second.size()) {
334 return false;
335 }
336 for (int64_t i = 0; i < first.size(); ++i) {
337 for (int64_t j = 0; j < first[i].replica_ids_size(); ++j) {
338 if (first[i].replica_ids(j) != second[j].replica_ids(i)) {
339 return false;
340 }
341 }
342 }
343 return true;
344 }
345
346 } // end namespace xla
347