1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/all_reduce_reassociate.h"
17
18 #include "tensorflow/compiler/xla/service/all_reduce_key.h"
19 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
20 #include "tensorflow/compiler/xla/service/hlo_computation.h"
21 #include "tensorflow/compiler/xla/service/hlo_domain_map.h"
22 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
23 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
24 #include "tensorflow/compiler/xla/service/hlo_query.h"
25
26 namespace xla {
27 namespace {
28
29 // Returns if the given all reduce instructions are compatible with each other.
30 // Note that since the given all-reduce instructions are connected to another
31 // instruction by a direct data flow edge, they must belong to the same domain.
32 // As a result, we don't need to include any domain information in the
33 // AllReduceKey to check compatibility.
AreCompatible(const HloAllReduceInstruction * ar0,const HloAllReduceInstruction * ar1,ReductionKind op_kind)34 bool AreCompatible(const HloAllReduceInstruction *ar0,
35 const HloAllReduceInstruction *ar1, ReductionKind op_kind) {
36 absl::optional<AllReduceKey> key0 = GetAllReduceKey(ar0);
37 absl::optional<AllReduceKey> key1 = GetAllReduceKey(ar1);
38 auto kind0 = MatchReductionComputation(ar0->to_apply());
39 return key0 && key1 && kind0 && *key0 == *key1 && kind0 == op_kind;
40 }
41
42 } // namespace
43
Run(HloModule * module)44 StatusOr<bool> AllReduceReassociate::Run(HloModule *module) {
45 if (hlo_query::ContainsLayoutConstrainedAllReduce(*module)) {
46 VLOG(1)
47 << "Skip AllReduceReassociate because the module contains all-reduce "
48 "with constrained layouts";
49 return false;
50 }
51
52 int64_t next_channel_id = hlo_query::NextChannelId(*module);
53
54 bool changed = false;
55 for (auto computation : module->computations()) {
56 for (HloInstruction *inst : computation->MakeInstructionPostOrder()) {
57 absl::optional<ReductionKind> kind = MatchReductionInstruction(inst);
58 if (!kind || inst->operand(0)->opcode() != HloOpcode::kAllReduce ||
59 inst->operand(1)->opcode() != HloOpcode::kAllReduce ||
60 !inst->shape().IsArray()) {
61 continue;
62 }
63
64 auto *ar0 = Cast<HloAllReduceInstruction>(inst->mutable_operand(0));
65 auto *ar1 = Cast<HloAllReduceInstruction>(inst->mutable_operand(1));
66 if (!AreCompatible(ar0, ar1, *kind)) {
67 VLOG(2) << "All-Reduce operations are not compatible, skipping";
68 continue;
69 }
70
71 if (ar0->user_count() != 1 || ar1->user_count() != 1) {
72 VLOG(2) << "All-Reduce operations have > 1 users";
73 continue;
74 }
75
76 // Found pattern op(ar(x), ar(y)). Transform it into ar(op(x,y)).
77 HloInstruction *new_op = computation->AddInstruction(
78 inst->CloneWithNewOperands(inst->shape(), {ar0->mutable_operand(0),
79 ar1->mutable_operand(0)}));
80 HloInstruction *new_ar = computation->AddInstruction(
81 ar0->CloneWithNewOperands(inst->shape(), {new_op}));
82
83 // Do not reuse channel_id from the existing instruction.
84 if (new_ar->channel_id()) {
85 new_ar->set_channel_id(next_channel_id++);
86 }
87
88 TF_RETURN_IF_ERROR(inst->ReplaceAllUsesWith(new_ar));
89 // Note that RemoveInstructionAndUnusedOperands may not remove the 2
90 // all-reduce operands of `inst` if they are not safe to remove otherwise,
91 // so manually these instructions.
92 TF_RETURN_IF_ERROR(computation->RemoveInstruction(inst));
93 TF_RETURN_IF_ERROR(computation->RemoveInstruction(ar0));
94 TF_RETURN_IF_ERROR(computation->RemoveInstruction(ar1));
95
96 changed = true;
97 }
98 }
99 return changed;
100 }
101
102 } // namespace xla
103