/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/while_loop_concat_code_motion.h"

#include <map>
#include <set>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace xla {

namespace {

// This algorithm tries to group HLO instructions into concat candidates. Each
// instruction can only belong to a single group.
//
// For simplicity, after finding the groups, it in-place updates the first group
// member to the full shape, and replaces non-grouped uses with slices of it.
// Then it relies on TupleSimplifier, WhileLoopSimplifier, and DCE passes to
// remove other elements.
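//
// For example (an illustrative sketch), if the loop body computes
//   g0 = get-tuple-element(param), index=0   // f32[4]
//   g1 = get-tuple-element(param), index=1   // f32[4]
//   m0 = multiply(g0, g0)
//   m1 = multiply(g1, g1)
//   c  = concatenate(m0, m1), dimensions={0} // f32[8]
// then {m0, m1} and {g0, g1} become groups: g0 is rewritten to carry the
// concatenated f32[8] data, the multiply runs once on the combined shape, and
// remaining uses of the narrow values become slices of the widened results.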

// Represents a group of elements and how to concat them.
struct ConcatGroup {
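  // Builds per-element sizes and offsets along the concat dim. E.g. (an
  // illustrative case) two f32[2,3] elements concatenated on existing dim 1
  // get element_sizes {3, 3} and element_offsets {0, 3}; with an inserted
  // concat dim, every size is 1 and the offsets are {0, 1}.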
  ConcatGroup(std::vector<HloInstruction*> elements, int64_t concat_dim,
              bool inserted_concat_dim)
      : elements(std::move(elements)),
        element_sizes(this->elements.size(), 1),
        element_offsets(this->elements.size(), 0),
        concat_dim(concat_dim),
        inserted_concat_dim(inserted_concat_dim) {
    if (inserted_concat_dim) {
      absl::c_iota(element_offsets, 0);
    } else {
      for (int64_t i = 0; i < element_sizes.size(); ++i) {
        element_sizes[i] = this->elements[i]->shape().dimensions(concat_dim);
        if (i > 0) {
          element_offsets[i] = element_offsets[i - 1] + element_sizes[i - 1];
        }
      }
    }
  }

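  // Returns the combined shape of the group. E.g. (illustrative) two f32[2,3]
  // elements yield f32[2,2,3] for an inserted concat dim 1, or f32[2,6] when
  // concatenating on the existing dim 1.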
  Shape GetConcatShape() const {
    if (inserted_concat_dim) {
      std::vector<int64> dims;
      const Shape& element_shape = elements.back()->shape();
      dims.reserve(element_shape.rank() + 1);
      for (int64_t i = 0; i < element_shape.rank(); ++i) {
        if (i == concat_dim) {
          dims.push_back(elements.size());
        }
        dims.push_back(element_shape.dimensions(i));
      }
      if (dims.size() == concat_dim) {
        dims.push_back(elements.size());
      }
      return ShapeUtil::MakeShape(element_shape.element_type(), dims);
    } else {
      int64_t dim_size = 0;
      for (int64_t size : element_sizes) {
        dim_size += size;
      }
      Shape shape = elements.back()->shape();
      shape.set_dimensions(concat_dim, dim_size);
      return shape;
    }
  }

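  // Creates a slice of the concatenated full data that corresponds to the
  // element at element_index. E.g. (illustrative) for full data f32[2,6] with
  // element_sizes {3, 3} on dim 1, element 1 is slice [0:2, 3:6]; if the
  // concat dim was inserted, the slice is also reshaped to drop that dim.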
  HloInstruction* CreateSlice(HloInstruction* full_data, int64_t element_index,
                              HloComputation* comp) const {
    Shape shape = full_data->shape();
    shape.set_dimensions(concat_dim, element_sizes[element_index]);
    std::vector<int64> starts(shape.rank(), 0);
    std::vector<int64> limits(shape.dimensions().begin(),
                              shape.dimensions().end());
    starts[concat_dim] = element_offsets[element_index];
    limits[concat_dim] += starts[concat_dim];
    auto slice = comp->AddInstruction(HloInstruction::CreateSlice(
        shape, full_data, starts, limits, std::vector<int64>(shape.rank(), 1)));
    if (!inserted_concat_dim) {
      return slice;
    }
    std::vector<int64> element_shape;
    element_shape.reserve(shape.rank());
    for (int64_t i = 0; i < shape.rank(); ++i) {
      if (i != concat_dim) {
        element_shape.push_back(shape.dimensions(i));
      }
    }
    return comp->AddInstruction(HloInstruction::CreateReshape(
        ShapeUtil::MakeShape(shape.element_type(), element_shape), slice));
  }

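  // Concatenates the given input elements into the combined shape. If the
  // concat dim is an inserted dim, each input is first reshaped to add a
  // size-1 dim at that position (e.g., illustratively, f32[2,3] inputs become
  // f32[2,1,3] before a dim-1 concatenate).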
  HloInstruction* CreateConcat(std::vector<HloInstruction*> input_elements,
                               HloComputation* comp) const {
    if (inserted_concat_dim) {
      for (int64_t i = 0; i < input_elements.size(); ++i) {
        std::vector<int64> element_shape;
        element_shape.reserve(input_elements[i]->shape().rank());
        for (int64_t j = 0; j < input_elements[i]->shape().rank(); ++j) {
          if (j == concat_dim) {
            element_shape.push_back(1);
          }
          element_shape.push_back(input_elements[i]->shape().dimensions(j));
        }
        if (element_shape.size() == concat_dim) {
          element_shape.push_back(1);
        }
        input_elements[i] = comp->AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(input_elements[i]->shape().element_type(),
                                 element_shape),
            input_elements[i]));
      }
    }

    return comp->AddInstruction(HloInstruction::CreateConcatenate(
        GetConcatShape(), input_elements, concat_dim));
  }

  std::vector<HloInstruction*> elements;
  std::vector<int64> element_sizes;
  std::vector<int64> element_offsets;
  int64 concat_dim;
  // Whether the concat dim is an inserted new dimension.
  bool inserted_concat_dim;
};

// A collection of ConcatGroups, where each HLO can belong to at most one
// group.
class ConcatGroups {
 public:
  // Returns the group index and the element index within that group for an
  // HLO, if it belongs to a group.
  absl::optional<std::pair<int64, int64>> GetGroupIndex(
      const HloInstruction* hlo) const {
    auto it = element_to_group_.find(hlo);
    if (it == element_to_group_.end()) {
      return absl::nullopt;
    }
    return it->second;
  }

  const ConcatGroup& GetGroup(int64_t index) const { return groups_[index]; }

  // Creates a new group and returns its index, or returns the index of an
  // existing identical group. If the new group doesn't exactly match an
  // existing group but shares some of its elements, returns -1 as the index.
  // It also returns whether a new group was created, so the return value is a
  // pair of {whether created, group index}.
  std::pair<bool, int64> MaybeCreateNewGroup(ConcatGroup group) {
    int64_t group_id = -1;
    absl::flat_hash_set<HloInstruction*> elements_dedup;
    for (int64_t i = 0; i < group.elements.size(); ++i) {
      if (!elements_dedup.insert(group.elements[i]).second) {
        VLOG(2) << "Duplicates in group. Element: "
                << group.elements[i]->ToString();
      }
      if (concat_disallowed_.contains(group.elements[i])) {
        VLOG(2) << "Failed creating group. Grouping disallowed on "
                << group.elements[i]->ToString();
        return std::pair<bool, int64>(false, -1);
      }
      auto existing = GetGroupIndex(group.elements[i]);
      if (existing.has_value() &&
          (i != existing->second ||
           groups_[existing->first].concat_dim != group.concat_dim)) {
        // We allow mismatched inserted_concat_dim, since that only requires a
        // trivial reshape.
        VLOG(2)
            << "Failed creating group. Different than existing group. Element: "
            << group.elements[i]->ToString();
        return std::pair<bool, int64>(false, -1);
      }
      if (i == 0 && existing.has_value()) {
        group_id = existing->first;
      }
      if (i > 0) {
        if (existing.has_value() && existing->first != group_id) {
          VLOG(2) << "Failed creating group. Different than existing group. "
                     "Element: "
                  << group.elements[i]->ToString();
          return std::pair<bool, int64>(false, -1);
        }
        if (!existing.has_value() && group_id >= 0) {
          VLOG(2) << "Failed creating group. Different than existing group. "
                     "Element: "
                  << group.elements[i]->ToString();
          return std::pair<bool, int64>(false, -1);
        }
      }
    }
    if (group_id >= 0) {
      VLOG(2) << "Group already exists at " << group_id << " for "
              << group.elements[0]->ToString();
      return std::pair<bool, int64>(false, group_id);
    }
    int64_t index = groups_.size();
    for (int64_t i = 0; i < group.elements.size(); ++i) {
      element_to_group_[group.elements[i]] = std::pair<int64, int64>(index, i);
    }
    VLOG(2) << "Created new group at " << index << " for "
            << group.elements[0]->ToString()
            << ", concat_dim: " << group.concat_dim
            << ", inserted: " << group.inserted_concat_dim;
    groups_.push_back(std::move(group));
    return std::pair<bool, int64>(true, index);
  }

  const std::vector<ConcatGroup>& Groups() const { return groups_; }

  int64 NextGroupIndex() const { return groups_.size(); }

  void RemoveTailingGroups(int64_t start_index) {
    while (groups_.size() > start_index) {
      for (auto element : groups_.back().elements) {
        element_to_group_.erase(element);
      }
      groups_.pop_back();
    }
  }

  void DisallowGroupingOn(const HloInstruction* hlo) {
    VLOG(2) << "Disallow grouping on " << hlo->ToString();
    concat_disallowed_.insert(hlo);
  }

 private:
  // element -> {group index in groups_, element index in group}.
  absl::flat_hash_map<const HloInstruction*, std::pair<int64, int64>>
      element_to_group_;
  std::vector<ConcatGroup> groups_;
  absl::flat_hash_set<const HloInstruction*> concat_disallowed_;
};

// Infers an operand's concat dim and whether it's an inserted dim. For example,
// if hlo is f32[2,4,2] broadcast(f32[2,4]), dimensions={0,1} concatenated on
// dim 2, then this function will return {2, true}.
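//
// Another example (illustrative): if hlo is f32[4,8] reduce(f32[2,4,8]),
// dimensions={0}, concatenated on existing dim 0, the operand's concat dim is
// 1, because the reduced dim 0 shifts the concat dim by one.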
//
// If the operand is already transformed to the combined shape, specify its
// group in combined_operand_group. (Only required for kReshape.)
absl::optional<std::pair<int64, bool>> GetOperandConcatDim(
    const HloInstruction* hlo, int64_t operand_index, int64_t hlo_concat_dim,
    bool hlo_inserted_concat_dim,
    const ConcatGroup* combined_operand_group = nullptr) {
  if (hlo->IsElementwise() || hlo->opcode() == HloOpcode::kAllReduce) {
    return std::pair<int64, bool>(hlo_concat_dim, hlo_inserted_concat_dim);
  }
  int64_t operand_concat_dim = -1;
  bool operand_inserted_concat_dim = false;
  const Shape& operand_shape =
      combined_operand_group == nullptr
          ? hlo->operand(operand_index)->shape()
          : combined_operand_group->elements.back()->shape();
  if (hlo->opcode() == HloOpcode::kBroadcast) {
    operand_concat_dim = 0;
    operand_inserted_concat_dim = true;
    // If the concat dim does not exist in the operand, try to place
    // operand_concat_dim adjacent to the mapped dims, the same way as in the
    // output.
    int64_t min_dist_to_concat_dim = hlo->shape().rank();
    for (int64_t i = 0; i < operand_shape.rank(); ++i) {
      if (hlo->dimensions(i) == hlo_concat_dim) {
        operand_concat_dim = i;
        operand_inserted_concat_dim = hlo_inserted_concat_dim;
        break;
      }
      if (hlo->dimensions(i) < hlo_concat_dim &&
          min_dist_to_concat_dim > hlo_concat_dim - hlo->dimensions(i)) {
        operand_concat_dim = i + 1;
        min_dist_to_concat_dim = hlo_concat_dim - hlo->dimensions(i);
      }
      if (hlo->dimensions(i) > hlo_concat_dim &&
          min_dist_to_concat_dim > hlo->dimensions(i) - hlo_concat_dim) {
        operand_concat_dim = i;
        min_dist_to_concat_dim = hlo->dimensions(i) - hlo_concat_dim;
      }
    }
  } else if (hlo->opcode() == HloOpcode::kReduce) {
    if (operand_index != 0) {
      return absl::nullopt;
    }
    operand_concat_dim = hlo_concat_dim;
    operand_inserted_concat_dim = hlo_inserted_concat_dim;
    std::set<int64> sorted_reduce_dims;
    for (int64_t dim : hlo->dimensions()) {
      sorted_reduce_dims.insert(dim);
    }
    for (int64_t dim : sorted_reduce_dims) {
      if ((hlo_inserted_concat_dim && dim < operand_concat_dim) ||
          (!hlo_inserted_concat_dim && dim <= operand_concat_dim)) {
        operand_concat_dim++;
      }
    }
  } else if (hlo->opcode() == HloOpcode::kReshape) {
    int64_t i = 0;
    int64_t j = 0;
    operand_inserted_concat_dim = false;
    // Only support adding/removing trivial dims.
    while (i < operand_shape.rank() || j <= hlo_concat_dim) {
      if (i < operand_shape.rank() && j < hlo->shape().rank() &&
          operand_shape.dimensions(i) == hlo->shape().dimensions(j)) {
        if (j == hlo_concat_dim) {
          operand_inserted_concat_dim =
              hlo_inserted_concat_dim && operand_shape.dimensions(i) != 1;
          operand_concat_dim = i;
          break;
        }
        i++;
        j++;
        continue;
      }
      if (i < operand_shape.rank() && operand_shape.dimensions(i) == 1) {
        if (j == hlo_concat_dim && hlo_inserted_concat_dim) {
          operand_concat_dim = i;
          break;
        }
        i++;
        continue;
      }
      if (j == hlo_concat_dim) {
        operand_concat_dim = i;
        operand_inserted_concat_dim = true;
        break;
      }
      if (j < hlo->shape().rank() && hlo->shape().dimensions(j) == 1) {
        j++;
        continue;
      }
      return absl::nullopt;
    }
  } else {
    return absl::nullopt;
  }
  CHECK_GE(operand_concat_dim, 0);
  return std::pair<int64, bool>(operand_concat_dim,
                                operand_inserted_concat_dim);
}

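// Adjusts hlo's shape to the group's combined concat shape, and updates
// opcode-specific properties (broadcast dimensions, reduce dimensions) to
// stay consistent with the widened operand.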
void ModifyHloPropertiesForConcatShape(const ConcatGroup& group,
                                       HloInstruction* hlo) {
  *hlo->mutable_shape() = group.GetConcatShape();
  if (hlo->opcode() == HloOpcode::kBroadcast) {
    // Use the last element to infer the operand concat dim, since the first
    // element's operand might have been rewritten.
    auto operand_dim = GetOperandConcatDim(
        group.elements.back(), 0, group.concat_dim, group.inserted_concat_dim);
    CHECK(operand_dim.has_value());
    int64_t operand_concat_dim = operand_dim->first;
    bool operand_inserted_concat_dim = operand_dim->second;
    if (operand_inserted_concat_dim) {
      // We should have added a dimension to the operand.
      CHECK_EQ(hlo->operand(0)->shape().rank(), hlo->dimensions().size() + 1)
          << hlo->ToString();
    } else {
      CHECK_EQ(hlo->operand(0)->shape().rank(), hlo->dimensions().size());
    }
    std::vector<int64> dims;
    for (int64_t i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
      if (i == operand_concat_dim && operand_inserted_concat_dim) {
        dims.push_back(group.concat_dim);
      } else {
        if (i > operand_concat_dim && operand_inserted_concat_dim) {
          dims.push_back(hlo->dimensions(i - 1));
        } else {
          dims.push_back(hlo->dimensions(i));
        }
        if (group.inserted_concat_dim && dims.back() >= group.concat_dim) {
          dims.back()++;
        }
      }
    }
    *hlo->mutable_dimensions() = std::move(dims);
  } else if (hlo->opcode() == HloOpcode::kReduce) {
    auto operand_dim = GetOperandConcatDim(
        group.elements.back(), 0, group.concat_dim, group.inserted_concat_dim);
    CHECK(operand_dim.has_value());
    int64_t operand_concat_dim = operand_dim->first;
    bool operand_inserted_concat_dim = operand_dim->second;
    if (operand_inserted_concat_dim) {
      auto dims = hlo->mutable_dimensions();
      for (int64_t i = 0; i < dims->size(); ++i) {
        if ((*dims)[i] >= operand_concat_dim) {
          (*dims)[i]++;
        }
      }
    }
  }
}

// Main method to assign groups to HLOs, based on a concat. Returns whether
// any new group was created for this concat.
bool GroupHlosForConcat(
    HloComputation* body, HloInstruction* concat,
    absl::flat_hash_map<const HloInstruction*, int64> topological_order,
    ConcatGroups* groups) {
  const int64_t group_size = concat->operand_count();
  absl::flat_hash_set<int64> used_groups;
  auto root_tuple = body->root_instruction();
  CHECK_EQ(root_tuple->opcode(), HloOpcode::kTuple);
  absl::flat_hash_map<HloInstruction*, int64> root_tuple_element_use_count;
  for (auto operand : root_tuple->operands()) {
    root_tuple_element_use_count.emplace(operand, 0).first->second++;
  }
  // Priority queue (implemented as an ordered multimap) sorted by topological
  // order. Users come before operands, so it uses
  // -topological_order[element0] as the key. We start with the concat
  // operands.
  std::multimap<int64, ConcatGroup> pq;
  const int64_t first_group_id_to_create = groups->NextGroupIndex();
  auto fail_and_cleanup = [&] {
    VLOG(1) << "Failed to get the subcomputation to optimize for "
            << concat->ToString() << ", clear groups starting at "
            << first_group_id_to_create;
    groups->RemoveTailingGroups(first_group_id_to_create);
    return false;
  };
  struct GroupUse {
    int64 group_id;
    bool newly_created;
    bool already_used_by_subcomp;
  };
  auto maybe_create_group = [&](ConcatGroup group) {
    auto res = groups->MaybeCreateNewGroup(std::move(group));
    GroupUse use{res.second, false, false};
    if (res.second < 0) {
      return use;
    }
    use.newly_created = res.first;
    use.already_used_by_subcomp = !used_groups.insert(res.second).second;
    return use;
  };
  std::vector<HloInstruction*> concat_operands(concat->operands().begin(),
                                               concat->operands().end());
  int64_t concat_operand_order = -topological_order[concat_operands[0]];
  pq.emplace(concat_operand_order,
             ConcatGroup(std::move(concat_operands),
                         concat->concatenate_dimension(), false));

  // Find the subcomputation on elements to combine, in order to move `concat`
  // out of the loop without adding new concats. We start from the concat's
  // operands, and the priority queue is ordered in reverse topological order
  // so we process outputs before inputs. Each entry in the queue is a group of
  // elements to combine. A legitimate group consists of identical ops, except
  // that they each operate on one element. When a group of loop inputs is
  // processed, we also enqueue the corresponding loop outputs to keep them
  // matched in shape.
  while (!pq.empty()) {
    auto group = std::move(pq.begin()->second);
    pq.erase(pq.begin());
    const auto& hlos = group.elements;
    VLOG(2) << "GroupHlosForConcat dequeued " << hlos[0]->ToString();
    bool group_is_param_gtes = false;
    if (absl::c_all_of(hlos, [&](const HloInstruction* element) {
          return element == hlos[0];
        })) {
      // Shared operand.
      if (groups->GetGroupIndex(hlos[0]).has_value()) {
        VLOG(1) << "We do not support the case where a shared operand is also "
                   "part of a group: "
                << hlos[0]->ToString();
        return fail_and_cleanup();
      }
      groups->DisallowGroupingOn(hlos[0]);
      continue;
    }
    if (absl::c_all_of(hlos, [&](const HloInstruction* element) {
          return element->opcode() == HloOpcode::kGetTupleElement &&
                 element->operand(0) == body->parameter_instruction(0);
        })) {
      group_is_param_gtes = true;
    } else if (((hlos[0]->IsElementwise() ||
                 hlos[0]->opcode() == HloOpcode::kAllReduce) &&
                !hlos[0]->HasSideEffect()) ||
               hlos[0]->opcode() == HloOpcode::kBroadcast ||
               hlos[0]->opcode() == HloOpcode::kReduce ||
               hlos[0]->opcode() == HloOpcode::kReshape ||
               hlos[0]->IsCustomCall("Sharding")) {
      if (hlos[0]->opcode() == HloOpcode::kAllReduce &&
          (!hlos[0]->shape().IsArray() || hlos[0]->IsCrossModuleAllReduce())) {
        VLOG(2) << "Unsupported allreduce: " << hlos[0]->ToString();
        return fail_and_cleanup();
      }
      // Check if these elements can be concatenated.
      if (absl::c_any_of(hlos, [&](const HloInstruction* element) {
            auto eq_operand = [](const HloInstruction* a,
                                 const HloInstruction* b) {
              return ShapeUtil::Compatible(a->shape(), b->shape());
            };
            auto eq_computations = [](const HloComputation* lhs,
                                      const HloComputation* rhs) {
              return lhs->Equal(*rhs, /*is_layout_sensitive=*/false);
            };
            if (!hlos[0]->Identical(*element, eq_operand, eq_computations,
                                    /*layout_sensitive=*/false)) {
              return true;
            }
            if (element->opcode() == HloOpcode::kReduce &&
                (element->operand_count() != 2 ||
                 element->operand(1) != hlos[0]->operand(1))) {
              return true;
            }
            return false;
          })) {
        VLOG(2) << "Different types of elements. First element: "
                << hlos[0]->ToString();
        return fail_and_cleanup();
      }
      // Now enqueue the inputs.
      int64_t input_count = hlos[0]->operand_count();
      if (hlos[0]->opcode() == HloOpcode::kReduce) {
        CHECK_EQ(input_count, 2);
        // Exclude the init value that we have checked to be the same.
        input_count = 1;
      }
      for (int64_t i = 0; i < input_count; ++i) {
        std::vector<HloInstruction*> elements(group_size);
        for (int64_t j = 0; j < group_size; ++j) {
          elements[j] = hlos[j]->mutable_operand(i);
        }
        auto maybe_new_concat_dim = GetOperandConcatDim(
            hlos[0], i, group.concat_dim, group.inserted_concat_dim);
        if (!maybe_new_concat_dim.has_value()) {
          VLOG(2) << "Cannot find operand concat dimension for operand " << i
                  << " of " << hlos[0]->ToString();
          return fail_and_cleanup();
        }
        int64_t new_group_concat_dim = maybe_new_concat_dim->first;
        bool inserted_concat_dim = maybe_new_concat_dim->second;
        // Enqueue the input group.
        int64_t element_order = -topological_order[elements[0]];
        pq.emplace(element_order,
                   ConcatGroup(std::move(elements), new_group_concat_dim,
                               inserted_concat_dim));
      }
    } else if (hlos[0]->opcode() == HloOpcode::kSlice) {
      int64_t offset = 0;
      auto operand = hlos[0]->operand(0);
      if (group.inserted_concat_dim) {
        VLOG(2) << "Slices cannot be grouped on new dimension.";
        return fail_and_cleanup();
      }
      if (groups->GetGroupIndex(operand).has_value()) {
        // Should not slice an operand to be grouped.
        return fail_and_cleanup();
      }
      groups->DisallowGroupingOn(operand);
      for (int64_t i = 0; i < group_size; ++i) {
        if (hlos[i]->operand(0) != operand) {
          VLOG(2) << "Slices of different operands.";
          return fail_and_cleanup();
        }
        for (int64_t j = 0; j < hlos[i]->shape().rank(); ++j) {
          if (hlos[i]->slice_strides(j) != 1) {
            VLOG(2) << "Slices with strides.";
            return fail_and_cleanup();
          }
          if (j == group.concat_dim) {
            if (hlos[i]->slice_starts(j) != offset) {
              VLOG(2) << "Slices with unsupported offsets.";
              return fail_and_cleanup();
            }
            offset += hlos[i]->shape().dimensions(j);
          } else {
            if (hlos[i]->slice_starts(j) != 0 ||
                hlos[i]->slice_limits(j) != operand->shape().dimensions(j)) {
              VLOG(2) << "Slice with unsupported offsets at dimension " << j
                      << ", " << hlos[i]->ToString();
              return fail_and_cleanup();
            }
          }
        }
      }
      if (offset != operand->shape().dimensions(group.concat_dim)) {
        VLOG(2) << "Slices with unsupported sizes.";
        return fail_and_cleanup();
      }
    } else {
      VLOG(2) << "Unsupported opcode: " << hlos[0]->ToString();
      return fail_and_cleanup();
    }
    auto guse = maybe_create_group(std::move(group));
    if (guse.group_id < 0) {
      VLOG(2) << "Failed to create group.";
      return fail_and_cleanup();
    }
    const auto& registered_group = groups->GetGroup(guse.group_id);
    if (!guse.already_used_by_subcomp && group_is_param_gtes) {
      // When we process a group of parameter GTEs, we should also enqueue the
      // corresponding root tuple operands, so that they have matching shapes.
      std::vector<HloInstruction*> new_outputs(group_size);
      for (int64_t i = 0; i < group_size; ++i) {
        new_outputs[i] = root_tuple->mutable_operand(
            registered_group.elements[i]->tuple_index());
      }
      int64_t new_output_order = -topological_order[new_outputs[0]];
      pq.emplace(
          new_output_order,
          ConcatGroup(std::move(new_outputs), registered_group.concat_dim,
                      registered_group.inserted_concat_dim));
    }
  }
  return groups->Groups().size() > first_group_id_to_create;
}

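// Returns, for each tuple element of the loop, whether it is used by the loop
// condition. If the condition uses the parameter in any way other than
// get-tuple-element, all elements are conservatively marked as used.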
std::vector<bool> TupleElementsUsedInCond(HloInstruction* loop) {
  std::vector<bool> result(loop->shape().tuple_shapes_size(), false);
  for (auto user : loop->while_condition()->parameter_instruction(0)->users()) {
    if (user->opcode() != HloOpcode::kGetTupleElement) {
      absl::c_fill(result, true);
      return result;
    }
    result[user->tuple_index()] = true;
  }
  return result;
}

// Adds copies to returned values to keep RewriteLoopWithConcatGroups simple:
// the copies do not have other users and only appear once in the root tuple.
Status AddCopiesToRoot(HloComputation* body,
                       absl::Span<HloInstruction* const> param_gtes,
                       ConcatGroups* groups) {
  auto root = body->root_instruction();
  CHECK_EQ(root->opcode(), HloOpcode::kTuple);
  std::vector<HloInstruction*> copies(root->operand_count(), nullptr);
  for (int64_t i = 0; i < copies.size(); ++i) {
    auto element = root->mutable_operand(i);
    if (!element->shape().IsArray()) {
      continue;
    }
    copies[i] = body->AddInstruction(HloInstruction::CreateUnary(
        element->shape(), HloOpcode::kCopy, element));
    TF_RETURN_IF_ERROR(root->ReplaceOperandWith(i, copies[i]));
  }
  for (int64_t i = 0; i < copies.size(); ++i) {
    auto copy = copies[i];
    if (groups->GetGroupIndex(copy).has_value()) {
      // Already handled by earlier group members.
      continue;
    }
    auto param_group_index = groups->GetGroupIndex(param_gtes[i]);
    if (!param_group_index.has_value()) {
      continue;
    }
    const auto& param_group = groups->GetGroup(param_group_index->first);
    std::vector<HloInstruction*> copy_group(param_group.elements.size());
    for (int64_t j = 0; j < copy_group.size(); ++j) {
      copy_group[j] = copies[param_group.elements[j]->tuple_index()];
    }
    CHECK(groups
              ->MaybeCreateNewGroup(
                  ConcatGroup(std::move(copy_group), param_group.concat_dim,
                              param_group.inserted_concat_dim))
              .first);
  }
  return Status::OK();
}

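// Removes the copies that AddCopiesToRoot inserted into the root tuple.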
Status RemoveCopiesFromRoot(HloComputation* body) {
  auto root = body->root_instruction();
  CHECK_EQ(root->opcode(), HloOpcode::kTuple);
  for (int64_t i = 0; i < root->operand_count(); ++i) {
    auto copy = root->mutable_operand(i);
    if (copy->opcode() == HloOpcode::kCopy) {
      TF_RETURN_IF_ERROR(root->ReplaceOperandWith(i, copy->mutable_operand(0)));
    }
  }
  return Status::OK();
}

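// Rewrites the loop and its body based on the assigned groups: the first
// element of each group is widened to the combined concat shape, and other
// uses are redirected to slices of it.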
Status RewriteLoopWithConcatGroups(HloInstruction* loop,
                                   absl::Span<HloInstruction* const> param_gtes,
                                   ConcatGroups& groups) {
  VLOG(1) << "RewriteLoopWithConcatGroups with " << groups.Groups().size()
          << " groups.";
  // For simplicity, for each group, we rewrite the first element into full
  // shape, and leave the other elements unchanged. Non-grouped users will
  // have slices of the expanded first element as their new inputs. Later
  // simplification and DCE passes can remove the other elements.
  absl::flat_hash_set<int64> processed_groups;
  auto body = loop->while_body();
  auto param = body->parameter_instruction(0);
  auto cond_param = loop->while_condition()->parameter_instruction(0);

  // First, modify loop signature and operands/users.
  std::vector<HloInstruction*> init_elements(loop->shape().tuple_shapes_size());
  for (int64_t i = 0; i < param_gtes.size(); ++i) {
    init_elements[i] =
        loop->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
            loop->shape().tuple_shapes(i), loop->mutable_operand(0), i));
  }
  for (int64_t i = 0; i < param_gtes.size(); ++i) {
    const auto& group_and_index = groups.GetGroupIndex(param_gtes[i]);
    if (!group_and_index.has_value() || group_and_index->second != 0) {
      continue;
    }
    const auto& group = groups.GetGroup(group_and_index->first);
    // Change body parameter shape.
    *param_gtes[i]->mutable_shape() = group.GetConcatShape();
    *param->mutable_shape()->mutable_tuple_shapes(i) = param_gtes[i]->shape();
    *body->root_instruction()->mutable_shape()->mutable_tuple_shapes(i) =
        param_gtes[i]->shape();
    *cond_param->mutable_shape()->mutable_tuple_shapes(i) =
        param_gtes[i]->shape();
    *loop->mutable_shape()->mutable_tuple_shapes(i) = param_gtes[i]->shape();
    processed_groups.insert(group_and_index->first);
    std::vector<HloInstruction*> input_concat_elements;
    input_concat_elements.reserve(group.elements.size());
    for (auto param_gte : group.elements) {
      input_concat_elements.push_back(init_elements[param_gte->tuple_index()]);
    }
    init_elements[i] =
        group.CreateConcat(std::move(input_concat_elements), loop->parent());
  }
  TF_RETURN_IF_ERROR(loop->ReplaceOperandWithDifferentShape(
      0, loop->parent()->AddInstruction(
             HloInstruction::CreateTuple(init_elements))));
  // Adjust loop users.
  auto original_loop_users = loop->users();
  const bool loop_is_root = loop == loop->parent()->root_instruction();
  std::vector<HloInstruction*> output_elements(
      loop->shape().tuple_shapes_size());
  for (int64_t i = 0; i < param_gtes.size(); ++i) {
    output_elements[i] =
        loop->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
            init_elements[i]->shape(), loop, i));
  }
  for (int64_t i = 0; i < param_gtes.size(); ++i) {
    const auto& group_and_index = groups.GetGroupIndex(param_gtes[i]);
    if (!group_and_index.has_value() || group_and_index->second != 0) {
      continue;
    }
    const auto& group = groups.GetGroup(group_and_index->first);
    auto concat_output = output_elements[group.elements[0]->tuple_index()];
    for (int64_t j = 0; j < group.elements.size(); ++j) {
      const auto param_gte = group.elements[j];
      output_elements[param_gte->tuple_index()] =
          group.CreateSlice(concat_output, j, loop->parent());
    }
  }
  auto new_output_tuple = loop->parent()->AddInstruction(
      HloInstruction::CreateTuple(output_elements));
  for (auto user : original_loop_users) {
    TF_RETURN_IF_ERROR(
        loop->ReplaceUseWithDifferentShape(user, new_output_tuple));
  }
  if (loop_is_root) {
    loop->parent()->set_root_instruction(new_output_tuple,
                                         /*accept_different_shape=*/true);
  }

  // Now rewrite the loop body.
  std::vector<HloInstruction*> slices_to_remove;
  absl::flat_hash_set<HloInstruction*> new_reshapes;
  for (auto hlo : body->MakeInstructionPostOrder()) {
    const auto& group_and_index = groups.GetGroupIndex(hlo);
    if (!group_and_index.has_value() || group_and_index->second != 0) {
      continue;
    }

    if (!processed_groups.insert(group_and_index->first).second) {
      // Already processed the group at the first element.
      continue;
    }
    const auto& group = groups.GetGroup(group_and_index->first);
    if (hlo->opcode() == HloOpcode::kSlice) {
      // We could just replace hlo with its operand; however, to follow the
      // practice of using the first element as full data, we defer that
      // replacement.
      slices_to_remove.push_back(hlo);
    } else {
      int64_t operand_count_to_adjust = hlo->operand_count();
      if (hlo->opcode() == HloOpcode::kReduce) {
        CHECK_EQ(operand_count_to_adjust, 2);
        operand_count_to_adjust = 1;
      }
      for (int64_t i = 0; i < operand_count_to_adjust; ++i) {
        auto operand_group_index = groups.GetGroupIndex(hlo->operand(i));
        const ConcatGroup* operand_group =
            operand_group_index.has_value()
                ? &groups.GetGroup(operand_group_index->first)
                : nullptr;
        auto maybe_operand_concat_dim = GetOperandConcatDim(
            hlo, i, group.concat_dim, group.inserted_concat_dim, operand_group);
        CHECK(maybe_operand_concat_dim.has_value())
            << "Operand " << i << " of " << hlo->ToString();
        int64_t operand_concat_dim = maybe_operand_concat_dim->first;
        bool operand_inserted_concat_dim = maybe_operand_concat_dim->second;
        if (operand_group != nullptr) {
          CHECK_EQ(operand_concat_dim, operand_group->concat_dim);
          if (operand_inserted_concat_dim !=
              operand_group->inserted_concat_dim) {
            // The operand's actual inserted_concat_dim doesn't match the
            // expected operand_inserted_concat_dim. Need a reshape.
            std::vector<int64> new_dims;
            int64_t d = 0;
            for (; d < operand_concat_dim; ++d) {
              new_dims.push_back(hlo->operand(i)->shape().dimensions(d));
            }
            if (operand_inserted_concat_dim) {
              // Split operand concat dim.
              new_dims.push_back(group.elements.size());
              new_dims.push_back(
                  hlo->operand(i)->shape().dimensions(operand_concat_dim) /
                  group.elements.size());
              d = operand_concat_dim + 1;
            } else {
              // Combine operand concat dim with the next.
              new_dims.push_back(
                  group.elements.size() *
                  hlo->operand(i)->shape().dimensions(operand_concat_dim + 1));
              d = operand_concat_dim + 2;
            }
            for (; d < hlo->operand(i)->shape().rank(); ++d) {
              new_dims.push_back(hlo->operand(i)->shape().dimensions(d));
            }
            auto reshape = body->AddInstruction(HloInstruction::CreateReshape(
                ShapeUtil::MakeShape(hlo->operand(i)->shape().element_type(),
                                     new_dims),
                hlo->mutable_operand(i)));
            new_reshapes.insert(reshape);
            TF_RETURN_IF_ERROR(
                hlo->ReplaceOperandWithDifferentShape(i, reshape));
          }
          continue;
        }
        // This is a shared operand; we need to broadcast it.
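        // E.g. (illustrative) a shared f32[2,3] operand in a 4-element group
        // with inserted operand concat dim 1 becomes a broadcast of shape
        // f32[2,4,3] with dimensions={0,2}, so each of the 4 combined pieces
        // sees the same data.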
        CHECK(
            absl::c_all_of(group.elements, [&](const HloInstruction* element) {
              return element->operand(i) == hlo->operand(i);
            }));
        VLOG(2) << "Broadcasting shared operand "
                << hlo->operand(i)->ToString();
        Shape data_shape = hlo->operand(i)->shape();
        std::vector<int64> broadcast_dims;
        std::vector<int64> broadcast_shape;
        for (int64_t j = 0; j < data_shape.rank(); ++j) {
          if (j < operand_concat_dim) {
            broadcast_dims.push_back(j);
          } else {
            broadcast_dims.push_back(j + 1);
          }
          if (j == operand_concat_dim) {
            broadcast_shape.push_back(group.elements.size());
          }
          broadcast_shape.push_back(data_shape.dimensions(j));
        }
        if (broadcast_shape.size() == data_shape.rank()) {
          // New dim at the end.
          broadcast_shape.push_back(group.elements.size());
        }
        auto broadcast = body->AddInstruction(HloInstruction::CreateBroadcast(
            ShapeUtil::MakeShape(data_shape.element_type(), broadcast_shape),
            hlo->mutable_operand(i), broadcast_dims));

        if (!operand_inserted_concat_dim) {
          // Concat on existing dim. Reshape to merge the broadcast dim.
          data_shape.set_dimensions(
              operand_concat_dim,
              data_shape.dimensions(operand_concat_dim) *
                  group.elements.size());
          broadcast = body->AddInstruction(
              HloInstruction::CreateReshape(data_shape, broadcast));
        }
        TF_RETURN_IF_ERROR(hlo->ReplaceOperandWithDifferentShape(i, broadcast));
      }
    }
    VLOG(2) << "Modifying HLO to full shape " << hlo->ToString();
    ModifyHloPropertiesForConcatShape(group, hlo);
    VLOG(2) << "Modified HLO to full shape " << hlo->ToString();
  }

  // For non-grouped HLOs, replace grouped inputs with slices. Also include
  // grouped reduce HLOs because their init values are not grouped.
  for (auto hlo : body->MakeInstructionPostOrder()) {
    if (new_reshapes.contains(hlo)) {
      continue;
    }
    const auto& group_and_index = groups.GetGroupIndex(hlo);
    if ((!group_and_index.has_value() || hlo->opcode() == HloOpcode::kReduce) &&
        hlo != body->root_instruction()) {
      auto operands = hlo->operands();
      if (group_and_index.has_value()) {
        // Only handle the reduce init value.
        CHECK_EQ(operands.size(), 2);
        CHECK_EQ(hlo->opcode(), HloOpcode::kReduce);
        operands.erase(operands.begin());
      }
      for (int64_t i = 0; i < operands.size(); ++i) {
        auto operand = operands[i];
        auto operand_group_index = groups.GetGroupIndex(operand);
        if (!operand_group_index.has_value()) {
          continue;
        }
        const auto& operand_group = groups.GetGroup(operand_group_index->first);
        auto slice = operand_group.CreateSlice(
            operand_group.elements[0], operand_group_index->second, body);
        TF_RETURN_IF_ERROR(hlo->ReplaceOperandWithDifferentShape(i, slice));
      }
    }
  }
  for (auto slice : slices_to_remove) {
    TF_RETURN_IF_ERROR(slice->ReplaceAllUsesWith(slice->mutable_operand(0)));
    TF_RETURN_IF_ERROR(body->RemoveInstruction(slice));
  }
  return Status::OK();
}

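// Tries the transformation on a single while loop. Returns whether the loop
// was changed.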
StatusOr<bool> RunOnLoop(HloInstruction* loop,
                         int64_t min_operand_count_to_optimize) {
  auto body = loop->while_body();
  auto param = body->parameter_instruction(0);
  auto root = body->root_instruction();
  if (!param->shape().IsTuple() || root->opcode() != HloOpcode::kTuple) {
    return false;
  }
  std::vector<HloInstruction*> gtes(param->shape().tuple_shapes_size(),
                                    nullptr);
  ConcatGroups groups;
  auto indices_used_in_cond = TupleElementsUsedInCond(loop);
  for (auto user : param->users()) {
    if (user->opcode() != HloOpcode::kGetTupleElement) {
      // Unhandled user opcode.
      return false;
    }
    int64_t idx = user->tuple_index();
    if (gtes[idx] != nullptr) {
      // Seen this index before.
      return false;
    }
    gtes[idx] = user;
    if (indices_used_in_cond[idx]) {
      groups.DisallowGroupingOn(user);
    }
  }
  std::vector<HloInstruction*> concats;
  auto body_instructions = body->MakeInstructionPostOrder();
  absl::flat_hash_map<const HloInstruction*, int64> topological_order;
  for (int64_t i = 0; i < body_instructions.size(); ++i) {
    auto hlo = body_instructions[i];
    topological_order[hlo] = i;
    if (hlo->opcode() == HloOpcode::kConcatenate &&
        hlo->operand_count() >= min_operand_count_to_optimize) {
      concats.push_back(hlo);
    }
  }

  for (auto& concat : concats) {
    if (!GroupHlosForConcat(body, concat, topological_order, &groups)) {
      concat = nullptr;
    }
  }
  if (groups.Groups().empty()) {
    return false;
  }

  TF_RETURN_IF_ERROR(AddCopiesToRoot(body, gtes, &groups));
  TF_RETURN_IF_ERROR(RewriteLoopWithConcatGroups(loop, gtes, groups));
  for (auto concat : concats) {
    if (concat == nullptr) {
      continue;
    }
    // We have replaced the operands of the concat with slices of the full
    // data.
    auto new_slice = concat->mutable_operand(0);
    CHECK_EQ(new_slice->opcode(), HloOpcode::kSlice);
    TF_RETURN_IF_ERROR(
        concat->ReplaceAllUsesWith(new_slice->mutable_operand(0)));
    TF_RETURN_IF_ERROR(body->RemoveInstruction(concat));
  }
  TF_RETURN_IF_ERROR(RemoveCopiesFromRoot(body));
  // Finally, pass through the replaced elements from parameter to root, so
  // that the while loop simplifier can get rid of them.
  for (auto gte : gtes) {
    auto group_index = groups.GetGroupIndex(gte);
    if (group_index.has_value() && group_index->second > 0) {
      TF_RETURN_IF_ERROR(root->ReplaceOperandWith(gte->tuple_index(), gte));
    }
  }
  return true;
}

}  // namespace

StatusOr<bool> WhileLoopConcatCodeMotion::Run(HloModule* module) {
  bool changed = false;
  for (HloComputation* comp : module->MakeComputationPostOrder()) {
    for (HloInstruction* hlo : comp->MakeInstructionPostOrder()) {
      if (hlo->opcode() == HloOpcode::kWhile) {
        TF_ASSIGN_OR_RETURN(bool loop_changed,
                            RunOnLoop(hlo, min_operand_count_to_optimize_));
        changed |= loop_changed;
      }
    }
  }
  if (changed) {
    HloPassPipeline pipeline("loop-concat-motion-cleanup");
    pipeline.AddPass<TupleSimplifier>();
    pipeline.AddPass<HloDCE>();
    pipeline.AddPass<WhileLoopSimplifier>();
    pipeline.AddPass<TupleSimplifier>();
    pipeline.AddPass<HloDCE>();
    TF_RETURN_IF_ERROR(pipeline.Run(module).status());
  }
  return changed;
}

}  // namespace xla