1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/all_reduce_combiner.h"
17
18 #include <algorithm>
19 #include <list>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/strings/str_join.h"
28 #include "tensorflow/compiler/xla/literal.h"
29 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
30 #include "tensorflow/compiler/xla/service/hlo_domain_map.h"
31 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
32 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
33 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
34 #include "tensorflow/compiler/xla/service/hlo_query.h"
35 #include "tensorflow/compiler/xla/service/hlo_reachability.h"
36 #include "tensorflow/compiler/xla/service/shape_inference.h"
37 #include "tensorflow/compiler/xla/shape_util.h"
38 #include "tensorflow/compiler/xla/status_macros.h"
39 #include "tensorflow/compiler/xla/xla_data.pb.h"
40 #include "tensorflow/core/lib/core/errors.h"
41 #include "tensorflow/core/platform/types.h"
42
43 namespace xla {
44 namespace {
45
46 // Combines the elements of to_combine into a single AllReduce op. All
47 // entries in to_combine must be AllReduce ops with exactly one operand
48 // and the same reduction operation.
CombineAllReduces(absl::Span<HloInstruction * const> to_combine)49 Status CombineAllReduces(absl::Span<HloInstruction* const> to_combine) {
50 if (to_combine.size() < 2) {
51 return Status::OK();
52 }
53 VLOG(1) << "Combined " << to_combine.size() << " CRS ops";
54
55 HloComputation& computation = *to_combine.back()->parent();
56 HloComputation* reduction = to_combine[0]->to_apply();
57 const HloOpcode type = reduction->root_instruction()->opcode();
58
59 // Create a single bigger AllReduce of the operands of the smaller
60 // AllReduces.
61 std::vector<HloInstruction*> operands;
62 std::vector<Shape> operand_shapes;
63 VLOG(1) << "Combining set";
64 for (HloInstruction* hlo : to_combine) {
65 VLOG(1) << "Set element: " << hlo->ToString();
66 TF_RET_CHECK(hlo->opcode() == HloOpcode::kAllReduce);
67 TF_RET_CHECK(hlo->operands().size() == 1);
68 TF_RET_CHECK(hlo->to_apply() == reduction ||
69 (hlo->to_apply()->instruction_count() == 3 &&
70 hlo->to_apply()->num_parameters() == 2 &&
71 hlo->to_apply()->root_instruction()->opcode() == type));
72 TF_RET_CHECK(hlo->shape().IsArray());
73 for (HloInstruction* operand : hlo->operands()) {
74 operands.push_back(operand);
75 operand_shapes.push_back(operand->shape());
76 }
77 }
78
79 HloInstruction* combined;
80 // AllReduce ops with more than one operand produce a tuple.
81 TF_RET_CHECK(operands.size() >= 2);
82 combined = computation.AddInstruction(HloInstruction::CreateAllReduce(
83 ShapeUtil::MakeTupleShape(operand_shapes), operands, reduction,
84 to_combine.front()->replica_groups(),
85 /*constrain_layout=*/false, to_combine.front()->channel_id(),
86 Cast<HloAllReduceInstruction>(to_combine.front())
87 ->use_global_device_ids()));
88
89 // We have to propagate the sharding manually because Domain instructions are
90 // not guaranteed to preserve it for side effecting instructions.
91 if (to_combine.front()->has_sharding()) {
92 combined->set_sharding(to_combine.front()->sharding());
93 }
94 VLOG(1) << "Replacing with : " << combined->ToString();
95
96 // Replace all the smaller AllReduces with elements of the tuple output
97 // of the single bigger AllReduce.
98 for (int64 i = 0; i < to_combine.size(); ++i) {
99 auto replace_with = HloInstruction::CreateGetTupleElement(
100 to_combine[i]->shape(), combined, i);
101 TF_RETURN_IF_ERROR(computation.ReplaceWithNewInstruction(
102 to_combine[i], std::move(replace_with)));
103 }
104 return Status::OK();
105 }
106
107 struct GroupKey {
GroupKeyxla::__anon921b162c0111::GroupKey108 GroupKey(const HloInstruction* hlo, const HloDomainMap& domain_map)
109 : opcode(hlo->to_apply()->root_instruction()->opcode()),
110 accum_type(hlo->to_apply()->root_instruction()->shape().element_type()),
111 domain_id(domain_map.GetDomainMetadataId(hlo)),
112 is_cross_shard(hlo->channel_id().has_value()),
113 use_global_device_ids(
114 Cast<HloAllReduceInstruction>(hlo)->use_global_device_ids()),
115 replica_groups(hlo->replica_groups()) {}
116
operator <xla::__anon921b162c0111::GroupKey117 bool operator<(const GroupKey& other) const {
118 if (opcode != other.opcode) {
119 return opcode < other.opcode;
120 }
121 if (accum_type != other.accum_type) {
122 return accum_type < other.accum_type;
123 }
124 if (domain_id != other.domain_id) {
125 return domain_id < other.domain_id;
126 }
127 if (is_cross_shard != other.is_cross_shard) {
128 return is_cross_shard < other.is_cross_shard;
129 }
130 if (use_global_device_ids != other.use_global_device_ids) {
131 return use_global_device_ids < other.use_global_device_ids;
132 }
133 if (replica_groups.size() != other.replica_groups.size()) {
134 return replica_groups.size() < other.replica_groups.size();
135 }
136 for (int64 i = 0; i < replica_groups.size(); ++i) {
137 const auto& rg = replica_groups[i];
138 const auto& org = other.replica_groups[i];
139 if (rg.replica_ids_size() != org.replica_ids_size()) {
140 return rg.replica_ids_size() < org.replica_ids_size();
141 }
142 for (int64 j = 0; j < rg.replica_ids_size(); ++j) {
143 if (rg.replica_ids(j) != org.replica_ids(j)) {
144 return rg.replica_ids(j) < org.replica_ids(j);
145 }
146 }
147 }
148 return false;
149 }
150
151 HloOpcode opcode;
152 PrimitiveType accum_type;
153 int64 domain_id;
154 bool is_cross_shard;
155 bool use_global_device_ids;
156 std::vector<ReplicaGroup> replica_groups;
157 };
158
159 // Group AllReduce instructions by the reduction types, e.g., add, min,
160 // max, replica groups and domain. For cross-module all reduce instructions
161 // we group them by the set of domains they are reducing across.
162 //
163 // Note that the shape of the reduction computation is not included in the
164 // reduction types, e.g.: "f32[] add" and "bf16[] add" will be the same type. We
165 // need to disallow combining CRS instructions with different domain metadata as
166 // well as that could end up short-cutting two or more different domains.
167 //
168 // In each group, the instructions should be in post order. We will then iterate
169 // each group and try to combine them, so to prevent non-determinism, we use
170 // std::map here.
171 //
172 // The return value is a list of groups where every group contains a list of
173 // all-reduce instruction sets in topological order and with a deterministic
174 // order within the set. Additionally due to the above constraints every all
175 // reduce set within a group will contain the same number of elements
176 // and every instruction within an all reduce set will have the same
177 // all-reduce-id (if specified) and thus shape (all reduce sets without an
178 // all-reduce-id will have a single instruction).
179 using InstructionGroups =
180 std::vector<std::vector<std::vector<HloInstruction*>>>;
CreateComputationGroups(HloComputation * computation)181 StatusOr<InstructionGroups> CreateComputationGroups(
182 HloComputation* computation) {
183 TF_ASSIGN_OR_RETURN(auto domain_map, HloDomainMap::Create(computation, ""));
184
185 // Group instructions by opcode, domain id and replica group.
186 std::map<GroupKey, std::vector<HloInstruction*>> opcode_groups;
187 for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
188 if (instruction->opcode() != HloOpcode::kAllReduce) {
189 continue;
190 }
191 if (instruction->to_apply()->instruction_count() != 3 ||
192 instruction->to_apply()->num_parameters() != 2) {
193 VLOG(1) << "Skipping due to non-trivial reduction function.";
194 continue;
195 }
196 opcode_groups[GroupKey(instruction, *domain_map)].push_back(instruction);
197 }
198
199 // Generate a unique all-reduce-id for instructions without one by negating
200 // the unique id of the hlo. This way we can treat cross module and normal CRS
201 // instructions uniformly.
202 auto channel_id = [](const HloInstruction* all_reduce) {
203 return all_reduce->IsCrossModuleAllReduce()
204 ? all_reduce->channel_id().value()
205 : -1 * all_reduce->unique_id();
206 };
207
208 // Group instructions by all-reduce id with instructions for an all-reduce id
209 // is listed along their group id and the (group id, instruction) pairs are
210 // sorted by group id in the vector.
211 std::map<int64, std::vector<std::pair<int64, HloInstruction*>>>
212 all_reduce_sets;
213 int64 group_id = 0;
214 for (auto& domain_groups : opcode_groups) {
215 for (HloInstruction* hlo : domain_groups.second) {
216 all_reduce_sets[channel_id(hlo)].emplace_back(group_id, hlo);
217 }
218 ++group_id;
219 }
220
221 // Group instructions by participating group ids. Instructions within a group
222 // are sorted by topological order and instructions within an all reduce group
223 // is still sorted by group id.
224 std::map<std::vector<int64>, std::vector<std::vector<HloInstruction*>>>
225 all_reduce_group_map;
226 for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
227 if (instruction->opcode() != HloOpcode::kAllReduce) {
228 continue;
229 }
230 if (instruction->to_apply()->instruction_count() != 3 ||
231 instruction->to_apply()->num_parameters() != 2) {
232 VLOG(1) << "Skipping due to non-trivial reduction function.";
233 continue;
234 }
235
236 int64 arid = channel_id(instruction);
237 if (all_reduce_sets.count(arid) == 0) {
238 // Already processed.
239 continue;
240 }
241
242 std::vector<int64> group_ids;
243 std::vector<HloInstruction*> instructions;
244 for (const auto& hlo : all_reduce_sets[arid]) {
245 group_ids.push_back(hlo.first);
246 instructions.push_back(hlo.second);
247 }
248 all_reduce_group_map[group_ids].push_back(std::move(instructions));
249 all_reduce_sets.erase(arid);
250 }
251 CHECK(all_reduce_sets.empty());
252
253 InstructionGroups groups;
254 for (const auto& all_reduce_group : all_reduce_group_map) {
255 groups.push_back(all_reduce_group.second);
256 }
257 return std::move(groups);
258 }
259
260 } // namespace
261
AllReduceCombiner(int64 combine_threshold_in_bytes,int64 combine_threshold_count)262 AllReduceCombiner::AllReduceCombiner(int64 combine_threshold_in_bytes,
263 int64 combine_threshold_count)
264 : combine_threshold_in_bytes_(combine_threshold_in_bytes),
265 combine_threshold_count_(combine_threshold_count) {}
266
Run(HloModule * module)267 StatusOr<bool> AllReduceCombiner::Run(HloModule* module) {
268 VLOG(1) << "Running AllReduceCombiner with threshold of "
269 << combine_threshold_in_bytes_ << " bytes";
270
271 if (combine_threshold_in_bytes_ <= 0 || combine_threshold_count_ <= 0) {
272 VLOG(1) << "Skip AllReduceCombiner because the threshold is zero";
273 return false;
274 }
275
276 if (hlo_query::ContainsLayoutConstrainedAllReduce(*module)) {
277 VLOG(1) << "Skip AllReduceCombiner because the module contains all-reduce "
278 "with constrained layouts";
279 return false;
280 }
281
282 bool changed = false;
283 for (HloComputation* computation : module->MakeNonfusionComputations()) {
284 TF_ASSIGN_OR_RETURN(auto groups, CreateComputationGroups(computation));
285 for (auto group : groups) {
286 // Recompute reachability after every combine group because we can't
287 // maintain a cross group topolgical order to be able to rely on the
288 // transitive dependencies to detect cycles.
289 auto reachability = HloReachabilityMap::Build(computation);
290
291 // Create a map to be able to find an instruction group based on the first
292 // instruction in the group. It will be used during the post order
293 // iteration to be able to process full groups at a time. Doing it only
294 // for one instruction in every group will be sufficient because all
295 // instruction have to schedule at the same time due to cross core
296 // dependencies.
297 absl::flat_hash_map<HloInstruction*, std::vector<HloInstruction*>*>
298 group_map;
299 for (auto& instruction : group) {
300 group_map[instruction.front()] = &instruction;
301 }
302
303 // Collect sets of AllReduce instructions to combine.
304 std::vector<std::vector<std::vector<HloInstruction*>>> combine_sets(1);
305 int64 current_size_in_bytes = 0;
306 int64 current_operand_count = 0;
307
308 // Iterate all instructions in post order and skip the ones not in the
309 // current group. We have to create a new post order iteration for every
310 // group because merging instructions in the previous group can made the
311 // original post order no longer hold.
312 // This will make it likely that we won't increase memory pressure much
313 // above combine_threshold_in_bytes, since two AllReduces that are
314 // near in post order are most likely, but not for sure, also near in
315 // scheduled order.
316 //
317 // TODO(b/70235266): This should usually be fine, but it's probably
318 // possible to construct some case where the memory usage increases beyond
319 // the threshold due to reordering of the instructions in scheduling. If
320 // this ever comes up as a real problem, it would be nice to implement
321 // safeguards so that that cannot possibly happen.
322 for (const HloInstruction* inst :
323 computation->MakeInstructionPostOrder()) {
324 auto it = group_map.find(inst);
325 if (it == group_map.end()) {
326 // Instruction belongs to a different group.
327 continue;
328 }
329 const auto& instructions = *it->second;
330
331 VLOG(1) << "Considering HLO " << instructions.front()->ToString()
332 << " with current set size of " << current_size_in_bytes
333 << " and current operand count of " << current_operand_count;
334
335 // We do not handle AllReduce ops that do not have exactly 1
336 // operand since that is simpler and this pass is the only way to
337 // generate such ops and it should rarely be important to consider the
338 // same ops again.
339 if (instructions.front()->operands().size() != 1) {
340 VLOG(1) << "Skipping due to "
341 << instructions.front()->operands().size() << " operands";
342 continue;
343 }
344
345 int64 size_in_bytes;
346 TF_RET_CHECK(instructions.front()->shape().IsArray());
347 size_in_bytes = ShapeUtil::ByteSizeOf(instructions.front()->shape());
348
349 if (size_in_bytes > combine_threshold_in_bytes_) {
350 VLOG(1) << "Skipping due to size " << size_in_bytes
351 << " above threshold";
352 // If the instruction is greather than the threshold, then we can
353 // never combine it with anything.
354 continue;
355 }
356
357 // If the current set is dependent on the instruction, then create a new
358 // one to avoid the dependency. We move on from the current set instead
359 // of ignoring the instruction since otherwise a single AllReduce
360 // instruction that all the other ones depend on (such as one on the
361 // forward pass of a model) could disable this optimization entirely.
362 TF_RET_CHECK(!combine_sets.empty());
363 for (const auto& previous : combine_sets.back()) {
364 // The reachability information does not reflect the planned
365 // combination from combine_sets. We cannot just bring it up to date
366 // cheaply since HloReachabilityMap does not track reachability
367 // updates transitively and doing it directly is expensive. However,
368 // leaving it stale has no effect on the reachability queries that we
369 // are doing here because we are considering the ops in a topological
370 // order, so we can just leave it stale.
371 //
372 // Proof: Suppose A is the instruction we are looking to combine and B
373 // is an element of the current combine set that we are looking to
374 // combine A into.
375 //
376 // First of all, we check that all elements in each set do not depend
377 // on each other, so combining the *current* combine set cannot create
378 // new dependencies between A and B. It remains to prove that
379 // combining the prior combine sets also cannot create a dependency
380 // between A and B.
381 //
382 // Assume to get a contradiction that there are two AllReduce
383 // ops C and D in combine_sets that will be combined and that A and B
384 // are not connected now but that they will be after combining C and
385 // D. Then there exist paths in the dependency graph such that one of
386 // these cases is true:
387 //
388 // A -> ... -> C and D -> ... -> B
389 // A -> ... -> D and C -> ... -> B
390 // B -> ... -> C and D -> ... -> A
391 // B -> ... -> D and C -> ... -> A
392 //
393 // None of these cases are possible because we are visiting the nodes
394 // in a topological order, so C and D cannot be in-between A and B.
395 // That is a contradiction, so combining the prior combine sets also
396 // cannot create a dependency between A and B.
397 bool new_set = false;
398 for (int64 i = 0; i < instructions.size(); ++i) {
399 if (reachability->IsReachable(previous[i], instructions[i])) {
400 VLOG(1) << "Starting new set due to dependency between "
401 << previous[i]->ToString() << " AND "
402 << instructions[i]->ToString();
403 new_set = true;
404 break;
405 }
406 }
407 if (new_set) {
408 combine_sets.emplace_back();
409 current_size_in_bytes = 0;
410 current_operand_count = 0;
411 break;
412 }
413 }
414
415 if (current_size_in_bytes + size_in_bytes >
416 combine_threshold_in_bytes_ ||
417 current_operand_count + 1 > combine_threshold_count_) {
418 VLOG(1) << "The instruction cannot be entered into the set due "
419 "to the combined size being too large.";
420 // In this case we cannot include the instruction into the current set
421 // since then it would grow beyond the threshold. The set of
422 // instructions to carry forward will either be the current set or the
423 // instruction by itself, whichever is smaller, since that maximizes
424 // the chance of being able to combine with the next instruction.
425 if (size_in_bytes > current_size_in_bytes) {
426 VLOG(1) << "Skipping as the instruction is larger than the set.";
427 continue; // keep the current set
428 }
429 VLOG(1)
430 << "Resetting the set as the set is larger than the instruction.";
431 combine_sets.emplace_back();
432 current_size_in_bytes = 0;
433 current_operand_count = 0;
434 }
435
436 VLOG(1) << "Adding instruction to set.";
437 combine_sets.back().push_back(instructions);
438 current_size_in_bytes += size_in_bytes;
439 current_operand_count += 1;
440 TF_RET_CHECK(current_size_in_bytes <= combine_threshold_in_bytes_);
441 TF_RET_CHECK(current_operand_count <= combine_threshold_count_);
442 }
443 VLOG(1) << "Done constructing sets. Final set size is "
444 << current_size_in_bytes << " bytes and " << current_operand_count
445 << " operands";
446
447 // Combine the collected sets of AllReduce instructions.
448 for (const auto& combine_set : combine_sets) {
449 if (combine_set.size() >= 2) {
450 changed = true;
451 for (int64 i = 0; i < combine_set.front().size(); ++i) {
452 std::vector<HloInstruction*> to_combine;
453 to_combine.reserve(combine_set.size());
454 for (const auto& c : combine_set) {
455 to_combine.push_back(c[i]);
456 }
457 TF_RETURN_IF_ERROR(CombineAllReduces(to_combine));
458 }
459 }
460 }
461 }
462 }
463
464 return changed;
465 }
466
467 } // namespace xla
468