• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
17 
18 #include <algorithm>
19 #include <cmath>
20 #include <functional>
21 #include <iterator>
22 #include <memory>
23 #include <numeric>
24 #include <string>
25 #include <utility>
26 #include <vector>
27 
28 #include "absl/algorithm/container.h"
29 #include "absl/container/flat_hash_map.h"
30 #include "absl/container/flat_hash_set.h"
31 #include "absl/container/inlined_vector.h"
32 #include "absl/memory/memory.h"
33 #include "absl/strings/str_cat.h"
34 #include "absl/types/optional.h"
35 #include "absl/types/span.h"
36 #include "tensorflow/compiler/xla/comparison_util.h"
37 #include "tensorflow/compiler/xla/layout_util.h"
38 #include "tensorflow/compiler/xla/literal.h"
39 #include "tensorflow/compiler/xla/literal_util.h"
40 #include "tensorflow/compiler/xla/overflow_util.h"
41 #include "tensorflow/compiler/xla/permutation_util.h"
42 #include "tensorflow/compiler/xla/primitive_util.h"
43 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
44 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
45 #include "tensorflow/compiler/xla/service/hlo_computation.h"
46 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
47 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
48 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
49 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
50 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
51 #include "tensorflow/compiler/xla/service/hlo_query.h"
52 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
53 #include "tensorflow/compiler/xla/shape.h"
54 #include "tensorflow/compiler/xla/shape_util.h"
55 #include "tensorflow/compiler/xla/status_macros.h"
56 #include "tensorflow/compiler/xla/types.h"
57 #include "tensorflow/compiler/xla/util.h"
58 #include "tensorflow/compiler/xla/window_util.h"
59 #include "tensorflow/compiler/xla/xla_data.pb.h"
60 #include "tensorflow/core/lib/core/bits.h"
61 #include "tensorflow/core/lib/core/errors.h"
62 #include "tensorflow/core/lib/core/status.h"
63 #include "tensorflow/core/platform/errors.h"
64 #include "tensorflow/core/platform/logging.h"
65 #include "tensorflow/core/platform/types.h"
66 #include "tensorflow/stream_executor/lib/statusor.h"
67 
68 namespace xla {
69 
70 namespace {
71 
72 namespace m = match;
73 
IsAll(const HloInstruction * op,int8 value)74 bool IsAll(const HloInstruction* op, int8 value) {
75   switch (op->opcode()) {
76     case HloOpcode::kBroadcast:
77       return IsAll(op->operand(0), value);
78     case HloOpcode::kConstant:
79       return op->literal().IsAll(value);
80     default:
81       return false;
82   }
83 }
84 
IsAnyOperandComplex(const HloInstruction * hlo)85 bool IsAnyOperandComplex(const HloInstruction* hlo) {
86   for (auto operand : hlo->operands()) {
87     if (ShapeUtil::ElementIsComplex(operand->shape())) {
88       return true;
89     }
90   }
91   return false;
92 }
93 
IsPositive(const HloInstruction * hlo,const AlgebraicSimplifierOptions & options)94 bool IsPositive(const HloInstruction* hlo,
95                 const AlgebraicSimplifierOptions& options) {
96   // Utility only handles real types.
97   if (IsAnyOperandComplex(hlo)) {
98     return false;
99   }
100   switch (hlo->opcode()) {
101     case HloOpcode::kGetTupleElement: {
102       const HloInstruction* gte_operand = hlo->operand(0);
103       switch (gte_operand->opcode()) {
104         case HloOpcode::kCustomCall: {
105           const auto& target = gte_operand->custom_call_target();
106           return target ==
107                      options.get_cudnn_batchnorm_forward_training_metadata() &&
108                  hlo->tuple_index() == 2;
109         }
110         default:
111           return false;
112       }
113     }
114     case HloOpcode::kPower:
115     case HloOpcode::kAbs:
116     case HloOpcode::kRsqrt:
117     case HloOpcode::kSqrt:
118       return IsPositive(hlo->operand(0), options);
119 
120     case HloOpcode::kMultiply: {
121       return hlo->operand(0) == hlo->operand(1) &&
122              IsPositive(hlo->operand(0), options);
123     }
124     default:
125       return false;
126   }
127 }
128 
IsNonNegative(const HloInstruction * hlo,const AlgebraicSimplifierOptions & options)129 bool IsNonNegative(const HloInstruction* hlo,
130                    const AlgebraicSimplifierOptions& options) {
131   // Utility only handles real types.
132   if (IsAnyOperandComplex(hlo)) {
133     return false;
134   }
135   switch (hlo->opcode()) {
136     case HloOpcode::kMultiply: {
137       return hlo->operand(0) == hlo->operand(1);
138     }
139     case HloOpcode::kAbs: {
140       return true;
141     }
142     default:
143       return IsPositive(hlo, options);
144   }
145 }
146 
147 // Checks whether `op` is a floating-point constant or broadcast of a constant
148 // of the form +/- 2^k for some integer k positive, negative, or zero.  Such
149 // values are interesting because multiplying by a power of 2 just moves the
150 // exponent.
IsAllFpConstantPowerOf2(const HloInstruction * op)151 bool IsAllFpConstantPowerOf2(const HloInstruction* op) {
152   // Unwrap the broadcast if necessary.
153   const HloInstruction* c;
154   if (!Match(op, m::ConstantEffectiveScalar(&c)) &&
155       !Match(op, m::Broadcast(m::Constant(&c).WithShape(
156                      m::Shape().IsEffectiveScalar())))) {
157     return false;
158   }
159   auto val = [&]() -> absl::optional<double> {
160     switch (c->shape().element_type()) {
161       case BF16:
162         return static_cast<double>(c->literal().GetFirstElement<bfloat16>());
163       case F16:
164         return static_cast<double>(c->literal().GetFirstElement<Eigen::half>());
165       case F32:
166         return c->literal().GetFirstElement<float>();
167       case F64:
168         return c->literal().GetFirstElement<double>();
169       default:
170         // Cowardly refuse to consider complex types.
171         return absl::nullopt;
172     }
173   }();
174   if (!val) {
175     return false;
176   }
177 
178   int exp;
179   double mantissa = std::frexp(*val, &exp);
180   // frexp returns a value in the range (-1, -0.5] U [0.5, 1).  A return value
181   // of +/-0.5 therefore indicates that the floating point value is a power of
182   // 2.
183   return mantissa == 0.5 || mantissa == -0.5;
184 }
185 
186 // Returns whether the given transpose produces a result which is bit-wise
187 // identical to its operand and thus may be replaced with a bitcast.
TransposeIsBitcast(const HloInstruction * transpose)188 bool TransposeIsBitcast(const HloInstruction* transpose) {
189   CHECK_EQ(HloOpcode::kTranspose, transpose->opcode());
190   const HloInstruction* operand = transpose->operand(0);
191   return ShapeUtil::TransposeIsBitcast(operand->shape(), transpose->shape(),
192                                        transpose->dimensions());
193 }
194 
195 // Recursive helper for method below.
BitcastingOperandOfReshapeOrCopyChainHelper(HloInstruction * instr,HloInstruction * operand,const AlgebraicSimplifierOptions & options)196 HloInstruction* BitcastingOperandOfReshapeOrCopyChainHelper(
197     HloInstruction* instr, HloInstruction* operand,
198     const AlgebraicSimplifierOptions& options) {
199   // Can't replace chain of copies and reshapes with bitcasts if the compiler
200   // used a memory layout which isn't compatible.
201   if (options.ReshapeIsBitcast(operand->shape(), instr->shape())) {
202     return operand;
203   }
204 
205   // If the operand is a copy or reshape try to see if the operand's operand
206   // would produce a bitcast with initial instruction.
207   if (HloOpcode::kReshape == operand->opcode() ||
208       HloOpcode::kCopy == operand->opcode()) {
209     return BitcastingOperandOfReshapeOrCopyChainHelper(
210         instr, operand->mutable_operand(0), options);
211   }
212   return nullptr;
213 }
214 
215 // Returns an operand of a chain of reshapes and copies that is bit-wise
216 // identical to first reshape or copy in the chain.
BitcastingOperandOfReshapeOrCopyChain(HloInstruction * instr,const AlgebraicSimplifierOptions & options)217 HloInstruction* BitcastingOperandOfReshapeOrCopyChain(
218     HloInstruction* instr, const AlgebraicSimplifierOptions& options) {
219   if (!options.is_layout_sensitive()) {
220     return nullptr;
221   }
222   CHECK(HloOpcode::kReshape == instr->opcode() ||
223         HloOpcode::kCopy == instr->opcode());
224   return BitcastingOperandOfReshapeOrCopyChainHelper(
225       instr, instr->mutable_operand(0), options);
226 }
227 
IsUnstridedSlice(const HloInstruction * hlo)228 bool IsUnstridedSlice(const HloInstruction* hlo) {
229   return absl::c_all_of(hlo->slice_strides(),
230                         [](int64 stride) { return stride == 1; });
231 }
232 
233 // Returns bool to determine whether a pair of converts can be eliminated.
IsConvertPairNoOp(const HloInstruction * convert)234 bool IsConvertPairNoOp(const HloInstruction* convert) {
235   //    [operand_convert]         [convert]
236   // (src)->convert-(intermediate)->convert-(dest)
237   const HloInstruction* operand_convert = convert->operand(0);
238   CHECK_EQ(operand_convert->opcode(), HloOpcode::kConvert);
239   const Shape& src_shape = operand_convert->operand(0)->shape();
240   const Shape& intermediate_shape = operand_convert->shape();
241   const Shape& dest_shape = convert->shape();
242 
243   const PrimitiveType src_type = src_shape.element_type();
244   const PrimitiveType intermediate_type = intermediate_shape.element_type();
245   const PrimitiveType dest_type = dest_shape.element_type();
246 
247   // src_type must be equal to dest_type.
248   if (src_type != dest_type) {
249     return false;
250   }
251 
252   // src_type must be a larger container than intermediate_type.
253   if (ShapeUtil::ByteSizeOfPrimitiveType(intermediate_type) <=
254       ShapeUtil::ByteSizeOfPrimitiveType(src_type)) {
255     return false;
256   }
257 
258   // Both src_type and intermediate_type must be either floating or integral.
259   bool is_conversion_floating =
260       ShapeUtil::ElementIsFloating(src_shape) &&
261       ShapeUtil::ElementIsFloating(intermediate_shape);
262   bool is_conversion_integral =
263       ShapeUtil::ElementIsIntegral(src_shape) &&
264       ShapeUtil::ElementIsIntegral(intermediate_shape);
265 
266   return is_conversion_floating || is_conversion_integral;
267 }
268 
269 // AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain
270 // algebraic expressions to simplified forms. Note: This only supports
271 // simplifications that simply look at the operands of an instruction. For the
272 // more general case a worklist based approach would be needed.
class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
 public:
  explicit AlgebraicSimplifierVisitor(const AlgebraicSimplifierOptions& options,
                                      AlgebraicSimplifier* simplifier)
      : options_(options), simplifier_(simplifier) {}

  // Opcode-specific rewrite hooks. Each inspects one instruction (and its
  // operands) and may replace it with a simplified form.
  Status HandleAbs(HloInstruction* abs) override;

  Status HandleAdd(HloInstruction* add) override;

  Status HandleAnd(HloInstruction* logical_and) override;

  Status HandleBitcast(HloInstruction* bitcast) override;

  Status HandleBitcastConvert(HloInstruction* bitcast) override;

  Status HandleBroadcast(HloInstruction* broadcast) override;

  Status HandleCompare(HloInstruction* compare) override;

  Status HandleConcatenate(HloInstruction* concatenate) override;

  Status HandleConstant(HloInstruction* constant) override;

  Status HandleCopy(HloInstruction* copy) override;

  Status HandleConvert(HloInstruction* convert) override;

  Status HandleComplex(HloInstruction* complex) override;

  Status HandleReal(HloInstruction* real) override;

  Status HandleImag(HloInstruction* imag) override;

  Status HandleIota(HloInstruction* instruction) override;

  Status HandleConvolution(HloInstruction* convolution) override;

  Status HandleDivide(HloInstruction* divide) override;

  Status HandleDot(HloInstruction* dot) override;

  Status HandleGather(HloInstruction* gather) override;

  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;

  Status HandleLog(HloInstruction* log) override;

  Status HandleMaximum(HloInstruction* maximum) override;

  Status HandleMinimum(HloInstruction* minimum) override;

  Status HandleClamp(HloInstruction* clamp) override;

  Status HandleMultiply(HloInstruction* multiply) override;

  Status HandleNegate(HloInstruction* negate) override;

  Status HandleNot(HloInstruction* logical_not) override;

  Status HandleOr(HloInstruction* logical_or) override;

  Status HandlePad(HloInstruction* pad) override;

  Status HandlePower(HloInstruction* power) override;

  Status HandleRemainder(HloInstruction* remainder) override;

  Status HandleReshape(HloInstruction* reshape) override;

  Status HandleReduce(HloInstruction* hlo) override;

  Status HandleReduceWindow(HloInstruction* reduce_window) override;

  Status HandleReverse(HloInstruction* reverse) override;

  Status HandleRsqrt(HloInstruction* rsqrt) override;

  Status HandleSlice(HloInstruction* slice) override;

  Status HandleSqrt(HloInstruction* sqrt) override;

  Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;

  Status HandleDynamicUpdateSlice(
      HloInstruction* dynamic_update_slice) override;
  Status HandleScatter(HloInstruction* scatter) override;

  Status HandleSelect(HloInstruction* select) override;

  Status HandleSort(HloInstruction* sort) override;

  Status HandleTranspose(HloInstruction* transpose) override;

  Status HandleSubtract(HloInstruction* sub) override;

  Status HandleMap(HloInstruction* map) override;

  // Runs the visitor on a computation. Returns whether any simplification
  // was performed.
  bool Run(HloComputation* computation,
           const AlgebraicSimplifierOptions& options,
           AlgebraicSimplifier* simplifier);

 private:
  // Removes degenerate dimension from dot.
  StatusOr<bool> RemoveDegenerateDimensionFromDot(HloInstruction* dot);

  // Converts to primitive type if the input hlo is not that type, otherwise
  // returns the original hlo.
  HloInstruction* AsType(HloInstruction* hlo,
                         const PrimitiveType element_type) {
    if (hlo->shape().element_type() == element_type) {
      return hlo;
    }
    Shape changed_shape =
        ShapeUtil::ChangeElementType(hlo->shape(), element_type);
    // Keep the layout of newly created shapes consistent with the
    // simplifier's layout policy.
    simplifier_->UpdateLayout(&changed_shape);
    return computation_->AddInstruction(
        HloInstruction::CreateConvert(changed_shape, hlo));
  }

  // Transposes a dot operand such that the batch dimensions are the most major,
  // and the contracting dimensions are most minor.
  StatusOr<HloInstruction*> NormalizeDotOperandToBatchMajorAndContractingMinor(
      HloInstruction* dot_operand, absl::Span<const int64> batch_dimensions,
      absl::Span<const int64> contracting_dimensions) {
    // Desired dimension order: batch dims, then the remaining (free) dims,
    // then contracting dims.
    std::vector<int64> transpose_dimensions(batch_dimensions.begin(),
                                            batch_dimensions.end());
    for (int64 i = 0; i < dot_operand->shape().rank(); ++i) {
      if (!(absl::c_linear_search(batch_dimensions, i) ||
            absl::c_linear_search(contracting_dimensions, i))) {
        transpose_dimensions.push_back(i);
      }
    }
    transpose_dimensions.insert(transpose_dimensions.end(),
                                contracting_dimensions.begin(),
                                contracting_dimensions.end());
    // If the permutation is the identity, no transpose is needed.
    if (absl::c_is_sorted(transpose_dimensions)) {
      return dot_operand;
    }
    return MakeTransposeHlo(dot_operand, transpose_dimensions);
  }

  // Helper method to perform and add reduction on a list of dimensions.
  HloInstruction* AddReduce(HloInstruction* hlo, absl::Span<const int64> dims,
                            PrimitiveType type) {
    HloInstruction* zero = computation_->AddInstruction(
        simplifier_->CreateConstantWithLayoutUpdated(
            LiteralUtil::Zero(hlo->shape().element_type()).Clone()));
    HloComputation* AddReduce_computation =
        GetOrCreateScalarAddComputation(type);
    // The output shape keeps only the dimensions that are not reduced away.
    Shape shape = ShapeUtil::FilterDimensions(
        [&](int64 dim) { return !absl::c_linear_search(dims, dim); },
        hlo->shape());
    simplifier_->UpdateLayout(&shape);
    return computation_->AddInstruction(HloInstruction::CreateReduce(
        shape, hlo, zero, dims, AddReduce_computation));
  }

  // Move scalar multiply to the smallest side of convolution to
  // reduce multiply computations.
  Status ScalarMultiplyReduction(HloInstruction* dot);

  // Convenience method for replacing an instruction with a bitcast. If operand
  // is not null, then the bitcast will use the specified operand instead of the
  // operand of the instruction.
  void ReplaceWithBitcast(HloInstruction* instruction,
                          HloInstruction* operand = nullptr);

  // Replace old instruction with new instruction if old and new instructions
  // have the same shape. Updates uses and root instruction. Returns whether a
  // replacement was made.
  bool ReplaceInstructionIfSameShape(HloInstruction* old_instruction,
                                     HloInstruction* new_instruction);

  // Returns whether the shape of the output of the given instructions are the
  // same for the purposes of simplification. If options_.is_layout_sensitive()
  // is true, then this tests shape equality including layout
  // (ShapeUtil::Equal). If options_.is_layout_sensitive() is false, then the
  // tests shape compatibility (ShapeUtil::Compatible).
  bool SameShape(const HloInstruction* lhs, const HloInstruction* rhs) const;

  // Returns whether it was possible to transform `root` to a clamp instruction.
  // With min a minimum instruction, max a maximum instruction, min_operand a
  // operand of min and max_operand a operand of max.
  // Precondition: root is either a minimum or a maximum.
  bool TransformToClampIfSameShape(HloInstruction* root, HloInstruction* min,
                                   HloInstruction* min_operand,
                                   HloInstruction* operand, HloInstruction* max,
                                   HloInstruction* max_operand);

  // A Broadcast that feeds an element-wise operation with a unique non-scalar
  // operand can sink to after the operation.
  StatusOr<bool> TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
      HloInstruction* broadcast);

  StatusOr<HloInstruction*> OptimizeDotOfConcat(HloInstruction* dot);
  StatusOr<HloInstruction*> OptimizeDotOfConcatHelper(
      const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim,
      HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped);

  StatusOr<HloInstruction*> OptimizeDotOfGather(HloInstruction* dot);

  StatusOr<HloInstruction*> OptimizeDotOfReorderContractingDims(
      HloInstruction* dot);

  // Returns a cached scalar-add computation for `type`, building and
  // memoizing it on first use (see scalar_add_computations_).
  HloComputation* GetOrCreateScalarAddComputation(PrimitiveType type) {
    HloComputation*& scalar_add_computation = scalar_add_computations_[type];
    if (scalar_add_computation) {
      return scalar_add_computation;
    }

    HloComputation::Builder b("scalar_add_computation");
    Shape shape = ShapeUtil::MakeShape(type, {});
    simplifier_->UpdateLayout(&shape);
    auto scalar_lhs = b.AddInstruction(
        HloInstruction::CreateParameter(0, shape, "scalar_lhs"));
    auto scalar_rhs = b.AddInstruction(
        HloInstruction::CreateParameter(1, shape, "scalar_rhs"));
    auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs));
    scalar_add_computation =
        computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
    return scalar_add_computation;
  }

  // Tries to fold a kPad in the input or filter into the convolution
  // instruction's window.
  StatusOr<bool> FoldConvInputPad(HloInstruction* convolution);
  StatusOr<bool> FoldConvFilterPad(HloInstruction* convolution);

  // Tries to swap convolution operands if they would result in a more efficient
  // convolution.
  StatusOr<bool> SwapConvOperands(HloInstruction* convolution);

  // Tries to use a kDot in place of the given convolution.
  StatusOr<bool> SimplifyConvToDot(HloInstruction* convolution);

  // Tries to simplify a slice where the result of the slice is a scalar.
  StatusOr<bool> TrySimplifyScalarSlice(HloInstruction* slice);

  // Tries to convert slice(reshape(X)) into reshape(slice(X))
  StatusOr<bool> TryToReorderSliceAndReshape(HloInstruction* slice);

  // Tries to convert slice(reverse(X)) into reverse(slice(X))
  StatusOr<bool> TryToReorderSliceAndReverse(HloInstruction* slice);

  // Tries to simplify `(and (< a N) (< a K))` in cases where `N <= K` into
  // `(< a N)`. This is crucial for being able to figure out the loop trip
  // count.
  //
  // Assumes that the input is conjunction.
  StatusOr<bool> TrySimplifyTautologicalCompare(HloInstruction* conjunction);

  // Useful when we want to use the same visitor over multiple computations.
  void ResetState(HloComputation* computation);

  // Current HloComputation instance the AlgebraicSimplifierVisitor is
  // traversing.
  HloComputation* computation_;

  // The backend-specific options selected for the algebraic simplifier.
  const AlgebraicSimplifierOptions& options_;

  // Whether algebraic simplification has occurred.
  bool changed_ = false;

  // Cached computation for adding two scalars of a given type.
  absl::flat_hash_map<PrimitiveType, HloComputation*> scalar_add_computations_;

  AlgebraicSimplifier* simplifier_ = nullptr;
};
545 
546 }  // namespace
547 
ResetState(HloComputation * computation)548 void AlgebraicSimplifierVisitor::ResetState(HloComputation* computation) {
549   changed_ = false;
550   ResetVisitStates();
551   computation_ = computation;
552 }
553 
Run(HloComputation * computation,const AlgebraicSimplifierOptions & options,AlgebraicSimplifier * simplifier)554 bool AlgebraicSimplifierVisitor::Run(HloComputation* computation,
555                                      const AlgebraicSimplifierOptions& options,
556                                      AlgebraicSimplifier* simplifier) {
557   ResetState(computation);
558   TF_CHECK_OK(computation->Accept(this));
559   return changed_ || changed();
560 }
561 
SameShape(const HloInstruction * lhs,const HloInstruction * rhs) const562 bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs,
563                                            const HloInstruction* rhs) const {
564   if (options_.is_layout_sensitive()) {
565     return ShapeUtil::Equal(lhs->shape(), rhs->shape());
566   } else {
567     return ShapeUtil::Compatible(lhs->shape(), rhs->shape());
568   }
569 }
570 
571 namespace {
572 
GetConstantValue(HloInstruction * inst)573 float GetConstantValue(HloInstruction* inst) {
574   switch (inst->shape().element_type()) {
575     case BF16:
576       return static_cast<float>(inst->literal().GetFirstElement<bfloat16>());
577     case F32:
578       return inst->literal().GetFirstElement<float>();
579     default:
580       LOG(FATAL) << "Unsupported data type: " << inst->shape().element_type();
581   }
582 }
583 
IsOpCodeMultiplyCommutative(HloOpcode opcode)584 bool IsOpCodeMultiplyCommutative(HloOpcode opcode) {
585   switch (opcode) {
586     case HloOpcode::kMultiply:
587     case HloOpcode::kTranspose:
588     case HloOpcode::kReshape:
589     case HloOpcode::kSelect:
590       return true;
591     default:
592       return false;
593   }
594 }
595 
MakeScalarInstruction(HloInstruction * target,float multiplier)596 std::unique_ptr<HloInstruction> MakeScalarInstruction(HloInstruction* target,
597                                                       float multiplier) {
598   switch (target->shape().element_type()) {
599     case BF16:
600       return HloInstruction::CreateConstant(LiteralUtil::ConvertF32ToBF16(
601           LiteralUtil::CreateR0<float>(multiplier)));
602       break;
603     case F32:
604       return HloInstruction::CreateConstant(
605           LiteralUtil::CreateR0<float>(multiplier));
606       break;
607     default:
608       LOG(FATAL) << "Unsupported data type: " << target->shape().element_type();
609   }
610 }
611 
612 }  // namespace
613 
// Collects scalar constant multiplies reachable from `dot` (through its
// operands and, when safe, its users) along multiply-commutative ops, removes
// them, and reattaches their combined product as a single scalar multiply on
// whichever of {lhs, rhs, output} has the fewest elements. This minimizes the
// number of element-wise multiplications performed.
Status AlgebraicSimplifierVisitor::ScalarMultiplyReduction(
    HloInstruction* dot) {
  // We only process bfloat16 and float32 for now.
  if (dot->shape().element_type() != BF16 &&
      dot->shape().element_type() != F32) {
    return Status::OK();
  }

  auto lhs = dot->mutable_operand(0);
  auto rhs = dot->mutable_operand(1);

  const int64 dot_size = ShapeUtil::ElementsIn(dot->shape());
  const int64 lhs_size = ShapeUtil::ElementsIn(lhs->shape());
  const int64 rhs_size = ShapeUtil::ElementsIn(rhs->shape());

  // The tensor the combined multiplier will be attached to.
  HloInstruction* target = nullptr;
  // (current node, user, operand_index)
  std::vector<std::tuple<HloInstruction*, HloInstruction*, int64>> operands;
  std::vector<HloInstruction*> users;

  // Find which side of dot has the smallest size:
  // operand 0, operand 1, or output.
  // Only strictly-larger sides are searched for multiplies to strip; the
  // users list is only seeded when dot has exactly one user, since moving a
  // multiply across dot would otherwise change other users' values.
  if (dot_size <= std::min(lhs_size, rhs_size)) {
    target = dot;
    if (dot_size < lhs_size) {
      operands.emplace_back(lhs, dot, 0);
    }
    if (dot_size < rhs_size) {
      operands.emplace_back(rhs, dot, 1);
    }
  } else if (lhs_size <= rhs_size) {
    target = lhs;
    if (lhs_size < rhs_size) {
      operands.emplace_back(rhs, dot, 1);
    }
    if (lhs_size < dot_size && dot->user_count() == 1) {
      users.push_back(dot->users().front());
    }
  } else {
    target = rhs;
    if (rhs_size < lhs_size) {
      operands.emplace_back(lhs, dot, 0);
    }
    if (rhs_size < dot_size && dot->user_count() == 1) {
      users.push_back(dot->users().front());
    }
  }

  // Scalar multipliers stripped out of the graph, to be recombined below.
  std::vector<float> values;

  // DFS to find scalar multiply ops from the operands.
  while (!operands.empty()) {
    HloInstruction* inst;
    HloInstruction* user;
    int64 index;
    std::tie(inst, user, index) = operands.back();
    operands.pop_back();

    // Skip the op types that are not commutative with multiply.
    if (!IsOpCodeMultiplyCommutative(inst->opcode())) {
      continue;
    }

    HloInstruction* operand;
    HloInstruction* multiplier;
    // Pattern match a scalar multiply.
    if (Match(inst, m::MultiplyAnyOrder(
                        m::Op(&operand),
                        m::Broadcast(m::ConstantScalar(&multiplier))))) {
      CHECK_LT(index, user->operand_count());
      CHECK_EQ(inst, user->operands()[index]);

      // When found a scalar multiply, save its scalar value.
      values.push_back(GetConstantValue(multiplier));
      // And remove the scalar multiply op.
      TF_RETURN_IF_ERROR(user->ReplaceOperandWith(index, operand));
      inst = operand;
    }

    // Push the operands of inst.
    int64 i = 0;
    for (auto* operand : inst->operands()) {
      operands.emplace_back(operand, inst, i++);
    }
  }

  // DFS to find scalar multiply ops from the users.
  while (!users.empty()) {
    auto inst = users.back();
    users.pop_back();

    if (!IsOpCodeMultiplyCommutative(inst->opcode())) {
      continue;
    }

    HloInstruction* operand;
    HloInstruction* multiplier;
    if (Match(inst, m::MultiplyAnyOrder(
                        m::Op(&operand),
                        m::Broadcast(m::ConstantScalar(&multiplier))))) {
      values.push_back(GetConstantValue(multiplier));

      TF_RETURN_IF_ERROR(inst->ReplaceAllUsesWith(operand));
      inst = operand;
    }

    // Process the instructions with only one user.
    // Otherwise moving scalar multiply to the operands changes the values of
    // other users.
    if (inst->user_count() == 1) {
      users.push_back(inst->users().front());
    }
  }

  // Nothing was stripped, so the graph is unchanged.
  if (values.empty()) {
    return Status::OK();
  }

  changed_ = true;

  // Combine all constant multipliers.
  float multiplier = 1.0;
  for (const float v : values) {
    multiplier *= v;
  }

  // Create a new const scalar multiply instruction.
  HloInstruction* new_const_inst;
  new_const_inst =
      computation_->AddInstruction(MakeScalarInstruction(target, multiplier));

  // Broadcast the scalar multiplier.
  HloInstruction* new_broadcast = computation_->AddInstruction(
      HloInstruction::CreateBroadcast(target->shape(), new_const_inst, {}));
  // Create a new scalar multiply instruction.
  HloInstruction* new_multiply =
      computation_->AddInstruction(HloInstruction::CreateBinary(
          target->shape(), HloOpcode::kMultiply, target, new_broadcast));
  CHECK_EQ(new_multiply->shape(), target->shape());

  // Update the dependency with the rest of the instructions.
  if (target == lhs) {
    return dot->ReplaceOperandWith(0, new_multiply);
  } else if (target == rhs) {
    return dot->ReplaceOperandWith(1, new_multiply);
  } else {
    CHECK_EQ(target, dot);
    return dot->ReplaceAllUsesWith(new_multiply);
  }
}
764 
ReplaceWithBitcast(HloInstruction * instruction,HloInstruction * operand)765 void AlgebraicSimplifierVisitor::ReplaceWithBitcast(HloInstruction* instruction,
766                                                     HloInstruction* operand) {
767   CHECK_EQ(1, instruction->operand_count());
768   if (operand == nullptr) {
769     operand = instruction->mutable_operand(0);
770   }
771   CHECK_EQ(ShapeUtil::ElementsIn(instruction->shape()),
772            ShapeUtil::ElementsIn(operand->shape()));
773   CHECK_EQ(ShapeUtil::ByteSizeOf(instruction->shape()),
774            ShapeUtil::ByteSizeOf(operand->shape()));
775 
776   auto bitcast = computation_->AddInstruction(
777       HloInstruction::CreateBitcast(instruction->shape(), operand));
778   bitcast->set_metadata(instruction->metadata());
779   TF_CHECK_OK(ReplaceInstruction(instruction, bitcast));
780 }
781 
ReplaceInstructionIfSameShape(HloInstruction * old_instruction,HloInstruction * new_instruction)782 bool AlgebraicSimplifierVisitor::ReplaceInstructionIfSameShape(
783     HloInstruction* old_instruction, HloInstruction* new_instruction) {
784   if (!SameShape(old_instruction, new_instruction)) {
785     return false;
786   }
787   TF_CHECK_OK(ReplaceInstruction(old_instruction, new_instruction));
788   return true;
789 }
790 
HandleAbs(HloInstruction * abs)791 Status AlgebraicSimplifierVisitor::HandleAbs(HloInstruction* abs) {
792   HloInstruction* abs_operand = abs->mutable_operand(0);
793   VLOG(10) << "trying transform [Abs(A) => A] " << abs->ToString()
794            << " Abs operand is: " << abs_operand->ToString();
795   if (IsNonNegative(abs->operand(0), options_)) {
796     return ReplaceInstruction(abs, abs_operand);
797   }
798   return Status::OK();
799 }
800 
// Simplification rules for kAdd. Each rule below attempts one rewrite and
// returns early on success; if no rule fires the add is left untouched.
// Rules handled here: identity (A+0), constant canonicalization and
// reassociation, narrowing an add fed by a zero-initialized
// dynamic-update-slice, distributing a common multiplicand, and merging
// adds with scatter operands.
Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
  HloInstruction *lhs, *rhs;
  CHECK(Match(add, m::Add(m::Op(&lhs), m::Op(&rhs))));

  // A + 0 => A
  VLOG(10) << "trying transform [A + 0 => A]: " << add->ToString();
  if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(add, lhs)) {
    return Status::OK();
  }
  // 0 + A => A
  VLOG(10) << "trying transform [0 + A => A]: " << add->ToString();
  if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(add, rhs)) {
    return Status::OK();
  }

  // Canonicalization: Put constants on the right.  This makes the reassociation
  // rules below simpler.
  VLOG(10) << "trying transform [Const + A => A + Const]";
  if (Match(add, m::Add(m::Constant(), m::NonConstant()))) {
    return ReplaceWithNewInstruction(
        add,
        HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, rhs, lhs));
  }

  // Reassociate to allow constant folding.
  //
  // Note: This is not general.  For example, we won't reassociate
  //
  //   (A + C1) + (B + C2) =>  A + B + (C1 + C2).
  //
  VLOG(10) << "trying transform [(A + C1) + C2 => A + (C1 + C2)]";
  HloInstruction *a, *c1, *c2;
  if (Match(add, m::Add(m::Add(m::NonConstant(&a), m::Constant(&c1)),
                        m::Constant(&c2))) ||
      Match(add, m::Add(m::Add(m::NonConstant(&a),
                               m::Broadcast(m::ConstantScalar(&c1))),
                        m::Broadcast(m::ConstantScalar(&c2))))) {
    TF_ASSIGN_OR_RETURN(auto* sum_of_constants,
                        MakeBinaryHlo(HloOpcode::kAdd, c1, c2));
    // In the broadcast-scalar case, the folded constant is scalar; broadcast
    // it back up to the add's shape so the replacement type-checks.
    if (ShapeUtil::IsScalar(sum_of_constants->shape()) &&
        !ShapeUtil::IsScalar(add->shape())) {
      sum_of_constants = computation_->AddInstruction(
          HloInstruction::CreateBroadcast(add->shape(), sum_of_constants, {}));
    }
    return ReplaceWithNewInstruction(
        add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, a,
                                          sum_of_constants));
  }

  // Convert add with fullshape into add with partial shape when a
  // portion of add is effective:
  //             zero (fullshape)   rhs (partialshape)
  // .           |                  |
  // . lhs .    dynamic_update_slice (fullshape)
  // . |         |
  // Add (fullshape)
  //
  // to:
  //              lhs
  //              |
  //             dynamic_slice (partialshape)   rhs (partialshape)
  // .           |                      |
  // . lhs .    add (partial_shape)+----+
  // . |         |
  // dynamic_update_slice (fullshape)
  //
  // This is pattern is discovered in control flow V2 gradient update.
  if (Match(add,
            m::Add(m::Op(&lhs),
                   m::Op(&rhs)
                       .WithOpcode(HloOpcode::kDynamicUpdateSlice)
                       .WithOperand(
                           0, m::Broadcast(m::ConstantEffectiveScalar(0)))))) {
    const Shape& partial_shape = rhs->operand(1)->shape();
    // Operands 2.. of the dynamic-update-slice are its start indices; reuse
    // them to slice the matching window out of lhs.
    auto sliced_lhs =
        computation_->AddInstruction(HloInstruction::CreateDynamicSlice(
            partial_shape, lhs, absl::MakeSpan(rhs->operands()).subspan(2),
            partial_shape.dimensions()));

    auto add_partial = computation_->AddInstruction(
        HloInstruction::CreateBinary(rhs->operand(1)->shape(), HloOpcode::kAdd,
                                     sliced_lhs, rhs->mutable_operand(1)));

    auto dynamic_update_slice_full = HloInstruction::CreateDynamicUpdateSlice(
        lhs->shape(), lhs, add_partial,
        absl::MakeSpan(rhs->operands()).subspan(2));

    return ReplaceWithNewInstruction(add, std::move(dynamic_update_slice_full));
  }

  // A*C + B*C => (A+B)*C
  //
  //  - If A, B, and C are integers, do this unconditionally. Proof of
  //    correctness: https://rise4fun.com/Alive/u9X.
  //
  //  - If A, B, and C are floating point, do this if C is a scalar constant or
  //    broadcast of scalar constant and is equal to +/- 2^k for some (possibly
  //    negative) integer k.
  //
  //    Multiplying by a power of 2 just moves the exponent, so our answer is
  //    exact modulo rounding of intermediate results so long as
  //
  //     - none of the three products has an exponent which underflows (so the
  //       result is 0 or denormal), and
  //     - none of the three products overflows to inf.
  //
  //    Proof: See algebraic_simplifier_proof_distributive_property.py.
  //
  //    We deem these differences in rounding, underflow, and overflow
  //    acceptable in the ML context.
  HloInstruction *b, *c;
  if (((Match(lhs, m::Multiply(m::Op(&a), m::Op(&c))) &&
        Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b)))) ||
       (Match(lhs, m::Multiply(m::Op(&c), m::Op(&a))) &&
        Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b))))) &&
      (ShapeUtil::ElementIsIntegral(add->shape()) ||
       options_.enable_floats_are_real() || IsAllFpConstantPowerOf2(c))) {
    return ReplaceWithNewInstruction(
        add, HloInstruction::CreateBinary(
                 add->shape(), HloOpcode::kMultiply,
                 computation_->AddInstruction(HloInstruction::CreateBinary(
                     add->shape(), HloOpcode::kAdd, a, b)),
                 c));
  }

  // The scatter rewrites below change operand shapes/layout assumptions, so
  // skip them in layout-sensitive mode.
  if (options_.is_layout_sensitive()) {
    return Status::OK();
  }

  // Detect adds whose operand(s) are single-use scatters with an add combiner
  // (root of to_apply is Add(p0, p1)); those can be folded into the scatter.
  HloInstruction* lhs_scatter_operand = nullptr;
  HloInstruction* rhs_scatter_operand = nullptr;
  HloInstruction* lhs_scatter_update = nullptr;
  HloInstruction* rhs_scatter_update = nullptr;
  HloInstruction* lhs_scatter_index = nullptr;
  HloInstruction* rhs_scatter_index = nullptr;
  bool lhs_scatter = Match(lhs, m::Scatter(m::Op(&lhs_scatter_operand),
                                           m::Op(&lhs_scatter_index),
                                           m::Op(&lhs_scatter_update))
                                    .WithOneUse()) &&
                     Match(lhs->to_apply()->root_instruction(),
                           m::Add(m::Parameter(), m::Parameter()));
  bool rhs_scatter = Match(rhs, m::Scatter(m::Op(&rhs_scatter_operand),
                                           m::Op(&rhs_scatter_index),
                                           m::Op(&rhs_scatter_update))
                                    .WithOneUse()) &&
                     Match(rhs->to_apply()->root_instruction(),
                           m::Add(m::Parameter(), m::Parameter()));
  if (rhs_scatter && lhs_scatter) {
    // Both operands are add-scatters: try to merge them into one scatter by
    // concatenating their indices and updates.
    const auto& lhs_dnums = lhs->scatter_dimension_numbers();
    const auto& rhs_dnums = rhs->scatter_dimension_numbers();
    absl::optional<int64> index_concat_dimension;
    absl::optional<int64> update_concat_dimension;
    // Don't try to combine scatters of different ranks.
    if (lhs_scatter_index->shape().rank() !=
        rhs_scatter_index->shape().rank()) {
      return Status::OK();
    }

    int64 first_index_dim = lhs_scatter_index->shape().rank();
    int64 first_update_dim = lhs_scatter_update->shape().rank();
    // Find a dimension where it is possible to concatenate the indices and
    // updates. This is the first and only non-equal dimension or the first
    // equally sized dimension.
    for (int64 d = lhs_scatter_index->shape().rank() - 1,
               update_dim = lhs_scatter_update->shape().rank() - 1;
         d >= 0; --d) {
      if (d == lhs_dnums.index_vector_dim()) {
        continue;
      }
      // Skip update dimensions that belong to the scatter window; they have
      // no corresponding index dimension.
      while (
          absl::c_linear_search(lhs_dnums.update_window_dims(), update_dim)) {
        --update_dim;
      }
      if (lhs_scatter_index->shape().dimensions(d) ==
          rhs_scatter_index->shape().dimensions(d)) {
        first_index_dim = d;
        first_update_dim = update_dim--;
        continue;
      }
      // More than one dimension of unequal size was found, bail out.
      if (index_concat_dimension) {
        return Status::OK();
      }
      index_concat_dimension = d;
      update_concat_dimension = update_dim--;
    }
    if (!index_concat_dimension) {
      index_concat_dimension = first_index_dim;
      update_concat_dimension = first_update_dim;
    }

    // A scalar scatter will require additional reshapes of the index and
    // update.
    if (*index_concat_dimension == lhs_scatter_index->shape().rank()) {
      return Status::OK();
    }
    // Only merge when the concatenated updates are smaller than the scatter
    // result; otherwise the concat itself would dominate.
    const bool update_concat_is_cheap =
        ShapeUtil::ElementsIn(rhs_scatter_update->shape()) +
            ShapeUtil::ElementsIn(lhs_scatter_update->shape()) <
        ShapeUtil::ElementsIn(lhs->shape());
    if (!update_concat_is_cheap) {
      return Status::OK();
    }
    const bool same_dimension_numbers =
        lhs_dnums.index_vector_dim() == rhs_dnums.index_vector_dim() &&
        absl::c_equal(lhs_dnums.scatter_dims_to_operand_dims(),
                      rhs_dnums.scatter_dims_to_operand_dims()) &&
        absl::c_equal(lhs_dnums.inserted_window_dims(),
                      rhs_dnums.inserted_window_dims()) &&
        absl::c_equal(lhs_dnums.update_window_dims(),
                      rhs_dnums.update_window_dims());
    // Concatenating indices breaks any sortedness/uniqueness guarantees, so
    // only allowed when neither scatter claims them.
    const bool index_concat_is_safe =
        !lhs->unique_indices() && !rhs->unique_indices() &&
        !DynCast<HloScatterInstruction>(lhs)->indices_are_sorted() &&
        !DynCast<HloScatterInstruction>(rhs)->indices_are_sorted();

    Shape lhs_update_window = ShapeUtil::FilterDimensions(
        [&](int64 dim) {
          return absl::c_linear_search(lhs_dnums.update_window_dims(), dim);
        },
        lhs_scatter_update->shape());
    Shape rhs_update_window = ShapeUtil::FilterDimensions(
        [&](int64 dim) {
          return absl::c_linear_search(rhs_dnums.update_window_dims(), dim);
        },
        rhs_scatter_update->shape());
    // Concatenate the indices and updates
    if (index_concat_is_safe && same_dimension_numbers &&
        index_concat_dimension &&
        lhs_scatter_index->shape().element_type() ==
            rhs_scatter_index->shape().element_type() &&
        ShapeUtil::SameDimensions(lhs_update_window, rhs_update_window)) {
      TF_ASSIGN_OR_RETURN(HloInstruction * new_operand,
                          MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand,
                                        rhs_scatter_operand));
      TF_ASSIGN_OR_RETURN(HloInstruction * new_index,
                          MakeConcatHlo({lhs_scatter_index, rhs_scatter_index},
                                        *index_concat_dimension));
      TF_ASSIGN_OR_RETURN(
          HloInstruction * new_update,
          MakeConcatHlo({lhs_scatter_update, rhs_scatter_update},
                        *update_concat_dimension));
      return ReplaceWithNewInstruction(
          add, HloInstruction::CreateScatter(
                   add->shape(), new_operand, new_index, new_update,
                   lhs->to_apply(), lhs_dnums, false, false));
    }
    // Fallback: chain the two scatters — add the operands once, feed the
    // result through rhs and then lhs.
    TF_ASSIGN_OR_RETURN(HloInstruction * new_operand,
                        MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand,
                                      rhs_scatter_operand));
    TF_RETURN_IF_ERROR(rhs->ReplaceOperandWith(0, new_operand));
    TF_RETURN_IF_ERROR(lhs->ReplaceOperandWith(0, rhs));
    return ReplaceInstruction(add, lhs);
  } else if (rhs_scatter) {
    // A + Scatter(B, ...) => Scatter(A + B, ...) (add combiner makes this
    // valid).
    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_operand,
        MakeBinaryHlo(HloOpcode::kAdd, lhs, rhs_scatter_operand));
    TF_RETURN_IF_ERROR(rhs->ReplaceOperandWith(0, new_operand));
    return ReplaceInstruction(add, rhs);
  } else if (lhs_scatter) {
    // Scatter(A, ...) + B => Scatter(A + B, ...).
    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_operand,
        MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand, rhs));
    TF_RETURN_IF_ERROR(lhs->ReplaceOperandWith(0, new_operand));
    return ReplaceInstruction(add, lhs);
  }
  return Status::OK();
}
1069 
TrySimplifyTautologicalCompare(HloInstruction * conjunction)1070 StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyTautologicalCompare(
1071     HloInstruction* conjunction) {
1072   HloInstruction *lhs, *rhs;
1073   if (!Match(conjunction, m::And(m::Op(&lhs), m::Op(&rhs)))) {
1074     return false;
1075   }
1076   struct LessThanCompareInfo {  // (LT var constant)
1077     HloInstruction* var;
1078     int64 constant;
1079   };
1080 
1081   auto get_compare_info =
1082       [&](HloInstruction* cmp) -> absl::optional<LessThanCompareInfo> {
1083     HloInstruction *lhs, *rhs;
1084     auto scalar_shape_matcher =
1085         m::Shape().IsEffectiveScalar().WithElementType(PrimitiveType::S32);
1086     if (Match(cmp, m::Compare(m::Op(&lhs),
1087                               m::Constant(&rhs).WithShape(scalar_shape_matcher))
1088                        .WithComparisonDirection(ComparisonDirection::kLt))) {
1089       return {LessThanCompareInfo{lhs, *rhs->literal().GetFirstInteger()}};
1090     } else if (Match(
1091                    cmp,
1092                    m::Compare(m::Constant(&lhs).WithShape(scalar_shape_matcher),
1093                               m::Op(&rhs))
1094                        .WithComparisonDirection(ComparisonDirection::kGt))) {
1095       return {LessThanCompareInfo{rhs, *lhs->literal().GetFirstInteger()}};
1096     }
1097     return absl::nullopt;
1098   };
1099 
1100   absl::optional<LessThanCompareInfo> lhs_info = get_compare_info(lhs);
1101   absl::optional<LessThanCompareInfo> rhs_info = get_compare_info(rhs);
1102   if (lhs_info && rhs_info && lhs_info->var == rhs_info->var) {
1103     int64 new_bound = std::min(lhs_info->constant, rhs_info->constant);
1104     TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
1105         conjunction,
1106         HloInstruction::CreateCompare(lhs->shape(), lhs_info->var,
1107                                       MakeScalarLike(lhs_info->var, new_bound),
1108                                       ComparisonDirection::kLt)));
1109     return true;
1110   }
1111   return false;
1112 }
1113 
HandleAnd(HloInstruction * logical_and)1114 Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) {
1115   HloInstruction *lhs, *rhs;
1116   CHECK(Match(logical_and, m::And(m::Op(&lhs), m::Op(&rhs))));
1117   // Simplify logical and
1118   if (ShapeUtil::HasPrimitiveType(lhs->shape(), xla::PRED) &&
1119       ShapeUtil::HasPrimitiveType(rhs->shape(), xla::PRED)) {
1120     // A && True => A
1121     VLOG(10) << "trying transform [A && True => A]: "
1122              << logical_and->ToString();
1123     if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(logical_and, lhs)) {
1124       return Status::OK();
1125     }
1126     // True && A => A
1127     VLOG(10) << "trying transform [True && A => A]: "
1128              << logical_and->ToString();
1129     if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(logical_and, rhs)) {
1130       return Status::OK();
1131     }
1132   }
1133 
1134   // A && False => False or A & 0 => 0
1135   VLOG(10) << "trying transform [A && False => False]: "
1136            << logical_and->ToString();
1137   if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(logical_and, rhs)) {
1138     return Status::OK();
1139   }
1140 
1141   // False && A => False or A & 0 => 0
1142   VLOG(10) << "trying transform [False && A => False]: "
1143            << logical_and->ToString();
1144   if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(logical_and, lhs)) {
1145     return Status::OK();
1146   }
1147 
1148   // Simplify tautological conjunctions.
1149   TF_ASSIGN_OR_RETURN(bool found_tautological_compare,
1150                       TrySimplifyTautologicalCompare(logical_and));
1151   if (found_tautological_compare) {
1152     return Status::OK();
1153   }
1154 
1155   return Status::OK();
1156 }
1157 
HandleBitcast(HloInstruction * bitcast)1158 Status AlgebraicSimplifierVisitor::HandleBitcast(HloInstruction* bitcast) {
1159   // If a bitcast feeds a bitcast, make it a single bitcast.
1160   HloInstruction* op;
1161   if (Match(bitcast, m::Bitcast(m::Bitcast(m::Op(&op))))) {
1162     return ReplaceWithNewInstruction(
1163         bitcast, HloInstruction::CreateBitcast(bitcast->shape(), op));
1164   }
1165   // All bitcasts can be eliminated (assuming layout constraints are
1166   // satisfied).
1167   ReplaceInstructionIfSameShape(bitcast, bitcast->mutable_operand(0));
1168   return Status::OK();
1169 }
1170 
HandleBitcastConvert(HloInstruction * bitcast)1171 Status AlgebraicSimplifierVisitor::HandleBitcastConvert(
1172     HloInstruction* bitcast) {
1173   // Eliminate bitcast converts between same shape.
1174   ReplaceInstructionIfSameShape(bitcast, bitcast->mutable_operand(0));
1175   return Status::OK();
1176 }
1177 
HandleCopy(HloInstruction * copy)1178 Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) {
1179   // If a copy feeds a copy, make it a single copy.
1180   HloInstruction* op;
1181   if (Match(copy, m::Copy(m::Copy(m::Op(&op))))) {
1182     return ReplaceWithNewInstruction(
1183         copy, HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy, op));
1184   }
1185   // All copies can be eliminated (assuming layout constraints are satisfied).
1186   if (ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0))) {
1187     return Status::OK();
1188   }
1189 
1190   if (HloInstruction* bitcast_operand =
1191           BitcastingOperandOfReshapeOrCopyChain(copy, options_)) {
1192     ReplaceWithBitcast(copy, bitcast_operand);
1193     return Status::OK();
1194   }
1195 
1196   // Replace Copy(Reshape()) with Reshape() if the Reshape is a logical bitcast.
1197   if (copy->operand(0)->opcode() == HloOpcode::kReshape &&
1198       copy->operand(0)->user_count() == 1 &&
1199       ShapeUtil::ReshapeIsBitcast(copy->operand(0)->shape(), copy->shape())) {
1200     return ReplaceWithNewInstruction(
1201         copy,
1202         copy->operand(0)->CloneWithNewOperands(
1203             copy->shape(), {copy->mutable_operand(0)->mutable_operand(0)}));
1204   }
1205   return Status::OK();
1206 }
1207 
// Simplifies kConcatenate:
//  - unary concat => its operand,
//  - drops zero-element operands,
//  - merges runs of adjacent unstrided slices of the same source into one
//    slice (recursing to catch multiple disjoint runs),
//  - rewrites a binary concat with a broadcast-scalar operand as a pad,
//  - rewrites a concat of N copies of the same size-1-dim operand as a
//    broadcast.
Status AlgebraicSimplifierVisitor::HandleConcatenate(
    HloInstruction* concatenate) {
  absl::Span<HloInstruction* const> operands(concatenate->operands());
  if (operands.size() == 1) {
    // Unary concatenates are useless.
    ReplaceInstructionIfSameShape(concatenate, operands[0]);
    return Status::OK();
  }
  // Filter out and remove empty operands.
  std::vector<HloInstruction*> nonempty_operands;
  for (HloInstruction* operand : operands) {
    if (!ShapeUtil::IsZeroElementArray(operand->shape())) {
      nonempty_operands.push_back(operand);
    }
  }
  if (nonempty_operands.size() < operands.size()) {
    HloInstruction* replacement;
    if (nonempty_operands.empty()) {
      // Every operand is empty, so the result is too; any operand of the
      // same shape works as a replacement.
      replacement = operands[0];
    } else if (nonempty_operands.size() == 1) {
      replacement = nonempty_operands[0];
    } else {
      replacement =
          computation_->AddInstruction(concatenate->CloneWithNewOperands(
              concatenate->shape(), nonempty_operands));
    }
    VLOG(10) << "trying to replace " << concatenate->ToString() << " with "
             << replacement->ToString();
    ReplaceInstructionIfSameShape(concatenate, replacement);
    return Status::OK();
  }

  if (options_.is_layout_sensitive()) {
    return Status::OK();
  }

  // Check if we can merge "adjacent" slice operands which take slices from the
  // same other op. For simplicity we only merge unstrided slices.
  int64 concatenate_dimension = concatenate->concatenate_dimension();
  for (int64 i = 0; i < operands.size(); ++i) {
    if (operands[i]->opcode() != HloOpcode::kSlice ||
        !IsUnstridedSlice(operands[i])) {
      continue;
    }
    int64 slice_end = operands[i]->slice_limits(concatenate_dimension);
    HloInstruction* slice_operand = operands[i]->mutable_operand(0);
    // Extend the run [i, j) while each next operand is the contiguous
    // continuation of the previous slice along the concat dimension.
    int64 j = i + 1;
    while (j < operands.size() && operands[j]->opcode() == HloOpcode::kSlice &&
           IsUnstridedSlice(operands[j]) &&
           operands[j]->operand(0) == slice_operand &&
           operands[j]->slice_starts(concatenate_dimension) == slice_end) {
      // Check that all the slice_start values are the same in all other
      // dimensions. This implies that the slice_limit values are also the same,
      // because operands of concatenate need to have the same shape, and we
      // already checked that the slices are unstrided.
      bool same_other_starts = true;
      for (int64 k = 0; k < operands[j]->slice_starts().size(); ++k) {
        if (k == concatenate_dimension) {
          continue;
        }
        if (operands[i]->slice_starts(k) != operands[j]->slice_starts(k)) {
          same_other_starts = false;
          break;
        }
      }
      if (!same_other_starts) {
        break;
      }
      slice_end = operands[j]->slice_limits(concatenate_dimension);
      ++j;
    }
    if (j - i > 1) {
      // Build one wide slice covering operands [i, j) and splice it into a
      // new concatenate in place of the run.
      Shape new_slice_shape = operands[i]->shape();
      new_slice_shape.set_dimensions(
          concatenate_dimension,
          slice_end - operands[i]->slice_starts(concatenate_dimension));
      simplifier_->UpdateLayout(&new_slice_shape);
      auto new_limit_indices = operands[i]->slice_limits();
      new_limit_indices[concatenate_dimension] = slice_end;
      auto new_slice_op =
          computation_->AddInstruction(HloInstruction::CreateSlice(
              new_slice_shape, slice_operand,
              /*start_indices=*/operands[i]->slice_starts(),
              /*limit_indices=*/new_limit_indices,
              /*strides=*/operands[i]->slice_strides()));
      std::vector<HloInstruction*> new_operands;
      for (int64 k = 0; k < i; ++k) {
        new_operands.push_back(operands[k]);
      }
      new_operands.push_back(new_slice_op);
      for (int64 k = j; k < operands.size(); ++k) {
        new_operands.push_back(operands[k]);
      }
      auto replacement =
          computation_->AddInstruction(concatenate->CloneWithNewOperands(
              concatenate->shape(), new_operands));

      // Recurse to handle multiple disjoint sequence of inputs. The
      // logic above merge only 1 sequential series of
      // inputs. Otherwise, it can lead to the FixPass optimization
      // hitting its threshold.
      if (ReplaceInstructionIfSameShape(concatenate, replacement)) {
        return HandleConcatenate(replacement);
      }

      return Status::OK();
    }
  }

  if (operands.size() == 2) {
    // A binary concat with a broadcasted scalar as an operand can be converted
    // into a pad which is simpler to fold into other operations.
    bool is_effective_low_pad = Match(
        operands[0], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar())));
    bool is_effective_high_pad = Match(
        operands[1], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar())));
    if (!is_effective_low_pad && !is_effective_high_pad) {
      return Status::OK();
    }
    PaddingConfig padding_config;
    for (int64 dim = 0; dim < operands[0]->shape().rank(); ++dim) {
      auto padding_config_dim = padding_config.add_dimensions();
      padding_config_dim->set_edge_padding_high(0);
      padding_config_dim->set_edge_padding_low(0);
      padding_config_dim->set_interior_padding(0);
      if (dim == concatenate_dimension) {
        if (is_effective_low_pad) {
          padding_config_dim->set_edge_padding_low(
              operands[0]->shape().dimensions(dim));
        } else {
          padding_config_dim->set_edge_padding_high(
              operands[1]->shape().dimensions(dim));
        }
      }
    }
    // The broadcast-scalar side supplies the pad value; the other operand is
    // the instruction being padded.
    int64 operand_to_pad = is_effective_low_pad ? 1 : 0;
    int64 pad_value_operand = is_effective_low_pad ? 0 : 1;
    HloInstruction* pad =
        computation_->AddInstruction(HloInstruction::CreatePad(
            concatenate->shape(), operands[operand_to_pad],
            operands[pad_value_operand]->mutable_operand(0), padding_config));
    return ReplaceInstruction(concatenate, pad);
  }

  // Concatenating N copies of the same operand with a size-1 concat
  // dimension is a broadcast of that operand with the dimension removed.
  if (absl::c_count(operands, operands[0]) == operands.size() &&
      operands[0]->shape().dimensions(concatenate_dimension) == 1) {
    Shape new_shape = operands[0]->shape();
    absl::InlinedVector<int64, 8> broadcast_dims;
    for (int64 i = 0; i < new_shape.rank(); ++i) {
      if (i == concatenate_dimension) {
        continue;
      }
      broadcast_dims.push_back(i);
    }
    new_shape.DeleteDimension(concatenate_dimension);
    return ReplaceInstruction(
        concatenate,
        MakeBroadcastHlo(MakeReshapeHlo(new_shape, operands[0]).ValueOrDie(),
                         broadcast_dims, concatenate->shape()));
  }
  return Status::OK();
}
1370 
BuildTupleConstant(HloComputation * computation,const LiteralSlice & literal,AlgebraicSimplifier * simplifier)1371 static HloInstruction* BuildTupleConstant(HloComputation* computation,
1372                                           const LiteralSlice& literal,
1373                                           AlgebraicSimplifier* simplifier) {
1374   if (literal.shape().IsTuple()) {
1375     std::vector<HloInstruction*> elems;
1376     elems.reserve(ShapeUtil::TupleElementCount(literal.shape()));
1377     for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) {
1378       elems.push_back(BuildTupleConstant(
1379           computation, LiteralSlice(literal, {i}), simplifier));
1380     }
1381     return computation->AddInstruction(HloInstruction::CreateTuple(elems));
1382   } else {
1383     return computation->AddInstruction(
1384         simplifier->CreateConstantWithLayoutUpdated(literal.Clone()));
1385   }
1386 }
1387 
HandleConstant(HloInstruction * constant)1388 Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
1389   // Tuple constants aren't directly supported by any backend. Expand them into
1390   // explicit Tuple instructions.
1391   if (constant->shape().IsTuple()) {
1392     return ReplaceInstruction(
1393         constant,
1394         BuildTupleConstant(computation_, constant->literal(), simplifier_));
1395   }
1396 
1397   if (constant->shape().element_type() == TOKEN) {
1398     return Status::OK();
1399   }
1400 
1401   // If a literal is all the same element replace it with a scalar broadcast.
1402   if (ShapeUtil::ElementsIn(constant->shape()) > 1 &&
1403       constant->literal().IsAllFirst()) {
1404     Literal unique_scalar(
1405         LiteralUtil::GetFirstScalarLiteral(constant->literal()));
1406     HloInstruction* scalar = computation_->AddInstruction(
1407         simplifier_->CreateConstantWithLayoutUpdated(std::move(unique_scalar)));
1408     return ReplaceWithNewInstruction(
1409         constant,
1410         HloInstruction::CreateBroadcast(constant->shape(), scalar, {}));
1411   }
1412 
1413   // If a literal is an increasing sequence from zero, replace it with an iota.
1414   if (constant->shape().rank() == 1 &&
1415       ShapeUtil::ElementsIn(constant->shape()) > 1 &&
1416       constant->literal().IsR1Iota()) {
1417     return ReplaceWithNewInstruction(
1418         constant, HloInstruction::CreateIota(constant->shape(), 0));
1419   }
1420   return Status::OK();
1421 }
1422 
HandleSubtract(HloInstruction * sub)1423 Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) {
1424   HloInstruction *lhs, *rhs;
1425   CHECK(Match(sub, m::Subtract(m::Op(&lhs), m::Op(&rhs))));
1426   // A - 0 => A
1427   VLOG(10) << "trying transform [A - 0 => A]: " << sub->ToString();
1428   if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(sub, lhs)) {
1429     return Status::OK();
1430   }
1431 
1432   // Canonicalize subtraction of a constant to addition.
1433   VLOG(10) << "trying transform [A - Const => A + (-Const)]";
1434   if (Match(sub, m::Subtract(m::NonConstant(&lhs), m::Constant(&rhs))) ||
1435       Match(sub, m::Subtract(m::NonConstant(&lhs),
1436                              m::Broadcast(m::Constant(&rhs))))) {
1437     HloInstruction* negative_const = computation_->AddInstruction(
1438         HloInstruction::CreateUnary(rhs->shape(), HloOpcode::kNegate, rhs));
1439     if (const HloInstruction* broadcast =
1440             DynCast<HloBroadcastInstruction>(sub->operand(1))) {
1441       negative_const =
1442           computation_->AddInstruction(HloInstruction::CreateBroadcast(
1443               broadcast->shape(), negative_const, broadcast->dimensions()));
1444     }
1445     return ReplaceWithNewInstruction(
1446         sub, HloInstruction::CreateBinary(sub->shape(), HloOpcode::kAdd, lhs,
1447                                           negative_const));
1448   }
1449 
1450   return Status::OK();
1451 }
1452 namespace {
1453 template <typename T>
InvertConstant(const HloInstruction & constant,Literal * result)1454 Status InvertConstant(const HloInstruction& constant, Literal* result) {
1455   return result->Populate<T>([&](absl::Span<const int64> indices) {
1456     return T{1.0} / constant.literal().Get<T>(indices);
1457   });
1458 }
1459 
1460 template <typename T>
TryDivideToShift(HloInstruction * divide,HloComputation * computation,AlgebraicSimplifier * simplifier)1461 std::unique_ptr<HloInstruction> TryDivideToShift(
1462     HloInstruction* divide, HloComputation* computation,
1463     AlgebraicSimplifier* simplifier) {
1464   HloInstruction *a, *b, *c;
1465   CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b))));
1466 
1467   if (ShapeUtil::ElementIsIntegral(divide->shape()) &&
1468       !Match(b, m::ConstantEffectiveScalar(&c)) &&
1469       !Match(b, m::Broadcast(m::ConstantEffectiveScalar(&c)))) {
1470     return nullptr;
1471   }
1472 
1473   if (ShapeUtil::ElementIsSigned(divide->shape())) {
1474     int64 b_value = c->literal().GetFirstElement<T>();
1475     if (b_value > 0 && IsPowerOfTwo(static_cast<uint64>(b_value))) {
1476       // Handle negative dividends by negating the result of the division.
1477       HloInstruction* zero_like_a = MakeScalarLike(a, 0);
1478 
1479       Shape changed_shape = ShapeUtil::ChangeElementType(a->shape(), PRED);
1480       simplifier->UpdateLayout(&changed_shape);
1481       auto* dividend_is_negative =
1482           computation->AddInstruction(HloInstruction::CreateCompare(
1483               changed_shape, a, zero_like_a, ComparisonDirection::kLt));
1484 
1485       auto* negated_dividend = computation->AddInstruction(
1486           HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a));
1487 
1488       auto* abs_dividend =
1489           computation->AddInstruction(HloInstruction::CreateTernary(
1490               a->shape(), HloOpcode::kSelect, dividend_is_negative,
1491               negated_dividend, a));
1492 
1493       auto* quotient = computation->AddInstruction(HloInstruction::CreateBinary(
1494           divide->shape(), HloOpcode::kShiftRightLogical, abs_dividend,
1495           MakeScalarLike(abs_dividend, tensorflow::Log2Floor64(b_value))));
1496 
1497       auto* neqated_quotient =
1498           computation->AddInstruction(HloInstruction::CreateUnary(
1499               quotient->shape(), HloOpcode::kNegate, quotient));
1500 
1501       return HloInstruction::CreateTernary(divide->shape(), HloOpcode::kSelect,
1502                                            dividend_is_negative,
1503                                            neqated_quotient, quotient);
1504     }
1505   } else {
1506     uint64 b_value = c->literal().GetFirstElement<T>();
1507     if (IsPowerOfTwo(b_value)) {
1508       return HloInstruction::CreateBinary(
1509           divide->shape(), HloOpcode::kShiftRightLogical, a,
1510           MakeScalarLike(a, tensorflow::Log2Floor64(b_value)));
1511     }
1512   }
1513 
1514   return nullptr;
1515 }
1516 }  // namespace
1517 
// Applies a sequence of algebraic rewrites to a divide instruction, in
// priority order: identity elision, strength reduction to shifts for
// integer power-of-two divisors, exp/pow/sqrt/rsqrt reassociations,
// constant-reciprocal multiplication, and divide-of-divide flattening.
// Each successful rewrite replaces `divide` and returns immediately.
Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
  HloInstruction *a, *b, *c, *d;
  CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b))));
  // A/1 => A
  VLOG(10) << "trying transform [A/1 => A]: " << divide->ToString();
  if (IsAll(b, 1) && ReplaceInstructionIfSameShape(divide, a)) {
    return Status::OK();
  }

  // A / B => A >> log2(B) if B is a power of 2.
  // The switch is needed because TryDivideToShift is templated on the
  // element type used to read the constant divisor's value.
  switch (divide->shape().element_type()) {
    case S8:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<int8>(divide, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    case S16:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<int16>(divide, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    case S32:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<int32>(divide, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    case S64:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<int64>(divide, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    case U8:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<uint8>(divide, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    case U16:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<uint16>(divide, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    case U32:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<uint32>(divide, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    case U64:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<uint64>(divide, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    default:
      break;
  }

  Shape* shape;
  // exp(A)/exp(B) => exp(A-B)
  if (Match(divide, m::Divide(m::Exp(m::Op(&a)), m::Exp(m::Op(&b)))
                        .WithShape(m::Shape(&shape)))) {
    VLOG(10) << "transform [exp(A)/exp(B) => exp(A-B)]: " << divide->ToString();
    HloInstruction* subtract = computation_->AddInstruction(
        HloInstruction::CreateBinary(*shape, HloOpcode::kSubtract, a, b));
    return ReplaceWithNewInstruction(
        divide, HloInstruction::CreateUnary(*shape, HloOpcode::kExp, subtract));
  }

  // A/exp(B) => A*exp(-B)
  if (Match(divide, m::Divide(m::Op(&a), m::Exp(m::Op(&b))))) {
    VLOG(10) << "transform [A/exp(B) => A*exp(-B)]: " << divide->ToString();
    HloInstruction* negate = computation_->AddInstruction(
        HloInstruction::CreateUnary(divide->shape(), HloOpcode::kNegate, b));
    HloInstruction* new_exp = computation_->AddInstruction(
        HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp, negate));
    return ReplaceWithNewInstruction(
        divide, HloInstruction::CreateBinary(divide->shape(),
                                             HloOpcode::kMultiply, a, new_exp));
  }

  // A/pow(B,C) => A*pow(B,-C)
  if (Match(divide, m::Divide(m::Op(&a), m::Power(m::Op(&b), m::Op(&c))))) {
    VLOG(10) << "transform [A/pow(B,C) => A*pow(B,-C)]: " << divide->ToString();
    // The output shape of the created negate operator should be the same as the
    // input.
    const Shape& negate_shape = c->shape();
    HloInstruction* negate = computation_->AddInstruction(
        HloInstruction::CreateUnary(negate_shape, HloOpcode::kNegate, c));
    // And the power operator should retain the output shape of the old one.
    const Shape& new_power_shape = b->shape();
    HloInstruction* new_power =
        computation_->AddInstruction(HloInstruction::CreateBinary(
            new_power_shape, HloOpcode::kPower, b, negate));
    return ReplaceWithNewInstruction(
        divide, HloInstruction::CreateBinary(
                    divide->shape(), HloOpcode::kMultiply, a, new_power));
  }

  // A/sqrt(B) => A*rsqrt(B).
  if (Match(divide, m::Divide(m::Op(&a), m::Sqrt(m::Op(&b))))) {
    auto* rsqrt = computation_->AddInstruction(
        HloInstruction::CreateUnary(divide->shape(), HloOpcode::kRsqrt, b));
    return ReplaceWithNewInstruction(
        divide, HloInstruction::CreateBinary(rsqrt->shape(),
                                             HloOpcode::kMultiply, a, rsqrt));
  }

  // A/rsqrt(B) => A*sqrt(B).
  if (Match(divide, m::Divide(m::Op(&a), m::Rsqrt(m::Op(&b))))) {
    auto* sqrt = computation_->AddInstruction(
        HloInstruction::CreateUnary(divide->shape(), HloOpcode::kSqrt, b));
    return ReplaceWithNewInstruction(
        divide, HloInstruction::CreateBinary(sqrt->shape(),
                                             HloOpcode::kMultiply, a, sqrt));
  }

  // Simplifying integral division would produce unexpected results.
  // (Integer division truncates, so the reciprocal/reassociation rewrites
  // below are only valid for non-integral element types.)
  if (ShapeUtil::ElementIsIntegral(divide->shape())) {
    return Status::OK();
  }

  // A / Const => A * (1 / Const)
  //
  // (Backends can do this transformation, but generally only if the constant is
  // a scalar.)
  if (Match(divide, m::Divide(m::NonConstant(&a), m::Op(&b))) &&
      (Match(b, m::Constant(&c)) || Match(b, m::Broadcast(m::Constant(&c))))) {
    Shape result_shape = c->literal().shape();
    Literal new_literal(result_shape);
    // Compute the elementwise reciprocal of the constant at compile time.
    switch (result_shape.element_type()) {
      case F16:
        TF_RETURN_IF_ERROR(InvertConstant<half>(*c, &new_literal));
        break;
      case F32:
        TF_RETURN_IF_ERROR(InvertConstant<float>(*c, &new_literal));
        break;
      case BF16:
        TF_RETURN_IF_ERROR(InvertConstant<bfloat16>(*c, &new_literal));
        break;
      case F64:
        TF_RETURN_IF_ERROR(InvertConstant<double>(*c, &new_literal));
        break;
      case C64:
        TF_RETURN_IF_ERROR(InvertConstant<complex64>(*c, &new_literal));
        break;
      case C128:
        TF_RETURN_IF_ERROR(InvertConstant<complex128>(*c, &new_literal));
        break;
      default:
        // No reciprocal for this element type; leave the divide untouched.
        return Status::OK();
    }
    auto inverse = computation_->AddInstruction(
        simplifier_->CreateConstantWithLayoutUpdated(new_literal.Clone()));
    // If the constant was matched under a broadcast (b != c), re-broadcast
    // the inverted constant to b's shape.
    if (b != c) {
      inverse = computation_->AddInstruction(HloInstruction::CreateBroadcast(
          b->shape(), inverse, b->dimensions()));
    }
    TF_ASSIGN_OR_RETURN(auto new_divide,
                        MakeBinaryHlo(HloOpcode::kMultiply, a, inverse));
    return ReplaceInstruction(divide, new_divide);
  }

  // (A / B) / (C / D)  =>  (A / B)*(D / C) => (A * D) / (B * C)
  if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)),
                              m::Divide(m::Op(&c), m::Op(&d))))) {
    TF_ASSIGN_OR_RETURN(auto a_times_d,
                        MakeBinaryHlo(HloOpcode::kMultiply, a, d));
    TF_ASSIGN_OR_RETURN(auto b_times_c,
                        MakeBinaryHlo(HloOpcode::kMultiply, b, c));
    TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kDivide,
                                                       a_times_d, b_times_c));

    return ReplaceInstruction(divide, new_divide);
  }

  // (A / B) / C => A / (B * C)
  if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), m::Op(&c)))) {
    TF_ASSIGN_OR_RETURN(auto b_times_c,
                        MakeBinaryHlo(HloOpcode::kMultiply, b, c));
    TF_ASSIGN_OR_RETURN(auto new_divide,
                        MakeBinaryHlo(HloOpcode::kDivide, a, b_times_c));
    return ReplaceInstruction(divide, new_divide);
  }

  // A / (B / C) => (A*C) / B
  if (Match(divide, m::Divide(m::Op(&a), m::Divide(m::Op(&b), m::Op(&c))))) {
    TF_ASSIGN_OR_RETURN(auto a_times_c,
                        MakeBinaryHlo(HloOpcode::kMultiply, a, c));
    TF_ASSIGN_OR_RETURN(auto new_divide,
                        MakeBinaryHlo(HloOpcode::kDivide, a_times_c, b));
    return ReplaceInstruction(divide, new_divide);
  }

  // If X is a convert from pred, then
  // X / broadcast(Y) => broadcast(1/Y) * X
  // (The scalar reciprocal is computed once and broadcast, instead of
  // dividing every element.)
  if (Match(divide,
            m::Divide(
                m::Convert(&a,
                           m::Op().WithShape(m::Shape().WithElementType(PRED))),
                m::Broadcast(m::Op(&b).WithShape(m::Shape().IsScalar()))))) {
    TF_ASSIGN_OR_RETURN(
        auto recip, MakeBinaryHlo(HloOpcode::kDivide, MakeScalarLike(b, 1), b));
    auto recip_bcast = computation_->AddInstruction(
        HloInstruction::CreateBroadcast(divide->shape(), recip, {}));
    TF_ASSIGN_OR_RETURN(auto mul,
                        MakeBinaryHlo(HloOpcode::kMultiply, recip_bcast, a));
    return ReplaceInstruction(divide, mul);
  }

  return Status::OK();
}
1735 
RemoveDegenerateDimensionFromDot(HloInstruction * dot)1736 StatusOr<bool> AlgebraicSimplifierVisitor::RemoveDegenerateDimensionFromDot(
1737     HloInstruction* dot) {
1738   const Shape& lhs_shape = dot->operand(0)->shape();
1739   int64 num_degenerate_lhs_dims = 0;
1740   std::vector<int64> lhs_dimension_map(lhs_shape.rank(), -1);
1741   for (int64 i = 0; i < lhs_shape.rank(); ++i) {
1742     if (lhs_shape.dimensions(i) == 1) {
1743       ++num_degenerate_lhs_dims;
1744     } else {
1745       lhs_dimension_map[i] = i - num_degenerate_lhs_dims;
1746     }
1747   }
1748 
1749   const Shape& rhs_shape = dot->operand(1)->shape();
1750   int64 num_degenerate_rhs_dims = 0;
1751   std::vector<int64> rhs_dimension_map(rhs_shape.rank(), -1);
1752   for (int64 i = 0; i < rhs_shape.rank(); ++i) {
1753     if (rhs_shape.dimensions(i) == 1) {
1754       ++num_degenerate_rhs_dims;
1755     } else {
1756       rhs_dimension_map[i] = i - num_degenerate_rhs_dims;
1757     }
1758   }
1759   if (num_degenerate_lhs_dims == 0 && num_degenerate_rhs_dims == 0) {
1760     return false;
1761   }
1762   const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
1763   DotDimensionNumbers new_dnums;
1764   for (int64 dim : dnums.lhs_batch_dimensions()) {
1765     int64 new_dim = lhs_dimension_map[dim];
1766     if (new_dim != -1) {
1767       new_dnums.add_lhs_batch_dimensions(new_dim);
1768     }
1769   }
1770   for (int64 dim : dnums.lhs_contracting_dimensions()) {
1771     int64 new_dim = lhs_dimension_map[dim];
1772     if (new_dim != -1) {
1773       new_dnums.add_lhs_contracting_dimensions(new_dim);
1774     }
1775   }
1776 
1777   for (int64 dim : dnums.rhs_batch_dimensions()) {
1778     int64 new_dim = rhs_dimension_map[dim];
1779     if (new_dim != -1) {
1780       new_dnums.add_rhs_batch_dimensions(new_dim);
1781     }
1782   }
1783   for (int64 dim : dnums.rhs_contracting_dimensions()) {
1784     int64 new_dim = rhs_dimension_map[dim];
1785     if (new_dim != -1) {
1786       new_dnums.add_rhs_contracting_dimensions(new_dim);
1787     }
1788   }
1789 
1790   HloInstruction* new_lhs =
1791       num_degenerate_lhs_dims > 0
1792           ? dot->parent()->AddInstruction(HloInstruction::CreateReshape(
1793                 ShapeUtil::DropDegenerateDimensions(lhs_shape),
1794                 dot->mutable_operand(0)))
1795           : dot->mutable_operand(0);
1796   HloInstruction* new_rhs =
1797       num_degenerate_rhs_dims > 0
1798           ? dot->parent()->AddInstruction(HloInstruction::CreateReshape(
1799                 ShapeUtil::DropDegenerateDimensions(rhs_shape),
1800                 dot->mutable_operand(1)))
1801           : dot->mutable_operand(1);
1802   TF_ASSIGN_OR_RETURN(
1803       auto new_dot,
1804       MakeDotHlo(new_lhs, new_rhs, new_dnums, dot->precision_config(),
1805                  /*preferred_element_type=*/dot->shape().element_type()));
1806   if (ShapeUtil::Compatible(dot->shape(), new_dot->shape())) {
1807     TF_RETURN_IF_ERROR(ReplaceInstruction(dot, new_dot));
1808   } else {
1809     TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
1810         dot, HloInstruction::CreateReshape(dot->shape(), new_dot)));
1811   }
1812   return true;
1813 }
1814 
OptimizeDotOfConcat(HloInstruction * dot)1815 StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcat(
1816     HloInstruction* dot) {
1817   const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
1818   if (dnums.lhs_contracting_dimensions_size() != 1 ||
1819       dnums.lhs_batch_dimensions_size() != 0 ||
1820       dot->shape().dimensions_size() != 2) {  // dot output 2D
1821     return nullptr;
1822   }
1823 
1824   const int64 lhs_contracting_dim = dnums.lhs_contracting_dimensions(0);
1825   const int64 rhs_contracting_dim = dnums.rhs_contracting_dimensions(0);
1826   HloInstruction *lhs, *rhs;
1827   CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
1828 
1829   TF_ASSIGN_OR_RETURN(
1830       HloInstruction * optimized_lhs_concat,
1831       OptimizeDotOfConcatHelper(*dot, lhs, lhs_contracting_dim, rhs,
1832                                 rhs_contracting_dim, /*swapped=*/false));
1833   if (optimized_lhs_concat) {
1834     return optimized_lhs_concat;
1835   }
1836 
1837   return OptimizeDotOfConcatHelper(*dot, rhs, rhs_contracting_dim, lhs,
1838                                    lhs_contracting_dim, /*swapped=*/true);
1839 }
1840 
// Rewrites dot(concat(L_0, L_1, ...), const) as sum_i dot(L_i, slice_i(const))
// -- see the diagram below. `swapped` indicates the concatenate was
// originally the dot's RHS: `lhs`/`rhs` here are then the dot's operands in
// reverse order, and each per-piece dot is built with its operands swapped
// back so the original contraction dimensions are preserved. Returns the
// accumulated sum, or nullptr when the pattern does not match (lhs must be
// a concatenate along its contracting dimension and rhs a constant).
StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper(
    const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim,
    HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped) {
  bool can_optimize = lhs->opcode() == HloOpcode::kConcatenate &&
                      lhs->concatenate_dimension() == lhs_contracting_dim &&
                      rhs->opcode() == HloOpcode::kConstant;
  if (!can_optimize) {
    return nullptr;
  }

  // We're replacing this:
  //
  //   +-----+-----+-----+      +-------------------+
  //   |     |     |     |      |                   |
  //   |     |     |     |      |        R_0        |
  //   |     |     |     |      |                   |
  //   |     |     |     |      +-------------------+
  //   |     |     |     |      |                   |
  //   | L_0 | L_1 | L_2 |   *  |        R_1        |
  //   |     |     |     |      |                   |
  //   |     |     |     |      +-------------------+
  //   |     |     |     |      |                   |
  //   |     |     |     |      |        R_2        |
  //   |     |     |     |      |                   |
  //   +-----+-----+-----+      +-------------------+
  //
  // with this:
  //
  // [Sum over i]
  //
  //   +-----+     +-------------------+
  //   |     |     |                   |
  //   |     |  *  |        R_i        |
  //   |     |     |                   |
  //   |     |     +-------------------+
  //   |     |
  //   | L_i |
  //   |     |
  //   |     |
  //   |     |
  //   |     |
  //   |     |
  //   +-----+
  //
  // where the LHS is a concatenate operation (so we can "split" the LHS tensor
  // for free) and the RHS is a constant tensor (and thus can be split at
  // compile time).  In the future, we may also want to do this when both the
  // LHS and the RHS are concatenate operations that line up along the dimension
  // being contracted over.
  //
  // We should be able to generalize this transform to work on a non-constant
  // RHS when/if we have in-place slices or support input-fusing slices into
  // Dots.

  // Dimension numbers for the new dot instructions we'll create (L_i * R_i in
  // the diagram above).
  DotDimensionNumbers new_dot_dnums;
  new_dot_dnums.add_lhs_contracting_dimensions(swapped ? rhs_contracting_dim
                                                       : lhs_contracting_dim);
  new_dot_dnums.add_rhs_contracting_dimensions(swapped ? lhs_contracting_dim
                                                       : rhs_contracting_dim);

  // Here we use the MKN notation, where the contracted dimension has K
  // elements and the two non-contracted dimensions have M and N elements.
  HloInstruction* add_result = nullptr;
  // Running offset into rhs's contracting dimension: each concat piece
  // consumes the next sub_k rows/columns of the constant.
  int64 rhs_contracting_dim_offset = 0;
  int64 n = rhs->shape().dimensions(1 - rhs_contracting_dim);
  for (HloInstruction* concat_op : lhs->operands()) {
    int64 sub_k = concat_op->shape().dimensions(lhs_contracting_dim);
    Shape rhs_slice_shape(rhs->shape());
    rhs_slice_shape.set_dimensions(rhs_contracting_dim, sub_k);
    simplifier_->UpdateLayout(&rhs_slice_shape);

    // Slice the constant: [offset, offset + sub_k) along the contracting
    // dimension, the full extent (n) along the other.
    std::array<int64, 2> start_indices;
    start_indices[rhs_contracting_dim] = rhs_contracting_dim_offset;
    start_indices[1 - rhs_contracting_dim] = 0;

    std::array<int64, 2> limit_indices;
    limit_indices[rhs_contracting_dim] = rhs_contracting_dim_offset + sub_k;
    limit_indices[1 - rhs_contracting_dim] = n;

    HloInstruction* rhs_slice =
        computation_->AddInstruction(HloInstruction::CreateSlice(
            rhs_slice_shape, rhs, /*start_indices=*/start_indices,
            /*limit_indices=*/limit_indices, /*strides=*/{1, 1}));

    // TODO(b/69062148): We can get rid of `swapped` once all backends support
    // "non-canonical" contraction dimensions (that contracts dimension 1 of the
    // LHS with dimension 0 of the RHS).  But for now we keep the same
    // contraction dimensions as the incoming dot operation to ensure the new
    // dot operations can be lowered.
    HloInstruction *new_dot_lhs, *new_dot_rhs;
    if (swapped) {
      new_dot_lhs = rhs_slice;
      new_dot_rhs = concat_op;
    } else {
      new_dot_lhs = concat_op;
      new_dot_rhs = rhs_slice;
    }

    auto* new_dot = computation_->AddInstruction(
        HloInstruction::CreateDot(dot.shape(), new_dot_lhs, new_dot_rhs,
                                  new_dot_dnums, dot.precision_config()));

    // Accumulate the per-piece partial products into a chain of adds.
    if (add_result) {
      add_result = computation_->AddInstruction(HloInstruction::CreateBinary(
          dot.shape(), HloOpcode::kAdd, add_result, new_dot));
    } else {
      add_result = new_dot;
    }

    rhs_contracting_dim_offset += sub_k;
  }

  return add_result;
}
1957 
OptimizeDotOfGather(HloInstruction * dot)1958 StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
1959     HloInstruction* dot) {
1960   const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
1961   if (dnums.lhs_contracting_dimensions_size() != 1 ||
1962       dnums.rhs_contracting_dimensions_size() != 1 ||
1963       dnums.lhs_batch_dimensions_size() != 0 ||
1964       dnums.rhs_batch_dimensions_size() != 0 ||
1965       dot->shape().dimensions_size() != 2) {  // dot output 2D
1966     VLOG(10) << "DotOfGather: Can only optimize 2D, non-batch dot operations.";
1967     return nullptr;
1968   }
1969 
1970   // Optimize either dot(DS(ctA), ctB)) or dot(ctB, DS(ctA)).
1971   // Currently a Gather is a DynamicSlice.
1972   auto is_dynamic_slice_constant_combination =
1973       [](HloInstruction* a, HloInstruction* b, int a_contracting_dimension) {
1974         // First operand is a DynamicSlice(Constant).
1975         if (a->opcode() != HloOpcode::kDynamicSlice) {
1976           return false;
1977         }
1978         auto* dynamic_slice_op = a->operand(0);
1979         if (dynamic_slice_op->opcode() != HloOpcode::kConstant) {
1980           return false;
1981         }
1982         // Second operand is a Constant.
1983         if (b->opcode() != HloOpcode::kConstant) {
1984           return false;
1985         }
1986         // The DynamicSlice output is a vector.
1987         const Shape& dynamic_slice_shape = a->shape();
1988         if (dynamic_slice_shape.dimensions(1 - a_contracting_dimension) != 1) {
1989           return false;
1990         }
1991         // Constant size is the same before and after slice in the contracting
1992         // dimension, otherwise we either must precompute for all possible slice
1993         // indices or dot is invalid.
1994         const Shape& dynamic_slice_op_shape = dynamic_slice_op->shape();
1995         if (dynamic_slice_op_shape.dimensions(a_contracting_dimension) !=
1996             dynamic_slice_shape.dimensions(a_contracting_dimension)) {
1997           return false;
1998         }
1999         return true;
2000       };
2001 
2002   HloInstruction* lhs = dot->mutable_operand(0);
2003   HloInstruction* rhs = dot->mutable_operand(1);
2004   int lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0);
2005   int rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0);
2006 
2007   if (!is_dynamic_slice_constant_combination(
2008           lhs, rhs, /*a_contracting_dimension=*/lhs_contracting_dimension) &&
2009       !is_dynamic_slice_constant_combination(
2010           rhs, lhs, /*a_contracting_dimension=*/rhs_contracting_dimension)) {
2011     VLOG(10) << "DotOfGather: Can only optimize dot(DS(ctA), ctB)) or "
2012                 "dot(ctB, DS(ctA)), where the two constants have equal "
2013                 "contracting dimensions.";
2014     return nullptr;
2015   }
2016 
2017   // LHS is DynamicSlice:
2018   // input: dot(DS(ctA), ctB))
2019   // where DS(ctA) = DS({M x K}, {start, 0}, {1, K}) and ctB = {K x N}.
2020   // => input dimensions: dot({1 x K}, {K x N}) => {1 x N}.
2021   // output: DS(dot(ctA, ctB))
2022   // => output dimensions: DS ({M x N}, {start, 0}, {1, N}) => {1 x N}.
2023 
2024   // RHS is DynamicSlice:
2025   // input: dot(ctA, DS(ctB))
2026   // where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, start}, {K, 1}).
2027   // => input dimensions: dot({M x K}, {K x 1}) => {M x 1}.
2028   // output: DS(dot(ctA, ctB))
2029   // => output dimensions: DS ({M x N}, {0, start}, {M, 1}) => {M x 1}.
2030 
2031   bool lhs_is_dynamic_slice = lhs->opcode() == HloOpcode::kDynamicSlice;
2032   HloDynamicSliceInstruction* dynamic_slice =
2033       lhs_is_dynamic_slice ? Cast<HloDynamicSliceInstruction>(lhs)
2034                            : Cast<HloDynamicSliceInstruction>(rhs);
2035 
2036   // ctA:
2037   HloInstruction* left_operand =
2038       lhs_is_dynamic_slice ? lhs->mutable_operand(0) : lhs;
2039   // ctB:
2040   HloInstruction* right_operand =
2041       lhs_is_dynamic_slice ? rhs : rhs->mutable_operand(0);
2042   // Build ctA x ctB.
2043   const int m = left_operand->shape().dimensions(1 - lhs_contracting_dimension);
2044   const int n =
2045       right_operand->shape().dimensions(1 - rhs_contracting_dimension);
2046   auto memoized_shape =
2047       ShapeUtil::MakeShape(dot->shape().element_type(), {m, n});
2048   simplifier_->UpdateLayout(&memoized_shape);
2049   auto* memoized_inst = computation_->AddInstruction(
2050       HloInstruction::CreateDot(memoized_shape, left_operand, right_operand,
2051                                 dnums, dot->precision_config()));
2052   // Get pair {start, 0} or {0, start}.
2053   // Position of start:
2054   int index_of_non_zero_start = lhs_is_dynamic_slice
2055                                     ? 1 - lhs_contracting_dimension
2056                                     : 1 - rhs_contracting_dimension;
2057   // Position of zero:
2058   int index_of_zero_start = 1 - index_of_non_zero_start;
2059 
2060   // Slice out start and 0 components and reorder if necessary.
2061   auto indices_type = dynamic_slice->operand(1)->shape().element_type();
2062   Shape s_shape = ShapeUtil::MakeShape(indices_type, {1});
2063   simplifier_->UpdateLayout(&s_shape);
2064   Shape d_shape = ShapeUtil::MakeShape(indices_type, {2});
2065   simplifier_->UpdateLayout(&d_shape);
2066   HloInstruction* non_zero_start =
2067       dynamic_slice->mutable_operand(1 + index_of_non_zero_start);
2068   HloInstruction* zero_start =
2069       dynamic_slice->mutable_operand(1 + index_of_zero_start);
2070   std::vector<HloInstruction*> new_start_indices;
2071   if (lhs_is_dynamic_slice) {
2072     new_start_indices = {non_zero_start, zero_start};
2073   } else {
2074     new_start_indices = {zero_start, non_zero_start};
2075   }
2076 
2077   // Build DynamicSlice(ctA x ctB).
2078   const int new_slice_m = lhs_is_dynamic_slice ? 1 : m;
2079   const int new_slice_n = lhs_is_dynamic_slice ? n : 1;
2080   auto* memoized_lookup =
2081       computation_->AddInstruction(HloInstruction::CreateDynamicSlice(
2082           dot->shape(), memoized_inst, new_start_indices,
2083           {new_slice_m, new_slice_n}));
2084 
2085   return memoized_lookup;
2086 }
2087 
2088 // This function tries to transform
2089 //   dot(reshape(transpose(A)), Const) to
2090 //   dot(reshape(A), reshape(transpose(reshape(Const)))),
2091 // so that the reshape and transpose on the Const side can be constant folded.
2092 //
2093 // The basic idea is that since the accumulation in the dot operation is
2094 // associative, so as long as we permute the elements of the contracting
2095 // dimensions on both sides of the dot in the same way, the result of the
2096 // dot is not affected.
2097 StatusOr<HloInstruction*>
OptimizeDotOfReorderContractingDims(HloInstruction * dot)2098 AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims(
2099     HloInstruction* dot) {
2100   // This transformation assumes layout is not assigned yet.
2101   if (options_.is_layout_sensitive()) {
2102     return nullptr;
2103   }
2104 
2105   // Canonicalize dot(<constant>, rhs) to dot(rhs, <constant>) to make the
2106   // remainder of this function easier.
2107   auto dnums = dot->dot_dimension_numbers();
2108   auto lhs_contracting_dims = dnums.lhs_contracting_dimensions();
2109   auto rhs_contracting_dims = dnums.rhs_contracting_dimensions();
2110   auto* lhs = dot->mutable_operand(0);
2111   auto* rhs = dot->mutable_operand(1);
2112   if (dot->operand(0)->IsConstant()) {
2113     std::swap(lhs, rhs);
2114     std::swap(lhs_contracting_dims, rhs_contracting_dims);
2115   }
2116 
2117   // Require single contracting dim to make the implementation easier to
2118   // track contracting dims.
2119   if (dnums.lhs_contracting_dimensions_size() != 1) {
2120     return nullptr;
2121   }
2122 
2123   // Pattern match Dot(reshape(transpose(input), constant))
2124   HloInstruction* reshape;
2125   HloInstruction* transpose;
2126   HloInstruction* input;
2127   HloInstruction* constant;
2128   if (!Match(lhs,
2129              m::Reshape(&reshape, m::Transpose(&transpose, m::Op(&input)))) ||
2130       !Match(rhs, m::Constant(&constant))) {
2131     return nullptr;
2132   }
2133 
2134   // Check that reshape squishes some dims into one dim and that this one
2135   // dim is the dot's lhs contracting dim. The size of unmodified_dims should
2136   // be N - 1, where N is the rank of the reshape output. This means that the
2137   // reshape squishes some dims into one dim. lhs contracting dim should not
2138   // be in unmodified_dims. This means that the squishing target dim is the
2139   // lhs contracting dim.
2140   auto unmodified_dims = ShapeUtil::DimensionsUnmodifiedByReshape(
2141       reshape->operand(0)->shape(), reshape->shape());
2142   CHECK_EQ(lhs_contracting_dims.size(), 1);
2143   if ((unmodified_dims.size() != reshape->shape().rank() - 1) ||
2144       absl::c_any_of(unmodified_dims, [&](const std::pair<int64, int64>& p) {
2145         return p.second == lhs_contracting_dims[0];
2146       })) {
2147     return nullptr;
2148   }
2149 
2150   // Virtually pull the reshape into the dot so the dot operates on the
2151   // transpose, with "unsquished" lhs contracting dims.  The new contracting
2152   // dims are all of the dims that are modified by the reshape -- that is, every
2153   // dimension that's not in `unmodified_dims[i].first`.
2154   //
2155   // (We don't need to actually create a new dot instruction. We can just keep
2156   // track of lhs and lhs_contracting_dims.)
2157   absl::flat_hash_set<int64> unmodified_transpose_dims;
2158   for (const auto& pair : unmodified_dims) {
2159     unmodified_transpose_dims.insert(pair.first);
2160   }
2161   lhs_contracting_dims.Clear();
2162   for (int64 i = 0; i < transpose->shape().dimensions_size(); ++i) {
2163     if (!unmodified_transpose_dims.contains(i)) {
2164       lhs_contracting_dims.Add(i);
2165     }
2166   }
2167   // We require the "unsquished" lhs contracting dims to be consecutive.
2168   auto is_iota = [](absl::Span<const int64> dims) {
2169     return absl::c_adjacent_find(dims, [](const int64 a, const int64 b) {
2170              return (b != a + 1);
2171            }) == dims.end();
2172   };
2173   if (!is_iota(AsInt64Slice(lhs_contracting_dims))) {
2174     return nullptr;
2175   }
2176   lhs = lhs->mutable_operand(0);
2177 
2178   // Check that the transpose only permutes the contracting dims.
2179   const auto& transpose_dims = transpose->dimensions();
2180   for (int64 i = 0; i < transpose_dims.size(); ++i) {
2181     if (transpose_dims[i] != i &&
2182         !absl::c_linear_search(lhs_contracting_dims, i)) {
2183       return nullptr;
2184     }
2185   }
2186   // Virtually pull the transpose into the dot. Now the dot is equivalent to
2187   // a new dot with "permuted" lhs contracting dims.
2188   std::vector<int64> permutation;
2189   for (auto dim : lhs_contracting_dims) {
2190     permutation.push_back(transpose_dims[dim] - lhs_contracting_dims[0]);
2191   }
2192   CHECK(IsPermutation(permutation));
2193   auto new_lhs_contracting_dims =
2194       ComposePermutations(AsInt64Slice(lhs_contracting_dims), permutation);
2195   lhs_contracting_dims.Clear();
2196   for (auto dim : new_lhs_contracting_dims) {
2197     lhs_contracting_dims.Add(dim);
2198   }
2199   lhs = lhs->mutable_operand(0);
2200 
2201   // All checks are passed at this point.
2202   //
2203   // Transform lhs. Remove the transpose and reshape by sorting the lhs
2204   // contracting dims and squishing them into a single one. We don't actually
2205   // squish the lhs_contracting_dims here because we still need the unsquished
2206   // contracting dims to invert reshape and transpose.
2207   absl::c_sort(lhs_contracting_dims);
2208   lhs = computation_->AddInstruction(
2209       HloInstruction::CreateReshape(reshape->shape(), lhs));
2210 
2211   // Transform rhs. Say the input HLO is:
2212   //
2213   //   t0 = f32[2, 2, 3] parameter(0)
2214   //   t1 = f32[2, 3, 2] transpose(t0) dimensions={0, 2, 1}
2215   //   t2 = f32[2, 6] reshape(t1)
2216   //   t3 = f32[6, 2] constant(...)
2217   //   dot = f32[2, 2] dot(t2, t3) lhs_contracting_dims={1},
2218   //                               rhs_contracting_dims={0}
2219   //
2220   // At this point in the function, we have decided that the second and third
2221   // dims of t0 can be switched to remove the transpose, and we have
2222   // "virtually decomposed" the input HLO to:
2223   //
2224   //   t0 = f32[2, 2, 3] parameter(0)
2225   //   t2' = f32[2, 6] reshape(t0)
2226   //   t3' = f32[6, 2] ops-to-be-filled ...
2227   //   dot = f32[2, 2] dot(t2', t3') lhs_contracting_dims={1},
2228   //                                 rhs_contracting_dims={0}
2229   //
2230   // The rest of this function is to fill in the ops of t3'. To do this, we
2231   // unsquish the contracting dimensions in t3 and then apply the inverse of
2232   // the transpose from t1.
2233 
2234   // Invert reshape.
2235   CHECK_EQ(rhs_contracting_dims.size(), 1);
2236   std::vector<int64> rhs_unsquished_shape_dims =
2237       SpanToVector(constant->shape().dimensions());
2238   auto it = rhs_unsquished_shape_dims.erase(rhs_unsquished_shape_dims.begin() +
2239                                             rhs_contracting_dims[0]);
2240   for (auto dim : lhs_contracting_dims) {
2241     it = rhs_unsquished_shape_dims.insert(it,
2242                                           transpose->shape().dimensions(dim));
2243     ++it;
2244   }
2245   HloInstruction* rhs_reshape =
2246       computation_->AddInstruction(HloInstruction::CreateReshape(
2247           ShapeUtil::MakeShape(constant->shape().element_type(),
2248                                rhs_unsquished_shape_dims),
2249           constant));
2250   rhs = rhs_reshape;
2251 
2252   // Rhs reshape "unsquishes" the single rhs contracting dim into multiple dims.
2253   rhs_contracting_dims.Resize(lhs_contracting_dims.size(),
2254                               rhs_contracting_dims[0]);
2255   absl::c_iota(rhs_contracting_dims, rhs_contracting_dims[0]);
2256 
2257   // Invert transpose. First compute the shape.
2258   std::vector<int64> rhs_transpose_shape_dims =
2259       SpanToVector(rhs_reshape->shape().dimensions());
2260   it = rhs_transpose_shape_dims.erase(
2261       rhs_transpose_shape_dims.begin() + rhs_contracting_dims[0],
2262       rhs_transpose_shape_dims.begin() + rhs_contracting_dims[0] +
2263           rhs_contracting_dims.size());
2264   for (auto dim : lhs_contracting_dims) {
2265     it = rhs_transpose_shape_dims.insert(it, input->shape().dimensions(dim));
2266     ++it;
2267   }
2268   // Then compute the transpose dims.
2269   std::vector<int64> rhs_transpose_dims(rhs_reshape->shape().rank());
2270   absl::c_iota(rhs_transpose_dims, 0);
2271   it = rhs_transpose_dims.erase(
2272       rhs_transpose_dims.begin() + rhs_contracting_dims[0],
2273       rhs_transpose_dims.begin() + rhs_contracting_dims[0] +
2274           rhs_contracting_dims.size());
2275   auto inverse_lhs_transpose_dims = InversePermutation(transpose_dims);
2276   for (auto dim : lhs_contracting_dims) {
2277     it = rhs_transpose_dims.insert(it, inverse_lhs_transpose_dims[dim] -
2278                                            lhs_contracting_dims[0] +
2279                                            rhs_contracting_dims[0]);
2280     ++it;
2281   }
2282   HloInstruction* rhs_transpose =
2283       computation_->AddInstruction(HloInstruction::CreateTranspose(
2284           ShapeUtil::MakeShape(constant->shape().element_type(),
2285                                rhs_transpose_shape_dims),
2286           rhs_reshape, rhs_transpose_dims));
2287   rhs = rhs_transpose;
2288 
2289   // Squish the multiple rhs contracting dims into a single one.
2290   rhs = computation_->AddInstruction(
2291       HloInstruction::CreateReshape(constant->shape(), rhs));
2292 
2293   // If we virtually swapped lhs and rhs, we need to swap it back before
2294   // creating new dot.
2295   if (dot->operand(0)->IsConstant()) {
2296     std::swap(lhs, rhs);
2297   }
2298 
2299   HloInstruction* new_dot =
2300       computation_->AddInstruction(HloInstruction::CreateDot(
2301           dot->shape(), lhs, rhs, dnums, dot->precision_config()));
2302   return new_dot;
2303 }
2304 
// Simplifies a kDot instruction by trying a sequence of rewrites, applying
// the first one that matches:
//   1. Zero-element dot (or operand)        -> broadcast of scalar 0.
//   2. No contracting dims                  -> elementwise multiply.
//   3. Operand with only batch+contracting  -> multiply + reduce
//      dims ("strength reduction").
//   4. dot(reshape(transpose(A)), Const)    -> fold reorder into Const side.
//   5. dot(concat(...), Const)              -> sum of smaller dots.
//   6. dot(ConstA, gather(i, ConstB))       -> gather of a batched dot.
//   7. Degenerate (size-1) dims removed.
//   8. dot(transpose(a), transpose(b))      -> transpose(dot(b, a)).
// All rewrites are skipped in layout-sensitive mode.
Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
  CHECK(computation_ == dot->parent());
  HloInstruction *lhs, *rhs;
  CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
  if (options_.is_layout_sensitive()) {
    // These rewrites reshape/transpose operands freely, which is not safe
    // once layouts are fixed.
    return Status::OK();
  }
  // Replace a zero element dot with a broadcast of the constant 0.
  if (ShapeUtil::IsZeroElementArray(dot->shape()) ||
      ShapeUtil::IsZeroElementArray(lhs->shape()) ||
      ShapeUtil::IsZeroElementArray(rhs->shape())) {
    auto zero = computation_->AddInstruction(
        simplifier_->CreateConstantWithLayoutUpdated(
            LiteralUtil::Zero(dot->shape().element_type())));
    return ReplaceWithNewInstruction(
        dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {}));
  }

  // If there are no contracting dimensions, a dot can be rewritten as
  // mul(broadcast(transpose(x)),broadcast(transpose(y)))
  if (options_.enable_dot_to_multiply_rewrite() &&
      dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 0) {
    // Put batch dims first on each operand so the broadcasts below line up
    // with the dot's output dimension order.
    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_lhs,
        NormalizeDotOperandToBatchMajorAndContractingMinor(
            lhs,
            AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()),
            AsInt64Slice(
                dot->dot_dimension_numbers().lhs_contracting_dimensions())));
    if (!ShapeUtil::SameElementType(dot->shape(), new_lhs->shape())) {
      new_lhs = MakeConvertToHlo(new_lhs, dot->shape().element_type());
    }
    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_rhs,
        NormalizeDotOperandToBatchMajorAndContractingMinor(
            rhs,
            AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()),
            AsInt64Slice(
                dot->dot_dimension_numbers().rhs_contracting_dimensions())));
    if (!ShapeUtil::SameElementType(dot->shape(), new_rhs->shape())) {
      new_rhs = MakeConvertToHlo(new_rhs, dot->shape().element_type());
    }
    if (dot->shape().rank() != lhs->shape().rank()) {
      // lhs dims occupy the leading output dims; broadcast over the trailing
      // (rhs-only) dims.
      std::vector<int64> lhs_broadcast_dims(lhs->shape().rank());
      absl::c_iota(lhs_broadcast_dims, 0);
      new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
          dot->shape(), new_lhs, lhs_broadcast_dims));
    }
    if (dot->shape().rank() != rhs->shape().rank()) {
      // rhs keeps the shared batch dims and then maps onto the trailing
      // output dims that follow the lhs-only dims.
      std::vector<int64> rhs_broadcast_dims(
          dot->dot_dimension_numbers().lhs_batch_dimensions_size());
      absl::c_iota(rhs_broadcast_dims, 0);
      for (int64 i = lhs->shape().rank(); i < dot->shape().rank(); ++i) {
        rhs_broadcast_dims.push_back(i);
      }
      new_rhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
          dot->shape(), new_rhs, rhs_broadcast_dims));
    }
    return ReplaceWithNewInstruction(
        dot, HloInstruction::CreateBinary(dot->shape(), HloOpcode::kMultiply,
                                          new_lhs, new_rhs));
  }

  // If the lhs or rhs have only batch and contracting dimensions, a dot can be
  // rewritten as reduce(mul(broadcast(transpose(x)),broadcast(transpose(y))))
  if (options_.enable_dot_strength_reduction() &&
      ((dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
            dot->dot_dimension_numbers().lhs_contracting_dimensions_size() ==
        lhs->shape().rank()) ||
       (dot->dot_dimension_numbers().rhs_contracting_dimensions_size() +
            dot->dot_dimension_numbers().rhs_batch_dimensions_size() ==
        rhs->shape().rank()))) {
    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_lhs,
        NormalizeDotOperandToBatchMajorAndContractingMinor(
            lhs,
            AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()),
            AsInt64Slice(
                dot->dot_dimension_numbers().lhs_contracting_dimensions())));
    if (!ShapeUtil::SameElementType(dot->shape(), new_lhs->shape())) {
      new_lhs = MakeConvertToHlo(new_lhs, dot->shape().element_type());
    }

    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_rhs,
        NormalizeDotOperandToBatchMajorAndContractingMinor(
            rhs,
            AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()),
            AsInt64Slice(
                dot->dot_dimension_numbers().rhs_contracting_dimensions())));
    if (!ShapeUtil::SameElementType(dot->shape(), new_rhs->shape())) {
      new_rhs = MakeConvertToHlo(new_rhs, dot->shape().element_type());
    }

    // "Outer" dims are the non-batch, non-contracting dims; by the guard
    // above at least one side has none.
    int64 lhs_outer_dims =
        lhs->shape().rank() -
        (dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
         dot->dot_dimension_numbers().lhs_contracting_dimensions_size());
    int64 rhs_outer_dims =
        rhs->shape().rank() -
        (dot->dot_dimension_numbers().rhs_batch_dimensions_size() +
         dot->dot_dimension_numbers().rhs_contracting_dimensions_size());
    CHECK(lhs_outer_dims == 0 || rhs_outer_dims == 0);
    if (rhs_outer_dims > 0) {
      // Broadcast the outer-dim-free lhs up to the rhs's shape, skipping the
      // rhs outer dims that sit between the batch and contracting dims.
      std::vector<int64> lhs_broadcast_dims(
          dot->dot_dimension_numbers().lhs_batch_dimensions_size());
      absl::c_iota(lhs_broadcast_dims, 0);
      lhs_broadcast_dims.resize(lhs->shape().rank());
      std::iota(lhs_broadcast_dims.begin() +
                    dot->dot_dimension_numbers().lhs_batch_dimensions_size(),
                lhs_broadcast_dims.end(),
                dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
                    rhs_outer_dims);
      new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
          new_rhs->shape(), new_lhs, lhs_broadcast_dims));
    } else if (lhs_outer_dims > 0) {
      // Symmetric case: broadcast the rhs up to the lhs's shape.
      std::vector<int64> rhs_broadcast_dims(
          dot->dot_dimension_numbers().rhs_batch_dimensions_size());
      absl::c_iota(rhs_broadcast_dims, 0);
      rhs_broadcast_dims.resize(rhs->shape().rank());
      std::iota(rhs_broadcast_dims.begin() +
                    dot->dot_dimension_numbers().rhs_batch_dimensions_size(),
                rhs_broadcast_dims.end(),
                dot->dot_dimension_numbers().rhs_batch_dimensions_size() +
                    lhs_outer_dims);
      new_rhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
          new_lhs->shape(), new_rhs, rhs_broadcast_dims));
    }

    TF_ASSIGN_OR_RETURN(HloInstruction * new_dot,
                        MakeBinaryHlo(HloOpcode::kMultiply, new_lhs, new_rhs));
    std::vector<int64> reduce_dims(
        dot->dot_dimension_numbers().lhs_contracting_dimensions_size());
    // Accumulate in (at least) F32 for narrow floating types to match dot's
    // accumulation behavior, then convert back at the end.
    PrimitiveType dot_type =
        ShapeUtil::ElementIsFloating(dot->shape())
            ? (dot->shape().element_type() == F64 ? F64 : F32)
            : dot->shape().element_type();
    new_dot = AsType(new_dot, dot_type);
    const int64 outer_dims = std::max(rhs_outer_dims, lhs_outer_dims);
    // The contracting dims to reduce sit right after the batch + outer dims.
    absl::c_iota(
        reduce_dims,
        outer_dims + dot->dot_dimension_numbers().lhs_batch_dimensions_size());
    new_dot = AddReduce(new_dot, reduce_dims, dot_type);
    new_dot = AsType(new_dot, dot->shape().element_type());
    return ReplaceInstruction(dot, new_dot);
  }

  // Simplify dot(reshape(transpose(A)), Const) to:
  // dot(reshape(A), reshape(transpose(reshape(Const)))), so that the reshape
  // and transpose on the Const side can be constant folded.
  TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_reorder_optimized,
                      OptimizeDotOfReorderContractingDims(dot));
  if (dot_of_reorder_optimized) {
    VLOG(10) << " Replaced dot " << dot->ToString()
             << " with new dot operation: "
             << dot_of_reorder_optimized->ToString();
    return ReplaceInstruction(dot, dot_of_reorder_optimized);
  }

  TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_concat_optimized,
                      OptimizeDotOfConcat(dot));
  if (dot_of_concat_optimized) {
    VLOG(10) << "Replaced dot(concat(...), constant) with add(dot(..., "
                "constant)...)";
    return ReplaceInstruction(dot, dot_of_concat_optimized);
  }

  // Simplify dot(ConstA, Gather(Index, ConstB)) to:
  // Gather(Index, dot*(ConstA, ConstB)), where dot* is an appropriately
  // batched version of dot.
  TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_gather_optimized,
                      OptimizeDotOfGather(dot));
  if (dot_of_gather_optimized) {
    VLOG(10) << "Replaced dot(constA, gather(i, constB)) with "
                "gather(i, dot*(constA, constB))";
    return ReplaceInstruction(dot, dot_of_gather_optimized);
  }

  TF_ASSIGN_OR_RETURN(bool removed_degenerate_dimensions,
                      RemoveDegenerateDimensionFromDot(dot));
  if (removed_degenerate_dimensions) {
    // The dot was replaced in-place by the helper; nothing more to do.
    return Status::OK();
  }

  // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)).
  if (dot->dot_dimension_numbers().lhs_batch_dimensions_size() == 0 &&
      dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 1 &&
      dot->dot_dimension_numbers().lhs_contracting_dimensions(0) == 1 &&
      dot->dot_dimension_numbers().rhs_contracting_dimensions(0) == 0 &&
      lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) {
    DotDimensionNumbers dot_dimension_numbers;
    dot_dimension_numbers.add_lhs_contracting_dimensions(1);
    dot_dimension_numbers.add_rhs_contracting_dimensions(0);
    auto new_dot = computation_->AddInstruction(HloInstruction::CreateDot(
        ShapeUtil::PermuteDimensions({1, 0}, dot->shape()),
        rhs->mutable_operand(0), lhs->mutable_operand(0), dot_dimension_numbers,
        dot->precision_config()));
    return ReplaceWithNewInstruction(
        dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0}));
  }

  return Status::OK();
}
2508 
HandleGather(HloInstruction * gather)2509 Status AlgebraicSimplifierVisitor::HandleGather(HloInstruction* gather) {
2510   const Shape& operand_shape = gather->operand(0)->shape();
2511   if (ShapeUtil::IsZeroElementArray(operand_shape)) {
2512     return ReplaceInstruction(gather, MakeScalarLike(gather, 0));
2513   }
2514 
2515   // Gathering from a scalar operand is simply a broadcast of that scalar
2516   if (ShapeUtil::IsEffectiveScalar(operand_shape)) {
2517     HloInstruction* new_operand = gather->mutable_operand(0);
2518     if (operand_shape.rank()) {
2519       TF_ASSIGN_OR_RETURN(new_operand,
2520                           MakeReshapeHlo(ShapeUtil::MakeScalarShape(
2521                                              operand_shape.element_type()),
2522                                          new_operand));
2523     }
2524     HloInstruction* new_gather =
2525         MakeBroadcastHlo(new_operand, {}, gather->shape());
2526     return ReplaceInstruction(gather, new_gather);
2527   }
2528   // If the operand of a gather is very small, it is easier to fuse a
2529   // sequence of selects.
2530   const Shape& index_shape = gather->operand(1)->shape();
2531   if (operand_shape.rank() == 1 &&
2532       operand_shape.dimensions(0) <= options_.very_small_gather_size() &&
2533       gather->gather_dimension_numbers().index_vector_dim() ==
2534           index_shape.rank() &&
2535       gather->gather_dimension_numbers().collapsed_slice_dims_size() == 1) {
2536     const int64 operand_elements = operand_shape.dimensions(0);
2537     auto get_value = [&](int64 i) {
2538       auto slice = computation_->AddInstruction(HloInstruction::CreateSlice(
2539           ShapeUtil::MakeShape(operand_shape.element_type(), {1}),
2540           gather->mutable_operand(0), {i}, {i + 1}, {1}));
2541       auto scalar = computation_->AddInstruction(HloInstruction::CreateReshape(
2542           ShapeUtil::MakeShape(operand_shape.element_type(), {}), slice));
2543       return computation_->AddInstruction(
2544           HloInstruction::CreateBroadcast(gather->shape(), scalar, {}));
2545     };
2546     auto result = get_value(0);
2547     auto pred_shape = ShapeUtil::ChangeElementType(gather->shape(), PRED);
2548     auto iter_shape = ShapeUtil::ChangeElementType(gather->shape(),
2549                                                    index_shape.element_type());
2550     for (int64 i = 0; i < operand_elements; ++i) {
2551       auto index_mask =
2552           computation_->AddInstruction(HloInstruction::CreateCompare(
2553               pred_shape, gather->mutable_operand(1),
2554               MakeScalarLike(gather->mutable_operand(1), i),
2555               ComparisonDirection::kGe));
2556       result = computation_->AddInstruction(
2557           HloInstruction::CreateTernary(gather->shape(), HloOpcode::kSelect,
2558                                         index_mask, get_value(i), result));
2559     }
2560     return ReplaceInstruction(gather, result);
2561   }
2562   return Status::OK();
2563 }
2564 
2565 namespace {
MinMaxToClamp(HloInstruction * clamp_lower_bound_bcast,HloInstruction * to_clamp,HloInstruction * clamp_upper_bound_bcast)2566 StatusOr<std::unique_ptr<HloInstruction>> MinMaxToClamp(
2567     HloInstruction* clamp_lower_bound_bcast, HloInstruction* to_clamp,
2568     HloInstruction* clamp_upper_bound_bcast) {
2569   HloInstruction* clamp_lower_bound;
2570   CHECK(Match(clamp_lower_bound_bcast,
2571               m::Broadcast(m::ConstantEffectiveScalar(&clamp_lower_bound))))
2572       << clamp_lower_bound_bcast->ToString();
2573 
2574   HloInstruction* clamp_upper_bound;
2575   CHECK(Match(clamp_upper_bound_bcast,
2576               m::Broadcast(m::ConstantEffectiveScalar(&clamp_upper_bound))))
2577       << clamp_upper_bound_bcast->ToString();
2578 
2579   const Literal& lower_bound =
2580       Cast<HloConstantInstruction>(clamp_lower_bound)->literal();
2581   const Literal& upper_bound =
2582       Cast<HloConstantInstruction>(clamp_upper_bound)->literal();
2583 
2584   std::unique_ptr<HloInstruction> lower_bound_instr =
2585       HloInstruction::CreateConstant(lower_bound.Clone());
2586   std::unique_ptr<HloInstruction> upper_bound_instr =
2587       HloInstruction::CreateConstant(upper_bound.Clone());
2588 
2589   std::unique_ptr<HloInstruction> cloned_instruction =
2590       HloInstruction::CreateCompare(
2591           ShapeUtil::ChangeElementType(lower_bound_instr->shape(), PRED),
2592           lower_bound_instr.get(), upper_bound_instr.get(),
2593           ComparisonDirection::kLt);
2594 
2595   HloEvaluator evaluator;
2596   TF_ASSIGN_OR_RETURN(auto result,
2597                       evaluator.Evaluate(cloned_instruction.get()));
2598   if (result.IsAll(true)) {
2599     return HloInstruction::CreateTernary(to_clamp->shape(), HloOpcode::kClamp,
2600                                          clamp_lower_bound_bcast, to_clamp,
2601                                          clamp_upper_bound_bcast);
2602   }
2603   return std::unique_ptr<HloInstruction>();
2604 }
2605 }  // namespace
2606 
HandleMaximum(HloInstruction * maximum)2607 Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) {
2608   HloInstruction *lhs, *rhs;
2609   CHECK(Match(maximum, m::Maximum(m::Op(&lhs), m::Op(&rhs))));
2610 
2611   HloInstruction* clamp_upper_bound_bcast;
2612   HloInstruction* clamp_lower_bound_bcast;
2613   HloInstruction* to_clamp;
2614   if (Match(maximum, m::MaximumAnyOrder(
2615                          m::Broadcast(&clamp_lower_bound_bcast,
2616                                       m::ConstantEffectiveScalar()),
2617                          m::MinimumAnyOrder(
2618                              m::Op(&to_clamp),
2619                              m::Broadcast(&clamp_upper_bound_bcast,
2620                                           m::ConstantEffectiveScalar()))))) {
2621     TF_ASSIGN_OR_RETURN(auto clamp,
2622                         MinMaxToClamp(clamp_lower_bound_bcast, to_clamp,
2623                                       clamp_upper_bound_bcast));
2624     if (clamp) {
2625       return ReplaceWithNewInstruction(maximum, std::move(clamp));
2626     }
2627   }
2628 
2629   HloInstruction* clamp_lower_bound;
2630   HloInstruction* clamp_upper_bound;
2631   HloInstruction* max_operand;
2632   HloInstruction* clamp;
2633   if (Match(maximum,
2634             m::MaximumAnyOrder(
2635                 m::Op(&max_operand),
2636                 m::Clamp(&clamp, m::Op(&clamp_lower_bound), m::Op(&to_clamp),
2637                          m::Op(&clamp_upper_bound))))) {
2638     if (max_operand == clamp_lower_bound &&
2639         ReplaceInstructionIfSameShape(maximum, clamp)) {
2640       return Status::OK();
2641     }
2642   }
2643 
2644   return Status::OK();
2645 }
2646 
HandleMinimum(HloInstruction * minimum)2647 Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum) {
2648   HloInstruction *lhs, *rhs;
2649   CHECK(Match(minimum, m::Minimum(m::Op(&lhs), m::Op(&rhs))));
2650 
2651   HloInstruction* clamp_upper_bound_bcast;
2652   HloInstruction* clamp_lower_bound_bcast;
2653   HloInstruction* to_clamp;
2654   if (Match(minimum, m::MinimumAnyOrder(
2655                          m::Broadcast(&clamp_upper_bound_bcast,
2656                                       m::ConstantEffectiveScalar()),
2657                          m::MaximumAnyOrder(
2658                              m::Op(&to_clamp),
2659                              m::Broadcast(&clamp_lower_bound_bcast,
2660                                           m::ConstantEffectiveScalar()))))) {
2661     TF_ASSIGN_OR_RETURN(auto clamp,
2662                         MinMaxToClamp(clamp_lower_bound_bcast, to_clamp,
2663                                       clamp_upper_bound_bcast));
2664     if (clamp) {
2665       return ReplaceWithNewInstruction(minimum, std::move(clamp));
2666     }
2667   }
2668 
2669   return Status::OK();
2670 }
2671 
HandleClamp(HloInstruction * clamp)2672 Status AlgebraicSimplifierVisitor::HandleClamp(HloInstruction* clamp) {
2673   HloInstruction* clamp_lower_bound;
2674   HloInstruction* clamp_upper_bound;
2675   HloInstruction* to_clamp;
2676   CHECK(Match(clamp, m::Clamp(m::Op(&clamp_lower_bound), m::Op(&to_clamp),
2677                               m::Op(&clamp_upper_bound))));
2678 
2679   // clamp(a, clamp(a, x, b), b) -> clamp(a, x, b)
2680   if (Match(to_clamp, m::Clamp(m::Op().Is(clamp_lower_bound), m::Op(),
2681                                m::Op().Is(clamp_upper_bound))) &&
2682       ReplaceInstructionIfSameShape(clamp, to_clamp)) {
2683     return Status::OK();
2684   }
2685 
2686   return Status::OK();
2687 }
2688 
HandleMultiply(HloInstruction * multiply)2689 Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
2690   HloInstruction *lhs, *rhs;
2691   CHECK(Match(multiply, m::Multiply(m::Op(&lhs), m::Op(&rhs))));
2692   // LHS*1 => LHS
2693   VLOG(10) << "trying transform [LHS*1 => LHS]: " << multiply->ToString();
2694   if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(multiply, lhs)) {
2695     return Status::OK();
2696   }
2697   // 1*RHS => RHS
2698   VLOG(10) << "trying transform [1*RHS => RHS]: " << multiply->ToString();
2699   if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(multiply, rhs)) {
2700     return Status::OK();
2701   }
2702 
2703   // 0*RHS => 0. Only applies for integral types for correct NaN-handling.
2704   if (IsAll(lhs, 0) &&
2705       primitive_util::IsIntegralType(multiply->shape().element_type()) &&
2706       ReplaceInstructionIfSameShape(multiply, lhs)) {
2707     return Status::OK();
2708   }
2709   // LHS*0 => 0
2710   if (IsAll(rhs, 0) &&
2711       primitive_util::IsIntegralType(multiply->shape().element_type()) &&
2712       ReplaceInstructionIfSameShape(multiply, rhs)) {
2713     return Status::OK();
2714   }
2715 
2716   {
2717     HloInstruction* abs_operand;
2718     if (lhs == rhs && Match(lhs, m::Abs(m::Op(&abs_operand))) &&
2719         !ShapeUtil::ElementIsComplex(abs_operand->shape())) {
2720       TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(0, abs_operand));
2721       TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(1, abs_operand));
2722       changed_ = true;
2723       return Status::OK();
2724     }
2725   }
2726 
2727   {
2728     HloInstruction *convert_operand, *operand;
2729     // Mul(Convert(Pred), operand) => select(pred, operand, 0)
2730     if (Match(multiply,
2731               m::MultiplyAnyOrder(
2732                   m::Op(&operand),
2733                   m::Convert(
2734                       m::Op(&convert_operand)
2735                           .WithShape(m::Shape().WithElementType(PRED)))))) {
2736       HloInstruction* zero_like_multiply =
2737           BroadcastZeros(computation_, multiply->shape().element_type(),
2738                          multiply->shape().dimensions());
2739       return ReplaceWithNewInstruction(
2740           multiply, HloInstruction::CreateTernary(
2741                         multiply->shape(), HloOpcode::kSelect, convert_operand,
2742                         operand, zero_like_multiply));
2743     }
2744   }
2745 
2746   {
2747     HloInstruction *a, *b, *c1, *c2;
2748     // Mul(Mul(x, constant1), Mul(y, constant2)) => Mul(Mul(x, y),
2749     // constant1*constant2)
2750     if (Match(multiply,
2751               m::MultiplyAnyOrder(
2752                   m::MultiplyAnyOrder(m::NonConstant(&a), m::Constant(&c1)),
2753                   m::MultiplyAnyOrder(m::NonConstant(&b), m::Constant(&c2))))) {
2754       TF_ASSIGN_OR_RETURN(auto* product_of_constants,
2755                           MakeBinaryHlo(HloOpcode::kMultiply, c1, c2));
2756       if (ShapeUtil::IsScalar(product_of_constants->shape()) &&
2757           !ShapeUtil::IsScalar(multiply->shape())) {
2758         product_of_constants =
2759             computation_->AddInstruction(HloInstruction::CreateBroadcast(
2760                 multiply->shape(), product_of_constants, {}));
2761       }
2762 
2763       return ReplaceWithNewInstruction(
2764           multiply,
2765           HloInstruction::CreateBinary(
2766               multiply->shape(), HloOpcode::kMultiply,
2767               computation_->AddInstruction(HloInstruction::CreateBinary(
2768                   multiply->shape(), HloOpcode::kMultiply, a, b)),
2769               product_of_constants));
2770     }
2771   }
2772 
2773   {
2774     HloInstruction *a, *c1, *c2;
2775     // Mul(Mul(a, constant1), constant2) => Mul(a, constant1*constant2)
2776     if (Match(multiply,
2777               m::MultiplyAnyOrder(
2778                   m::MultiplyAnyOrder(m::NonConstant(&a), m::Constant(&c1)),
2779                   m::Constant(&c2)))) {
2780       TF_ASSIGN_OR_RETURN(auto* product_of_constants,
2781                           MakeBinaryHlo(HloOpcode::kMultiply, c1, c2));
2782       if (ShapeUtil::IsScalar(product_of_constants->shape()) &&
2783           !ShapeUtil::IsScalar(multiply->shape())) {
2784         product_of_constants =
2785             computation_->AddInstruction(HloInstruction::CreateBroadcast(
2786                 multiply->shape(), product_of_constants, {}));
2787       }
2788 
2789       return ReplaceWithNewInstruction(
2790           multiply,
2791           HloInstruction::CreateBinary(multiply->shape(), HloOpcode::kMultiply,
2792                                        a, product_of_constants));
2793     }
2794   }
2795 
2796   {
2797     HloInstruction *a, *b, *constant, *op;
2798     // Mul(Mul(a, constant1), Broadcast(b)) =>
2799     // Mul(Broadcast(Mul(b, constant1), a))
2800     if (Match(multiply,
2801               m::MultiplyAnyOrder(m::MultiplyAnyOrder(m::NonConstant(&a),
2802                                                       m::Constant(&constant)),
2803                                   m::Op(&op))) ||
2804         Match(multiply,
2805               m::MultiplyAnyOrder(
2806                   m::MultiplyAnyOrder(m::NonConstant(&a),
2807                                       m::Broadcast(m::Constant(&constant))),
2808                   m::Op(&op)))) {
2809       // Check that the other side was a broadcast, and not of a constant.
2810       if (ShapeUtil::IsScalar(constant->shape()) &&
2811           Match(op, m::Broadcast(m::NonConstant()))) {
2812         auto dims = op->dimensions();
2813         b = op->mutable_operand(0);
2814         if (!ShapeUtil::IsScalar(b->shape())) {
2815           constant = computation_->AddInstruction(
2816               HloInstruction::CreateBroadcast(b->shape(), constant, {}));
2817         }
2818 
2819         auto new_mul =
2820             computation_->AddInstruction(HloInstruction::CreateBinary(
2821                 b->shape(), HloOpcode::kMultiply, b, constant));
2822 
2823         return ReplaceWithNewInstruction(
2824             multiply,
2825             HloInstruction::CreateBinary(
2826                 multiply->shape(), HloOpcode::kMultiply, a,
2827                 computation_->AddInstruction(HloInstruction::CreateBroadcast(
2828                     multiply->shape(), new_mul, dims))));
2829       }
2830     }
2831   }
2832 
2833   VLOG(10) << "trying transform [(A * C1) * C2 => A * (C1 * C2)]";
2834   HloInstruction *a, *c1, *c2;
2835   if (Match(multiply,
2836             m::Multiply(m::Multiply(m::NonConstant(&a), m::Constant(&c1)),
2837                         m::Constant(&c2))) ||
2838       Match(multiply,
2839             m::Multiply(
2840                 m::Multiply(m::Op(&a), m::Broadcast(m::ConstantScalar(&c1))),
2841                 m::Broadcast(m::ConstantScalar(&c2))))) {
2842     TF_ASSIGN_OR_RETURN(auto* product_of_constants,
2843                         MakeBinaryHlo(HloOpcode::kMultiply, c1, c2));
2844     if (ShapeUtil::IsScalar(product_of_constants->shape()) &&
2845         !ShapeUtil::IsScalar(multiply->shape())) {
2846       product_of_constants =
2847           computation_->AddInstruction(HloInstruction::CreateBroadcast(
2848               multiply->shape(), product_of_constants, {}));
2849     }
2850     return ReplaceWithNewInstruction(
2851         multiply,
2852         HloInstruction::CreateBinary(multiply->shape(), HloOpcode::kMultiply, a,
2853                                      product_of_constants));
2854   }
2855 
2856   VLOG(10) << "trying to transform exp(LHS) * exp(RHS) => exp(LHS+RHS) "
2857            << multiply->ToString();
2858   if (Match(multiply, m::Multiply(m::Exp(m::Op(&lhs)), m::Exp(m::Op(&rhs))))) {
2859     auto add = computation_->AddInstruction(HloInstruction::CreateBinary(
2860         multiply->shape(), HloOpcode::kAdd, lhs, rhs));
2861     return ReplaceWithNewInstruction(
2862         multiply,
2863         HloInstruction::CreateUnary(multiply->shape(), HloOpcode::kExp, add));
2864   }
2865 
2866   VLOG(10) << "trying transform [rsqrt(B) * rsqrt(B) => 1/B] "
2867            << multiply->ToString();
2868   HloInstruction* b;
2869   if (Match(multiply, m::Multiply(m::Rsqrt(m::Op(&b)), m::Rsqrt(m::Op(&b)))) &&
2870       IsPositive(b, options_)) {
2871     return ReplaceWithNewInstruction(
2872         multiply,
2873         HloInstruction::CreateBinary(multiply->shape(), HloOpcode::kDivide,
2874                                      MakeScalarLike(b, 1), b));
2875   }
2876 
2877   return Status::OK();
2878 }
2879 
HandleNegate(HloInstruction * negate)2880 Status AlgebraicSimplifierVisitor::HandleNegate(HloInstruction* negate) {
2881   // negate(negate(x)) => x
2882   HloInstruction* x;
2883   if (Match(negate, m::Negate(m::Negate(m::Op(&x)))) &&
2884       ReplaceInstructionIfSameShape(negate, x)) {
2885     return Status::OK();
2886   }
2887   return Status::OK();
2888 }
2889 
HandleNot(HloInstruction * logical_not)2890 Status AlgebraicSimplifierVisitor::HandleNot(HloInstruction* logical_not) {
2891   // not(not(x)) => x
2892   HloInstruction* x;
2893   if (Match(logical_not, m::Not(m::Not(m::Op(&x)))) &&
2894       ReplaceInstructionIfSameShape(logical_not, x)) {
2895     return Status::OK();
2896   }
2897   return Status::OK();
2898 }
2899 
HandleOr(HloInstruction * logical_or)2900 Status AlgebraicSimplifierVisitor::HandleOr(HloInstruction* logical_or) {
2901   HloInstruction *lhs, *rhs;
2902   CHECK(Match(logical_or, m::Or(m::Op(&lhs), m::Op(&rhs))));
2903 
2904   // Simplify logical or
2905   if (ShapeUtil::HasPrimitiveType(lhs->shape(), xla::PRED) &&
2906       ShapeUtil::HasPrimitiveType(rhs->shape(), xla::PRED)) {
2907     // A || True => True
2908     VLOG(10) << "trying transform [A || True => True]: "
2909              << logical_or->ToString();
2910     if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(logical_or, rhs)) {
2911       return Status::OK();
2912     }
2913     // True || A => True
2914     VLOG(10) << "trying transform [True || A => True]: "
2915              << logical_or->ToString();
2916     if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(logical_or, lhs)) {
2917       return Status::OK();
2918     }
2919   }
2920 
2921   // A || False => A and A | 0 => A
2922   VLOG(10) << "trying transform [A || False => A]: " << logical_or->ToString();
2923   if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(logical_or, lhs)) {
2924     return Status::OK();
2925   }
2926 
2927   // False || A => A and 0 | A => A
2928   VLOG(10) << "trying transform [False || A => A]: " << logical_or->ToString();
2929   if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(logical_or, rhs)) {
2930     return Status::OK();
2931   }
2932 
2933   return Status::OK();
2934 }
2935 
// Simplifies logarithms whose operand has a closed-form log:
//   ln(exp(A))   => A
//   ln(pow(A,B)) => B*ln(abs(A))   (B*ln(A) when A is complex)
//   ln(sqrt(A))  => 0.5*ln(A)
//   ln(rsqrt(A)) => -0.5*ln(A)
Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log) {
  // ln(exp(A)) => A
  VLOG(10) << "trying transform [ln(exp(A)) => A]: " << log->ToString();
  HloInstruction *a, *b;
  if (Match(log, m::Log(m::Exp(m::Op(&a)))) &&
      ReplaceInstructionIfSameShape(log, a)) {
    return Status::OK();
  }

  // ln(pow(A,B)) => B*ln(abs(A))
  // or B*ln(A) if A is complex.
  if (Match(log, m::Log(m::Power(m::Op(&a), m::Op(&b))))) {
    // For real element types the base goes through kAbs first; complex types
    // skip the abs (NOTE(review): presumably because kAbs would map a complex
    // value to its real magnitude and change the result -- confirm).
    auto abs_a = ShapeUtil::ElementIsComplex(a->shape())
                     ? a
                     : computation_->AddInstruction(HloInstruction::CreateUnary(
                           log->shape(), HloOpcode::kAbs, a));
    auto new_log = computation_->AddInstruction(
        HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, abs_a));
    return ReplaceWithNewInstruction(
        log, HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply,
                                          new_log, b));
  }

  // ln(sqrt(A)) => 0.5*ln(A)
  if (Match(log, m::Log(m::Sqrt(m::Op(&a))))) {
    auto new_log = computation_->AddInstruction(
        HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, a));
    return ReplaceWithNewInstruction(
        log, HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply,
                                          new_log, MakeScalarLike(log, 0.5)));
  }

  // ln(rsqrt(A)) => -0.5*ln(A)
  if (Match(log, m::Log(m::Rsqrt(m::Op(&a))))) {
    auto new_log = computation_->AddInstruction(
        HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, a));
    return ReplaceWithNewInstruction(
        log, HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply,
                                          new_log, MakeScalarLike(log, -0.5)));
  }

  return Status::OK();
}
2977 
HandleGetTupleElement(HloInstruction * get_tuple_element)2978 Status AlgebraicSimplifierVisitor::HandleGetTupleElement(
2979     HloInstruction* get_tuple_element) {
2980   auto operand = get_tuple_element->mutable_operand(0);
2981   if (operand->opcode() == HloOpcode::kTuple) {
2982     // get_tuple_element(make_tuple({A_0, A_1, ..., A_n}), i) => A_i
2983     VLOG(10) << "trying transform "
2984              << "[get_tuple_element(make_tuple({...,A_i,...}), i)] => A_i: "
2985              << get_tuple_element->ToString();
2986     if (ReplaceInstructionIfSameShape(
2987             get_tuple_element,
2988             operand->mutable_operand(get_tuple_element->tuple_index()))) {
2989       return Status::OK();
2990     }
2991   }
2992   return Status::OK();
2993 }
2994 
2995 namespace {
2996 
ReshapeLeavesDimensionsUnmodified(const HloInstruction * hlo,absl::Span<const int64> input_dim_indices)2997 absl::optional<std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
2998     const HloInstruction* hlo, absl::Span<const int64> input_dim_indices) {
2999   CHECK_EQ(hlo->opcode(), HloOpcode::kReshape);
3000   return ShapeUtil::ReshapeLeavesDimensionsUnmodified(
3001       hlo->operand(0)->shape(), hlo->shape(), input_dim_indices);
3002 }
3003 
3004 // Returns true if the output of "instruction" is a permutation of the
3005 // elements of "operand". Precondition: "operand" is an operand of
3006 // "instruction".
OutputIsPermutationOfOperandElements(HloInstruction * instruction,HloInstruction * operand)3007 bool OutputIsPermutationOfOperandElements(HloInstruction* instruction,
3008                                           HloInstruction* operand) {
3009   DCHECK(!instruction->OperandIndices(operand).empty());
3010   switch (instruction->opcode()) {
3011     case HloOpcode::kReshape:
3012     case HloOpcode::kReverse:
3013     case HloOpcode::kTranspose:
3014       return true;
3015     case HloOpcode::kSort:
3016       return (!instruction->shape().IsTuple());
3017     default:
3018       return false;
3019   }
3020 }
3021 
3022 // Returns true if the output of "instruction" is a subset of the elements of
3023 // "operand". Precondition: "operand" is an operand of "instruction".
OutputIsSubsetOfOperandElements(HloInstruction * instruction,HloInstruction * operand)3024 bool OutputIsSubsetOfOperandElements(HloInstruction* instruction,
3025                                      HloInstruction* operand) {
3026   const auto operand_indices = instruction->OperandIndices(operand);
3027   CHECK(!operand_indices.empty());
3028   if (operand_indices.size() != 1) {
3029     return false;
3030   }
3031   int64 operand_index = operand_indices[0];
3032   switch (instruction->opcode()) {
3033     case HloOpcode::kSlice:
3034       CHECK_EQ(0, operand_index);
3035       return true;
3036     case HloOpcode::kDynamicSlice:
3037       return operand_index == 0;
3038     default:
3039       return false;
3040   }
3041 }
3042 
3043 }  // namespace
3044 
// Simplifies broadcasts: element-count-preserving broadcasts become reshapes
// or transposes, broadcasts of reshapes/iotas/broadcasts are folded, scalar
// broadcasts are sunk past permuting/subsetting users, and degenerate operand
// dimensions are stripped.
Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
  HloInstruction* operand;
  CHECK(Match(broadcast, m::Broadcast(m::Op(&operand))));
  auto dims = broadcast->dimensions();
  // A degenerate broadcast of a reshape that does not change the number of
  // elements can be replaced by a reshape.
  if (std::is_sorted(dims.begin(), dims.end()) &&
      ShapeUtil::ElementsIn(broadcast->shape()) ==
          ShapeUtil::ElementsIn(operand->shape())) {
    VLOG(10) << "transform broadcast(X) -> reshape(X) where "
                "n(broadcast(X)) == n(X)";
    return ReplaceWithNewInstruction(
        broadcast, HloInstruction::CreateReshape(broadcast->shape(), operand));
  }

  // A degenerate broadcast that has the same input and output rank can be
  // converted into a transpose.
  if (broadcast->shape().rank() == operand->shape().rank() &&
      ShapeUtil::ElementsIn(broadcast->shape()) ==
          ShapeUtil::ElementsIn(operand->shape())) {
    VLOG(10) << "transform broadcast(X) -> transpose(X) where "
                "n(broadcast(X)) == n(X)";
    return ReplaceWithNewInstruction(
        broadcast,
        HloInstruction::CreateTranspose(broadcast->shape(), operand, dims));
  }

  // A broadcast of a reshape which merely inserts 1-sized dimensions can
  // elide its operand.
  {
    bool merely_inserts_or_deletes_1_sized_dimensions;
    std::vector<int64> inserted_indices, deleted_indices;
    std::tie(merely_inserts_or_deletes_1_sized_dimensions, deleted_indices,
             inserted_indices) =
        operand->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
    if (merely_inserts_or_deletes_1_sized_dimensions &&
        deleted_indices.empty()) {
      // Erase from the highest index down so earlier indices stay valid.
      std::reverse(inserted_indices.begin(), inserted_indices.end());
      for (auto inserted_index : inserted_indices) {
        dims.erase(dims.begin() + inserted_index);
      }
      return ReplaceWithNewInstruction(
          broadcast,
          HloInstruction::CreateBroadcast(broadcast->shape(),
                                          operand->mutable_operand(0), dims));
    }
  }

  // A Broadcast that feeds a unary element-wise operation can sink the
  // broadcast after the unary element-wise operation.
  TF_ASSIGN_OR_RETURN(
      bool sink_succeeded,
      TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(broadcast));
  changed_ |= sink_succeeded;
  if (sink_succeeded) {
    return Status::OK();
  }

  // A scalar broadcast feeding an instruction which only permutes (reshape,
  // transpose, sort, reverse) or selects a subset of operand elements (slice,
  // dynamic slice) can be replaced with a broadcast directly to the output
  // shape of the instruction.
  if (ShapeUtil::IsScalar(operand->shape())) {
    for (HloInstruction* user : broadcast->users()) {
      // Skip if the broadcast user has no uses itself.
      if (user->user_count() == 0 && user != computation_->root_instruction()) {
        continue;
      }
      if (OutputIsPermutationOfOperandElements(user, broadcast) ||
          OutputIsSubsetOfOperandElements(user, broadcast)) {
        VLOG(10) << "transform permuting/subset  of a scalar broadcast into "
                 << "a single broadcast";
        HloInstruction* new_broadcast = computation_->AddInstruction(
            HloInstruction::CreateBroadcast(user->shape(), operand, {}));
        // Use HloInstruction::ReplaceAllUsesWith instead of
        // HloComputation::ReplaceWithNewInstruction because we are replacing an
        // instruction other than the visited instruction.
        changed_ = true;
        return user->ReplaceAllUsesWith(new_broadcast);
      }
    }
    return Status::OK();
  }

  // broadcast(iota) -> iota.
  if (operand->opcode() == HloOpcode::kIota) {
    return ReplaceWithNewInstruction(
        broadcast,
        HloInstruction::CreateIota(
            broadcast->shape(),
            // The operand's iota dimension maps through the broadcast's
            // dimension mapping to the output dimension.
            dims[Cast<HloIotaInstruction>(operand)->iota_dimension()]));
  }

  // Merge two consecutive broadcasts into a single one.
  if (operand->opcode() == HloOpcode::kBroadcast) {
    // Compose the two dimension mappings: inner dim -> outer dim.
    std::vector<int64> new_dimensions;
    for (auto dim : operand->dimensions()) {
      new_dimensions.push_back(dims[dim]);
    }
    return ReplaceWithNewInstruction(
        broadcast,
        HloInstruction::CreateBroadcast(
            broadcast->shape(), operand->mutable_operand(0), new_dimensions));
  }
  // The degenerate-dimension rewrite below inserts a reshape, which is not
  // safe once layouts are fixed.
  if (options_.is_layout_sensitive()) {
    return Status::OK();
  }
  if (ShapeUtil::HasDegenerateDimensions(operand->shape())) {
    // Drop the operand's 1-sized dimensions via a reshape and broadcast from
    // the reduced shape, keeping only the mappings of non-degenerate dims.
    auto new_operand =
        operand->parent()->AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::DropDegenerateDimensions(operand->shape()), operand));
    std::vector<int64> new_dims;
    new_dims.reserve(new_operand->shape().rank());
    for (int64 i = 0; i < operand->shape().rank(); ++i) {
      if (operand->shape().dimensions(i) != 1) {
        new_dims.push_back(dims[i]);
      }
    }
    return ReplaceWithNewInstruction(
        broadcast, HloInstruction::CreateBroadcast(broadcast->shape(),
                                                   new_operand, new_dims));
  }
  return Status::OK();
}
3169 
// Simplifies comparisons: moves a broadcast scalar addend across the compare
// (signed integers only), folds trivially-true/false unsigned comparisons
// with zero, folds iota-vs-zero comparisons, and folds integral
// self-comparisons.
Status AlgebraicSimplifierVisitor::HandleCompare(HloInstruction* compare) {
  HloInstruction* lhs;
  HloInstruction* rhs;
  CHECK(Match(compare, m::Compare(m::Op(&lhs), m::Op(&rhs))));
  {
    // compare(broadcast(a) + x, broadcast(b)) ==>
    //   compare(x, broadcast(b-a)), only enabled for integral types.
    HloInstruction *x, *a, *b;
    if (Match(compare,
              m::Compare(
                  m::AddAnyOrder(m::Op(&x), m::Broadcast(m::Op(&a).WithShape(
                                                m::Shape().IsScalar()))),
                  m::Broadcast(m::Op(&b).WithShape(m::Shape().IsScalar()))))) {
      // NOTE(review): restricted to signed integral x; floats are excluded
      // (the rewrite is not exact under rounding), and presumably unsigned is
      // excluded because b-a may conceptually go negative -- confirm.
      if (ShapeUtil::ElementIsSigned(x->shape()) &&
          ShapeUtil::ElementIsIntegral(x->shape())) {
        HloInstruction* sub =
            computation_->AddInstruction(HloInstruction::CreateBinary(
                b->shape(), HloOpcode::kSubtract, b, a));
        HloInstruction* broadcast = computation_->AddInstruction(
            HloInstruction::CreateBroadcast(x->shape(), sub, {}));
        HloInstruction* new_compare = computation_->AddInstruction(
            HloInstruction::CreateCompare(compare->shape(), x, broadcast,
                                          compare->comparison_direction()));
        return ReplaceInstruction(compare, new_compare);
      }
    }
  }

  // Unsigned values are never below zero, so comparisons with an all-zero
  // operand are constant.
  if (Cast<HloCompareInstruction>(compare)->type() ==
      Comparison::Type::kUnsigned) {
    // X u<  0 -> false
    if (compare->comparison_direction() == ComparisonDirection::kLt &&
        IsAll(rhs, 0)) {
      return ReplaceInstruction(compare, MakeScalarLike(compare, false));
    }
    // X u>= 0 -> true
    if (compare->comparison_direction() == ComparisonDirection::kGe &&
        IsAll(rhs, 0)) {
      return ReplaceInstruction(compare, MakeScalarLike(compare, true));
    }
    // 0 u>  X -> false
    if (compare->comparison_direction() == ComparisonDirection::kGt &&
        IsAll(lhs, 0)) {
      return ReplaceInstruction(compare, MakeScalarLike(compare, false));
    }
    // 0 u<= X -> true
    if (compare->comparison_direction() == ComparisonDirection::kLe &&
        IsAll(lhs, 0)) {
      return ReplaceInstruction(compare, MakeScalarLike(compare, true));
    }
  }

  // Iota values start at zero and never go below it, so comparing an iota
  // against zero is constant for Lt/Gt/Ge/Le in the directions below.
  if (compare->comparison_direction() == ComparisonDirection::kLt &&
      lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) {
    return ReplaceInstruction(compare, MakeScalarLike(compare, false));
  } else if (compare->comparison_direction() == ComparisonDirection::kGt &&
             IsAll(lhs, 0) && rhs->opcode() == HloOpcode::kIota) {
    return ReplaceInstruction(compare, MakeScalarLike(compare, false));
  } else if (compare->comparison_direction() == ComparisonDirection::kGe &&
             lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) {
    return ReplaceInstruction(compare, MakeScalarLike(compare, true));
  } else if (compare->comparison_direction() == ComparisonDirection::kLe &&
             IsAll(lhs, 0) && rhs->opcode() == HloOpcode::kIota) {
    return ReplaceInstruction(compare, MakeScalarLike(compare, true));
  }
  // Self-comparison folds for integral types only; float self-comparison is
  // not constant in the presence of NaN.
  if (lhs == rhs &&
      primitive_util::IsIntegralType(lhs->shape().element_type())) {
    switch (compare->comparison_direction()) {
      case ComparisonDirection::kGt:
      case ComparisonDirection::kLt:
      case ComparisonDirection::kNe:
        return ReplaceInstruction(compare, MakeScalarLike(compare, false));
      case ComparisonDirection::kEq:
      case ComparisonDirection::kGe:
      case ComparisonDirection::kLe:
        return ReplaceInstruction(compare, MakeScalarLike(compare, true));
    }
  }
  return Status::OK();
}
3250 
HandleConvert(HloInstruction * convert)3251 Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) {
3252   PrimitiveType src_type = convert->operand(0)->shape().element_type();
3253   PrimitiveType dest_type = convert->shape().element_type();
3254   // A conversion to the same element type as the operand is a nop and can be
3255   // removed.  A conversion of a constant can be simplified by making a new
3256   // constant.
3257   if (src_type == dest_type) {
3258     return ReplaceInstruction(convert, convert->mutable_operand(0));
3259   }
3260 
3261   // Eliminate a convert pair if it is a no-op. The following are a few
3262   // example cases that are being handled:
3263   // 1. convert(convert(A, $TYPE1), $TYPE2) is simplified to A if A is of $TYPE2
3264   //    and convert(A, $TYPE1) is an upcast
3265   // 2. convert(convert(A, $TYPE1),$TYPE2) is simplified to A if A is of $TYPE2
3266   //    and convert(A, $TYPE1) is an upcast and is an integral conversion from
3267   //    unsigned to signed (only signed to unsigned conversion is NOT allowed)
3268   // 3. Tuple(convert(A, $TYPE1) , floor(convert(convert(A, $TYPE1), $TYPE2)),
3269   //    convert(convert(A, $TYPE1), $TYPE2)) is simplified to Tuple(convert(A,
3270   //    $TYPE1) , floor(A), A) -> a case where the first convert has a
3271   //    fan-out
3272   if (convert->operand(0)->opcode() == HloOpcode::kConvert &&
3273       IsConvertPairNoOp(convert)) {
3274     return ReplaceInstruction(convert,
3275                               convert->mutable_operand(0)->mutable_operand(0));
3276   }
3277   return Status::OK();
3278 }
3279 
3280 // Complex(Real(c), Imag(c)) -> c
HandleComplex(HloInstruction * complex)3281 Status AlgebraicSimplifierVisitor::HandleComplex(HloInstruction* complex) {
3282   HloInstruction *c0, *c1;
3283   if (Match(complex, m::Complex(m::Real(m::Op(&c0)), m::Imag(m::Op(&c1)))) &&
3284       c0 == c1) {
3285     return ReplaceInstruction(complex, c0);
3286   }
3287   return Status::OK();
3288 }
3289 
3290 // Real(Complex(r, i)) -> r
HandleReal(HloInstruction * real)3291 Status AlgebraicSimplifierVisitor::HandleReal(HloInstruction* real) {
3292   HloInstruction* op;
3293   if (Match(real, m::Real(m::Complex(m::Op(&op), m::Op())))) {
3294     return ReplaceInstruction(real, op);
3295   }
3296   return Status::OK();
3297 }
3298 
3299 // Imag(Complex(r, i)) -> i
HandleImag(HloInstruction * imag)3300 Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) {
3301   HloInstruction* op;
3302   if (Match(imag, m::Imag(m::Complex(m::Op(), m::Op(&op))))) {
3303     return ReplaceInstruction(imag, op);
3304   }
3305   return Status::OK();
3306 }
3307 
HandleIota(HloInstruction * instruction)3308 Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) {
3309   // iota -> zero if the iota dimension never produces an element other than
3310   // zero.
3311   auto* iota = Cast<HloIotaInstruction>(instruction);
3312   if (iota->shape().dimensions(iota->iota_dimension()) <= 1) {
3313     auto zero = computation_->AddInstruction(
3314         simplifier_->CreateConstantWithLayoutUpdated(
3315             LiteralUtil::Zero(iota->shape().element_type()).Clone()));
3316     return ReplaceWithNewInstruction(
3317         iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {}));
3318   }
3319   return Status::OK();
3320 }
3321 
// Simplifies pads: pad of an empty array becomes a broadcast of the padding
// value, interior padding on 1-sized dimensions is dropped, all-zero pads are
// removed, a pad of a broadcast is re-ordered to pad the smaller operand
// first, and negative padding is replaced by non-negative padding plus a
// slice.
Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
  // Padding an empty array produces an array containing only the padding
  // value (operand 1).
  if (ShapeUtil::IsZeroElementArray(pad->operand(0)->shape())) {
    return ReplaceWithNewInstruction(
        pad, HloInstruction::CreateBroadcast(pad->shape(),
                                             pad->mutable_operand(1), {}));
  }

  // Interior padding on one sized dimensions have no effect. As a result it
  // makes other simplifications possible if there is no interior padding.
  if (HasInteriorPadding(pad->padding_config())) {
    PaddingConfig padding_config = pad->padding_config();
    bool cleared_interior_padding = false;
    for (int64 i = 0; i < pad->shape().rank(); ++i) {
      if (padding_config.dimensions(i).interior_padding() > 0 &&
          pad->operand(0)->shape().dimensions(i) == 1) {
        cleared_interior_padding = true;
        padding_config.mutable_dimensions(i)->set_interior_padding(0);
      }
    }
    if (cleared_interior_padding) {
      return ReplaceWithNewInstruction(
          pad,
          HloInstruction::CreatePad(pad->shape(), pad->mutable_operand(0),
                                    pad->mutable_operand(1), padding_config));
    }
  }

  // Eliminate nop pads (padding all zero), and replace a pad with negative
  // padding with a pad with non-negative padding followed by a slice.
  bool all_zero = true;
  bool has_negative = false;
  // Used to possibly split off the unchanged padding dimensions.
  std::vector<int64> padding_dimensions;
  int64 dimension_index = 0;
  for (auto& padding_dimension : pad->padding_config().dimensions()) {
    if (padding_dimension.edge_padding_low() < 0 ||
        padding_dimension.edge_padding_high() < 0) {
      has_negative = true;
    }
    if (padding_dimension.edge_padding_low() != 0 ||
        padding_dimension.edge_padding_high() != 0) {
      all_zero = false;
      padding_dimensions.push_back(dimension_index);
    } else if (padding_dimension.interior_padding()) {
      // No edge padding, but interior padding still changes this dimension.
      padding_dimensions.push_back(dimension_index);
    }
    dimension_index++;
  }

  if (all_zero) {
    if (ReplaceInstructionIfSameShape(pad, pad->mutable_operand(0))) {
      return Status::OK();
    }
  }

  // The context of this optimization can be found at b/163617402
  // It tries to capture the case of pad(broadcast(x)), where
  // x->shape().dimensions(), or broadcast(x)->dimensions(), is
  // a subset of the padded dimensions in pad->config(),
  // and the padded dimensions in pad->config() is in turn a strict
  // subset of broadcast->shape().dimensions(). The combined op can be
  // rewritten to broadcast2(pad(broadcast1(x))), where broadcast1 extends
  // x  with dimensions that need to be padded, and broadcast2 extends
  // the result of padding to full dimensions.
  // TODO(qyi): for future extensions: The condition for broadcast(x)
  // ->dimensions() to be a subset of padded dimensions in pad->config()
  // does not have to be strictly required, but it makes the calculation
  // for optimization easier, so it is required by the current implementation.
  // Only the second condition between the padded dimensions and the
  // dimensions of the final shape have to be enforced for the optimization
  // to make sense. If needed to remove the first constraint, the shape
  // calculations across the implementation need to be re-adjusted.
  auto pad_dims = padding_dimensions.size();
  if (pad_dims < dimension_index &&
      pad->operand(0)->opcode() == HloOpcode::kBroadcast &&
      pad->operand(0)->user_count() == 1 &&
      pad->operand(0)->operand(0)->shape().rank() <= pad_dims) {
    // Check broadcast operand dimensions is a subset of pading_dimensions.
    // If not, skip the optimization.
    bool opt_is_valid = true;
    std::vector<int64> broadcast_dimensions;
    HloBroadcastInstruction* broadcast =
        static_cast<HloBroadcastInstruction*>(pad->mutable_operand(0));
    for (auto broadcast_index : broadcast->dimensions()) {
      bool found = false;
      for (int i = 0; i < pad_dims; ++i) {
        if (broadcast_index == padding_dimensions[i]) {
          // Remap the broadcast dimension to its position within the
          // padded-dimensions list (the reduced shape built below).
          broadcast_dimensions.push_back(i);
          found = true;
          break;
        }
      }
      if (!found) {
        opt_is_valid = false;
        break;
      }
    }
    if (opt_is_valid) {
      auto pad_shape = pad->shape();
      auto broadcast_shape = broadcast->shape();
      auto pad_shape1 = pad_shape;
      auto broadcast_shape1 = broadcast_shape;
      PaddingConfig pad_config;
      // Build the reduced shapes (pad_shape1/broadcast_shape1) by deleting
      // every dimension that is not padded, walking from the highest
      // dimension down so remaining indices stay valid. dimension_index
      // currently equals the rank and is consumed by the loops below.
      for (int i = padding_dimensions.size() - 1; i >= 0; --i) {
        int64 j = padding_dimensions[i];
        while (--dimension_index > j) {
          broadcast_shape1.DeleteDimension(dimension_index);
          pad_shape1.DeleteDimension(dimension_index);
        }
      }
      while (--dimension_index >= 0) {
        broadcast_shape1.DeleteDimension(dimension_index);
        pad_shape1.DeleteDimension(dimension_index);
      }
      // Keep only the padded dimensions' entries in the new pad config.
      for (auto dimension_to_pad : padding_dimensions) {
        auto dimension = pad_config.add_dimensions();
        *dimension = pad->padding_config().dimensions(dimension_to_pad);
      }
      // Rewrite the existing broadcast in place to the reduced shape
      // (broadcast1), clone the pad onto it (pad2), then broadcast the padded
      // result back to the original output shape (broadcast2).
      *broadcast->mutable_shape() = broadcast_shape1;
      *broadcast->mutable_dimensions() = broadcast_dimensions;
      simplifier_->UpdateLayout(broadcast->mutable_shape());
      auto pad2 =
          computation_->AddInstruction(pad->CloneWithNewShape(pad_shape1));
      *pad2->mutable_padding_config() = pad_config;
      simplifier_->UpdateLayout(pad2->mutable_shape());
      auto broadcast2 = computation_->AddInstruction(
          HloInstruction::CreateBroadcast(pad_shape, pad2, padding_dimensions));
      return ReplaceInstruction(pad, broadcast2);
    }
  }

  if (has_negative && options_.enable_negative_padding_replacement()) {
    // Pad has negative padding. Replace with a pad with the non-negative
    // padding followed by a slice which effectively performs the negative
    // padding.
    // TODO(b/34628603): Add support for negative padding in the backends, or
    // change kPad semantics to disallow negative padding and use slice
    // instead.

    // First construct the padding config with non-negative entries and the
    // compute the shape of this new pad instruction.
    PaddingConfig nonzero_padding = pad->padding_config();
    for (int i = 0; i < pad->padding_config().dimensions_size(); ++i) {
      PaddingConfig::PaddingConfigDimension* padding_dimension =
          nonzero_padding.mutable_dimensions(i);
      // Set negative padding to zero.
      if (padding_dimension->edge_padding_low() < 0) {
        padding_dimension->set_edge_padding_low(0);
      }
      if (padding_dimension->edge_padding_high() < 0) {
        padding_dimension->set_edge_padding_high(0);
      }
    }

    TF_ASSIGN_OR_RETURN(HloInstruction * nonzero_pad,
                        MakePadHlo(pad->mutable_operand(0),
                                   pad->mutable_operand(1), nonzero_padding));
    // Copy the layout from the original pad instructions. The new pad and the
    // slice instruction should all have the same layout.
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        pad->shape(), nonzero_pad->mutable_shape()));
    simplifier_->UpdateLayout(nonzero_pad->mutable_shape());

    // Second, construct the slice instruction to perform the negative
    // padding.
    std::vector<int64> start_indices;
    std::vector<int64> end_indices;
    std::vector<int64> strides;
    for (int64 i = 0; i < pad->padding_config().dimensions_size(); ++i) {
      const PaddingConfig::PaddingConfigDimension& padding_dimension =
          pad->padding_config().dimensions(i);
      int64 start = 0;
      if (padding_dimension.edge_padding_low() < 0) {
        // Negative low padding trims from the front: start past the trimmed
        // region.
        start = -1 * padding_dimension.edge_padding_low();
      }
      int64 end = nonzero_pad->shape().dimensions(i);
      if (padding_dimension.edge_padding_high() < 0) {
        // Negative high padding trims from the back: pull the end in.
        end += padding_dimension.edge_padding_high();
      }
      start_indices.push_back(start);
      end_indices.push_back(end);
      strides.push_back(1);
    }

    TF_ASSIGN_OR_RETURN(
        HloInstruction * slice,
        MakeSliceHlo(nonzero_pad, start_indices, end_indices, strides));
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        pad->shape(), slice->mutable_shape()));
    simplifier_->UpdateLayout(slice->mutable_shape());

    // Verify that the slice shape matches the pad shape.
    auto equal = Shape::Equal();
    if (!options_.is_layout_sensitive()) {
      equal.IgnoreTilesInLayout();
    }
    TF_RET_CHECK(equal(slice->shape(), pad->shape()));

    return ReplaceInstruction(pad, slice);
  }

  return Status::OK();
}
3525 
HandlePower(HloInstruction * power)3526 Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
3527   VLOG(10) << "trying transform [pow(A, 0) => 1]: " << power->ToString();
3528   HloInstruction *lhs, *rhs;
3529   CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs))));
3530   if (IsAll(rhs, 0)) {
3531     return ReplaceInstruction(power, MakeScalarLike(power, 1));
3532   }
3533 
3534   VLOG(10) << "trying transform [pow(A, 1) => A]: " << power->ToString();
3535   if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(power, lhs)) {
3536     return Status::OK();
3537   }
3538 
3539   // pow(exp(A),B) => exp(A*B)
3540   HloInstruction *a, *b;
3541   if (Match(power, m::Power(m::Exp(m::Op(&a)), m::Op(&b)))) {
3542     auto a_times_b = computation_->AddInstruction(HloInstruction::CreateBinary(
3543         power->shape(), HloOpcode::kMultiply, a, b));
3544     return ReplaceWithNewInstruction(
3545         power, HloInstruction::CreateUnary(power->shape(), HloOpcode::kExp,
3546                                            a_times_b));
3547   }
3548 
3549   VLOG(10) << "trying transform [pow(A, 2) => A*A]: " << power->ToString();
3550   if (IsAll(rhs, 2)) {
3551     return ReplaceWithNewInstruction(
3552         power, HloInstruction::CreateBinary(power->shape(),
3553                                             HloOpcode::kMultiply, lhs, lhs));
3554   }
3555 
3556   // Pow(A, 3) is used in GELU.
3557   VLOG(10) << "trying transform [pow(A, 3) => A*A*A]: " << power->ToString();
3558   if (IsAll(rhs, 3)) {
3559     HloInstruction* tmp =
3560         computation_->AddInstruction(HloInstruction::CreateBinary(
3561             power->shape(), HloOpcode::kMultiply, lhs, lhs));
3562     return ReplaceWithNewInstruction(
3563         power, HloInstruction::CreateBinary(power->shape(),
3564                                             HloOpcode::kMultiply, lhs, tmp));
3565   }
3566 
3567   VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
3568   if (IsAll(rhs, -1)) {
3569     return ReplaceWithNewInstruction(
3570         power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide,
3571                                             MakeScalarLike(lhs, 1), lhs));
3572   }
3573 
3574   return Status::OK();
3575 }
3576 
// Tries to sink `broadcast` below its elementwise users, i.e. rewrite
// op(broadcast(x), ...) into broadcast(op(x, ...)), for every user whose
// other operands are "compatible" broadcasts (scalar broadcasts, or
// broadcasts from the same shape with the same dimensions). Returns true if
// any user was rewritten.
StatusOr<bool>
AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
    HloInstruction* broadcast) {
  TF_RET_CHECK(broadcast->opcode() == HloOpcode::kBroadcast);
  bool changed = false;
  if (ShapeUtil::IsScalar(broadcast->shape())) {
    return false;
  }
  HloInstruction* operand = broadcast->mutable_operand(0);
  // A broadcast from a scalar can be re-broadcast to any shape, so it is
  // always compatible.
  auto is_scalar_broadcast = [](const HloInstruction* instruction) {
    return instruction->opcode() == HloOpcode::kBroadcast &&
           ShapeUtil::IsScalar(instruction->operand(0)->shape());
  };
  // A broadcast from the same input shape with the same dimension mapping can
  // share `broadcast`'s pre-broadcast shape.
  auto is_equal_broadcast = [operand,
                             broadcast](const HloInstruction* instruction) {
    return instruction->opcode() == HloOpcode::kBroadcast &&
           ShapeUtil::Equal(operand->shape(),
                            instruction->operand(0)->shape()) &&
           broadcast->dimensions() == instruction->dimensions();
  };
  auto is_compatible_broadcast = [&](const HloInstruction* instruction) {
    return is_scalar_broadcast(instruction) || is_equal_broadcast(instruction);
  };
  for (HloInstruction* user : broadcast->users()) {
    // Skip dead users (no uses and not the root) -- rewriting them is wasted
    // work.
    if (user->user_count() == 0 && user != computation_->root_instruction()) {
      continue;
    }
    // Do not move reshapes or broadcasts past copies since the shape the copy
    // will operate on will change.
    if (user->opcode() == HloOpcode::kCopy) {
      continue;
    }
    // Do not change the shape of fusion nodes in case there a multiple shapes
    // inside the fusion node already.
    if (user->opcode() == HloOpcode::kFusion) {
      continue;
    }
    // Only elementwise ops commute with broadcast.
    if (!user->IsElementwise()) {
      continue;
    }

    // Check if all the operands of the user are compatible broadcasts for
    // sinking. (They are either scalar broadcasts or broadcasts casting
    // from/to the same shape/dimensions)
    int64 compatible_broadcast_count = 0;
    int64 broadcast_use_count = 0;
    for (HloInstruction* user_operand : user->operands()) {
      if (is_compatible_broadcast(user_operand)) {
        ++compatible_broadcast_count;
      } else if (broadcast == user_operand) {
        ++broadcast_use_count;
      }
    }
    // Every operand must be either a compatible broadcast or `broadcast`
    // itself; otherwise sinking would change the user's semantics.
    if (compatible_broadcast_count + broadcast_use_count !=
        user->operand_count()) {
      continue;
    }
    std::vector<HloInstruction*> new_operands;
    new_operands.reserve(user->operand_count());

    Shape changed_shape;
    for (HloInstruction* user_operand : user->operands()) {
      // If this is a broadcast operand that is not our original broadcast input
      // to this function then we might need to change the input.
      if (is_compatible_broadcast(user_operand)) {
        // If this is a broadcast from a scalar value rewrite a broadcast from
        // the scalar to the new shape enforced from the other broadcast
        // operands.
        if (is_scalar_broadcast(user_operand)) {
          changed_shape = ShapeUtil::ChangeElementType(
              operand->shape(), user_operand->shape().element_type());
          simplifier_->UpdateLayout(&changed_shape);
          new_operands.push_back(
              computation_->AddInstruction(HloInstruction::CreateBroadcast(
                  changed_shape, user_operand->mutable_operand(0), {})));
          user_operand->SetupDerivedInstruction(new_operands.back());
        } else {
          // For the non-scalar broadcasts we guarantee that the shape of the
          // operand of the broadcast needs to be already a compatible shape.
          new_operands.push_back(user_operand->mutable_operand(0));
        }
      } else {
        CHECK_EQ(broadcast, user_operand);
        new_operands.push_back(operand);
      }
    }
    VLOG(4) << "Sinking broadcast after user:";
    VLOG(4) << "  old broadcast: " << broadcast->ToString();
    VLOG(4) << "  old user: " << user->ToString();
    // Clone the user on the pre-broadcast shape, then re-broadcast its result
    // to the user's original shape so existing consumers are unaffected.
    changed_shape = ShapeUtil::ChangeElementType(operand->shape(),
                                                 user->shape().element_type());
    simplifier_->UpdateLayout(&changed_shape);
    HloInstruction* new_user = computation_->AddInstruction(
        user->CloneWithNewOperands(changed_shape, new_operands));
    VLOG(4) << "  new user: " << new_user->ToString();
    HloInstruction* new_broadcast =
        computation_->AddInstruction(HloInstruction::CreateBroadcast(
            user->shape(), new_user, broadcast->dimensions()));
    broadcast->SetupDerivedInstruction(new_broadcast);
    VLOG(4) << "  new broadcast: " << new_broadcast->ToString();
    TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_broadcast));
    changed = true;
  }
  return changed;
}
3682 
3683 namespace {
3684 template <typename T>
TryRemainderToAnd(HloInstruction * remainder,HloComputation * computation,AlgebraicSimplifier * simplifier)3685 std::unique_ptr<HloInstruction> TryRemainderToAnd(
3686     HloInstruction* remainder, HloComputation* computation,
3687     AlgebraicSimplifier* simplifier) {
3688   HloInstruction *a, *b, *c;
3689   CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b))));
3690 
3691   if (ShapeUtil::ElementIsIntegral(remainder->shape()) &&
3692       !Match(b, m::ConstantEffectiveScalar(&c)) &&
3693       !Match(b, m::Broadcast(m::ConstantEffectiveScalar(&c)))) {
3694     return nullptr;
3695   }
3696 
3697   if (ShapeUtil::ElementIsSigned(remainder->shape())) {
3698     int64 b_value = c->literal().GetFirstElement<T>();
3699     if (b_value > 0 && IsPowerOfTwo(static_cast<uint64>(b_value))) {
3700       // Handle negative dividends by negating the result of the division.
3701       HloInstruction* zero_like_a = BroadcastZeros(
3702           computation, a->shape().element_type(), a->shape().dimensions());
3703 
3704       auto* dividend_is_negative =
3705           computation->AddInstruction(HloInstruction::CreateCompare(
3706               ShapeUtil::ChangeElementType(a->shape(), PRED), a, zero_like_a,
3707               ComparisonDirection::kLt));
3708 
3709       auto* negated_dividend = computation->AddInstruction(
3710           HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a));
3711 
3712       auto* abs_dividend =
3713           computation->AddInstruction(HloInstruction::CreateTernary(
3714               a->shape(), HloOpcode::kSelect, dividend_is_negative,
3715               negated_dividend, a));
3716 
3717       auto* quotient = computation->AddInstruction(HloInstruction::CreateBinary(
3718           remainder->shape(), HloOpcode::kAnd, abs_dividend,
3719           MakeScalarLike(abs_dividend, b_value - 1)));
3720 
3721       auto* neqated_quotient =
3722           computation->AddInstruction(HloInstruction::CreateUnary(
3723               quotient->shape(), HloOpcode::kNegate, quotient));
3724 
3725       return HloInstruction::CreateTernary(
3726           remainder->shape(), HloOpcode::kSelect, dividend_is_negative,
3727           neqated_quotient, quotient);
3728     }
3729   } else {
3730     uint64 b_value = c->literal().GetFirstElement<T>();
3731     if (IsPowerOfTwo(b_value)) {
3732       HloInstruction* mask_amount = computation->AddInstruction(
3733           simplifier->CreateConstantWithLayoutUpdated(
3734               LiteralUtil::CreateR0<T>(b_value - 1)));
3735       if (!ShapeUtil::IsScalar(b->shape())) {
3736         mask_amount = computation->AddInstruction(
3737             HloInstruction::CreateBroadcast(b->shape(), mask_amount, {}));
3738       }
3739       return HloInstruction::CreateBinary(remainder->shape(), HloOpcode::kAnd,
3740                                           a, mask_amount);
3741     }
3742   }
3743   return nullptr;
3744 }
3745 }  // namespace
3746 
HandleRemainder(HloInstruction * remainder)3747 Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) {
3748   HloInstruction *a, *b;
3749   CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b))));
3750 
3751   // (A % B) % B == A % B.
3752   if (Match(a, m::Remainder(m::Op(), m::Op().Is(b)))) {
3753     return ReplaceInstruction(remainder, a);
3754   }
3755 
3756   // A % B => A & (B - 1) if B is a power of 2.
3757   switch (remainder->shape().element_type()) {
3758     case S8:
3759       if (std::unique_ptr<HloInstruction> shift =
3760               TryRemainderToAnd<int8>(remainder, computation_, simplifier_)) {
3761         return ReplaceWithNewInstruction(remainder, std::move(shift));
3762       }
3763       break;
3764     case S16:
3765       if (std::unique_ptr<HloInstruction> shift =
3766               TryRemainderToAnd<int16>(remainder, computation_, simplifier_)) {
3767         return ReplaceWithNewInstruction(remainder, std::move(shift));
3768       }
3769       break;
3770     case S32:
3771       if (std::unique_ptr<HloInstruction> shift =
3772               TryRemainderToAnd<int32>(remainder, computation_, simplifier_)) {
3773         return ReplaceWithNewInstruction(remainder, std::move(shift));
3774       }
3775       break;
3776     case S64:
3777       if (std::unique_ptr<HloInstruction> shift =
3778               TryRemainderToAnd<int64>(remainder, computation_, simplifier_)) {
3779         return ReplaceWithNewInstruction(remainder, std::move(shift));
3780       }
3781       break;
3782     case U8:
3783       if (std::unique_ptr<HloInstruction> shift =
3784               TryRemainderToAnd<uint8>(remainder, computation_, simplifier_)) {
3785         return ReplaceWithNewInstruction(remainder, std::move(shift));
3786       }
3787       break;
3788     case U16:
3789       if (std::unique_ptr<HloInstruction> shift =
3790               TryRemainderToAnd<uint16>(remainder, computation_, simplifier_)) {
3791         return ReplaceWithNewInstruction(remainder, std::move(shift));
3792       }
3793       break;
3794     case U32:
3795       if (std::unique_ptr<HloInstruction> shift =
3796               TryRemainderToAnd<uint32>(remainder, computation_, simplifier_)) {
3797         return ReplaceWithNewInstruction(remainder, std::move(shift));
3798       }
3799       break;
3800     case U64:
3801       if (std::unique_ptr<HloInstruction> shift =
3802               TryRemainderToAnd<uint64>(remainder, computation_, simplifier_)) {
3803         return ReplaceWithNewInstruction(remainder, std::move(shift));
3804       }
3805       break;
3806     default:
3807       break;
3808   }
3809 
3810   // If M < N, then {0, ..., M} % N ==> {0, ..., M}.
3811   //
3812   // Currently this only covers the case when N is a broadcasted constant
3813   // scalar.  We could also cover the case when N is a non-broadcasted constant
3814   // with the same value repeated.
3815   HloInstruction* iota;
3816   HloInstruction* divisor;
3817   if (Match(remainder,
3818             m::Remainder(m::Iota(&iota),
3819                          m::Broadcast(m::ConstantEffectiveScalar(&divisor))))) {
3820     // The iota counts {0, ..., iota_upper_bound - 1}.  (Actually this is
3821     // conservative; the iota may overflow and count up to a smaller value than
3822     // this.  But that's OK for our purposes here.)
3823     int64 iota_upper_bound = iota->shape().dimensions(
3824         Cast<HloIotaInstruction>(iota)->iota_dimension());
3825     absl::optional<int64> divisor_val = divisor->literal().GetIntegralAsS64(
3826         std::vector<int64>(0, divisor->shape().dimensions_size()));
3827     if (divisor_val && *divisor_val >= iota_upper_bound) {
3828       return ReplaceInstruction(remainder, iota);
3829     }
3830   }
3831 
3832   // (X + N) % N = X % N, so long as X + N does not overflow.
3833   //
3834   // We don't have range tracking in XLA that would let us know whether X + N
3835   // overflows, so for now we only do this simplification when X is an iota.  We
3836   // could add other operations where it's easy to see a range, such as
3837   // remainder, convert, etc., though at some point we'd probably want a
3838   // range-tracking analysis.
3839   HloInstruction* bcast;
3840   HloInstruction* addend;
3841   if (Match(
3842           remainder,
3843           m::Remainder(
3844               m::AddAnyOrder(m::Iota(&iota),
3845                              m::Broadcast(m::ConstantEffectiveScalar(&addend))),
3846               m::Broadcast(&bcast, m::ConstantEffectiveScalar(&divisor)))) &&
3847       addend == divisor) {
3848     // The iota counts {0, ...iota_upper_bound - 1}, with the same caveat above
3849     // that iota_upper_bound is conservative, and the true upper bound may be
3850     // smaller.
3851     int64 iota_upper_bound = iota->shape().dimensions(
3852         Cast<HloIotaInstruction>(iota)->iota_dimension());
3853     absl::optional<int64> divisor_val = divisor->literal().GetIntegralAsS64(
3854         std::vector<int64>(0, divisor->shape().dimensions_size()));
3855     if (divisor_val) {
3856       // Check whether divisor_val + iota_upper_bound - 1 overflows.
3857       absl::optional<int64> max_val =
3858           OverflowSafeAdd(*divisor_val, iota_upper_bound);
3859       if (max_val.has_value() &&
3860           FitsInIntegralType(*max_val, iota->shape().element_type())) {
3861         return ReplaceWithNewInstruction(
3862             remainder,
3863             HloInstruction::CreateBinary(remainder->shape(),
3864                                          HloOpcode::kRemainder, iota, bcast));
3865       }
3866     }
3867   }
3868 
3869   return Status::OK();
3870 }
3871 
// Simplifies reshape instructions: zero-element shapes become constants,
// no-op and chained reshapes are collapsed, reshapes are folded into rng and
// broadcast producers, reshape of iota becomes a mixed-radix iota sum, and
// remaining reshapes are converted to bitcasts when layouts permit.
Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
  auto operand = reshape->mutable_operand(0);

  // Reshape directly to empty constant if the shape contains zero-element
  // dimension.
  if (ShapeUtil::IsZeroElementArray(reshape->shape())) {
    // If the instruction doesn't have a layout, use a default layout for
    // the literal result.
    Shape reshaped_shape = reshape->shape();
    if (!LayoutUtil::HasLayout(reshaped_shape)) {
      LayoutUtil::SetToDefaultLayout(&reshaped_shape);
    }
    auto empty_constant = simplifier_->CreateConstantWithLayoutUpdated(
        Literal::CreateFromShape(reshaped_shape));

    return ReplaceWithNewInstruction(reshape, std::move(empty_constant));
  }

  // Delete no-op reshapes, i.e. where shape = operand shape.
  if (SameShape(reshape, operand)) {
    VLOG(10) << "deleting no-op reshape";
    return ReplaceInstruction(reshape, operand);
  }

  // Merge reshapes: reshape(reshape(x)) => reshape(x).
  if (HloOpcode::kReshape == operand->opcode()) {
    return ReplaceWithNewInstruction(
        reshape, HloInstruction::CreateReshape(reshape->shape(),
                                               operand->mutable_operand(0)));
  }

  // An rng with a single user can simply produce the reshaped shape directly;
  // mutating its shape is safe because no other consumer observes it.
  if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) {
    *operand->mutable_shape() = reshape->shape();
    return ReplaceInstruction(reshape, operand);
  }

  // reshape(broadcast(x)) => broadcast(x) when the reshape leaves the
  // broadcast's mapped dimensions unmodified.
  if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) {
    auto opt_dims = ReshapeLeavesDimensionsUnmodified(
        reshape, reshape->operand(0)->dimensions());
    if (opt_dims.has_value()) {
      return ReplaceWithNewInstruction(
          reshape,
          HloInstruction::CreateBroadcast(
              reshape->shape(), reshape->mutable_operand(0)->mutable_operand(0),
              *opt_dims));
    }
  }

  // reshape(iota) -> iota or a mixed radix calculation like
  // s32[2,3,4] reshape(s32[24] iota()) to
  // add(
  //    add(s32[2,3,4] iota() iota_dimension=2,
  //        4 * s32[2,3,4] iota() iota_dimension=1),
  //    12 * s32[2,3,4] iota() iota_dimension=0).
  if (operand->opcode() == HloOpcode::kIota) {
    auto* iota = Cast<HloIotaInstruction>(operand);
    // CommonFactors pairs up (input dim, output dim) boundaries where the
    // reshape factors the iota dimension into several output dimensions.
    auto common_factors =
        CommonFactors(reshape->operand(0)->shape().dimensions(),
                      reshape->shape().dimensions());
    auto iota_dim = absl::c_find_if(
        common_factors, [&](const std::pair<int64, int64>& dim_pair) {
          return dim_pair.first == iota->iota_dimension() &&
                 reshape->shape().dimensions(dim_pair.second) > 1;
        });
    auto next_dim = absl::c_find_if(
        common_factors, [&](const std::pair<int64, int64>& dim_pair) {
          return dim_pair.first == iota->iota_dimension() + 1;
        });
    if (iota_dim != common_factors.end() && next_dim != common_factors.end()) {
      // Accumulate iota * multiplier terms minor-to-major over the output
      // dimensions the iota dimension was split into; `multiplier` is the
      // place value (radix product) of each output dimension.
      int64 multiplier = 1;
      HloInstruction* new_reshape = nullptr;

      for (int64 dim = (iota_dim + 1)->second - 1; dim >= iota_dim->second;
           --dim) {
        HloInstruction* new_iota = computation_->AddInstruction(
            HloInstruction::CreateIota(reshape->shape(), dim));
        iota->SetupDerivedInstruction(new_iota);
        if (new_reshape) {
          new_reshape =
              computation_->AddInstruction(HloInstruction::CreateBinary(
                  reshape->shape(), HloOpcode::kAdd, new_reshape,
                  computation_->AddInstruction(HloInstruction::CreateBinary(
                      reshape->shape(), HloOpcode::kMultiply, new_iota,
                      MakeScalarLike(reshape, multiplier)))));
          reshape->SetupDerivedInstruction(new_reshape);
        } else {
          new_reshape = new_iota;
        }
        multiplier *= reshape->shape().dimensions(dim);
      }
      reshape->SetupDerivedInstruction(new_reshape);
      return ReplaceInstruction(reshape, new_reshape);
    }
  }

  // Make this a bitcast if possible.
  if (HloInstruction* bitcast_operand =
          BitcastingOperandOfReshapeOrCopyChain(reshape, options_)) {
    ReplaceWithBitcast(reshape, bitcast_operand);
  }
  return Status::OK();
}
3974 
HandleReverse(HloInstruction * reverse)3975 Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) {
3976   // When all the dimensions to reverse are trivial (i.e. the bound is 1),
3977   // there is nothing to be done.
3978   auto dim_is_one = [&](int64 i) -> bool {
3979     return reverse->shape().dimensions(i) == 1;
3980   };
3981   if (absl::c_all_of(reverse->dimensions(), dim_is_one)) {
3982     return ReplaceInstruction(reverse, reverse->mutable_operand(0));
3983   }
3984   return Status::OK();
3985 }
3986 
TrySimplifyScalarSlice(HloInstruction * slice)3987 StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyScalarSlice(
3988     HloInstruction* slice) {
3989   // Only try to do this for effective scalars. We could do the same for slicing
3990   // out larger pieces of padding (replacing with a broadcast of the padding
3991   // value), but this is probably not worth it.
3992   if (!ShapeUtil::IsEffectiveScalar(slice->shape())) {
3993     return false;
3994   }
3995 
3996   if (slice->operand(0)->opcode() == HloOpcode::kConcatenate) {
3997     VLOG(10) << "Trying to simplify scalar slice of concat";
3998     // Only do this for R1, there's no chance of this being useful otherwise.
3999     if (slice->shape().rank() != 1) {
4000       VLOG(10) << "Not folding, slice is not rank 1";
4001       return false;
4002     }
4003     HloConcatenateInstruction* concat =
4004         Cast<HloConcatenateInstruction>(slice->mutable_operand(0));
4005     int64 operand_start = 0;
4006     int64 operand_num = 0;
4007     // Weird loop structure to avoid annoying off-by-one errors.
4008     while (true) {
4009       TF_RET_CHECK(operand_num < concat->operand_count());
4010       const HloInstruction* operand = concat->operand(operand_num);
4011       int64 next_operand_start = operand_start + operand->shape().dimensions(0);
4012       if (next_operand_start > slice->slice_starts(0)) {
4013         break;
4014       }
4015       operand_start = next_operand_start;
4016       operand_num++;
4017     }
4018 
4019     bool replaced = ReplaceInstructionIfSameShape(
4020         slice, concat->mutable_operand(operand_num));
4021     if (replaced) {
4022       VLOG(10) << "Folding scalar slice of concat into concat operand";
4023     } else {
4024       VLOG(10) << "Folding scalar slice of concat into slice of concat operand";
4025       TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
4026           slice, HloInstruction::CreateSlice(
4027                      slice->shape(), concat->mutable_operand(operand_num),
4028                      {slice->slice_starts(0) - operand_start},
4029                      {slice->slice_starts(0) - operand_start + 1},
4030                      slice->slice_strides())));
4031     }
4032     return true;
4033   }
4034 
4035   return false;
4036 }
4037 
TryToReorderSliceAndReshape(HloInstruction * slice)4038 StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape(
4039     HloInstruction* slice) {
4040   CHECK_EQ(slice->opcode(), HloOpcode::kSlice);
4041   if (!IsUnstridedSlice(slice)) {
4042     return false;
4043   }
4044   HloInstruction* reshape = slice->mutable_operand(0);
4045   if (reshape->opcode() != HloOpcode::kReshape) {
4046     return false;
4047   }
4048   HloInstruction* new_slice_operand = reshape->mutable_operand(0);
4049   int64 slice_rank = slice->shape().rank();
4050   std::vector<int64> sliced_dims;
4051   for (int64 i = 0; i < slice_rank; ++i) {
4052     if (slice->slice_starts(i) != 0 ||
4053         slice->slice_limits(i) != reshape->shape().dimensions(i)) {
4054       sliced_dims.push_back(i);
4055     }
4056   }
4057 
4058   if (sliced_dims.size() == 1 && sliced_dims[0] == 0 &&
4059       slice->slice_starts(0) == 0) {
4060     const Shape& new_slice_shape = new_slice_operand->shape();
4061     const int64 rank = new_slice_shape.rank();
4062     std::vector<int64> new_slice_starts(rank, 0);
4063     std::vector<int64> new_slice_stides(rank, 1);
4064     std::vector<int64> new_slice_limits(new_slice_shape.dimensions().begin(),
4065                                         new_slice_shape.dimensions().end());
4066     int64 slice_elements = ShapeUtil::ElementsIn(slice->shape());
4067     for (int64 i = rank - 1; i >= 0; --i) {
4068       if (slice_elements >= new_slice_limits[i]) {
4069         if (slice_elements % new_slice_limits[i] != 0) {
4070           return false;
4071         }
4072         slice_elements /= new_slice_limits[i];
4073       } else {
4074         new_slice_limits[i] = slice_elements;
4075         slice_elements = 1;
4076       }
4077     }
4078     HloInstruction* new_slice =
4079         computation_->AddInstruction(HloInstruction::CreateSlice(
4080             ShapeUtil::MakeShape(new_slice_shape.element_type(),
4081                                  new_slice_limits),
4082             new_slice_operand, new_slice_starts, new_slice_limits,
4083             new_slice_stides));
4084     simplifier_->UpdateLayout(new_slice->mutable_shape());
4085     TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
4086         slice, HloInstruction::CreateReshape(slice->shape(), new_slice)));
4087     return true;
4088   }
4089   return false;
4090 }
4091 
// Allows a slice to move through a reverse, i.e. rewrites slice(reverse(x))
// as reverse(slice(x)), updating the slice's start/limit for each reversed
// dimension. Strides and the output shape are unchanged. Returns true if a
// replacement was made.
StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReverse(
    HloInstruction* slice) {
  VLOG(2) << "Entered TryToReorderSliceAndReverse for slice:"
          << slice->ToString();
  if (Match(slice, m::Slice(m::Reverse()))) {
    HloInstruction* reverse = slice->mutable_operand(0);
    HloInstruction* reverse_operand = reverse->mutable_operand(0);
    std::vector<int64> new_starts = slice->slice_starts();
    std::vector<int64> new_limits = slice->slice_limits();
    std::vector<int64> new_strides = slice->slice_strides();
    // Mirror the slice window across each reversed dimension.
    for (auto rdim : reverse->dimensions()) {
      int64 start = slice->slice_starts(rdim);
      int64 limit = slice->slice_limits(rdim);
      int64 stride = slice->slice_strides(rdim);
      // find_nth allows us to compute the appropriate index to begin
      // with during reverse even in the presence of non-unit strides
      int64 find_nth = (limit - start - 1) / stride;
      find_nth = start + find_nth * stride;
      // Tighten the limit to the last element actually selected, so the
      // mirrored window is exact for strided slices.
      limit = find_nth + 1;
      new_starts[rdim] =
          (reverse->shape().dimensions(rdim) - start) - (limit - start);
      new_limits[rdim] = reverse->shape().dimensions(rdim) - start;
      VLOG(2) << "Analyzing dim:" << rdim << " (start,limit):" << start << ","
              << limit << " and new (start, limit):" << new_starts[rdim] << ","
              << new_limits[rdim];
    }
    // New slice formed from the reverse_operand, but strides and shape of the
    // slice output remains the same. New slice's starts and limits are updated
    // for ONLY the reversed dimensions as indicated above.
    HloInstruction* new_slice = computation_->AddInstruction(
        HloInstruction::CreateSlice(slice->shape(), reverse_operand, new_starts,
                                    new_limits, new_strides));
    simplifier_->UpdateLayout(new_slice->mutable_shape());
    TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
        slice, HloInstruction::CreateReverse(new_slice->shape(), new_slice,
                                             reverse->dimensions())));
    // We do not delete the old reverse, since there might be another
    // consumer of that reverse (i.e., full reverse output). DCE should take
    // care of any deletion that is necessary if there was no use of reverse.
    return true;
  }
  return false;
}
4137 
Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
  // Simplifies a kSlice by attempting, in order:
  //   1. eliding a no-op slice (same shape as operand),
  //   2. folding slice(pad(x)) into a broadcast of the pad value, x itself,
  //      or a slice of x,
  //   3. composing slice(slice(x)) into a single slice,
  //   4. removing a slice that only trims broadcast-created dimensions,
  //   5. scalar-slice simplification and sinking broadcasts through slices,
  //   6. replacing slice(concat(...)) with the matching concat operand,
  //   7. reordering the slice with a reshape or reverse when profitable.
  // Delete no-op slices, i.e. where shape = operand shape.
  if (ReplaceInstructionIfSameShape(slice, slice->mutable_operand(0))) {
    return Status::OK();
  }

  HloInstruction* pad;
  HloInstruction* pad_operand;
  if (Match(slice, m::Slice(m::Pad(&pad, m::Op(&pad_operand), m::Op())))) {
    // Is the result of the slice the pad operand.
    bool slice_undoes_pad = true;
    // Can the slice be moved to the pad_operand without any padding being read.
    bool slice_inside_pad = true;
    // Does this slice slice out padding only.
    bool slice_in_padding = false;
    std::vector<int64> new_starts = slice->slice_starts();
    std::vector<int64> new_limits = slice->slice_limits();
    for (int64 i = 0; i < slice->shape().rank(); ++i) {
      const int64 start = slice->slice_starts(i);
      const int64 stride = slice->slice_strides(i);
      const int64 limit = slice->slice_limits(i);
      const int64 size = pad->shape().dimensions(i);

      const auto& dim = pad->padding_config().dimensions(i);
      const int64 low = dim.edge_padding_low();
      const int64 high = dim.edge_padding_high();
      const int64 interior = dim.interior_padding();
      // First index (in padded coordinates) of the trailing edge padding;
      // data lives in [low, edge).
      const int64 edge = size - high;

      // Slice range [start, limit) falls entirely inside the low or high
      // edge padding in this dimension, so the result is pure pad value.
      if (limit <= low || start >= edge) {
        slice_in_padding = true;
        break;
      }

      // To undo the pad exactly, the slice must begin where the data begins
      // and its stride must step over every interior-padding element.
      if (start != low || stride - 1 != interior) {
        slice_undoes_pad = false;
      }

      // To move the slice onto pad_operand, no padding (edge or interior)
      // may be read and the slice must be unstrided.
      if (start < low || limit > edge || interior != 0 || stride != 1) {
        slice_inside_pad = false;
      }
      // Translate bounds from padded coordinates to operand coordinates.
      new_starts[i] -= low;
      new_limits[i] -= low;
    }
    if (slice_in_padding) {
      // The slice reads only padding: replace with a broadcast of the pad
      // value (pad operand 1).
      HloInstruction* broadcast =
          MakeBroadcastHlo(pad->mutable_operand(1), {}, slice->shape());
      *(broadcast->mutable_shape()) = slice->shape();
      return ReplaceInstruction(slice, broadcast);
    }
    if (slice_undoes_pad && ReplaceInstructionIfSameShape(slice, pad_operand)) {
      return Status::OK();
    }
    if (slice_inside_pad) {
      // The slice reads only real data: slice the pad operand directly.
      TF_ASSIGN_OR_RETURN(HloInstruction * new_slice,
                          MakeSliceHlo(pad_operand, new_starts, new_limits,
                                       slice->slice_strides()));
      *(new_slice->mutable_shape()) = slice->shape();
      return ReplaceInstruction(slice, new_slice);
    }
  }

  // slice(slice(x)) -> slice(x) for unstrided slices: offsets compose by
  // simple addition.
  if (slice->operand(0)->opcode() == HloOpcode::kSlice &&
      IsUnstridedSlice(slice) && IsUnstridedSlice(slice->operand(0))) {
    HloInstruction* operand_slice = slice->mutable_operand(0);
    std::vector<int64> new_slice_starts = slice->slice_starts();
    std::vector<int64> new_slice_limits = slice->slice_limits();
    for (int64 i = 0; i < new_slice_starts.size(); ++i) {
      new_slice_starts[i] += operand_slice->slice_starts(i);
      new_slice_limits[i] += operand_slice->slice_starts(i);
    }
    return ReplaceWithNewInstruction(
        slice, HloInstruction::CreateSlice(
                   slice->shape(), operand_slice->mutable_operand(0),
                   new_slice_starts, new_slice_limits, slice->slice_strides()));
  }

  // True iff the operand is a broadcast and every dimension that carries
  // broadcast-operand data is left untouched by the slice (only
  // broadcast-created dimensions are sliced).
  auto only_broadcast_dims_sliced = [&] {
    if (slice->operand(0)->opcode() != HloOpcode::kBroadcast) {
      return false;
    }
    for (int64 dim : slice->operand(0)->dimensions()) {
      if (slice->slice_starts(dim) != 0 || slice->slice_strides(dim) != 1 ||
          slice->slice_limits(dim) !=
              slice->operand(0)->shape().dimensions(dim)) {
        return false;
      }
    }
    return true;
  };
  if (only_broadcast_dims_sliced()) {
    // The slice can be absorbed into a new broadcast with the slice's shape.
    return ReplaceWithNewInstruction(
        slice,
        HloInstruction::CreateBroadcast(
            slice->shape(), slice->mutable_operand(0)->mutable_operand(0),
            slice->mutable_operand(0)->dimensions()));
  }

  TF_ASSIGN_OR_RETURN(bool replaced, TrySimplifyScalarSlice(slice));
  if (replaced) {
    return Status::OK();
  }

  // slice(broadcast(x)) -> broadcast(slice(x)): slicing the smaller
  // pre-broadcast operand first is never worse.
  HloInstruction* broadcast;
  HloInstruction* broadcast_operand;
  if (Match(slice,
            m::Slice(m::Broadcast(&broadcast, m::Op(&broadcast_operand))))) {
    std::vector<int64> new_slice_starts;
    std::vector<int64> new_slice_strides;
    std::vector<int64> new_slice_limits;
    new_slice_starts.reserve(broadcast_operand->shape().rank());
    new_slice_strides.reserve(broadcast_operand->shape().rank());
    new_slice_limits.reserve(broadcast_operand->shape().rank());
    // Project the slice parameters onto the broadcast operand's dimensions.
    for (int64 dim : broadcast->dimensions()) {
      new_slice_starts.push_back(slice->slice_starts(dim));
      new_slice_strides.push_back(slice->slice_strides(dim));
      new_slice_limits.push_back(slice->slice_limits(dim));
    }
    VLOG(3) << "Sink broadcast through slice";
    VLOG(3) << "Original slice: " << slice->ToString();
    VLOG(3) << "Original broadcast: " << broadcast->ToString();
    auto new_slice_shape = broadcast_operand->shape();
    for (int64 i = 0; i < broadcast_operand->shape().rank(); ++i) {
      // Extent of a strided slice: ceil((limit - start) / stride).
      int64 size_i = (new_slice_limits[i] - new_slice_starts[i] +
                      new_slice_strides[i] - 1) /
                     new_slice_strides[i];
      new_slice_shape.set_dimensions(i, size_i);
    }
    simplifier_->UpdateLayout(&new_slice_shape);
    HloComputation* computation = broadcast_operand->parent();
    auto new_slice = computation->AddInstruction(HloInstruction::CreateSlice(
        new_slice_shape, broadcast_operand, new_slice_starts, new_slice_limits,
        new_slice_strides));
    auto new_broadcast = HloInstruction::CreateBroadcast(
        slice->shape(), new_slice, broadcast->dimensions());
    VLOG(3) << "New slice: " << slice->ToString();
    VLOG(3) << "New broadcast: " << new_broadcast->ToString();
    return ReplaceWithNewInstruction(slice, std::move(new_broadcast));
  }

  // Try to simplify concat -> slice to an operand of concat. The slice must
  // exactly cover one concatenated piece (same shape, matching start offset
  // along the concat dimension).
  if (slice->operand(0)->opcode() == HloOpcode::kConcatenate &&
      IsUnstridedSlice(slice)) {
    auto concat = slice->operand(0);
    int64 concat_dim = concat->concatenate_dimension();
    int64 piece_start = 0;
    for (auto piece : concat->operands()) {
      if (!SameShape(piece, slice)) {
        piece_start += piece->shape().dimensions(concat_dim);
        continue;
      }
      if (slice->slice_starts(concat_dim) == piece_start) {
        return ReplaceInstruction(slice, piece);
      }
      piece_start += piece->shape().dimensions(concat_dim);
    }
  }

  // Do not try to reorder slices and reshapes after layout assignment as it may
  // be invalid.
  if (!options_.is_layout_sensitive()) {
    TF_ASSIGN_OR_RETURN(replaced, TryToReorderSliceAndReshape(slice));
  }
  if (replaced) {
    return Status::OK();
  }

  bool reversed = false;
  if (Match(slice, m::Slice(m::Reverse(m::Op())))) {
    TF_ASSIGN_OR_RETURN(reversed, TryToReorderSliceAndReverse(slice));
  }
  if (reversed) {
    return Status::OK();
  }

  return Status::OK();
}
4315 
HandleRsqrt(HloInstruction * rsqrt)4316 Status AlgebraicSimplifierVisitor::HandleRsqrt(HloInstruction* rsqrt) {
4317   VLOG(10) << "trying transform [rsqrt(Pow(A, -2)) => |A|] "
4318            << rsqrt->ToString();
4319   HloInstruction* rsqrt_operand = rsqrt->mutable_operand(0);
4320   if (rsqrt_operand->opcode() == HloOpcode::kPower &&
4321       IsAll(rsqrt_operand->operand(1), -2) &&
4322       IsPositive(rsqrt_operand, options_)) {
4323     return ReplaceWithNewInstruction(
4324         rsqrt, HloInstruction::CreateUnary(rsqrt->shape(), HloOpcode::kAbs,
4325                                            rsqrt_operand->mutable_operand(0)));
4326   }
4327 
4328   VLOG(10) << "trying transform [rsqrt(Divide(1, A)) => sqrt(A)] "
4329            << rsqrt->ToString();
4330   if (rsqrt_operand->opcode() == HloOpcode::kDivide &&
4331       IsAll(rsqrt_operand->operand(0), 1) &&
4332       IsPositive(rsqrt_operand->operand(1), options_)) {
4333     return ReplaceWithNewInstruction(
4334         rsqrt, HloInstruction::CreateUnary(rsqrt->shape(), HloOpcode::kSqrt,
4335                                            rsqrt_operand->mutable_operand(1)));
4336   }
4337 
4338   return Status::OK();
4339 }
4340 
Status AlgebraicSimplifierVisitor::HandleDynamicSlice(
    HloInstruction* dynamic_slice) {
  // Simplifies a kDynamicSlice by trying, in order:
  //   1. eliding it when the result is scalar or shape-identical to operand,
  //   2. sinking a broadcast operand through the dynamic-slice,
  //   3. converting it to a static kSlice when all offsets are constants.
  auto operand = dynamic_slice->mutable_operand(0);
  // A scalar result can only come from a scalar operand.
  if (ShapeUtil::IsScalar(dynamic_slice->shape())) {
    return ReplaceInstruction(dynamic_slice, operand);
  }
  // DynamicSlice where operand has the same size as the output is simply equal
  // to operand.
  if (SameShape(operand, dynamic_slice)) {
    return ReplaceInstruction(dynamic_slice, operand);
  }

  // dynamic-slice(broadcast(x)) -> broadcast(dynamic-slice(x)): slice the
  // smaller pre-broadcast operand using only the indices/sizes of dimensions
  // that originate from x.
  HloInstruction* broadcast_operand;
  if (Match(operand, m::Broadcast(m::Op(&broadcast_operand)))) {
    std::vector<HloInstruction*> new_indices;
    new_indices.reserve(broadcast_operand->shape().rank());
    std::vector<int64> new_slice_sizes;
    new_slice_sizes.reserve(broadcast_operand->shape().rank());

    // Index operands of dynamic-slice start at operand index 1, one per
    // output dimension.
    for (int64 dim : operand->dimensions()) {
      new_indices.push_back(dynamic_slice->mutable_operand(1 + dim));
      new_slice_sizes.push_back(dynamic_slice->slice_sizes(dim));
    }

    VLOG(3) << "Sink broadcast through dynamic slice";
    VLOG(3) << "Original dynamic slice: " << dynamic_slice->ToString();
    VLOG(3) << "Original broadcast: " << operand->ToString();
    HloInstruction* new_dynamic_slice = broadcast_operand;
    // If the broadcast operand is a scalar (no dimensions), no inner
    // dynamic-slice is needed at all.
    if (!new_slice_sizes.empty()) {
      auto new_ds_shape = broadcast_operand->shape();
      for (int64 i = 0; i < broadcast_operand->shape().rank(); ++i) {
        new_ds_shape.set_dimensions(i, new_slice_sizes[i]);
      }
      simplifier_->UpdateLayout(&new_ds_shape);
      HloComputation* computation = broadcast_operand->parent();
      new_dynamic_slice =
          computation->AddInstruction(HloInstruction::CreateDynamicSlice(
              new_ds_shape, broadcast_operand, new_indices, new_slice_sizes));
    }
    auto new_broadcast = HloInstruction::CreateBroadcast(
        dynamic_slice->shape(), new_dynamic_slice, operand->dimensions());
    VLOG(3) << "New dynamic slice: " << dynamic_slice->ToString();
    VLOG(3) << "New broadcast: " << new_broadcast->ToString();
    return ReplaceWithNewInstruction(dynamic_slice, std::move(new_broadcast));
  }

  // Convert a dynamic slice into a slice if all offsets are constant and the
  // operand is not constant.
  if (operand->opcode() != HloOpcode::kConstant &&
      absl::c_all_of(absl::MakeSpan(dynamic_slice->operands().begin() + 1,
                                    dynamic_slice->operands().end()),
                     [](HloInstruction* operand) {
                       return operand->opcode() == HloOpcode::kConstant &&
                              ShapeUtil::ElementIsIntegral(operand->shape());
                     })) {
    const int64 rank = operand->shape().rank();
    std::vector<int64> slice_starts(rank);
    std::vector<int64> slice_limits(rank);
    std::vector<int64> slice_strides(rank, 1);

    for (int64 i = 0; i < rank; ++i) {
      absl::optional<int64> offset =
          dynamic_slice->operand(i + 1)->literal().GetFirstInteger();
      // Bail out entirely on an unreadable or negative offset.
      if (!offset || *offset < 0) {
        return Status::OK();
      }
      // Clamp the start so the slice stays in bounds, matching
      // dynamic-slice's documented out-of-bounds clamping semantics.
      const int64 max_offset =
          dynamic_slice->operand(0)->shape().dimensions(i) -
          dynamic_slice->shape().dimensions(i);
      slice_starts[i] = std::min(max_offset, *offset);
      slice_limits[i] =
          std::min(max_offset, *offset) + dynamic_slice->shape().dimensions(i);
    }
    return ReplaceWithNewInstruction(
        dynamic_slice,
        HloInstruction::CreateSlice(dynamic_slice->shape(), operand,
                                    slice_starts, slice_limits, slice_strides));
  }
  return Status::OK();
}
4421 
Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
    HloInstruction* dynamic_update_slice) {
  // Rewriting DynamicUpdateSlice when it matches
  // dynamic_update_slice(broadcast(constant),data,constant_index0,...)
  // to a Pad(x, constant)
  // Only Broadcast considered currently, other ops need to be considered
  // in the future.
  HloInstruction* updated = dynamic_update_slice->mutable_operand(0);
  HloInstruction* dus_update = dynamic_update_slice->mutable_operand(1);
  HloInstruction* pad_value;
  if (Match(updated,
            m::Broadcast(m::Op(&pad_value).WithShape(m::Shape().IsScalar())))) {
    auto updated_shape = updated->shape();
    auto update_shape = dus_update->shape();
    auto update_start_indx = dynamic_update_slice->operand(2);
    int64 offset = 0;
    bool compatible = true;
    // Whether the start indices to dynamic update slice is a list,
    // output of a tuple/concatenate, we setup the update_start_indx
    // appropriately.
    if (ShapeUtil::IsScalar(update_start_indx->shape())) {
      // Per-dimension scalar indices: read them straight off the
      // dynamic-update-slice's operands, which start at index 2.
      update_start_indx = dynamic_update_slice;
      offset = 2;
    } else {
      if (update_start_indx->opcode() == HloOpcode::kTuple ||
          update_start_indx->opcode() == HloOpcode::kConcatenate) {
        // Vector-of-indices form: indices are the tuple/concat operands.
        offset = 0;
      } else {
        compatible = false;
      }
    }
    PaddingConfig padding_config;
    if (compatible) {
      // Build one padding dimension per output dim; every start index must
      // be a readable constant scalar, otherwise abandon the rewrite.
      for (int64 dim = 0; dim < updated_shape.rank(); ++dim) {
        auto padding_config_dim = padding_config.add_dimensions();
        auto slice_dim_start = update_start_indx->operand(dim + offset);
        if (!Match(slice_dim_start, m::ConstantScalar())) {
          compatible = false;
          break;
        }
        VLOG(2) << "slice: " << slice_dim_start->ToString();
        absl::optional<int64> beg =
            slice_dim_start->literal().GetFirstInteger();
        if (!beg) {
          compatible = false;
          break;
        }
        VLOG(2) << "beg value: " << *beg;
        auto update_width = ShapeUtil::GetDimension(update_shape, dim);
        auto bcast_width = ShapeUtil::GetDimension(updated_shape, dim);
        // Clamp beg so that it is non-negative.
        *beg = std::max<int64>(0, *beg);
        // Clamp beg so that it is in-bounds.
        *beg = std::min<int64>(bcast_width - update_width, *beg);
        VLOG(2) << "adjusted beg value: " << *beg;
        // Low padding places the update at offset beg; high padding fills
        // the remainder with the broadcast's scalar value.
        padding_config_dim->set_edge_padding_low(*beg);
        padding_config_dim->set_edge_padding_high(bcast_width -
                                                  (*beg + update_width));
        // dynamic_update_slice does not specify a stride
        padding_config_dim->set_interior_padding(0);
      }
    }

    if (compatible) {
      HloInstruction* pad =
          computation_->AddInstruction(HloInstruction::CreatePad(
              updated_shape, dus_update, pad_value, padding_config));
      VLOG(2) << dynamic_update_slice->ToString();
      VLOG(2) << " with pad:" << pad->ToString();
      VLOG(2) << " Computation before rewrite is: "
              << dynamic_update_slice->parent()->ToString();
      return ReplaceInstruction(dynamic_update_slice, pad);
    }
  }

  // DynamicUpdateSlice where operand and dus_update have the same size is
  // equal to dus_update.
  if (SameShape(dynamic_update_slice, dus_update)) {
    return ReplaceInstruction(dynamic_update_slice, dus_update);
  }

  // If any dimension of dus_update is 0, elide the DynamicUpdateSlice.  This
  // optimization becomes invalid should we later prefer to warn about out of
  // bound indices.
  if (ShapeUtil::IsZeroElementArray(dus_update->shape())) {
    return ReplaceInstruction(dynamic_update_slice, updated);
  }
  return Status::OK();
}
4511 
Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
  // Simplifies a (possibly variadic) kReduce by trying, in order:
  //   1. replacing zero-element reductions with broadcasts of the init value,
  //   2. turning a single-input tuple reduce into a normal reduce,
  //   3. (layout-insensitive only) element-count-preserving reduce -> reshape,
  //   4. folding a feeding transpose, reduce, or reshape into the reduce,
  //   5. distributing the reduce over a concatenate,
  //   6. absorbing reduced batch dimensions of a feeding dot into the dot.
  HloReduceInstruction* reduce = Cast<HloReduceInstruction>(hlo);
  bool multi_output_reduce = reduce->shape().IsTuple();
  // For tuple reduce, we require all reduce shapes to be the same, up to the
  // element types, so we can just use the first operand and the first result
  // as a representative.
  auto arg = reduce->inputs()[0];
  auto init_value = reduce->init_values()[0];
  const Shape& reduce_result_shape =
      multi_output_reduce ? reduce->shape().tuple_shapes(0) : reduce->shape();

  absl::Span<const int64> dimensions(reduce->dimensions());
  HloComputation* function = reduce->to_apply();
  // Reducing an empty array (or producing one) yields only init values.
  if (ShapeUtil::IsZeroElementArray(arg->shape()) ||
      ShapeUtil::IsZeroElementArray(reduce_result_shape)) {
    if (multi_output_reduce) {
      std::vector<HloInstruction*> broadcast_inits;
      int64 inputs = reduce->input_count();
      for (int64 i = 0; i < inputs; ++i) {
        broadcast_inits.push_back(computation_->AddInstruction(
            HloInstruction::CreateBroadcast(reduce->shape().tuple_shapes(i),
                                            reduce->init_values()[i], {})));
      }
      return ReplaceWithNewInstruction(
          reduce, HloInstruction::CreateTuple(broadcast_inits));
    } else {
      return ReplaceWithNewInstruction(
          reduce,
          HloInstruction::CreateBroadcast(reduce_result_shape, init_value, {}));
    }
  }

  // Turn trivial variadic reductions into normal reductions.
  if (multi_output_reduce && reduce->shape().tuple_shapes_size() == 1 &&
      reduce->input_count() == 1 &&
      Match(function->root_instruction(), m::Tuple())) {
    // Clone the reduction computation with the tuple root replaced by its
    // single element so it becomes a scalar-rooted reducer.
    absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
        replacements;
    replacements[function->root_instruction()] = nullptr;
    auto new_function = computation_->parent()->AddEmbeddedComputation(
        function->CloneWithReplacements(
            std::move(replacements), /*extra_parameters=*/{},
            /*context=*/nullptr,
            /*suffix=*/"clone",
            /*new_root=*/function->root_instruction()->operand(0)));
    auto new_reduce = computation_->AddInstruction(
        HloInstruction::CreateReduce(reduce_result_shape, arg, init_value,
                                     reduce->dimensions(), new_function));
    // Wrap in a tuple so consumers of the original tuple shape still work.
    return ReplaceWithNewInstruction(reduce,
                                     HloInstruction::CreateTuple({new_reduce}));
  }

  // The remaining rewrites may change layouts, so they only apply before
  // layout assignment.
  if (options_.is_layout_sensitive()) {
    return Status::OK();
  }

  // If the reduction results in the same number of elements, then the only
  // possible side effect would be a reshape. Since the init_value is an
  // identity of the reduction function, we can therefore replace the reduce
  // with a simple reshape, ignoring the reduction function completely.
  if (ShapeUtil::ElementsIn(reduce_result_shape) ==
      ShapeUtil::ElementsIn(arg->shape())) {
    if (multi_output_reduce) {
      std::vector<HloInstruction*> reshaped_args;
      int64 inputs = reduce->input_count();
      for (int64 i = 0; i < inputs; ++i) {
        reshaped_args.push_back(
            computation_->AddInstruction(HloInstruction::CreateReshape(
                reduce->shape().tuple_shapes(i), reduce->inputs()[i])));
      }
      return ReplaceWithNewInstruction(
          reduce, HloInstruction::CreateTuple(reshaped_args));
    } else {
      return ReplaceWithNewInstruction(
          reduce, HloInstruction::CreateReshape(reduce_result_shape, arg));
    }
  }

  // TODO(b/131122694): Most of those optimizations below can be done for
  // multi-output reduces.
  if (multi_output_reduce) {
    return Status::OK();
  }

  // A Transpose feeding a reduce can simply permute the reduction dimensions
  // field if the output of the reduce is a vector or scalar. Higher ranked
  // result may require a transpose of the output.
  if (reduce_result_shape.rank() <= 1 &&
      arg->opcode() == HloOpcode::kTranspose) {
    auto transpose_dimensions = arg->dimensions();
    std::vector<int64> new_reduce_dimensions;
    for (auto dim : dimensions) {
      new_reduce_dimensions.push_back(transpose_dimensions[dim]);
    }
    return ReplaceWithNewInstruction(
        reduce, HloInstruction::CreateReduce(
                    reduce_result_shape, arg->mutable_operand(0), init_value,
                    new_reduce_dimensions, function));
  }

  // If a reduce feeds a reduce with the same computation and initial value,
  // they can be combined into a single reduce.
  if (arg->opcode() == HloOpcode::kReduce &&
      init_value->Identical(*arg->operand(1)) &&
      *function == *arg->to_apply()) {
    // Create a new reduce with the combined reduction dimensions of both
    // reduces.
    std::vector<int64> arg_dims = arg->dimensions();
    absl::c_sort(arg_dims);
    std::vector<int64> reduce_dims = reduce->dimensions();
    absl::c_sort(reduce_dims);
    // Transform reduce_dims to the same rank as the operand of the operand.
    // Each inner-reduce dimension below an outer-reduce dimension shifts the
    // outer index up by one.
    for (int64 arg_dim : arg_dims) {
      for (int64& dim : reduce_dims) {
        if (dim >= arg_dim) {
          ++dim;
        }
      }
    }
    std::vector<int64> new_dimensions;
    new_dimensions.reserve(arg->dimensions().size() +
                           reduce->dimensions().size());
    std::merge(arg_dims.begin(), arg_dims.end(), reduce_dims.begin(),
               reduce_dims.end(), std::back_inserter(new_dimensions));
    return ReplaceWithNewInstruction(
        reduce, HloInstruction::CreateReduce(
                    reduce_result_shape, arg->mutable_operand(0), init_value,
                    new_dimensions, function));
  }

  // A reshape that collapses multiple dimensions into a dimension being
  // reduced can just reduce all of those dimensions instead of doing a
  // collapsing reshape before a reduction.
  if (options_.enable_reduce_of_reshape() &&
      arg->opcode() == HloOpcode::kReshape) {
    std::vector<std::pair<int64, int64>> unmodified_dims =
        ShapeUtil::DimensionsUnmodifiedByReshape(arg->operand(0)->shape(),
                                                 arg->shape());
    std::vector<bool> arg_dim_in_output(arg->shape().rank(), true);
    std::vector<bool> arg_dim_unmodified(arg->shape().rank(), false);
    for (auto dim : dimensions) {
      arg_dim_in_output[dim] = false;
    }
    for (auto dim_pair : unmodified_dims) {
      arg_dim_unmodified[dim_pair.second] = true;
    }
    // The goal is to verify that all dimensions that are not removed in the
    // reduce are unmodified by the reshape. For example:
    // reduce(reshape([A,B*C], a[A,B,C]),[1]) = reduce(a[A, B, C], [1, 2])
    bool can_move_reshape_into_reduce = true;
    for (int64 i = 0; i < arg_dim_in_output.size(); ++i) {
      if (arg_dim_in_output[i] && !arg_dim_unmodified[i]) {
        can_move_reshape_into_reduce = false;
      }
    }
    if (can_move_reshape_into_reduce) {
      changed_ = true;
      // Keep (don't reduce) exactly the pre-reshape dimensions that map to
      // surviving output dimensions; reduce everything else.
      absl::flat_hash_set<int64> dimensions_not_to_reduce;
      for (auto dim_pair : unmodified_dims) {
        if (arg_dim_in_output[dim_pair.second]) {
          dimensions_not_to_reduce.insert(dim_pair.first);
        }
      }
      std::vector<int64> new_reduce_dimensions;
      for (int64 i = 0; i < arg->operand(0)->shape().rank(); ++i) {
        if (!dimensions_not_to_reduce.contains(i)) {
          new_reduce_dimensions.push_back(i);
        }
      }
      return ReplaceWithNewInstruction(
          reduce, HloInstruction::CreateReduce(
                      reduce_result_shape, arg->mutable_operand(0), init_value,
                      new_reduce_dimensions, function));
    }
  }
  // Convert Reduce(concat({a,b,...})) to
  //  map(reduce(a),map(reduce(b),...,))
  //
  // This should make fusion easier or use less memory bandwidth in the unfused
  // case.
  if (arg->opcode() == HloOpcode::kConcatenate &&
      absl::c_linear_search(reduce->dimensions(),
                            arg->concatenate_dimension())) {
    // Left-fold the per-operand reduces with the reducer itself (applied as
    // an elementwise map).
    HloInstruction* old_reduce = nullptr;
    for (HloInstruction* operand : arg->operands()) {
      HloInstruction* new_reduce = computation_->AddInstruction(
          HloInstruction::CreateReduce(reduce_result_shape, operand, init_value,
                                       reduce->dimensions(), function));
      if (old_reduce != nullptr) {
        new_reduce = computation_->AddInstruction(HloInstruction::CreateMap(
            reduce_result_shape, {old_reduce, new_reduce}, function));
      }
      old_reduce = new_reduce;
    }
    return ReplaceInstruction(reduce, old_reduce);
  }

  HloInstruction *dot, *lhs, *rhs;
  // Convert Reduce(Dot(X,Y)) to Dot(X,Y) if any of the dimensions reduced were
  // batch dimensions of the dot. The transformation supports reducing other
  // dimensions as well.
  if (options_.enable_dot_strength_reduction() &&
      Match(arg, m::Dot(&dot, m::Op(&lhs), m::Op(&rhs)).WithOneUser()) &&
      Match(reduce->to_apply()->root_instruction(),
            m::Add(m::Parameter(), m::Parameter())) &&
      absl::c_any_of(reduce->dimensions(), [&](int64 dim) {
        return dim < dot->dot_dimension_numbers().lhs_batch_dimensions_size();
      })) {
    const auto& dnums = dot->dot_dimension_numbers();
    DotDimensionNumbers new_dnums = dnums;
    new_dnums.clear_lhs_batch_dimensions();
    new_dnums.clear_rhs_batch_dimensions();
    int64 removed_dims = 0;
    // Reduced batch dimensions become contracting dimensions of the new dot;
    // the rest stay batch dimensions.
    for (int64 batch_dim = 0; batch_dim < dnums.lhs_batch_dimensions_size();
         ++batch_dim) {
      if (absl::c_linear_search(reduce->dimensions(), batch_dim)) {
        new_dnums.add_rhs_contracting_dimensions(
            dnums.rhs_batch_dimensions(batch_dim));
        new_dnums.add_lhs_contracting_dimensions(
            dnums.lhs_batch_dimensions(batch_dim));
        ++removed_dims;
      } else {
        new_dnums.add_rhs_batch_dimensions(
            dnums.rhs_batch_dimensions(batch_dim));
        new_dnums.add_lhs_batch_dimensions(
            dnums.lhs_batch_dimensions(batch_dim));
      }
    }
    // Non-batch reduced dimensions remain, shifted down by the number of
    // batch dimensions the new dot no longer produces.
    std::vector<int64> reduce_dims;
    for (int64 dim : reduce->dimensions()) {
      if (dim >= dnums.lhs_batch_dimensions_size()) {
        reduce_dims.push_back(dim - removed_dims);
      }
    }
    TF_ASSIGN_OR_RETURN(
        auto new_dot,
        MakeDotHlo(lhs, rhs, new_dnums, dot->precision_config(),
                   /*preferred_element_type=*/dot->shape().element_type()));
    dot->SetupDerivedInstruction(new_dot);
    if (reduce_dims.empty()) {
      return ReplaceInstruction(hlo, new_dot);
    }
    TF_ASSIGN_OR_RETURN(
        auto new_reduce,
        MakeReduceHlo(new_dot, init_value, reduce_dims, HloOpcode::kAdd));
    reduce->SetupDerivedInstruction(new_reduce);
    return ReplaceInstruction(hlo, new_reduce);
  }
  return Status::OK();
}
4762 
HandleReduceWindow(HloInstruction * reduce_window)4763 Status AlgebraicSimplifierVisitor::HandleReduceWindow(
4764     HloInstruction* reduce_window) {
4765   // TODO(b/73062247) Variadic reduce window is not yet supported in simplifier.
4766   if (reduce_window->shape().IsTuple()) {
4767     return Status::OK();
4768   }
4769   if (ShapeUtil::IsZeroElementArray(reduce_window->operand(0)->shape())) {
4770     return ReplaceWithNewInstruction(
4771         reduce_window,
4772         HloInstruction::CreateBroadcast(reduce_window->shape(),
4773                                         reduce_window->mutable_operand(1), {}));
4774   }
4775   auto operand = reduce_window->mutable_operand(0);
4776   const Window& window = reduce_window->window();
4777   auto function = reduce_window->to_apply();
4778   if (ShapeUtil::IsScalar(operand->shape())) {
4779     TF_RET_CHECK(ShapeUtil::IsScalar(reduce_window->shape()));
4780     return ReplaceWithNewInstruction(
4781         reduce_window,
4782         HloInstruction::CreateMap(reduce_window->shape(),
4783                                   {reduce_window->mutable_operand(1), operand},
4784                                   function));
4785   }
4786 
4787   if (options_.enable_window_reduce_to_reduce_replacement()) {
4788     // A reduce window can be expressed as a reduce and a reshape if all
4789     // dimensions either have a window size of one or the entire dimension. If
4790     // there is no stride, dilation, or padding, this is as easy as checking the
4791     // size of the output shape and window dimension.
4792     //
4793     // The reshape is a bitcast since it adds one-sized dimensions. Often these
4794     // ones are immediately removed as well with another reshape. The
4795     // implementation of reduce tends to be slightly more efficient at reducing
4796     // entire dimensions compared to reduce window.
4797     auto effective_reduce_dims = [&] {
4798       if (window_util::HasStride(window) || window_util::HasDilation(window) ||
4799           window_util::HasPadding(window)) {
4800         return absl::InlinedVector<int64, 8>{};
4801       }
4802       absl::InlinedVector<int64, 8> reduce_dims;
4803       for (int64 i = 0; i < window.dimensions_size(); ++i) {
4804         if (window.dimensions(i).size() == 1) {
4805           continue;
4806         } else if (reduce_window->shape().dimensions(i) == 1) {
4807           reduce_dims.push_back(i);
4808         } else {
4809           return absl::InlinedVector<int64, 8>{};
4810         }
4811       }
4812       return reduce_dims;
4813     }();
4814 
4815     // If a reduce window can be expressed as a reduce, do so and reshape the
4816     // output.
4817     if (!effective_reduce_dims.empty()) {
4818       Shape reduce_shape = ShapeUtil::FilterDimensions(
4819           [&](int64 dim) {
4820             return !absl::c_linear_search(effective_reduce_dims, dim);
4821           },
4822           reduce_window->shape());
4823       simplifier_->UpdateLayout(&reduce_shape);
4824       HloInstruction* reduce =
4825           computation_->AddInstruction(HloInstruction::CreateReduce(
4826               /*shape=*/reduce_shape,
4827               /*operand=*/operand,
4828               /*init_value=*/reduce_window->mutable_operand(1),
4829               /*dimensions_to_reduce=*/effective_reduce_dims,
4830               /*reduce_computation=*/function));
4831       return ReplaceWithNewInstruction(
4832           reduce_window,
4833           HloInstruction::CreateReshape(reduce_window->shape(), reduce));
4834     }
4835   }
4836 
4837   // This optimization folds a pad op into reduce_window.
4838   HloInstruction* pad;
4839   const HloInstruction* convert = nullptr;
4840   if (operand->opcode() == HloOpcode::kPad) {
4841     pad = operand;
4842   } else if (operand->opcode() == HloOpcode::kConvert &&
4843              operand->operand(0)->opcode() == HloOpcode::kPad) {
4844     convert = operand;
4845     pad = operand->mutable_operand(0);
4846   } else {
4847     VLOG(10) << "Not folding pad into reduce-window as there is no pad.";
4848     return Status::OK();
4849   }
4850 
4851   VLOG(10) << "Considering folding Pad: " << pad->ToString()
4852            << "\ninto reduce-window: " << reduce_window->ToString()
4853            << (convert != nullptr
4854                    ? absl::StrCat("\nvia convert: ", convert->ToString())
4855                    : "");
4856 
4857   // Do not fold interior padding into ReduceWindow since the backends do not
4858   // support it.
4859   const PaddingConfig& pad_config = pad->padding_config();
4860   if (HasInteriorPadding(pad_config) && window_util::HasBaseDilation(window)) {
4861     VLOG(10) << "Not folding interior pad into base-dilated reduce-window.";
4862     return Status::OK();
4863   }
4864 
4865   // If reduce_window already has padding, the pad value of the pad op and the
4866   // init value of reduce_window must match to allow folding the pad.
4867   const HloInstruction* pad_value = pad->operand(1);
4868   const HloInstruction* reduce_init_value = reduce_window->operand(1);
4869   if (pad_value != reduce_init_value) {
4870     auto literals_are_equivalent = [&] {
4871       auto& pad_literal = pad_value->literal();
4872       auto& reduce_init_literal = reduce_init_value->literal();
4873       if (pad_literal == reduce_init_literal) {
4874         return true;
4875       }
4876       auto converted_pad_literal =
4877           pad_literal.ConvertToShape(reduce_init_value->shape());
4878       if (!converted_pad_literal.ok()) {
4879         return false;
4880       }
4881       return converted_pad_literal.ValueOrDie() == reduce_init_literal;
4882     };
4883     // The pad value is usually a constant, so we handle that case and do not
4884     // try to get more fancy about proving equivalence in cases beyond that.
4885     if (pad_value->opcode() != HloOpcode::kConstant ||
4886         reduce_init_value->opcode() != HloOpcode::kConstant ||
4887         !literals_are_equivalent()) {
4888       VLOG(10) << "Not folding pad into reduce-window due to different pad "
4889                   "values.";
4890       return Status::OK();
4891     }
4892   }
4893 
4894   // If the pad puts a single non-identity value in each window that we're
4895   // reducing, then this is a broadcast.
4896   HloInstruction* pad_operand = pad->mutable_operand(0);
4897   auto is_effective_broadcast = [&] {
4898     if (window_util::HasStride(window)) {
4899       VLOG(10) << "Window has stride.";
4900       return false;
4901     }
4902     if (!window_util::HasSymmetricPadding(pad_config)) {
4903       VLOG(10) << "Window has uneven padding.";
4904       return false;
4905     }
4906     if (HasInteriorPadding(pad_config)) {
4907       VLOG(10) << "Window has interior padding.";
4908       return false;
4909     }
4910     for (int64 i = 0; i < pad_config.dimensions_size(); ++i) {
4911       const auto& pad_dimension = pad_config.dimensions(i);
4912       if ((pad_dimension.edge_padding_low() != 0 ||
4913            pad_dimension.edge_padding_high() != 0) &&
4914           pad_operand->shape().dimensions(i) != 1) {
4915         VLOG(10) << "Found non-trivial dimension being padded: " << i;
4916         return false;
4917       }
4918     }
4919     VLOG(10) << "Found to be padding trivial dimensions only.";
4920 
4921     for (int64 i = 0; i < window.dimensions_size(); ++i) {
4922       const auto& pad_dimension = pad_config.dimensions(i);
4923       const WindowDimension& window_dimension = window.dimensions(i);
4924       bool dimension_has_padding = (pad_dimension.edge_padding_low() != 0 ||
4925                                     pad_dimension.edge_padding_high() != 0);
4926       if (dimension_has_padding &&
4927           window_dimension.size() < pad_dimension.edge_padding_low() + 1) {
4928         VLOG(10) << "Found window did not cover single unpadded element in "
4929                     "dimension: "
4930                  << i;
4931         return false;
4932       }
4933       if (pad_operand->shape().dimensions(i) != 1 &&
4934           window_dimension.size() != 1) {
4935         VLOG(10) << "Found window covers more than one element in non-trivial "
4936                     "dimension: "
4937                  << i;
4938         return false;
4939       }
4940     }
4941     VLOG(10) << "Found window covers a single unpadded element.";
4942     return true;
4943   };
4944 
4945   HloInstruction* new_reduce_window_operand;
4946   if (convert != nullptr) {
4947     Shape changed_shape = ShapeUtil::ChangeElementType(
4948         pad_operand->shape(), convert->shape().element_type());
4949     simplifier_->UpdateLayout(&changed_shape);
4950     new_reduce_window_operand = computation_->AddInstruction(
4951         HloInstruction::CreateConvert(changed_shape, pad_operand));
4952   } else {
4953     new_reduce_window_operand = pad_operand;
4954   }
4955 
4956   if (is_effective_broadcast()) {
4957     VLOG(10) << "Replacing pad/reduce-window with broadcast.";
4958     auto fadd = [this](std::unique_ptr<HloInstruction> x) {
4959       return computation_->AddInstruction(std::move(x));
4960     };
4961     return ReplaceWithNewInstruction(
4962         reduce_window, HloInstruction::CreateBroadcastSequence(
4963                            /*output_shape=*/reduce_window->shape(),
4964                            /*operand=*/new_reduce_window_operand, fadd));
4965   }
4966 
4967   // Carry out the folding of the pad into reduce_window.
4968   VLOG(10) << "Folding pad into reduce-window.";
4969   Window new_window = window;
4970   const int64 rank = reduce_window->shape().rank();
4971   TF_RET_CHECK(pad_config.dimensions_size() == rank);
4972   TF_RET_CHECK(window.dimensions_size() == rank);
4973   for (int64 i = 0; i < rank; ++i) {
4974     const auto& pad_dim = pad_config.dimensions(i);
4975     auto& window_dim = *new_window.mutable_dimensions(i);
4976     window_dim.set_padding_low(window_dim.padding_low() +
4977                                pad_dim.edge_padding_low());
4978     window_dim.set_padding_high(window_dim.padding_high() +
4979                                 pad_dim.edge_padding_high());
4980     if (pad_dim.interior_padding() != 0) {
4981       CHECK_EQ(window_dim.base_dilation(), 1);
4982       window_dim.set_base_dilation(1 + pad_dim.interior_padding());
4983     }
4984   }
4985 
4986   return ReplaceWithNewInstruction(
4987       reduce_window, HloInstruction::CreateReduceWindow(
4988                          /*shape=*/reduce_window->shape(),
4989                          /*operand=*/new_reduce_window_operand,
4990                          /*init_value=*/reduce_window->mutable_operand(1),
4991                          /*window=*/new_window,
4992                          /*reduce_computation=*/function));
4993 }
4994 
HandleSelect(HloInstruction * select)4995 Status AlgebraicSimplifierVisitor::HandleSelect(HloInstruction* select) {
4996   // select(x, y, y) -> y.
4997   if (select->operand(1) == select->operand(2)) {
4998     return ReplaceInstruction(select, select->mutable_operand(1));
4999   }
5000   // select(true, x, y) -> x.
5001   if (IsAll(select->operand(0), true)) {
5002     return ReplaceInstruction(select, select->mutable_operand(1));
5003   }
5004   // select(false, x, y) -> y.
5005   if (IsAll(select->operand(0), false)) {
5006     return ReplaceInstruction(select, select->mutable_operand(2));
5007   }
5008   return Status::OK();
5009 }
5010 
HandleScatter(HloInstruction * scatter)5011 Status AlgebraicSimplifierVisitor::HandleScatter(HloInstruction* scatter) {
5012   if (ShapeUtil::IsZeroElementArray(scatter->operand(2)->shape()) &&
5013       ReplaceInstructionIfSameShape(scatter, scatter->mutable_operand(0))) {
5014     return Status::OK();
5015   }
5016   if (ShapeUtil::IsZeroElementArray(scatter->operand(1)->shape()) &&
5017       SameShape(scatter, scatter->operand(0)) &&
5018       SameShape(scatter, scatter->operand(2))) {
5019     return ReplaceWithNewInstruction(
5020         scatter, HloInstruction::CreateMap(
5021                      scatter->shape(),
5022                      {scatter->mutable_operand(0), scatter->mutable_operand(2)},
5023                      scatter->to_apply()));
5024   }
5025   return Status::OK();
5026 }
HandleSort(HloInstruction * sort)5027 Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) {
5028   auto operand = sort->mutable_operand(0);
5029   int64 dimension_to_sort = sort->dimensions(0);
5030   if (ShapeUtil::IsZeroElementArray(operand->shape()) ||
5031       operand->shape().dimensions(dimension_to_sort) <= 1) {
5032     if (sort->operand_count() == 1) {
5033       return ReplaceInstruction(sort, operand);
5034     }
5035     // If it is key/value sort, the output of sort is a tuple.
5036     return ReplaceWithNewInstruction(
5037         sort, HloInstruction::CreateTuple(sort->operands()));
5038   }
5039   return Status::OK();
5040 }
5041 
HandleSqrt(HloInstruction * sqrt)5042 Status AlgebraicSimplifierVisitor::HandleSqrt(HloInstruction* sqrt) {
5043   VLOG(10) << "trying transform [sqrt(A*A) => |A|] " << sqrt->ToString();
5044   HloInstruction* sqrt_operand = sqrt->mutable_operand(0);
5045   if (sqrt_operand->opcode() == HloOpcode::kMultiply &&
5046       sqrt_operand->operand(0) == sqrt_operand->operand(1)) {
5047     return ReplaceWithNewInstruction(
5048         sqrt, HloInstruction::CreateUnary(
5049                   sqrt_operand->mutable_operand(0)->shape(), HloOpcode::kAbs,
5050                   sqrt_operand->mutable_operand(0)));
5051   }
5052   return Status::OK();
5053 }
5054 
5055 namespace {
OnlyPermutesDegenerateDims(const Shape & shape,absl::Span<const int64> perm)5056 bool OnlyPermutesDegenerateDims(const Shape& shape,
5057                                 absl::Span<const int64> perm) {
5058   std::vector<int64> new_permutation;
5059   int64 degenerate_count = 0;
5060   for (int64 i = 0; i < perm.size(); ++i) {
5061     if (shape.dimensions(i) != 1) {
5062       new_permutation.push_back(perm[i]);
5063     } else {
5064       ++degenerate_count;
5065     }
5066   }
5067   return degenerate_count > 0 && absl::c_is_sorted(new_permutation);
5068 }
5069 }  // namespace
5070 
Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) {
  auto operand = transpose->mutable_operand(0);
  // A sorted permutation is the identity, so the transpose is a no-op.
  if (std::is_sorted(transpose->dimensions().begin(),
                     transpose->dimensions().end())) {
    VLOG(10) << "deleting no-op transpose";
    return ReplaceInstruction(transpose, operand);
  }

  // transpose(transpose(x)) -> transpose(x) with the composed permutation.
  if (HloOpcode::kTranspose == operand->opcode()) {
    return ReplaceWithNewInstruction(
        transpose, HloInstruction::CreateTranspose(
                       transpose->shape(), operand->mutable_operand(0),
                       ComposePermutations(operand->dimensions(),
                                           transpose->dimensions())));
  }

  // Convert transpose(dot(a,b)) to dot(b,a).
  if (operand->opcode() == HloOpcode::kDot && operand->user_count() == 1 &&
      operand->shape().rank() == 2) {
    TF_ASSIGN_OR_RETURN(bool did_transform, [&]() -> StatusOr<bool> {
      const auto& dnums = operand->dot_dimension_numbers();
      // Only handle non-batched dots whose operands carry exactly one
      // non-contracting dimension each, so swapping lhs/rhs is exactly a
      // transpose of the 2D result.
      if (dnums.lhs_batch_dimensions_size() != 0) {
        return false;
      }
      HloInstruction* lhs = operand->mutable_operand(0);
      if (lhs->shape().rank() != 1 + dnums.lhs_contracting_dimensions_size()) {
        return false;
      }
      HloInstruction* rhs = operand->mutable_operand(1);
      if (rhs->shape().rank() != 1 + dnums.rhs_contracting_dimensions_size()) {
        return false;
      }
      // Swap the operands, and with them the lhs/rhs contracting dimensions.
      DotDimensionNumbers new_dnums;
      *new_dnums.mutable_lhs_contracting_dimensions() =
          dnums.rhs_contracting_dimensions();
      *new_dnums.mutable_rhs_contracting_dimensions() =
          dnums.lhs_contracting_dimensions();
      TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
          transpose, HloInstruction::CreateDot(transpose->shape(), /*lhs=*/rhs,
                                               /*rhs=*/lhs, new_dnums,
                                               operand->precision_config())));
      return true;
    }());
    if (did_transform) {
      return Status::OK();
    }
  }

  // Replace the transpose with a reshape if it only permutes degenerate
  // (size-1) dimensions; the data order is unchanged in that case.
  if (OnlyPermutesDegenerateDims(transpose->shape(), transpose->dimensions())) {
    return ReplaceWithNewInstruction(
        transpose, HloInstruction::CreateReshape(
                       transpose->shape(), transpose->mutable_operand(0)));
  }

  // Fold a transpose of an rng into the rng by rewriting the rng's output
  // shape in place (only safe when the transpose is the rng's sole user;
  // presumably valid because rng elements are drawn independently — confirm).
  if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) {
    *operand->mutable_shape() = transpose->shape();
    return ReplaceInstruction(transpose, operand);
  }

  // In layout-sensitive mode, a transpose that does not move any data can be
  // expressed as a bitcast.
  if (options_.is_layout_sensitive() &&
      options_.replace_transpose_with_bitcast() &&
      TransposeIsBitcast(transpose)) {
    ReplaceWithBitcast(transpose);
    return Status::OK();
  }

  return Status::OK();
}
5141 
// Tries to fold a zero-valued kPad feeding the convolution's input (lhs)
// into the convolution's window: edge padding is added to the window's
// padding_low/high, and interior padding becomes base dilation. Returns
// true iff the convolution was replaced.
StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvInputPad(
    HloInstruction* convolution) {
  auto* lhs = convolution->mutable_operand(0);
  auto* rhs = convolution->mutable_operand(1);
  const auto& window = convolution->window();
  const ConvolutionDimensionNumbers& dnums =
      convolution->convolution_dimension_numbers();

  if (lhs->opcode() != HloOpcode::kPad) {
    return false;
  }

  // Convolution's padding is always zero, so bail if the kPad is adding
  // something other than zero.
  if (!IsAll(lhs->operand(1), 0)) {
    return false;
  }

  const auto& padding = lhs->padding_config();

  // Can't pad batch or feature dims.
  for (int64 dim :
       {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) {
    const auto& p = padding.dimensions(dim);
    if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
        p.interior_padding() != 0) {
      return false;
    }
  }

  // Compute the window which is the result of merging the kPad and the
  // convolution's existing window.
  Window new_window = window;
  for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) {
    auto& w = *new_window.mutable_dimensions(dim);
    const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim));
    // Edge padding composes with itself in the straightforward way, but
    // composing interior padding is nontrivial, and we cowardly refuse to
    // think about it. If we see interior padding in either the kPad or conv,
    // bail if there's any sort of padding in the other.
    if (p.interior_padding() != 0 &&
        (w.padding_low() != 0 || w.padding_high() != 0 ||
         w.base_dilation() != 1)) {
      return false;
    }
    if (w.base_dilation() != 1 &&
        (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
         p.interior_padding() != 0)) {
      return false;
    }

    // Edge padding accumulates onto the window's padding; interior padding
    // of N becomes base dilation of N+1.
    w.set_padding_low(w.padding_low() + p.edge_padding_low());
    w.set_padding_high(w.padding_high() + p.edge_padding_high());
    if (p.interior_padding() != 0) {
      CHECK_EQ(w.base_dilation(), 1);
      w.set_base_dilation(1 + p.interior_padding());
    }
  }

  // Replace the convolution with one reading the pad's input directly,
  // using the merged window.
  auto new_conv = convolution->CloneWithNewOperands(
      convolution->shape(), {lhs->mutable_operand(0), rhs});
  new_conv->set_window(new_window);
  TF_RETURN_IF_ERROR(
      ReplaceWithNewInstruction(convolution, std::move(new_conv)));
  return true;
}
5208 
// Tries to fold a zero-valued kPad of the convolution's filter (rhs) into
// the convolution's window. Only interior padding on the kernel's spatial
// dimensions can be absorbed (as window dilation); edge padding of the
// filter is not expressible in a convolution window. Returns true iff the
// convolution was replaced.
StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvFilterPad(
    HloInstruction* convolution) {
  auto* lhs = convolution->mutable_operand(0);
  auto* rhs = convolution->mutable_operand(1);
  const ConvolutionDimensionNumbers& dnums =
      convolution->convolution_dimension_numbers();

  if (rhs->opcode() != HloOpcode::kPad) {
    return false;
  }

  // Convolution's padding is always zero, so bail if the kPad is adding
  // something other than zero.
  if (!IsAll(rhs->operand(1), 0)) {
    return false;
  }

  const auto& padding = rhs->padding_config();

  // Can't pad or dilate feature dims.
  for (int64 dim : {dnums.kernel_input_feature_dimension(),
                    dnums.kernel_output_feature_dimension()}) {
    const auto& p = padding.dimensions(dim);
    if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
        p.interior_padding() != 0) {
      return false;
    }
  }

  // Compute the window which is the result of merging the kPad and the
  // convolution's existing window.
  Window new_window = convolution->window();
  for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) {
    auto& w = *new_window.mutable_dimensions(dim);
    const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim));

    // We can only do this transformation if p adds dilation to the filter --
    // edge padding on the filter is not supported in conv.
    if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) {
      return false;
    }

    // Nothing to do if the kPad for this dim is entirely a nop.
    if (p.interior_padding() == 0) {
      continue;
    }

    // We cowardly refuse to think about how dilation composes with itself;
    // bail if both the kPad and conv have dilation on this dimension.
    if (w.window_dilation() > 1) {
      return false;
    }
    CHECK_EQ(w.window_dilation(), 1);
    // Interior padding of N becomes window dilation of N+1, and the window
    // size is reset to the *unpadded* kernel extent, since the dilation now
    // accounts for the inserted zeros.
    w.set_window_dilation(1 + p.interior_padding());
    w.set_size(rhs->operand(0)->shape().dimensions(
        dnums.kernel_spatial_dimensions(dim)));
  }

  // Replace the convolution with one reading the pad's input directly,
  // using the merged window.
  auto new_conv = convolution->CloneWithNewOperands(
      convolution->shape(), {lhs, rhs->mutable_operand(0)});
  new_conv->set_window(new_window);
  TF_RETURN_IF_ERROR(
      ReplaceWithNewInstruction(convolution, std::move(new_conv)));
  return true;
}
5274 
// Tries to swap the convolution's operands -- convolving the kernel by the
// activations instead -- when a naive implementation of the swapped form
// would use strictly fewer flops (compared via the product of per-dimension
// window sizes). Returns true iff the convolution was replaced.
StatusOr<bool> AlgebraicSimplifierVisitor::SwapConvOperands(
    HloInstruction* convolution) {
  if (!options_.enable_conv_operand_swap() || options_.is_layout_sensitive()) {
    return false;
  }
  // Grouped (feature or batch) convolutions are not handled.
  if (convolution->feature_group_count() > 1 ||
      convolution->batch_group_count() > 1) {
    return false;
  }

  const auto& dnums = convolution->convolution_dimension_numbers();
  const auto& window_dims = convolution->window().dimensions();
  Window swapped_window;

  HloInstruction *input = convolution->mutable_operand(0),
                 *kernel = convolution->mutable_operand(1);
  // Product of window sizes for the current and the swapped convolution,
  // used as a flop-count proxy below.
  int64 kernel_product = 1;
  int64 swapped_kernel_product = 1;
  DimensionVector reverse_dimensions;
  for (int64 spatial_dim = 0;
       spatial_dim < dnums.input_spatial_dimensions_size(); ++spatial_dim) {
    const int64 kernel_size = window_dims[spatial_dim].size();
    // A dimension without reversal, padding, or window dilation may be
    // encoding a group or contraction dimension; such dimensions are copied
    // into the swapped window unchanged.
    const bool can_be_group_or_contraction =
        !window_dims[spatial_dim].window_reversal() &&
        window_dims[spatial_dim].padding_low() == 0 &&
        window_dims[spatial_dim].padding_high() == 0 &&
        window_dims[spatial_dim].window_dilation() == 1;
    const bool is_group_dim =
        can_be_group_or_contraction &&
        window_dims[spatial_dim].base_dilation() == kernel_size &&
        window_dims[spatial_dim].stride() == kernel_size - 1;
    const int64 input_size =
        input->shape().dimensions(dnums.input_spatial_dimensions(spatial_dim));
    const bool is_pure_contraction_dim =
        kernel_size == input_size && can_be_group_or_contraction &&
        window_dims[spatial_dim].base_dilation() == 1 &&
        window_dims[spatial_dim].stride() == 1;
    if (is_group_dim || is_pure_contraction_dim) {
      *(swapped_window.add_dimensions()) = window_dims[spatial_dim];
      continue;
    }

    const int64 dilated_kernel_size =
        1 + (kernel_size - 1) * window_dims[spatial_dim].window_dilation();
    const int64 dilated_input_size =
        1 + (input_size - 1) * window_dims[spatial_dim].base_dilation();

    // Don't decide to swap if the input size is one, since many convolution
    // implementations can easily hand that special case efficiently.
    kernel_product *= kernel_size;
    swapped_kernel_product *= input_size == 1 ? kernel_size : input_size;

    auto new_dim = swapped_window.add_dimensions();
    new_dim->set_size(input_size);
    // If the kernel is not reversed, the activations must be manually reversed.
    if (!window_dims[spatial_dim].window_reversal()) {
      reverse_dimensions.push_back(
          dnums.kernel_spatial_dimensions(spatial_dim));
    }
    // The input is not originally reversed so it must be reversed to move the
    // kernel.
    new_dim->set_window_reversal(true);
    // Base dilation and window dilation switch places.
    new_dim->set_base_dilation(window_dims[spatial_dim].window_dilation());
    new_dim->set_window_dilation(window_dims[spatial_dim].base_dilation());
    new_dim->set_stride(window_dims[spatial_dim].stride());
    new_dim->set_padding_low(dilated_input_size +
                             window_dims[spatial_dim].padding_low() -
                             dilated_kernel_size);
    new_dim->set_padding_high(dilated_input_size +
                              window_dims[spatial_dim].padding_high() -
                              dilated_kernel_size);
  }

  // Don't transform if a naive convolution implementation would not have fewer
  // flops.
  if (kernel_product <= swapped_kernel_product) {
    return false;
  }
  ConvolutionDimensionNumbers swapped_dnums;
  *swapped_dnums.mutable_output_spatial_dimensions() =
      dnums.output_spatial_dimensions();
  // Swap batch and output feature of the output.
  swapped_dnums.set_output_batch_dimension(dnums.output_feature_dimension());
  swapped_dnums.set_output_feature_dimension(dnums.output_batch_dimension());

  // Swap input dnums with kernel dnums
  *swapped_dnums.mutable_input_spatial_dimensions() =
      dnums.kernel_spatial_dimensions();
  swapped_dnums.set_input_batch_dimension(
      dnums.kernel_output_feature_dimension());
  swapped_dnums.set_input_feature_dimension(
      dnums.kernel_input_feature_dimension());

  // Swap kernel dnums with input dnums
  *swapped_dnums.mutable_kernel_spatial_dimensions() =
      dnums.input_spatial_dimensions();
  swapped_dnums.set_kernel_output_feature_dimension(
      dnums.input_batch_dimension());
  swapped_dnums.set_kernel_input_feature_dimension(
      dnums.input_feature_dimension());

  // The operand precisions swap along with the operands.
  PrecisionConfig precision_config;
  precision_config.add_operand_precision(
      convolution->precision_config().operand_precision(1));
  precision_config.add_operand_precision(
      convolution->precision_config().operand_precision(0));
  if (!reverse_dimensions.empty()) {
    TF_ASSIGN_OR_RETURN(kernel, MakeReverseHlo(kernel, reverse_dimensions));
  }
  TF_ASSIGN_OR_RETURN(
      HloInstruction * new_convolution,
      MakeConvolveHlo(
          kernel, input, /*feature_group_count=*/1,
          /*batch_group_count=*/1, swapped_window, swapped_dnums,
          precision_config,
          /*preferred_element_type=*/convolution->shape().element_type()));

  convolution->SetupDerivedInstruction(new_convolution);
  TF_RETURN_IF_ERROR(ReplaceInstruction(convolution, new_convolution));

  return true;
}
5398 
SimplifyConvToDot(HloInstruction * convolution)5399 StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot(
5400     HloInstruction* convolution) {
5401   auto* lhs = convolution->mutable_operand(0);
5402   auto* rhs = convolution->mutable_operand(1);
5403   const auto& window = convolution->window();
5404   const ConvolutionDimensionNumbers& dnums =
5405       convolution->convolution_dimension_numbers();
5406 
5407   if (!options_.enable_conv_simplification()) {
5408     return false;
5409   }
5410 
5411   // TODO(b/31337498): For now, we cowardly refuse to do this optimization in
5412   // layout-insensitive mode, for fear of adding nontrivial reshapes.
5413   if (!options_.is_layout_sensitive()) {
5414     return false;
5415   }
5416 
5417   const Shape& input_shape = lhs->shape();
5418   const Shape& filter_shape = rhs->shape();
5419   const Shape& convolution_shape = convolution->shape();
5420   TF_RET_CHECK(LayoutUtil::HasLayout(input_shape));
5421   TF_RET_CHECK(LayoutUtil::HasLayout(filter_shape));
5422   TF_RET_CHECK(LayoutUtil::HasLayout(convolution_shape));
5423 
5424   // Require the spatial dimensions in the kernel to have a bound of one.
5425   for (int64 i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) {
5426     if (filter_shape.dimensions(dnums.kernel_spatial_dimensions(i)) != 1) {
5427       return false;
5428     }
5429   }
5430 
5431   // Stride ignores part of the output, which matrix multiplication does not do,
5432   // so require no stride. Padding and base (lhs) dilation both implicitly
5433   // extend the data, which matrix multiplication also does not do, so require
5434   // no padding and no base (lhs) dilation. Window (rhs) dilation has no effect
5435   // for a 1x1 window, so window dilation is no problem.
5436   if (window_util::HasStride(window) || window_util::HasPadding(window) ||
5437       window_util::HasBaseDilation(window)) {
5438     return false;
5439   }
5440 
5441   // Also, the shapes must align for a rowmajor matmul:
5442   // - the input and output have the same layout.
5443   // - for input/output, the channel dimension must be the most minor. Other
5444   //   spatial dims can be in any order.
5445   // - for filters, the input channel dimension must be more major than the
5446   //   output channel dimension. The width+height don't matter because
5447   //   they are 1.
5448   //
5449   // These constraints are harsh. If the channel dimension is the most major
5450   // and/or the layout of input/output feature dimensions are reversed, we can
5451   // still convert Conv into more efficient Matmul with operand transposition
5452   // (such as the transposition flags in cuBLAS SGEMM).
5453   if (!LayoutUtil::Equal(input_shape.layout(), convolution_shape.layout()) ||
5454       LayoutUtil::Minor(input_shape.layout(), 0) !=
5455           dnums.input_feature_dimension() ||
5456       LayoutUtil::Minor(convolution_shape.layout(), 0) !=
5457           dnums.output_feature_dimension() ||
5458       // The input feature dimension should come later in the minor-to-major
5459       // order.
5460       (PositionInContainer(LayoutUtil::MinorToMajor(filter_shape),
5461                            dnums.kernel_input_feature_dimension()) <
5462        PositionInContainer(LayoutUtil::MinorToMajor(filter_shape),
5463                            dnums.kernel_output_feature_dimension()))) {
5464     return false;
5465   }
5466 
5467   auto add_bitcast = [&](Shape shape, HloInstruction* operand) {
5468     std::vector<int64> dims(operand->shape().dimensions_size());
5469     std::iota(dims.begin(), dims.end(), 0);
5470     return computation_->AddInstruction(
5471         HloInstruction::CreateBitcast(shape, operand));
5472   };
5473 
5474   // Replace it with a dot, with bitcasts around it to get the right shape.
5475   const int64 input_channels =
5476       input_shape.dimensions(dnums.input_feature_dimension());
5477   const int64 output_channels =
5478       filter_shape.dimensions(dnums.kernel_output_feature_dimension());
5479 
5480   // Computes the product of the non-feature dimensions.
5481   int64 conv_width = 1;
5482   for (int i = 0; i < input_shape.dimensions_size(); ++i) {
5483     if (i != dnums.input_feature_dimension()) {
5484       conv_width *= input_shape.dimensions(i);
5485     }
5486   }
5487 
5488   // We already checked feature_dimension is most minor, so data in input_shape
5489   // and row-major {conv_width,input_channels} are bitwise identical.
5490   Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout(
5491       input_shape.element_type(), {conv_width, input_channels});
5492   simplifier_->UpdateLayout(&new_input_shape);
5493   // We already checked input_feature_dimension is more major than
5494   // output_feature_dimension, so data in filter_shape and row-major
5495   // {input_channels,output_channels} are bitwise identical.
5496   Shape new_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout(
5497       filter_shape.element_type(), {input_channels, output_channels});
5498   simplifier_->UpdateLayout(&new_filter_shape);
5499   Shape dot_output_shape = ShapeUtil::MakeShapeWithDescendingLayout(
5500       convolution_shape.element_type(), {conv_width, output_channels});
5501   simplifier_->UpdateLayout(&dot_output_shape);
5502 
5503   auto new_lhs = add_bitcast(new_input_shape, lhs);
5504   auto new_rhs = add_bitcast(new_filter_shape, rhs);
5505   DotDimensionNumbers dot_dimension_numbers;
5506   dot_dimension_numbers.add_lhs_contracting_dimensions(1);
5507   dot_dimension_numbers.add_rhs_contracting_dimensions(0);
5508   auto dot = computation_->AddInstruction(HloInstruction::CreateDot(
5509       dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers,
5510       convolution->precision_config()));
5511 
5512   TF_RETURN_IF_ERROR(
5513       ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)));
5514   return true;
5515 }
5516 
HandleConvolution(HloInstruction * convolution)5517 Status AlgebraicSimplifierVisitor::HandleConvolution(
5518     HloInstruction* convolution) {
5519   if (options_.enable_scalar_multiply_reduction()) {
5520     TF_RETURN_IF_ERROR(ScalarMultiplyReduction(convolution));
5521   }
5522 
5523   // Zero-sized input or filter.
5524   if (ShapeUtil::IsZeroElementArray(convolution->operand(0)->shape()) ||
5525       ShapeUtil::IsZeroElementArray(convolution->operand(1)->shape())) {
5526     return ReplaceInstruction(convolution, MakeScalarLike(convolution, 0));
5527   }
5528 
5529   // Try to merge padding/dilation of the input with the convolution's window.
5530   TF_ASSIGN_OR_RETURN(bool folded_input_pad, FoldConvInputPad(convolution));
5531   if (folded_input_pad) {
5532     return Status::OK();
5533   }
5534 
5535   // Try to merge dilation of the filter with the convolution's window.
5536   TF_ASSIGN_OR_RETURN(bool folded_filter_pad, FoldConvFilterPad(convolution));
5537   if (folded_filter_pad) {
5538     return Status::OK();
5539   }
5540 
5541   // Try to swap convolution operands.
5542   TF_ASSIGN_OR_RETURN(bool swapped, SwapConvOperands(convolution));
5543   if (swapped) {
5544     return Status::OK();
5545   }
5546   // Try to replace the convolution with a kDot instruction.
5547   TF_ASSIGN_OR_RETURN(bool replaced_with_dot, SimplifyConvToDot(convolution));
5548   if (replaced_with_dot) {
5549     return Status::OK();
5550   }
5551 
5552   return Status::OK();
5553 }
5554 
TransformToClampIfSameShape(HloInstruction * root,HloInstruction * min,HloInstruction * min_operand,HloInstruction * operand,HloInstruction * max,HloInstruction * max_operand)5555 bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape(
5556     HloInstruction* root, HloInstruction* min, HloInstruction* min_operand,
5557     HloInstruction* operand, HloInstruction* max, HloInstruction* max_operand) {
5558   // Ensure shapes of min and max operand are equal to match current shape
5559   // inference.
5560   if (!SameShape(min_operand, max_operand)) {
5561     return false;
5562   }
5563 
5564   auto clamp = HloInstruction::CreateTernary(root->shape(), HloOpcode::kClamp,
5565                                              max_operand, operand, min_operand);
5566   TF_CHECK_OK(ReplaceWithNewInstruction(root, std::move(clamp)));
5567   return true;
5568 }
5569 
HandleMap(HloInstruction * map)5570 Status AlgebraicSimplifierVisitor::HandleMap(HloInstruction* map) {
5571   auto* map_computation = map->to_apply();
5572   auto* map_root = map_computation->root_instruction();
5573   if (map_root->opcode() == HloOpcode::kParameter) {
5574     ReplaceInstructionIfSameShape(
5575         map, map->mutable_operand(map_root->parameter_number()));
5576     return Status::OK();
5577   }
5578   if (map_root->opcode() == HloOpcode::kConstant) {
5579     if (!ShapeUtil::IsScalar(map_root->shape())) {
5580       return Status::OK();
5581     }
5582     auto clone = map_root->CloneWithNewOperands(map_root->shape(), {});
5583     if (ShapeUtil::IsScalar(map->shape())) {
5584       return ReplaceWithNewInstruction(map, std::move(clone));
5585     }
5586     return ReplaceWithNewInstruction(
5587         map,
5588         HloInstruction::CreateBroadcast(
5589             map->shape(), computation_->AddInstruction(std::move(clone)), {}));
5590   }
5591   // Inline the map if the map computation only contains an elementwise
5592   // operation that can accept arbitrary shapes.
5593   if (map_root->opcode() == HloOpcode::kFusion || !map_root->IsElementwise()) {
5594     return Status::OK();
5595   }
5596   std::vector<HloInstruction*> new_operands;
5597   for (auto* root_operand : map_root->operands()) {
5598     if (root_operand->opcode() != HloOpcode::kParameter) {
5599       return Status::OK();
5600     }
5601     new_operands.push_back(
5602         map->mutable_operand(root_operand->parameter_number()));
5603   }
5604   auto clone = map_root->CloneWithNewOperands(map->shape(), new_operands);
5605   return ReplaceWithNewInstruction(map, std::move(clone));
5606 }
5607 
Run(HloModule * module)5608 StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) {
5609   XLA_VLOG_LINES(2,
5610                  "AlgebraicSimplifier::Run(), before:\n" + module->ToString());
5611   bool changed = false;
5612   AlgebraicSimplifierVisitor visitor(options_, this);
5613   for (auto* comp : module->MakeNonfusionComputations()) {
5614     if (visitor.Run(comp, options_, this)) {
5615       changed = true;
5616     }
5617   }
5618   XLA_VLOG_LINES(2,
5619                  "AlgebraicSimplifier::Run(), after:\n" + module->ToString());
5620   return changed;
5621 }
5622 
5623 }  // namespace xla
5624