1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <queue>
17
18 #include "tensorflow/compiler/xla/service/hlo_reachability.h"
19
20 namespace xla {
21
HloReachabilityMap(absl::Span<const HloInstruction * const> instructions)22 HloReachabilityMap::HloReachabilityMap(
23 absl::Span<const HloInstruction* const> instructions)
24 : size_(instructions.size()) {
25 bit_vectors_.reserve(size_);
26 for (const HloInstruction* hlo : instructions) {
27 indices_[GetKey(hlo)] = bit_vectors_.size();
28 bit_vectors_.emplace_back(size_);
29 }
30 CHECK_EQ(size_, indices_.size()); // instructions should be unique
31 }
32
SetReachabilityToUnion(absl::Span<const HloInstruction * const> inputs,const HloInstruction * instruction)33 bool HloReachabilityMap::SetReachabilityToUnion(
34 absl::Span<const HloInstruction* const> inputs,
35 const HloInstruction* instruction) {
36 BitVector& bit_vector = GetBitVector(instruction);
37 tmp_bit_vector_ = bit_vector;
38 SetReachabilityToUnionHelper(inputs, instruction, &bit_vector);
39 return bit_vector != tmp_bit_vector_;
40 }
41
FastSetReachabilityToUnion(absl::Span<const HloInstruction * const> inputs,const HloInstruction * instruction)42 void HloReachabilityMap::FastSetReachabilityToUnion(
43 absl::Span<const HloInstruction* const> inputs,
44 const HloInstruction* instruction) {
45 SetReachabilityToUnionHelper(inputs, instruction, &GetBitVector(instruction));
46 }
47
SetReachabilityToUnionHelper(absl::Span<const HloInstruction * const> inputs,const HloInstruction * instruction,BitVector * bit_vector)48 void HloReachabilityMap::SetReachabilityToUnionHelper(
49 absl::Span<const HloInstruction* const> inputs,
50 const HloInstruction* instruction, BitVector* bit_vector) {
51 // If instruction is part of inputs, don't reset the bit_vector.
52 if (!absl::c_linear_search(inputs, instruction)) {
53 bit_vector->SetToZero();
54 }
55 bit_vector->Set(GetIndex(instruction).v);
56 for (const HloInstruction* input : inputs) {
57 if (input != instruction) {
58 bit_vector->OrWith(GetBitVector(input));
59 }
60 }
61 }
62
Replace(const HloInstruction * original,const HloInstruction * replacement)63 void HloReachabilityMap::Replace(const HloInstruction* original,
64 const HloInstruction* replacement) {
65 if (GetKey(original) == GetKey(replacement)) {
66 return;
67 }
68 indices_[GetKey(replacement)] = GetIndex(original).v;
69 indices_.erase(GetKey(original));
70 }
71
SetReachable(Index a,Index b)72 void HloReachabilityMap::SetReachable(Index a, Index b) {
73 GetBitVector(b).Set(a.v);
74 }
75
BuildWithRestrictions(const HloComputation * computation,absl::FunctionRef<void (const HloInstruction *,std::vector<HloInstruction * > *)> add_dependencies)76 std::unique_ptr<HloReachabilityMap> HloReachabilityMap::BuildWithRestrictions(
77 const HloComputation* computation,
78 absl::FunctionRef<void(const HloInstruction*,
79 std::vector<HloInstruction*>*)>
80 add_dependencies) {
81 const auto& all = computation->MakeInstructionPostOrder();
82 auto result = absl::make_unique<HloReachabilityMap>(all);
83
84 std::vector<HloInstruction*> inputs;
85 for (const HloInstruction* hlo : all) {
86 inputs.clear();
87 add_dependencies(hlo, &inputs);
88 result->FastSetReachabilityToUnion(inputs, hlo);
89 }
90 return result;
91 }
92
Build(const HloComputation * computation)93 std::unique_ptr<HloReachabilityMap> HloReachabilityMap::Build(
94 const HloComputation* computation) {
95 const auto& all = computation->MakeInstructionPostOrder();
96 auto result = absl::make_unique<HloReachabilityMap>(all);
97 auto channel_group = computation->ComputeChannelDependencies();
98
99 std::vector<HloInstruction*> inputs;
100
101 const auto add_input = [&channel_group, &inputs](HloInstruction* input) {
102 inputs.push_back(input);
103 if (input->opcode() == HloOpcode::kAllReduce && input->channel_id()) {
104 auto it = channel_group.find(*input->channel_id());
105 if (it != channel_group.end()) {
106 inputs.insert(inputs.end(), it->second.begin(), it->second.end());
107 }
108 }
109 };
110
111 const auto add_dependencies = [&add_input](const HloInstruction* hlo) {
112 for (HloInstruction* operand : hlo->operands()) {
113 add_input(operand);
114 }
115 for (HloInstruction* predecessor : hlo->control_predecessors()) {
116 add_input(predecessor);
117 }
118 };
119
120 for (const HloInstruction* hlo : all) {
121 inputs.clear();
122 add_dependencies(hlo);
123
124 switch (hlo->opcode()) {
125 case HloOpcode::kRecvDone: {
126 auto it = channel_group.find(*hlo->channel_id());
127 if (it != channel_group.end()) {
128 for (HloInstruction* channel : it->second) {
129 if (channel->opcode() == HloOpcode::kSend) {
130 add_input(channel);
131 }
132 }
133 }
134 break;
135 }
136 case HloOpcode::kAllReduce: {
137 auto channel_id = hlo->channel_id();
138 if (channel_id) {
139 auto it = channel_group.find(channel_id.value());
140 if (it != channel_group.end()) {
141 for (HloInstruction* all_reduce : it->second) {
142 add_dependencies(all_reduce);
143 }
144 }
145 }
146 break;
147 }
148 default:
149 break;
150 }
151
152 result->FastSetReachabilityToUnion(inputs, hlo);
153 }
154 return result;
155 }
156
UpdateReachabilityThroughInstruction(const HloInstruction * instruction)157 void HloReachabilityMap::UpdateReachabilityThroughInstruction(
158 const HloInstruction* instruction) {
159 std::queue<const HloInstruction*> worklist;
160 worklist.push(instruction);
161
162 std::vector<HloInstruction*> inputs;
163
164 while (!worklist.empty()) {
165 const HloInstruction* item = worklist.front();
166 worklist.pop();
167
168 inputs.assign(item->operands().begin(), item->operands().end());
169 inputs.insert(inputs.end(), item->control_predecessors().begin(),
170 item->control_predecessors().end());
171
172 if (SetReachabilityToUnion(inputs, item)) {
173 // Add immediate successors to worklist.
174 for (const HloInstruction* user : item->users()) {
175 worklist.push(user);
176 }
177 for (const HloInstruction* succ : item->control_successors()) {
178 worklist.push(succ);
179 }
180 }
181 }
182 }
183
184 } // namespace xla
185