1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/ar_crs_combiner.h"
17
18 #include <string>
19 #include <utility>
20 #include <vector>
21
22 #include "tensorflow/compiler/xla/literal.h"
23 #include "tensorflow/compiler/xla/literal_util.h"
24 #include "tensorflow/compiler/xla/service/call_graph.h"
25 #include "tensorflow/compiler/xla/service/hlo_computation.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
28 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/types.h"
32
33 namespace xla {
34
35 namespace m = match;
36
37 // Checks if the argument instruction is an AllReduce, followed by a certain
38 // sequence of instructions and then a CRS. It must be possible to move
39 // the AR past each instruction in the sequence. Returns the CRS, which is the
40 // last instruction in the sequence.
MatchesArCrsPattern(HloInstruction * instruction)41 absl::optional<ArCrsCombiner::ArCrsPair> ArCrsCombiner::MatchesArCrsPattern(
42 HloInstruction* instruction) {
43 auto can_ar_move_past_instruction = [](HloInstruction* instruction) -> bool {
44 if (instruction->user_count() != 1) {
45 return false;
46 }
47 switch (instruction->opcode()) {
48 case HloOpcode::kBitcast:
49 case HloOpcode::kTranspose:
50 case HloOpcode::kReshape:
51 return true;
52 case HloOpcode::kConvert:
53 // Can be moved across if both input and output is either float or
54 // integer (e.g. S32<->U32 or F32<->BF16)
55 return ShapeUtil::ElementIsFloating(instruction->shape()) ==
56 ShapeUtil::ElementIsFloating(instruction->operand(0)->shape());
57 case HloOpcode::kAdd:
58 case HloOpcode::kSubtract:
59 case HloOpcode::kMultiply:
60 // Only supported for floating point operands.
61 return ShapeUtil::ElementIsFloating(instruction->shape());
62 default:
63 return false;
64 }
65 };
66
67 auto computation_is_addition = [](HloComputation* c) {
68 return c->instruction_count() == 3 &&
69 Match(c->root_instruction(), m::Add(m::Parameter(), m::Parameter()));
70 };
71
72 if (!instruction->IsCrossModuleAllReduce() ||
73 !computation_is_addition(instruction->called_computations()[0]) ||
74 instruction->user_count() != 1) {
75 return absl::nullopt;
76 }
77 auto next = instruction->users()[0];
78 int64 distance = 1;
79 while (!next->IsCrossReplicaAllReduce()) {
80 if (can_ar_move_past_instruction(next)) {
81 next = next->users()[0];
82 } else {
83 return absl::nullopt;
84 }
85 ++distance;
86 }
87 if (!Cast<HloAllReduceInstruction>(next)->IsNoop() &&
88 computation_is_addition(next->called_computations()[0])) {
89 return absl::optional<ArCrsPair>(ArCrsPair(instruction, next, distance));
90 } else {
91 return absl::nullopt;
92 }
93 }
94
WhileFromBodyParameter(HloInstruction * instruction)95 absl::optional<HloInstruction*> ArCrsCombiner::WhileFromBodyParameter(
96 HloInstruction* instruction) {
97 CHECK_EQ(HloOpcode::kParameter, instruction->opcode());
98 HloComputation* computation = instruction->parent();
99 auto caller_instructions = call_graph_->GetComputationCallers(computation);
100 if (caller_instructions.size() == 1) {
101 auto caller_instruction = caller_instructions[0];
102 if (caller_instruction->opcode() == HloOpcode::kWhile) {
103 return caller_instruction;
104 }
105 }
106 return absl::nullopt;
107 }
108
GetAllTuples(HloInstruction * instruction)109 std::vector<HloInstruction*> ArCrsCombiner::GetAllTuples(
110 HloInstruction* instruction) {
111 if (instruction->opcode() == HloOpcode::kTuple) {
112 return {instruction};
113 }
114 if (instruction->opcode() == HloOpcode::kDomain) {
115 return GetAllTuples(instruction->operands()[0]);
116 }
117 if (instruction->opcode() == HloOpcode::kParameter) {
118 auto maybe_while = WhileFromBodyParameter(instruction);
119 if (!maybe_while) {
120 return {};
121 }
122 auto while_instr = *maybe_while;
123 auto init_tuples = GetAllTuples(while_instr->while_init());
124 auto body_tuples =
125 GetAllTuples(while_instr->while_body()->root_instruction());
126 if (init_tuples.empty() || body_tuples.empty()) {
127 return {};
128 }
129 init_tuples.insert(init_tuples.end(), body_tuples.begin(),
130 body_tuples.end());
131 return init_tuples;
132 }
133 if (instruction->opcode() == HloOpcode::kGetTupleElement) {
134 std::vector<HloInstruction*> result_tuples;
135 for (auto tuple : GetAllTuples(instruction->operands()[0])) {
136 auto tmp_tuples =
137 GetAllTuples(tuple->mutable_operand(instruction->tuple_index()));
138 if (tmp_tuples.empty()) {
139 return {};
140 }
141 result_tuples.insert(result_tuples.end(), tmp_tuples.begin(),
142 tmp_tuples.end());
143 }
144 return result_tuples;
145 }
146 return {};
147 }
148
TupleElementsComputeSameValue(HloInstruction * tuple_shaped_instruction,int64 i1,int64 i2,absl::flat_hash_map<int64,int64> * visited_pairs)149 bool ArCrsCombiner::TupleElementsComputeSameValue(
150 HloInstruction* tuple_shaped_instruction, int64 i1, int64 i2,
151 absl::flat_hash_map<int64, int64>* visited_pairs) {
152 auto tuples = GetAllTuples(tuple_shaped_instruction);
153 if (tuples.empty()) {
154 return false;
155 }
156 for (auto tuple : tuples) {
157 CHECK_EQ(tuple->opcode(), HloOpcode::kTuple);
158 if (!InstructionsComputeSameValue(tuple->mutable_operand(i1),
159 tuple->mutable_operand(i2),
160 visited_pairs)) {
161 return false;
162 }
163 }
164 return true;
165 }
166
167 /* static */
TestInstructionsComputeSameValue(HloInstruction * i1,HloInstruction * i2)168 bool ArCrsCombiner::TestInstructionsComputeSameValue(HloInstruction* i1,
169 HloInstruction* i2) {
170 ArCrsCombiner combiner(/*num_spatial_partitions=*/2);
171 auto module = i1->parent()->parent();
172 CHECK_EQ(module, i2->parent()->parent());
173 combiner.call_graph_ = CallGraph::Build(module);
174 absl::flat_hash_map<int64, int64> visited_pairs;
175 return combiner.InstructionsComputeSameValue(i1, i2, &visited_pairs);
176 }
177
// Returns true if `i1` and `i2` can be proven to compute the same value.
// `visited_pairs` memoizes unique-id pairs already under comparison; besides
// caching, it is what makes comparisons through cyclic structures (while
// loops) terminate.
bool ArCrsCombiner::InstructionsComputeSameValue(
    HloInstruction* i1, HloInstruction* i2,
    absl::flat_hash_map<int64, int64>* visited_pairs) {
  if (i1 == i2) {
    return true;
  }
  // Key the memo on (min_uid, max_uid) so lookup is symmetric in i1/i2.
  auto uid1 = i1->unique_id();
  auto uid2 = i2->unique_id();
  auto min_uid = std::min(uid1, uid2);
  auto max_uid = std::max(uid1, uid2);
  auto it = visited_pairs->find(min_uid);
  if (it != visited_pairs->end() && max_uid == it->second) {
    return true;
  }
  auto opcode1 = i1->opcode();
  auto operands1 = i1->operands();
  if (opcode1 != i2->opcode() || operands1.size() != i2->operands().size()) {
    return false;
  }
  auto eq_computations = [](const HloComputation* a, const HloComputation* b) {
    return *a == *b;
  };
  // Cross-module all-reduces are compared directly via Identical (including
  // operand identity) rather than by recursing on operand values.
  if (i1->IsCrossModuleAllReduce()) {
    return i1->Identical(*i2,
                         /*eq_operands=*/std::equal_to<const HloInstruction*>(),
                         eq_computations,
                         /*layout_sensitive=*/false);
  }
  // Record the pair as visited BEFORE recursing on operands, so that cyclic
  // comparisons (values flowing through while loops) terminate.
  visited_pairs->emplace(min_uid, max_uid);
  for (int i = 0; i < operands1.size(); ++i) {
    auto operand1 = operands1[i];
    auto operand2 = i2->operands()[i];
    if (!InstructionsComputeSameValue(operand1, operand2, visited_pairs)) {
      return false;
    }
  }
  if (opcode1 == HloOpcode::kParameter) {
    // In the general case, we don't try to prove equality of parameters.
    // We only try in the context of get-tuple-element
    // (see TupleElementsComputeSameValue).
    return false;
  }
  if (opcode1 == HloOpcode::kGetTupleElement) {
    // Equal indices into operands already proven equal compute the same
    // value; for differing indices, fall back to comparing the elements of
    // every tuple that can reach the operand.
    return i1->tuple_index() == i2->tuple_index() ||
           TupleElementsComputeSameValue(operands1[0], i1->tuple_index(),
                                         i2->tuple_index(), visited_pairs);
  }
  // Don't check that the operands are identical, because Identical can
  // return false for instructions that compute the same value but are not
  // identical, which we don't want. We have checked the arguments with
  // InstructionsComputeSameValue earlier.
  auto eq_instructions = [](const HloInstruction* i1,
                            const HloInstruction* i2) -> bool { return true; };
  return i1->Identical(*i2, eq_instructions, eq_computations,
                       /*layout_sensitive=*/false);
}
234
// Scans the module for AR->CRS patterns and groups the matched pairs in
// all_reduce_map_ keyed by the AR's all_reduce_id. When several ARs reach the
// same CRS, only the one with the longest AR->CRS path is kept (tracked via
// crs_reserved_map_).
void ArCrsCombiner::GroupAllReducesById(HloModule* module) {
  // Say that two or more ARs lead to the same CRS: (AR1, CRS), (AR2, CRS),
  // ... , (ARn, CRS).
  // If as we traverse the HLO graph we start tracking the pair (AR2, CRS),
  // and later find that AR1's distance from the CRS is longer, we discard
  // AR2 and start tracking AR1. We put the discarded ids in this set, in order
  // to skip processing of short paths when we encounter the other ARs that
  // have the same id as AR2.
  absl::flat_hash_set<int64> discarded_ar_ids;
  for (HloComputation* computation : module->MakeNonfusionComputations()) {
    for (HloInstruction* instruction : computation->instructions()) {
      auto maybe_pair = MatchesArCrsPattern(instruction);
      if (maybe_pair) {
        auto pair = *maybe_pair;
        int64 ar_id = *(instruction->all_reduce_id());
        // Skip ARs whose id was previously discarded in favor of a longer
        // path to the same CRS.
        if (discarded_ar_ids.find(ar_id) != discarded_ar_ids.end()) {
          continue;
        }
        auto it = crs_reserved_map_.find(pair.crs);
        if (it != crs_reserved_map_.end()) {
          // This CRS is already claimed by another AR id.
          auto prev_ar_id = it->second;
          // Since there is another AR paired with CRS,
          // all_reduce_map_[prev_ar_id] should exist, but
          // all_reduce_map_[ar_id] shouldn't.
          CHECK(all_reduce_map_.find(ar_id) == all_reduce_map_.end());
          CHECK_NE(prev_ar_id, ar_id);
          auto prev_pair = all_reduce_map_[prev_ar_id].back();
          int64 prev_distance = prev_pair.distance;
          if (prev_distance < pair.distance) {
            // The current AR's distance to CRS is longer than the previously
            // tracked AR, so we discard the previous AR.
            all_reduce_map_.erase(prev_ar_id);
            discarded_ar_ids.insert(prev_ar_id);
            all_reduce_map_[ar_id].push_back(pair);
            crs_reserved_map_[pair.crs] = ar_id;
          } else {
            // Discard the current AR id because we are keeping the previously
            // tracked AR.
            discarded_ar_ids.insert(ar_id);
          }
        } else {
          // First time this CRS is seen; record the pair and reserve the CRS.
          if (all_reduce_map_.find(ar_id) != all_reduce_map_.end()) {
            int64 prev_distance = all_reduce_map_[ar_id].back().distance;
            CHECK_EQ(prev_distance, pair.distance)
                << "All ARs with the same AR ID must have the same distance "
                   "from the corresponding CRSs. Found: "
                << prev_distance << " and " << pair.distance;
          }
          all_reduce_map_[ar_id].push_back(pair);
          crs_reserved_map_[pair.crs] = ar_id;
        }
      }
    }
  }
}
290
KeepProvablyEqualInstructionGroups()291 void ArCrsCombiner::KeepProvablyEqualInstructionGroups() {
292 for (auto it : all_reduce_map_) {
293 auto all_reduce_id = it.first;
294 auto pairs_vec = it.second;
295 CHECK_EQ(pairs_vec.size(), num_spatial_partitions_);
296 auto instr_0 = pairs_vec[0].ar;
297 for (int i = 1; i < pairs_vec.size(); ++i) {
298 auto instr_i = pairs_vec[i].ar;
299 auto next_0 = instr_0->users()[0];
300 auto next_i = instr_i->users()[0];
301 absl::flat_hash_map<int64, int64> visited_pairs;
302 while (true) {
303 if (!InstructionsComputeSameValue(next_0, next_i, &visited_pairs)) {
304 all_reduce_map_.erase(all_reduce_id);
305 break;
306 }
307 if (next_0->IsCrossReplicaAllReduce()) {
308 break;
309 }
310 next_0 = next_0->users()[0];
311 next_i = next_i->users()[0];
312 }
313 }
314 }
315 }
316
// Rewrites every surviving group: removes each cross-module AR, compensates
// additions/subtractions on the AR->CRS path by dividing the other operand by
// the partition count, and turns the CRS into an all-core AllReduce by giving
// it the AR's all_reduce_id. Returns true iff the module was changed.
StatusOr<bool> ArCrsCombiner::RewriteGraph() {
  if (all_reduce_map_.empty()) {
    return false;
  }
  for (auto it : all_reduce_map_) {
    auto pairs_vec = it.second;
    for (auto pair : pairs_vec) {
      auto all_reduce = pair.ar;
      auto parent_computation = all_reduce->parent();
      auto all_reduce_id = all_reduce->all_reduce_id();
      // Splice the AR out of the chain: its single user now reads the AR's
      // operand directly.
      auto prev = all_reduce->mutable_operand(0);
      auto next = all_reduce->users()[0];
      TF_CHECK_OK(all_reduce->ReplaceUseWith(next, prev));
      TF_CHECK_OK(parent_computation->RemoveInstruction(all_reduce));
      // Walk the chain down to the CRS, fixing up each instruction the AR
      // was moved past.
      while (!next->IsCrossReplicaAllReduce()) {
        switch (next->opcode()) {
          case HloOpcode::kBitcast:
          case HloOpcode::kTranspose:
          case HloOpcode::kReshape:
          case HloOpcode::kConvert:
          case HloOpcode::kMultiply:
            // These commute with the removed AR; nothing to compensate.
            break;
          case HloOpcode::kAdd:
          case HloOpcode::kSubtract: {
            auto other_operand = (next->operands()[0] == prev)
                                     ? next->operands()[1]
                                     : next->operands()[0];
            // To move the AR past the addition/subtraction, we need to divide
            // other_operand by the number of spatial partitions, except if
            // other_operand is a cross-module AR, which can be eliminated.
            if (other_operand->IsCrossModuleAllReduce() &&
                other_operand->user_count() == 1) {
              TF_CHECK_OK(other_operand->ReplaceAllUsesWith(
                  other_operand->mutable_operand(0)));
            } else {
              auto shape = other_operand->shape();
              Literal lit(shape);
              // NOTE(review): PopulateWithValue<float> assumes the operand's
              // element type is F32, but MatchesArCrsPattern only guarantees
              // a floating element type (could be BF16/F16/F64) -- confirm.
              lit.PopulateWithValue<float>(num_spatial_partitions_);
              auto divisor = parent_computation->AddInstruction(
                  HloInstruction::CreateConstant(lit.Clone()));
              auto division = parent_computation->AddInstruction(
                  HloInstruction::CreateBinary(shape, HloOpcode::kDivide,
                                               other_operand, divisor));
              TF_CHECK_OK(other_operand->ReplaceUseWith(next, division));
            }
            break;
          }
          default:
            // MatchesArCrsPattern only admits the opcodes above, so anything
            // else indicates an internal inconsistency.
            LOG(FATAL) << "Unexpected instruction: " << next->ToShortString();
        }
        prev = next;
        next = next->users()[0];
      }
      // The AllReduce and the CRS are combined to an all-core AllReduce.
      next->set_all_reduce_id(all_reduce_id);
    }
  }
  return true;
}
376
// Pass entry point: builds the call graph, groups matched AR->CRS pairs by
// all_reduce_id, prunes groups whose paths are not provably equal, and
// rewrites the remaining groups. Returns true iff the module changed.
StatusOr<bool> ArCrsCombiner::Run(HloModule* module) {
  call_graph_ = CallGraph::Build(module);

  GroupAllReducesById(module);

  KeepProvablyEqualInstructionGroups();

  return RewriteGraph();
}
386
387 } // namespace xla
388