• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
17 
18 #include "absl/types/optional.h"
19 #include "tensorflow/compiler/xla/service/global_device_id.h"
20 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
21 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
22 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
23 #include "tensorflow/compiler/xla/util.h"
24 
25 namespace xla {
26 
27 // Match the instruction to a reduction kind. We can represent and/or of pred as
28 // min/max. This works because pred is stored as an 8-bit int of value 0 or 1.
MatchReductionInstruction(const HloInstruction * hlo)29 absl::optional<ReductionKind> MatchReductionInstruction(
30     const HloInstruction* hlo) {
31   PrimitiveType type = hlo->shape().element_type();
32   switch (hlo->opcode()) {
33     case HloOpcode::kAdd:
34       return ReductionKind::SUM;
35     case HloOpcode::kMultiply:
36       return ReductionKind::PRODUCT;
37     case HloOpcode::kMinimum:
38       return ReductionKind::MIN;
39     case HloOpcode::kMaximum:
40       return ReductionKind::MAX;
41     case HloOpcode::kAnd:
42       return type == PRED ? absl::optional<ReductionKind>(ReductionKind::MIN)
43                           : absl::nullopt;
44     case HloOpcode::kOr:
45       return type == PRED ? absl::optional<ReductionKind>(ReductionKind::MAX)
46                           : absl::nullopt;
47     default:
48       return absl::nullopt;
49   }
50 }
51 
MatchReductionComputation(const HloComputation * computation)52 absl::optional<ReductionKind> MatchReductionComputation(
53     const HloComputation* computation) {
54   namespace m = match;
55   const HloInstruction* root = computation->root_instruction();
56   auto kind = MatchReductionInstruction(root);
57   if (kind && !Match(root, m::Op()
58                                .WithBinaryOperandsAnyOrder(m::Parameter(0),
59                                                            m::Parameter(1))
60                                .WithShape(m::Shape().IsEffectiveScalar()))) {
61     kind = absl::nullopt;
62   }
63   return kind;
64 }
65 
GetParticipatingIDs(int current_id,absl::optional<int> total_participant_count,absl::Span<const ReplicaGroup> groups)66 StatusOr<std::vector<int>> GetParticipatingIDs(
67     int current_id, absl::optional<int> total_participant_count,
68     absl::Span<const ReplicaGroup> groups) {
69   // Empty replica_groups() means that all replicas participate.
70   if (groups.empty()) {
71     TF_RET_CHECK(total_participant_count.has_value());
72     std::vector<int> all_participants(*total_participant_count);
73     absl::c_iota(all_participants, 0);
74     return all_participants;
75   }
76 
77   // Figure out the other replicas that go together with this one.
78   absl::optional<ReplicaGroup> group;
79   for (const ReplicaGroup& g : groups) {
80     if (absl::c_linear_search(g.replica_ids(), current_id)) {
81       TF_RET_CHECK(!group.has_value())
82           << "ID " << current_id << " appears twice in replica groups";
83       group = g;
84     }
85   }
86   TF_RET_CHECK(group.has_value())
87       << "ID " << current_id << " doesn't appear in replica groups";
88   return std::vector<int>(group->replica_ids().begin(),
89                           group->replica_ids().end());
90 }
91 
92 // Returns the group formation mode implied by (a) whether the operation has
93 // channel_id and (b) if it has use_global_device_ids and if yes, its value.
GetCollectiveOpGroupMode(bool has_channel_id,absl::optional<bool> use_global_device_ids)94 StatusOr<CollectiveOpGroupMode> GetCollectiveOpGroupMode(
95     bool has_channel_id, absl::optional<bool> use_global_device_ids) {
96   if (!has_channel_id) {
97     if (!use_global_device_ids.has_value() || !*use_global_device_ids) {
98       return CollectiveOpGroupMode::kCrossReplica;
99     } else {
100       return InvalidArgument(
101           "Invalid combination of has_channel_id and use_global_device_ids");
102     }
103   } else {
104     if (!use_global_device_ids.has_value()) {
105       return CollectiveOpGroupMode::kCrossPartition;
106     } else if (!*use_global_device_ids) {
107       return CollectiveOpGroupMode::kCrossReplicaAndPartition;
108     } else {
109       return CollectiveOpGroupMode::kFlattenedID;
110     }
111   }
112 }
113 
CollectiveOpGroupModeToString(CollectiveOpGroupMode group_mode)114 absl::string_view CollectiveOpGroupModeToString(
115     CollectiveOpGroupMode group_mode) {
116   switch (group_mode) {
117     case CollectiveOpGroupMode::kCrossReplica:
118       return "kCrossReplica";
119     case CollectiveOpGroupMode::kCrossPartition:
120       return "kCrossPartition";
121     case CollectiveOpGroupMode::kCrossReplicaAndPartition:
122       return "kCrossReplicaAndPartition";
123     case CollectiveOpGroupMode::kFlattenedID:
124       return "kFlattenedID";
125   }
126 }
127 
// Computes, for every participating group, the full list of GlobalDeviceIds
// that belong to it. Empty `replica_groups` is expanded to a single group
// containing all participants (illegal in kFlattenedID mode); the IDs inside
// each ReplicaGroup are interpreted according to `group_mode` (replica ids,
// partition ids, or flattened ids).
StatusOr<std::vector<std::vector<GlobalDeviceId>>>
GetParticipatingDevicesGroups(const DeviceAssignment& device_assignment,
                              absl::Span<const ReplicaGroup> replica_groups,
                              CollectiveOpGroupMode group_mode) {
  int replica_count = device_assignment.replica_count();
  int partition_count = device_assignment.computation_count();

  std::vector<ReplicaGroup> participating_replica_groups =
      SpanToVector(replica_groups);

  // If replica groups are empty, assume a group with all replicas.
  if (replica_groups.empty()) {
    if (group_mode == CollectiveOpGroupMode::kFlattenedID) {
      // replica groups contain flattened-ids and cannot be empty.
      // NOTE: this TF_RET_CHECK is inside the `replica_groups.empty()` branch,
      // so in kFlattenedID mode it always fails and returns the error.
      TF_RET_CHECK(!replica_groups.empty())
          << "replica groups cannot be empty for kFlattenedID mode";
    }

    // The implicit all-participants group covers partition ids in
    // kCrossPartition mode and replica ids in every other mode.
    int total_participant_count;
    if (group_mode == CollectiveOpGroupMode::kCrossPartition) {
      // replica group are partition ids.
      total_participant_count = partition_count;
    } else {
      // replica group are replica ids.
      total_participant_count = replica_count;
    }

    ReplicaGroup replica_group = ReplicaGroup();
    for (int id = 0; id < total_participant_count; id++) {
      replica_group.add_replica_ids(id);
    }
    participating_replica_groups.push_back(replica_group);
  }

  std::vector<std::vector<GlobalDeviceId>> groups;
  switch (group_mode) {
    case CollectiveOpGroupMode::kCrossReplica: {
      for (const auto& replica_group : participating_replica_groups) {
        // replica_group contains replica id, participants contains all
        // replica_group's replica_ids for the current partition.
        // One device group is emitted per (replica_group, partition) pair.
        for (int partition_id = 0; partition_id < partition_count;
             partition_id++) {
          std::vector<GlobalDeviceId> participants;
          participants.reserve(replica_group.replica_ids().size());

          for (int replica_id : replica_group.replica_ids()) {
            participants.emplace_back(
                device_assignment(replica_id, partition_id));
          }
          groups.push_back(participants);
        }
      }
      return groups;
    }
    case CollectiveOpGroupMode::kCrossPartition: {
      for (const auto& replica_group : participating_replica_groups) {
        // replica_group contains partition id, participants contains all
        // replica_group's partition_ids for the current replica_id.
        // One device group is emitted per (replica_group, replica) pair.
        for (int replica_id = 0; replica_id < replica_count; replica_id++) {
          std::vector<GlobalDeviceId> participants;
          participants.reserve(replica_group.replica_ids().size());

          for (int partition_id : replica_group.replica_ids()) {
            participants.emplace_back(
                device_assignment(replica_id, partition_id));
          }
          groups.push_back(participants);
        }
      }
      return groups;
    }
    case CollectiveOpGroupMode::kCrossReplicaAndPartition: {
      for (const auto& replica_group : participating_replica_groups) {
        std::vector<GlobalDeviceId> participants;
        participants.reserve(replica_group.replica_ids().size() *
                             partition_count);

        // replica_group contains replica id, participants contains all
        // replica_group's replica_ids for all partitions.
        for (int replica_id : replica_group.replica_ids()) {
          for (int partition_id = 0; partition_id < partition_count;
               partition_id++) {
            participants.emplace_back(
                device_assignment(replica_id, partition_id));
          }
        }
        groups.push_back(participants);
      }
      return groups;
    }
    case CollectiveOpGroupMode::kFlattenedID: {
      for (const auto& replica_group : participating_replica_groups) {
        std::vector<GlobalDeviceId> participants;
        participants.reserve(replica_group.replica_ids().size());

        for (int flattened_id : replica_group.replica_ids()) {
          // Map from flattened id back to replica_id, partition_id.
          // Flattened ids are laid out replica-major:
          // flattened = replica_id * partition_count + partition_id.
          int replica_id = flattened_id / partition_count;
          int partition_id = flattened_id % partition_count;
          participants.emplace_back(
              device_assignment(replica_id, partition_id));
        }
        groups.push_back(participants);
      }
      return groups;
    }
  }
}
236 
// Returns the set of GlobalDeviceIds that participate in the same collective
// group as `device_id`, interpreting `replica_groups` according to
// `group_mode`. Empty `replica_groups` means all IDs participate (illegal in
// kFlattenedID mode).
StatusOr<std::vector<GlobalDeviceId>> GetParticipatingDevices(
    GlobalDeviceId device_id, const DeviceAssignment& device_assignment,
    absl::Span<const ReplicaGroup> replica_groups,
    CollectiveOpGroupMode group_mode) {
  int replica_count = device_assignment.replica_count();
  int partition_count = device_assignment.computation_count();

  // Recover this device's (replica_id, partition_id) coordinates.
  TF_ASSIGN_OR_RETURN(const DeviceAssignment::LogicalID logical_id,
                      device_assignment.LogicalIdForDevice(device_id));
  int current_replica_id = logical_id.replica_id;
  int current_partition_id = logical_id.computation_id;

  std::vector<GlobalDeviceId> participants;
  switch (group_mode) {
    case CollectiveOpGroupMode::kCrossReplica: {
      // This is a cross replica operation. replica group contains replica id.
      // use current replica id to find the set of participating replicas. If
      // replica groups are empty, assume a group with all replicas.
      TF_ASSIGN_OR_RETURN(std::vector<int> participating_replicas,
                          GetParticipatingIDs(current_replica_id, replica_count,
                                              replica_groups));

      // The set of participating devices is the replicas from the current
      // partition.
      participants.reserve(participating_replicas.size());
      for (int replica_id : participating_replicas) {
        participants.emplace_back(
            device_assignment(replica_id, current_partition_id));
      }
      return participants;
    }

    case CollectiveOpGroupMode::kCrossPartition: {
      // replica_groups contain partition_id, group contains all partitions for
      // the current replica.
      TF_ASSIGN_OR_RETURN(std::vector<int> participating_partitions,
                          GetParticipatingIDs(current_partition_id,
                                              partition_count, replica_groups));
      participants.reserve(participating_partitions.size());
      for (int partition_id : participating_partitions) {
        participants.emplace_back(
            device_assignment(current_replica_id, partition_id));
      }
      return participants;
    }

    case CollectiveOpGroupMode::kCrossReplicaAndPartition: {
      // replica_groups contain replica_ids. Group contains replicas for all
      // partitions.
      TF_ASSIGN_OR_RETURN(std::vector<int> participating_replicas,
                          GetParticipatingIDs(current_replica_id, replica_count,
                                              replica_groups));
      participants.reserve(participating_replicas.size() * partition_count);
      for (int replica_id : participating_replicas) {
        for (int partition_id = 0; partition_id < partition_count;
             ++partition_id) {
          participants.emplace_back(
              device_assignment(replica_id, partition_id));
        }
      }
      return participants;
    }

    case CollectiveOpGroupMode::kFlattenedID: {
      // replica groups contain flattened-ids and cannot be empty.
      TF_RET_CHECK(!replica_groups.empty())
          << "replica groups cannot be empty for kFlattenedID mode";

      // Flattened ids are replica-major over (replica_id, partition_id).
      int current_flattened_id =
          current_replica_id * partition_count + current_partition_id;

      // Find participants based on flattened id. replica_groups cannot be empty
      // so no need to pass in total_participant_count.
      TF_ASSIGN_OR_RETURN(
          std::vector<int> participating_flattened_ids,
          GetParticipatingIDs(current_flattened_id,
                              /*total_participant_count=*/absl::nullopt,
                              replica_groups));

      participants.reserve(participating_flattened_ids.size());
      for (int flattened_id : participating_flattened_ids) {
        // Map from flattened id back to replica_id, partition_id.
        int replica_id = flattened_id / partition_count;
        int partition_id = flattened_id % partition_count;
        participants.emplace_back(device_assignment(replica_id, partition_id));
      }
      return participants;
    }
  }
}
327 
ReplicaGroupsOrthogonal(absl::Span<const ReplicaGroup> first,absl::Span<const ReplicaGroup> second)328 bool ReplicaGroupsOrthogonal(absl::Span<const ReplicaGroup> first,
329                              absl::Span<const ReplicaGroup> second) {
330   if (first.size() != second[0].replica_ids_size()) {
331     return false;
332   }
333   if (first[0].replica_ids_size() != second.size()) {
334     return false;
335   }
336   for (int64_t i = 0; i < first.size(); ++i) {
337     for (int64_t j = 0; j < first[i].replica_ids_size(); ++j) {
338       if (first[i].replica_ids(j) != second[j].replica_ids(i)) {
339         return false;
340       }
341     }
342   }
343   return true;
344 }
345 
346 }  // end namespace xla
347