1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <queue>
17
18 #include "tensorflow/compiler/xla/service/hlo_reachability.h"
19
20 namespace xla {
21
HloReachabilityMap(absl::Span<const HloInstruction * const> instructions)22 HloReachabilityMap::HloReachabilityMap(
23 absl::Span<const HloInstruction* const> instructions)
24 : size_(instructions.size()) {
25 bit_vectors_.reserve(size_);
26 for (const HloInstruction* hlo : instructions) {
27 indices_[GetKey(hlo)] = bit_vectors_.size();
28 bit_vectors_.emplace_back(size_);
29 }
30 CHECK_EQ(size_, indices_.size()); // instructions should be unique
31 }
32
SetReachabilityToUnion(absl::Span<const HloInstruction * const> inputs,const HloInstruction * instruction)33 bool HloReachabilityMap::SetReachabilityToUnion(
34 absl::Span<const HloInstruction* const> inputs,
35 const HloInstruction* instruction) {
36 BitVector& bit_vector = GetBitVector(instruction);
37 tmp_bit_vector_ = bit_vector;
38 SetReachabilityToUnionHelper(inputs, instruction, &bit_vector);
39 return bit_vector != tmp_bit_vector_;
40 }
41
FastSetReachabilityToUnion(absl::Span<const HloInstruction * const> inputs,const HloInstruction * instruction)42 void HloReachabilityMap::FastSetReachabilityToUnion(
43 absl::Span<const HloInstruction* const> inputs,
44 const HloInstruction* instruction) {
45 SetReachabilityToUnionHelper(inputs, instruction, &GetBitVector(instruction));
46 }
47
SetReachabilityToUnionHelper(absl::Span<const HloInstruction * const> inputs,const HloInstruction * instruction,BitVector * bit_vector)48 void HloReachabilityMap::SetReachabilityToUnionHelper(
49 absl::Span<const HloInstruction* const> inputs,
50 const HloInstruction* instruction, BitVector* bit_vector) {
51 // If instruction is part of inputs, don't reset the bit_vector.
52 if (!absl::c_linear_search(inputs, instruction)) {
53 bit_vector->SetToZero();
54 }
55 bit_vector->Set(GetIndex(instruction));
56 for (const HloInstruction* input : inputs) {
57 if (input != instruction) {
58 bit_vector->OrWith(GetBitVector(input));
59 }
60 }
61 }
62
SetReachable(const HloInstruction * a,const HloInstruction * b)63 void HloReachabilityMap::SetReachable(const HloInstruction* a,
64 const HloInstruction* b) {
65 GetBitVector(b).Set(GetIndex(a));
66 }
67
IsReachable(const HloInstruction * a,const HloInstruction * b) const68 bool HloReachabilityMap::IsReachable(const HloInstruction* a,
69 const HloInstruction* b) const {
70 return GetBitVector(b).Get(GetIndex(a));
71 }
72
IsConnected(const HloInstruction * a,const HloInstruction * b) const73 bool HloReachabilityMap::IsConnected(const HloInstruction* a,
74 const HloInstruction* b) const {
75 return IsReachable(a, b) || IsReachable(b, a);
76 }
77
Build(const HloComputation * computation)78 std::unique_ptr<HloReachabilityMap> HloReachabilityMap::Build(
79 const HloComputation* computation) {
80 const auto& all = computation->MakeInstructionPostOrder();
81 auto result = absl::make_unique<HloReachabilityMap>(all);
82 auto channel_group = computation->ComputeChannelDependencies();
83
84 for (const HloInstruction* hlo : all) {
85 std::vector<HloInstruction*> inputs;
86 const auto add_input = [&channel_group, &inputs](HloInstruction* input) {
87 inputs.push_back(input);
88 if (input->opcode() == HloOpcode::kAllReduce && input->all_reduce_id()) {
89 auto it = channel_group.find(*input->all_reduce_id());
90 if (it != channel_group.end()) {
91 inputs.insert(inputs.end(), it->second.begin(), it->second.end());
92 }
93 }
94 };
95
96 const auto add_dependencies = [&add_input](const HloInstruction* hlo) {
97 for (HloInstruction* operand : hlo->operands()) {
98 add_input(operand);
99 }
100 for (HloInstruction* predecessor : hlo->control_predecessors()) {
101 add_input(predecessor);
102 }
103 };
104
105 add_dependencies(hlo);
106
107 switch (hlo->opcode()) {
108 case HloOpcode::kRecvDone: {
109 auto it = channel_group.find(hlo->channel_id());
110 if (it != channel_group.end()) {
111 for (HloInstruction* channel : it->second) {
112 if (channel->opcode() == HloOpcode::kSend) {
113 add_input(channel);
114 }
115 }
116 }
117 break;
118 }
119 case HloOpcode::kAllReduce: {
120 auto all_reduce_id = hlo->all_reduce_id();
121 if (all_reduce_id) {
122 auto it = channel_group.find(all_reduce_id.value());
123 if (it != channel_group.end()) {
124 for (HloInstruction* all_reduce : it->second) {
125 add_dependencies(all_reduce);
126 }
127 }
128 }
129 break;
130 }
131 default:
132 break;
133 }
134
135 result->FastSetReachabilityToUnion(inputs, hlo);
136 }
137 return result;
138 }
139
UpdateReachabilityThroughInstruction(const HloInstruction * instruction)140 void HloReachabilityMap::UpdateReachabilityThroughInstruction(
141 const HloInstruction* instruction) {
142 std::queue<const HloInstruction*> worklist;
143 worklist.push(instruction);
144
145 std::vector<HloInstruction*> inputs;
146
147 while (!worklist.empty()) {
148 const HloInstruction* item = worklist.front();
149 worklist.pop();
150
151 inputs.assign(item->operands().begin(), item->operands().end());
152 inputs.insert(inputs.end(), item->control_predecessors().begin(),
153 item->control_predecessors().end());
154
155 if (SetReachabilityToUnion(inputs, item)) {
156 // Add immediate successors to worklist.
157 for (const HloInstruction* user : item->users()) {
158 worklist.push(user);
159 }
160 for (const HloInstruction* succ : item->control_successors()) {
161 worklist.push(succ);
162 }
163 }
164 }
165 }
166
167 } // namespace xla
168