/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/while_loop_concat_code_motion.h"

#include <map>
#include <set>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace xla {

namespace {

// This algorithm tries to group HLO instructions into concat candidates. Each
// instruction can only belong to a single group.
//
// For simplicity, after finding the groups, it in-place updates the first group
// member to the full shape, and replaces non-grouped uses with slices of it.
// Then it relies on TupleSimplifier, WhileLoopSimplifier, and DCE passes to
// remove other elements.

// Represents a group of elements and how to concat them.
struct ConcatGroup {
  ConcatGroup(std::vector<HloInstruction*> elements, int64_t concat_dim,
              bool inserted_concat_dim)
      : elements(std::move(elements)),
        element_sizes(this->elements.size(), 1),
        element_offsets(this->elements.size(), 0),
        concat_dim(concat_dim),
        inserted_concat_dim(inserted_concat_dim) {
    if (inserted_concat_dim) {
      absl::c_iota(element_offsets, 0);
    } else {
      for (int64_t i = 0; i < element_sizes.size(); ++i) {
        element_sizes[i] = this->elements[i]->shape().dimensions(concat_dim);
        if (i > 0) {
          element_offsets[i] = element_offsets[i - 1] + element_sizes[i - 1];
        }
      }
    }
  }

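  // Computes the shape of the concatenated value: if the concat dim is an
  // inserted dimension, a new dimension of size elements.size() is added at
  // concat_dim; otherwise the sizes of the existing concat_dim are summed.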
  Shape GetConcatShape() const {
    if (inserted_concat_dim) {
      std::vector<int64> dims;
      const Shape& element_shape = elements.back()->shape();
      dims.reserve(element_shape.rank() + 1);
      for (int64_t i = 0; i < element_shape.rank(); ++i) {
        if (i == concat_dim) {
          dims.push_back(elements.size());
        }
        dims.push_back(element_shape.dimensions(i));
      }
      if (dims.size() == concat_dim) {
        dims.push_back(elements.size());
      }
      return ShapeUtil::MakeShape(element_shape.element_type(), dims);
    } else {
      int64_t dim_size = 0;
      for (int64_t size : element_sizes) {
        dim_size += size;
      }
      Shape shape = elements.back()->shape();
      shape.set_dimensions(concat_dim, dim_size);
      return shape;
    }
  }

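  // Slices one element's data out of the combined value `full_data`, and
  // reshapes away the concat dim if it was an inserted dimension.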
  HloInstruction* CreateSlice(HloInstruction* full_data, int64_t element_index,
                              HloComputation* comp) const {
    Shape shape = full_data->shape();
    shape.set_dimensions(concat_dim, element_sizes[element_index]);
    std::vector<int64> starts(shape.rank(), 0);
    std::vector<int64> limits(shape.dimensions().begin(),
                              shape.dimensions().end());
    starts[concat_dim] = element_offsets[element_index];
    limits[concat_dim] += starts[concat_dim];
    auto slice = comp->AddInstruction(HloInstruction::CreateSlice(
        shape, full_data, starts, limits, std::vector<int64>(shape.rank(), 1)));
    if (!inserted_concat_dim) {
      return slice;
    }
    std::vector<int64> element_shape;
    element_shape.reserve(shape.rank());
    for (int64_t i = 0; i < shape.rank(); ++i) {
      if (i != concat_dim) {
        element_shape.push_back(shape.dimensions(i));
      }
    }
    return comp->AddInstruction(HloInstruction::CreateReshape(
        ShapeUtil::MakeShape(shape.element_type(), element_shape), slice));
  }

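  // Builds the concatenate of `input_elements` along concat_dim. If the concat
  // dim is an inserted dimension, each element is first reshaped to carry a
  // size-1 dimension at concat_dim.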
  HloInstruction* CreateConcat(std::vector<HloInstruction*> input_elements,
                               HloComputation* comp) const {
    if (inserted_concat_dim) {
      for (int64_t i = 0; i < input_elements.size(); ++i) {
        std::vector<int64> element_shape;
        element_shape.reserve(input_elements[i]->shape().rank());
        for (int64_t j = 0; j < input_elements[i]->shape().rank(); ++j) {
          if (j == concat_dim) {
            element_shape.push_back(1);
          }
          element_shape.push_back(input_elements[i]->shape().dimensions(j));
        }
        if (element_shape.size() == concat_dim) {
          element_shape.push_back(1);
        }
        input_elements[i] = comp->AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(input_elements[i]->shape().element_type(),
                                 element_shape),
            input_elements[i]));
      }
    }

    return comp->AddInstruction(HloInstruction::CreateConcatenate(
        GetConcatShape(), input_elements, concat_dim));
  }

  std::vector<HloInstruction*> elements;
  std::vector<int64> element_sizes;
  std::vector<int64> element_offsets;
  int64 concat_dim;
  // Whether the concat dim is an inserted new dimension.
  bool inserted_concat_dim;
};

// A collection of ConcatGroup's where each HLO can only belong to a single
// group.
class ConcatGroups {
 public:
  // Returns the group index and element index in group for an HLO, if it
  // belongs to a group.
  absl::optional<std::pair<int64, int64>> GetGroupIndex(
      const HloInstruction* hlo) const {
    auto it = element_to_group_.find(hlo);
    if (it == element_to_group_.end()) {
      return absl::nullopt;
    }
    return it->second;
  }

  const ConcatGroup& GetGroup(int64_t index) const { return groups_[index]; }

  // Creates a new group and returns the index if it doesn't exist, or returns
  // the existing group index. If the new group doesn't match exactly with an
  // existing group but shares some of the elements, returns -1 as the index.
  // It also returns whether a new group is created. So the return value is a
  // pair of {whether created, group index}.
  std::pair<bool, int64> MaybeCreateNewGroup(ConcatGroup group) {
    int64_t group_id = -1;
    absl::flat_hash_set<HloInstruction*> elements_dedup;
    for (int64_t i = 0; i < group.elements.size(); ++i) {
      if (!elements_dedup.insert(group.elements[i]).second) {
        VLOG(2) << "Duplicates in group. Element: "
                << group.elements[i]->ToString();
      }
      if (concat_disallowed_.contains(group.elements[i])) {
        VLOG(2) << "Failed creating group. Grouping disallowed on "
                << group.elements[i]->ToString();
        return std::pair<bool, int64>(false, -1);
      }
      auto existing = GetGroupIndex(group.elements[i]);
      if (existing.has_value() &&
          (i != existing->second ||
           groups_[existing->first].concat_dim != group.concat_dim)) {
        // We allow mismatched inserted_concat_dim, since that only requires a
        // trivial reshape.
        VLOG(2)
            << "Failed creating group. Different than existing group. Element: "
            << group.elements[i]->ToString();
        return std::pair<bool, int64>(false, -1);
      }
      if (i == 0 && existing.has_value()) {
        group_id = existing->first;
      }
      if (i > 0) {
        if (existing.has_value() && existing->first != group_id) {
          VLOG(2) << "Failed creating group. Different than existing group. "
                     "Element: "
                  << group.elements[i]->ToString();
          return std::pair<bool, int64>(false, -1);
        }
        if (!existing.has_value() && group_id >= 0) {
          VLOG(2) << "Failed creating group. Different than existing group. "
                     "Element: "
                  << group.elements[i]->ToString();
          return std::pair<bool, int64>(false, -1);
        }
      }
    }
    if (group_id >= 0) {
      VLOG(2) << "Group already exists at " << group_id << " for "
              << group.elements[0]->ToString();
      return std::pair<bool, int64>(false, group_id);
    }
    int64_t index = groups_.size();
    for (int64_t i = 0; i < group.elements.size(); ++i) {
      element_to_group_[group.elements[i]] = std::pair<int64, int64>(index, i);
    }
    VLOG(2) << "Created new group at " << index << " for "
            << group.elements[0]->ToString()
            << ", concat_dim: " << group.concat_dim
            << ", inserted: " << group.inserted_concat_dim;
    groups_.push_back(std::move(group));
    return std::pair<bool, int64>(true, index);
  }

  const std::vector<ConcatGroup>& Groups() const { return groups_; }

  int64 NextGroupIndex() const { return groups_.size(); }

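  // Removes all groups at indices >= start_index, and erases their elements
  // from the element-to-group map.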
  void RemoveTailingGroups(int64_t start_index) {
    while (groups_.size() > start_index) {
      for (auto element : groups_.back().elements) {
        element_to_group_.erase(element);
      }
      groups_.pop_back();
    }
  }

  void DisallowGroupingOn(const HloInstruction* hlo) {
    VLOG(2) << "Disallow grouping on " << hlo->ToString();
    concat_disallowed_.insert(hlo);
  }

 private:
  // element -> {group index in groups_, element index in group}.
  absl::flat_hash_map<const HloInstruction*, std::pair<int64, int64>>
      element_to_group_;
  std::vector<ConcatGroup> groups_;
  absl::flat_hash_set<const HloInstruction*> concat_disallowed_;
};

// Infers an operand's concat dim and whether it's an inserted dim. For example,
// if hlo is f32[2,4,2] broadcast(f32[2,4]), dimensions={0,1} concatenated on
// dim 2, then this function will return {2, true}.
//
// If the operand is already transformed to the combined shape, specify its
// group in combined_operand_group. (Only required for kReshape.)
absl::optional<std::pair<int64, bool>> GetOperandConcatDim(
    const HloInstruction* hlo, int64_t operand_index, int64_t hlo_concat_dim,
    bool hlo_inserted_concat_dim,
    const ConcatGroup* combined_operand_group = nullptr) {
  if (hlo->IsElementwise() || hlo->opcode() == HloOpcode::kAllReduce) {
    return std::pair<int64, bool>(hlo_concat_dim, hlo_inserted_concat_dim);
  }
  int64_t operand_concat_dim = -1;
  bool operand_inserted_concat_dim = false;
  const Shape& operand_shape =
      combined_operand_group == nullptr
          ? hlo->operand(operand_index)->shape()
          : combined_operand_group->elements.back()->shape();
  if (hlo->opcode() == HloOpcode::kBroadcast) {
    operand_concat_dim = 0;
    operand_inserted_concat_dim = true;
    // Try to place operand_concat_dim adjacent to dims the same way as the
    // output, if it does not exist in the operand.
    int64_t min_dist_to_concat_dim = hlo->shape().rank();
    for (int64_t i = 0; i < operand_shape.rank(); ++i) {
      if (hlo->dimensions(i) == hlo_concat_dim) {
        operand_concat_dim = i;
        operand_inserted_concat_dim = hlo_inserted_concat_dim;
        break;
      }
      if (hlo->dimensions(i) < hlo_concat_dim &&
          min_dist_to_concat_dim > hlo_concat_dim - hlo->dimensions(i)) {
        operand_concat_dim = i + 1;
        min_dist_to_concat_dim = hlo_concat_dim - hlo->dimensions(i);
      }
      if (hlo->dimensions(i) > hlo_concat_dim &&
          min_dist_to_concat_dim > hlo->dimensions(i) - hlo_concat_dim) {
        operand_concat_dim = i;
        min_dist_to_concat_dim = hlo->dimensions(i) - hlo_concat_dim;
      }
    }
  } else if (hlo->opcode() == HloOpcode::kReduce) {
    if (operand_index != 0) {
      return absl::nullopt;
    }
    operand_concat_dim = hlo_concat_dim;
    operand_inserted_concat_dim = hlo_inserted_concat_dim;
    std::set<int64> sorted_reduce_dims;
    for (int64_t dim : hlo->dimensions()) {
      sorted_reduce_dims.insert(dim);
    }
    for (int64_t dim : sorted_reduce_dims) {
      if ((hlo_inserted_concat_dim && dim < operand_concat_dim) ||
          (!hlo_inserted_concat_dim && dim <= operand_concat_dim)) {
        operand_concat_dim++;
      }
    }
  } else if (hlo->opcode() == HloOpcode::kReshape) {
    int64_t i = 0;
    int64_t j = 0;
    operand_inserted_concat_dim = false;
    // Only support adding/removing trivial dims.
    while (i < operand_shape.rank() || j <= hlo_concat_dim) {
      if (i < operand_shape.rank() && j < hlo->shape().rank() &&
          operand_shape.dimensions(i) == hlo->shape().dimensions(j)) {
        if (j == hlo_concat_dim) {
          operand_inserted_concat_dim =
              hlo_inserted_concat_dim && operand_shape.dimensions(i) != 1;
          operand_concat_dim = i;
          break;
        }
        i++;
        j++;
        continue;
      }
      if (i < operand_shape.rank() && operand_shape.dimensions(i) == 1) {
        if (j == hlo_concat_dim && hlo_inserted_concat_dim) {
          operand_concat_dim = i;
          break;
        }
        i++;
        continue;
      }
      if (j == hlo_concat_dim) {
        operand_concat_dim = i;
        operand_inserted_concat_dim = true;
        break;
      }
      if (j < hlo->shape().rank() && hlo->shape().dimensions(j) == 1) {
        j++;
        continue;
      }
      return absl::nullopt;
    }
  } else {
    return absl::nullopt;
  }
  CHECK_GE(operand_concat_dim, 0);
  return std::pair<int64, bool>(operand_concat_dim,
                                operand_inserted_concat_dim);
}

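// Rewrites an HLO in place so it produces its group's combined shape: sets the
// result shape to the concat shape and, for broadcasts and reduces, adjusts
// the dimension numbers to account for the concat dim.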
void ModifyHloPropertiesForConcatShape(const ConcatGroup& group,
                                       HloInstruction* hlo) {
  *hlo->mutable_shape() = group.GetConcatShape();
  if (hlo->opcode() == HloOpcode::kBroadcast) {
    // Use the last element to infer the operand concat dim, since the first
    // element's operand might have been rewritten.
    auto operand_dim = GetOperandConcatDim(
        group.elements.back(), 0, group.concat_dim, group.inserted_concat_dim);
    CHECK(operand_dim.has_value());
    int64_t operand_concat_dim = operand_dim->first;
    bool operand_inserted_concat_dim = operand_dim->second;
    if (operand_inserted_concat_dim) {
      // We should have added a dimension on the operand.
      CHECK_EQ(hlo->operand(0)->shape().rank(), hlo->dimensions().size() + 1)
          << hlo->ToString();
    } else {
      CHECK_EQ(hlo->operand(0)->shape().rank(), hlo->dimensions().size());
    }
    std::vector<int64> dims;
    for (int64_t i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
      if (i == operand_concat_dim && operand_inserted_concat_dim) {
        dims.push_back(group.concat_dim);
      } else {
        if (i > operand_concat_dim && operand_inserted_concat_dim) {
          dims.push_back(hlo->dimensions(i - 1));
        } else {
          dims.push_back(hlo->dimensions(i));
        }
        if (group.inserted_concat_dim && dims.back() >= group.concat_dim) {
          dims.back()++;
        }
      }
    }
    *hlo->mutable_dimensions() = std::move(dims);
  } else if (hlo->opcode() == HloOpcode::kReduce) {
    auto operand_dim = GetOperandConcatDim(
        group.elements.back(), 0, group.concat_dim, group.inserted_concat_dim);
    CHECK(operand_dim.has_value());
    int64_t operand_concat_dim = operand_dim->first;
    bool operand_inserted_concat_dim = operand_dim->second;
    if (operand_inserted_concat_dim) {
      auto dims = hlo->mutable_dimensions();
      for (int64_t i = 0; i < dims->size(); ++i) {
        if ((*dims)[i] >= operand_concat_dim) {
          (*dims)[i]++;
        }
      }
    }
  }
}

// Main method to assign groups to HLOs, based on a concat.
bool GroupHlosForConcat(
    HloComputation* body, HloInstruction* concat,
    absl::flat_hash_map<const HloInstruction*, int64> topological_order,
    ConcatGroups* groups) {
  const int64_t group_size = concat->operand_count();
  absl::flat_hash_set<int64> used_groups;
  auto root_tuple = body->root_instruction();
  CHECK_EQ(root_tuple->opcode(), HloOpcode::kTuple);
  absl::flat_hash_map<HloInstruction*, int64> root_tuple_element_use_count;
  for (auto operand : root_tuple->operands()) {
    root_tuple_element_use_count.emplace(operand, 0).first->second++;
  }
  // Priority Queue sorted by topological order. Users come before operands, so
  // it uses -topological_order[element0] as the key. We start with the concat
  // operands.
  std::multimap<int64, ConcatGroup> pq;
  const int64_t first_group_id_to_create = groups->NextGroupIndex();
  auto fail_and_cleanup = [&] {
    VLOG(1) << "Failed to get the subcomputation to optimize for "
            << concat->ToString() << ", clear groups starting at "
            << first_group_id_to_create;
    groups->RemoveTailingGroups(first_group_id_to_create);
    return false;
  };
  struct GroupUse {
    int64 group_id;
    bool newly_created;
    bool already_used_by_subcomp;
  };
  auto maybe_create_group = [&](ConcatGroup group) {
    auto res = groups->MaybeCreateNewGroup(std::move(group));
    GroupUse use{res.second, false, false};
    if (res.second < 0) {
      return use;
    }
    use.newly_created = res.first;
    use.already_used_by_subcomp = !used_groups.insert(res.second).second;
    return use;
  };
  std::vector<HloInstruction*> concat_operands(concat->operands().begin(),
                                               concat->operands().end());
  int64_t concat_operand_order = -topological_order[concat_operands[0]];
  pq.emplace(concat_operand_order,
             ConcatGroup(std::move(concat_operands),
                         concat->concatenate_dimension(), false));

  // Find the subcomputation on elements to combine, in order to move `concat`
  // out of the loop without adding new concats. We start from the concat's
  // operands, and the priority queue is ordered in reverse topological order
  // so we process outputs before inputs. Each entry in the queue is a group of
  // elements to combine. A legitimate group consists of identical ops, except
  // that they each operate on one element. When a group of loop inputs is
  // processed, we also enqueue the corresponding loop outputs to keep their
  // shapes matched.
  while (!pq.empty()) {
    auto group = std::move(pq.begin()->second);
    pq.erase(pq.begin());
    const auto& hlos = group.elements;
    VLOG(2) << "GroupHlosForConcat dequeued " << hlos[0]->ToString();
    bool group_is_param_gtes = false;
    if (absl::c_all_of(hlos, [&](const HloInstruction* element) {
          return element == hlos[0];
        })) {
      // Shared operand.
      if (groups->GetGroupIndex(hlos[0]).has_value()) {
        VLOG(1) << "We do not support the case where a shared operand is also "
                   "part of a group: "
                << hlos[0]->ToString();
        return fail_and_cleanup();
      }
      groups->DisallowGroupingOn(hlos[0]);
      continue;
    }
    if (absl::c_all_of(hlos, [&](const HloInstruction* element) {
          return element->opcode() == HloOpcode::kGetTupleElement &&
                 element->operand(0) == body->parameter_instruction(0);
        })) {
      group_is_param_gtes = true;
    } else if (((hlos[0]->IsElementwise() ||
                 hlos[0]->opcode() == HloOpcode::kAllReduce) &&
                !hlos[0]->HasSideEffect()) ||
               hlos[0]->opcode() == HloOpcode::kBroadcast ||
               hlos[0]->opcode() == HloOpcode::kReduce ||
               hlos[0]->opcode() == HloOpcode::kReshape ||
               hlos[0]->IsCustomCall("Sharding")) {
      if (hlos[0]->opcode() == HloOpcode::kAllReduce &&
          (!hlos[0]->shape().IsArray() || hlos[0]->IsCrossModuleAllReduce())) {
        VLOG(2) << "Unsupported allreduce: " << hlos[0]->ToString();
        return fail_and_cleanup();
      }
      // Check if these elements can be concatenated.
      if (absl::c_any_of(hlos, [&](const HloInstruction* element) {
            auto eq_operand = [](const HloInstruction* a,
                                 const HloInstruction* b) {
              return ShapeUtil::Compatible(a->shape(), b->shape());
            };
            auto eq_computations = [](const HloComputation* lhs,
                                      const HloComputation* rhs) {
              return lhs->Equal(*rhs, /*is_layout_sensitive=*/false);
            };
            if (!hlos[0]->Identical(*element, eq_operand, eq_computations,
                                    /*layout_sensitive=*/false)) {
              return true;
            }
            if (element->opcode() == HloOpcode::kReduce &&
                (element->operand_count() != 2 ||
                 element->operand(1) != hlos[0]->operand(1))) {
              return true;
            }
            return false;
          })) {
        VLOG(2) << "Different types of elements. First element: "
                << hlos[0]->ToString();
        return fail_and_cleanup();
      }
      // Now enqueue the inputs.
      int64_t input_count = hlos[0]->operand_count();
      if (hlos[0]->opcode() == HloOpcode::kReduce) {
        CHECK_EQ(input_count, 2);
        // Exclude the init value that we have checked to be the same.
        input_count = 1;
      }
      for (int64_t i = 0; i < input_count; ++i) {
        std::vector<HloInstruction*> elements(group_size);
        for (int64_t j = 0; j < group_size; ++j) {
          elements[j] = hlos[j]->mutable_operand(i);
        }
        auto maybe_new_concat_dim = GetOperandConcatDim(
            hlos[0], i, group.concat_dim, group.inserted_concat_dim);
        if (!maybe_new_concat_dim.has_value()) {
          VLOG(2) << "Cannot find operand concat dimension for operand " << i
                  << " of " << hlos[0]->ToString();
          return fail_and_cleanup();
        }
        int64_t new_group_concat_dim = maybe_new_concat_dim->first;
        bool inserted_concat_dim = maybe_new_concat_dim->second;
        // Enqueue the input group.
        int64_t element_order = -topological_order[elements[0]];
        pq.emplace(element_order,
                   ConcatGroup(std::move(elements), new_group_concat_dim,
                               inserted_concat_dim));
      }
    } else if (hlos[0]->opcode() == HloOpcode::kSlice) {
      int64_t offset = 0;
      auto operand = hlos[0]->operand(0);
      if (group.inserted_concat_dim) {
        VLOG(2) << "Slices cannot be grouped on new dimension.";
        return fail_and_cleanup();
      }
      if (groups->GetGroupIndex(operand).has_value()) {
        // Should not slice an operand to be grouped.
        return fail_and_cleanup();
      }
      groups->DisallowGroupingOn(operand);
      for (int64_t i = 0; i < group_size; ++i) {
        if (hlos[i]->operand(0) != operand) {
          VLOG(2) << "Slices of different operands.";
          return fail_and_cleanup();
        }
        for (int64_t j = 0; j < hlos[i]->shape().rank(); ++j) {
          if (hlos[i]->slice_strides(j) != 1) {
            VLOG(2) << "Slices with strides.";
            return fail_and_cleanup();
          }
          if (j == group.concat_dim) {
            if (hlos[i]->slice_starts(j) != offset) {
              VLOG(2) << "Slices with unsupported offsets.";
              return fail_and_cleanup();
            }
            offset += hlos[i]->shape().dimensions(j);
          } else {
            if (hlos[i]->slice_starts(j) != 0 ||
                hlos[i]->slice_limits(j) != operand->shape().dimensions(j)) {
              VLOG(2) << "Slice with unsupported offsets at dimension " << j
                      << ", " << hlos[i]->ToString();
              return fail_and_cleanup();
            }
          }
        }
      }
      if (offset != operand->shape().dimensions(group.concat_dim)) {
        VLOG(2) << "Slices with unsupported sizes.";
        return fail_and_cleanup();
      }
    } else {
      VLOG(2) << "Unsupported opcode: " << hlos[0]->ToString();
      return fail_and_cleanup();
    }
    auto guse = maybe_create_group(std::move(group));
    if (guse.group_id < 0) {
      VLOG(2) << "Failed to create group.";
      return fail_and_cleanup();
    }
    const auto& registered_group = groups->GetGroup(guse.group_id);
    if (!guse.already_used_by_subcomp && group_is_param_gtes) {
      // When we process a group of parameter GTEs, we should also enqueue the
      // corresponding root tuple operands, so that they have matching shapes.
      std::vector<HloInstruction*> new_outputs(group_size);
      for (int64_t i = 0; i < group_size; ++i) {
        new_outputs[i] = root_tuple->mutable_operand(
            registered_group.elements[i]->tuple_index());
      }
      int64_t new_output_order = -topological_order[new_outputs[0]];
      pq.emplace(
          new_output_order,
          ConcatGroup(std::move(new_outputs), registered_group.concat_dim,
                      registered_group.inserted_concat_dim));
    }
  }
  return groups->Groups().size() > first_group_id_to_create;
}

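// Returns, for each tuple element of the loop, whether it is used by the while
// condition. If the condition uses its parameter other than through
// get-tuple-element, all elements are conservatively marked as used.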
std::vector<bool> TupleElementsUsedInCond(HloInstruction* loop) {
  std::vector<bool> result(loop->shape().tuple_shapes_size(), false);
  for (auto user : loop->while_condition()->parameter_instruction(0)->users()) {
    if (user->opcode() != HloOpcode::kGetTupleElement) {
      absl::c_fill(result, true);
      return result;
    }
    result[user->tuple_index()] = true;
  }
  return result;
}

// Adds copies to returned values to keep RewriteLoopWithConcatGroups simple:
// the copies do not have other users and only appear once in the root tuple.
Status AddCopiesToRoot(HloComputation* body,
                       absl::Span<HloInstruction* const> param_gtes,
                       ConcatGroups* groups) {
  auto root = body->root_instruction();
  CHECK_EQ(root->opcode(), HloOpcode::kTuple);
  std::vector<HloInstruction*> copies(root->operand_count(), nullptr);
  for (int64_t i = 0; i < copies.size(); ++i) {
    auto element = root->mutable_operand(i);
    if (!element->shape().IsArray()) {
      continue;
    }
    copies[i] = body->AddInstruction(HloInstruction::CreateUnary(
        element->shape(), HloOpcode::kCopy, element));
    TF_RETURN_IF_ERROR(root->ReplaceOperandWith(i, copies[i]));
  }
  for (int64_t i = 0; i < copies.size(); ++i) {
    auto copy = copies[i];
    if (groups->GetGroupIndex(copy).has_value()) {
      // Already handled by earlier group members.
      continue;
    }
    auto param_group_index = groups->GetGroupIndex(param_gtes[i]);
    if (!param_group_index.has_value()) {
      continue;
    }
    const auto& param_group = groups->GetGroup(param_group_index->first);
    std::vector<HloInstruction*> copy_group(param_group.elements.size());
    for (int64_t j = 0; j < copy_group.size(); ++j) {
      copy_group[j] = copies[param_group.elements[j]->tuple_index()];
    }
    CHECK(groups
              ->MaybeCreateNewGroup(
                  ConcatGroup(std::move(copy_group), param_group.concat_dim,
                              param_group.inserted_concat_dim))
              .first);
  }
  return Status::OK();
}

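// Removes the copies added by AddCopiesToRoot once the rewrite is done.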
Status RemoveCopiesFromRoot(HloComputation* body) {
  auto root = body->root_instruction();
  CHECK_EQ(root->opcode(), HloOpcode::kTuple);
  for (int64_t i = 0; i < root->operand_count(); ++i) {
    auto copy = root->mutable_operand(i);
    if (copy->opcode() == HloOpcode::kCopy) {
      TF_RETURN_IF_ERROR(root->ReplaceOperandWith(i, copy->mutable_operand(0)));
    }
  }
  return Status::OK();
}

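// Rewrites the while loop and its body so that each concat group flows through
// the loop as a single combined value: the loop signature, init operand, loop
// users, and body instructions are all updated.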
Status RewriteLoopWithConcatGroups(HloInstruction* loop,
                                   absl::Span<HloInstruction* const> param_gtes,
                                   ConcatGroups& groups) {
  VLOG(1) << "RewriteLoopWithConcatGroups with " << groups.Groups().size()
          << " groups.";
  // For simplicity, for each group, we rewrite the first element into full
  // shape, and leave the other elements unchanged. Non-grouped users will
  // have slices of the expanded first element as the new input. Later
  // simplification and DCE passes can remove the other elements.
  absl::flat_hash_set<int64> processed_groups;
  auto body = loop->while_body();
  auto param = body->parameter_instruction(0);
  auto cond_param = loop->while_condition()->parameter_instruction(0);

  // First, modify loop signature and operands/users.
  std::vector<HloInstruction*> init_elements(loop->shape().tuple_shapes_size());
  for (int64_t i = 0; i < param_gtes.size(); ++i) {
    init_elements[i] =
        loop->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
            loop->shape().tuple_shapes(i), loop->mutable_operand(0), i));
  }
  for (int64_t i = 0; i < param_gtes.size(); ++i) {
    const auto& group_and_index = groups.GetGroupIndex(param_gtes[i]);
    if (!group_and_index.has_value() || group_and_index->second != 0) {
      continue;
    }
    const auto& group = groups.GetGroup(group_and_index->first);
    // Change body parameter shape.
    *param_gtes[i]->mutable_shape() = group.GetConcatShape();
    *param->mutable_shape()->mutable_tuple_shapes(i) = param_gtes[i]->shape();
    *body->root_instruction()->mutable_shape()->mutable_tuple_shapes(i) =
        param_gtes[i]->shape();
    *cond_param->mutable_shape()->mutable_tuple_shapes(i) =
        param_gtes[i]->shape();
    *loop->mutable_shape()->mutable_tuple_shapes(i) = param_gtes[i]->shape();
    processed_groups.insert(group_and_index->first);
    std::vector<HloInstruction*> input_concat_elements;
    input_concat_elements.reserve(group.elements.size());
    for (auto param_gte : group.elements) {
      input_concat_elements.push_back(init_elements[param_gte->tuple_index()]);
    }
    init_elements[i] =
        group.CreateConcat(std::move(input_concat_elements), loop->parent());
  }
  TF_RETURN_IF_ERROR(loop->ReplaceOperandWithDifferentShape(
      0, loop->parent()->AddInstruction(
             HloInstruction::CreateTuple(init_elements))));
  // Adjust loop users.
  auto original_loop_users = loop->users();
  const bool loop_is_root = loop == loop->parent()->root_instruction();
  std::vector<HloInstruction*> output_elements(
      loop->shape().tuple_shapes_size());
  for (int64_t i = 0; i < param_gtes.size(); ++i) {
    output_elements[i] =
        loop->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
            init_elements[i]->shape(), loop, i));
  }
  for (int64_t i = 0; i < param_gtes.size(); ++i) {
    const auto& group_and_index = groups.GetGroupIndex(param_gtes[i]);
    if (!group_and_index.has_value() || group_and_index->second != 0) {
      continue;
    }
    const auto& group = groups.GetGroup(group_and_index->first);
    auto concat_output = output_elements[group.elements[0]->tuple_index()];
    for (int64_t j = 0; j < group.elements.size(); ++j) {
      const auto param_gte = group.elements[j];
      output_elements[param_gte->tuple_index()] =
          group.CreateSlice(concat_output, j, loop->parent());
    }
  }
  auto new_output_tuple = loop->parent()->AddInstruction(
      HloInstruction::CreateTuple(output_elements));
  for (auto user : original_loop_users) {
    TF_RETURN_IF_ERROR(
        loop->ReplaceUseWithDifferentShape(user, new_output_tuple));
  }
  if (loop_is_root) {
    loop->parent()->set_root_instruction(new_output_tuple,
                                         /*accept_different_shape=*/true);
  }

  // Now rewrite the loop body.
  std::vector<HloInstruction*> slices_to_remove;
  absl::flat_hash_set<HloInstruction*> new_reshapes;
  for (auto hlo : body->MakeInstructionPostOrder()) {
    const auto& group_and_index = groups.GetGroupIndex(hlo);
    if (!group_and_index.has_value() || group_and_index->second != 0) {
      continue;
    }

    if (!processed_groups.insert(group_and_index->first).second) {
      // Already processed the group at the first element.
      continue;
    }
    const auto& group = groups.GetGroup(group_and_index->first);
    if (hlo->opcode() == HloOpcode::kSlice) {
      // We could just replace hlo with its operand; however, to follow the
      // practice of using the first element as full data, we defer that
      // replacement.
      slices_to_remove.push_back(hlo);
    } else {
      int64_t operand_count_to_adjust = hlo->operand_count();
      if (hlo->opcode() == HloOpcode::kReduce) {
        CHECK_EQ(operand_count_to_adjust, 2);
        operand_count_to_adjust = 1;
      }
      for (int64_t i = 0; i < operand_count_to_adjust; ++i) {
        auto operand_group_index = groups.GetGroupIndex(hlo->operand(i));
        const ConcatGroup* operand_group =
            operand_group_index.has_value()
                ? &groups.GetGroup(operand_group_index->first)
                : nullptr;
        auto maybe_operand_concat_dim = GetOperandConcatDim(
            hlo, i, group.concat_dim, group.inserted_concat_dim, operand_group);
        CHECK(maybe_operand_concat_dim.has_value())
            << "Operand " << i << " of " << hlo->ToString();
        int64_t operand_concat_dim = maybe_operand_concat_dim->first;
        bool operand_inserted_concat_dim = maybe_operand_concat_dim->second;
        if (operand_group != nullptr) {
          CHECK_EQ(operand_concat_dim, operand_group->concat_dim);
          if (operand_inserted_concat_dim !=
              operand_group->inserted_concat_dim) {
            // The operand's actual inserted_concat_dim doesn't match the
            // expected operand_inserted_concat_dim. Need a reshape.
            std::vector<int64> new_dims;
            int64_t d = 0;
            for (; d < operand_concat_dim; ++d) {
              new_dims.push_back(hlo->operand(i)->shape().dimensions(d));
            }
            if (operand_inserted_concat_dim) {
              // Split operand concat dim.
              new_dims.push_back(group.elements.size());
              new_dims.push_back(
                  hlo->operand(i)->shape().dimensions(operand_concat_dim) /
                  group.elements.size());
              d = operand_concat_dim + 1;
            } else {
              // Combine operand concat dim with the next.
              new_dims.push_back(
                  group.elements.size() *
                  hlo->operand(i)->shape().dimensions(operand_concat_dim + 1));
              d = operand_concat_dim + 2;
            }
            for (; d < hlo->operand(i)->shape().rank(); ++d) {
              new_dims.push_back(hlo->operand(i)->shape().dimensions(d));
            }
            auto reshape = body->AddInstruction(HloInstruction::CreateReshape(
                ShapeUtil::MakeShape(hlo->operand(i)->shape().element_type(),
                                     new_dims),
                hlo->mutable_operand(i)));
            new_reshapes.insert(reshape);
            TF_RETURN_IF_ERROR(
                hlo->ReplaceOperandWithDifferentShape(i, reshape));
          }
          continue;
        }
        // This is a shared operand; we need to broadcast it.
        CHECK(
            absl::c_all_of(group.elements, [&](const HloInstruction* element) {
              return element->operand(i) == hlo->operand(i);
            }));
        VLOG(2) << "Broadcasting shared operand "
                << hlo->operand(i)->ToString();
        Shape data_shape = hlo->operand(i)->shape();
        std::vector<int64> broadcast_dims;
        std::vector<int64> broadcast_shape;
        for (int64_t j = 0; j < data_shape.rank(); ++j) {
          if (j < operand_concat_dim) {
            broadcast_dims.push_back(j);
          } else {
            broadcast_dims.push_back(j + 1);
          }
          if (j == operand_concat_dim) {
            broadcast_shape.push_back(group.elements.size());
          }
          broadcast_shape.push_back(data_shape.dimensions(j));
        }
        if (broadcast_shape.size() == data_shape.rank()) {
          // New dim at the end.
          broadcast_shape.push_back(group.elements.size());
        }
        auto broadcast = body->AddInstruction(HloInstruction::CreateBroadcast(
            ShapeUtil::MakeShape(data_shape.element_type(), broadcast_shape),
            hlo->mutable_operand(i), broadcast_dims));

        if (!operand_inserted_concat_dim) {
          // Concat on existing dim. Reshape to merge the broadcast dim.
          data_shape.set_dimensions(
              operand_concat_dim,
              data_shape.dimensions(operand_concat_dim) *
                  group.elements.size());
          broadcast = body->AddInstruction(
              HloInstruction::CreateReshape(data_shape, broadcast));
        }
        TF_RETURN_IF_ERROR(hlo->ReplaceOperandWithDifferentShape(i, broadcast));
      }
    }
    VLOG(2) << "Modifying HLO to full shape " << hlo->ToString();
    ModifyHloPropertiesForConcatShape(group, hlo);
    VLOG(2) << "Modified HLO to full shape " << hlo->ToString();
  }

  // For non-grouped HLOs, replace grouped inputs with slices. Also include
  // grouped reduce HLOs because their init values are not grouped.
  for (auto hlo : body->MakeInstructionPostOrder()) {
    if (new_reshapes.contains(hlo)) {
      continue;
    }
    const auto& group_and_index = groups.GetGroupIndex(hlo);
    if ((!group_and_index.has_value() || hlo->opcode() == HloOpcode::kReduce) &&
        hlo != body->root_instruction()) {
      auto operands = hlo->operands();
      if (group_and_index.has_value()) {
        // Only handle reduce init value.
        CHECK_EQ(operands.size(), 2);
        CHECK_EQ(hlo->opcode(), HloOpcode::kReduce);
        operands.erase(operands.begin());
      }
      for (int64_t i = 0; i < operands.size(); ++i) {
        auto operand = operands[i];
        auto operand_group_index = groups.GetGroupIndex(operand);
        if (!operand_group_index.has_value()) {
          continue;
        }
        const auto& operand_group = groups.GetGroup(operand_group_index->first);
        auto slice = operand_group.CreateSlice(
            operand_group.elements[0], operand_group_index->second, body);
        TF_RETURN_IF_ERROR(hlo->ReplaceOperandWithDifferentShape(i, slice));
      }
    }
  }
  for (auto slice : slices_to_remove) {
    TF_RETURN_IF_ERROR(slice->ReplaceAllUsesWith(slice->mutable_operand(0)));
    TF_RETURN_IF_ERROR(body->RemoveInstruction(slice));
  }
  return Status::OK();
}

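// Attempts the transformation on a single while loop. Returns true if the
// loop was rewritten.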
StatusOr<bool> RunOnLoop(HloInstruction* loop,
                         int64_t min_operand_count_to_optimize) {
  auto body = loop->while_body();
  auto param = body->parameter_instruction(0);
  auto root = body->root_instruction();
  if (!param->shape().IsTuple() || root->opcode() != HloOpcode::kTuple) {
    return false;
  }
  std::vector<HloInstruction*> gtes(param->shape().tuple_shapes_size(),
                                    nullptr);
  ConcatGroups groups;
  auto indices_used_in_cond = TupleElementsUsedInCond(loop);
  for (auto user : param->users()) {
    if (user->opcode() != HloOpcode::kGetTupleElement) {
      // Unhandled user opcode.
      return false;
    }
    int64_t idx = user->tuple_index();
    if (gtes[idx] != nullptr) {
      // Seen this index before.
      return false;
    }
    gtes[idx] = user;
    if (indices_used_in_cond[idx]) {
      groups.DisallowGroupingOn(user);
    }
  }
  std::vector<HloInstruction*> concats;
  auto body_instructions = body->MakeInstructionPostOrder();
  absl::flat_hash_map<const HloInstruction*, int64> topological_order;
  for (int64_t i = 0; i < body_instructions.size(); ++i) {
    auto hlo = body_instructions[i];
    topological_order[hlo] = i;
    if (hlo->opcode() == HloOpcode::kConcatenate &&
        hlo->operand_count() >= min_operand_count_to_optimize) {
      concats.push_back(hlo);
    }
  }

  for (auto& concat : concats) {
    if (!GroupHlosForConcat(body, concat, topological_order, &groups)) {
      concat = nullptr;
    }
  }
  if (groups.Groups().empty()) {
    return false;
  }

  TF_RETURN_IF_ERROR(AddCopiesToRoot(body, gtes, &groups));
  TF_RETURN_IF_ERROR(RewriteLoopWithConcatGroups(loop, gtes, groups));
  for (auto concat : concats) {
    if (concat == nullptr) {
      continue;
    }
    // We have replaced the operands of the concat with slices of full data.
    auto new_slice = concat->mutable_operand(0);
    CHECK_EQ(new_slice->opcode(), HloOpcode::kSlice);
    TF_RETURN_IF_ERROR(
        concat->ReplaceAllUsesWith(new_slice->mutable_operand(0)));
    TF_RETURN_IF_ERROR(body->RemoveInstruction(concat));
  }
  TF_RETURN_IF_ERROR(RemoveCopiesFromRoot(body));
  // Finally, pass through the replaced elements from parameter to root, so
  // that the while loop simplifier can get rid of them.
  for (auto gte : gtes) {
    auto group_index = groups.GetGroupIndex(gte);
    if (group_index.has_value() && group_index->second > 0) {
      TF_RETURN_IF_ERROR(root->ReplaceOperandWith(gte->tuple_index(), gte));
    }
  }
  return true;
}

}  // namespace

StatusOr<bool> WhileLoopConcatCodeMotion::Run(HloModule* module) {
  bool changed = false;
  for (HloComputation* comp : module->MakeComputationPostOrder()) {
    for (HloInstruction* hlo : comp->MakeInstructionPostOrder()) {
      if (hlo->opcode() == HloOpcode::kWhile) {
        TF_ASSIGN_OR_RETURN(bool loop_changed,
                            RunOnLoop(hlo, min_operand_count_to_optimize_));
        changed |= loop_changed;
      }
    }
  }
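  // The rewrite leaves the original (now unused) group elements in place; run
  // simplification and DCE passes to clean them up.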
  if (changed) {
    HloPassPipeline pipeline("loop-concat-motion-cleanup");
    pipeline.AddPass<TupleSimplifier>();
    pipeline.AddPass<HloDCE>();
    pipeline.AddPass<WhileLoopSimplifier>();
    pipeline.AddPass<TupleSimplifier>();
    pipeline.AddPass<HloDCE>();
    TF_RETURN_IF_ERROR(pipeline.Run(module).status());
  }
  return changed;
}

}  // namespace xla