• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
17 
18 #include <algorithm>
19 #include <cmath>
20 #include <functional>
21 #include <iterator>
22 #include <memory>
23 #include <numeric>
24 #include <string>
25 #include <utility>
26 #include <vector>
27 
28 #include "absl/algorithm/container.h"
29 #include "absl/container/flat_hash_map.h"
30 #include "absl/container/flat_hash_set.h"
31 #include "absl/container/inlined_vector.h"
32 #include "absl/memory/memory.h"
33 #include "absl/strings/str_cat.h"
34 #include "absl/types/optional.h"
35 #include "absl/types/span.h"
36 #include "tensorflow/compiler/xla/comparison_util.h"
37 #include "tensorflow/compiler/xla/layout_util.h"
38 #include "tensorflow/compiler/xla/literal.h"
39 #include "tensorflow/compiler/xla/literal_util.h"
40 #include "tensorflow/compiler/xla/overflow_util.h"
41 #include "tensorflow/compiler/xla/permutation_util.h"
42 #include "tensorflow/compiler/xla/primitive_util.h"
43 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
44 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
45 #include "tensorflow/compiler/xla/service/hlo_computation.h"
46 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
47 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
48 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
49 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
50 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
51 #include "tensorflow/compiler/xla/service/hlo_query.h"
52 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
53 #include "tensorflow/compiler/xla/shape.h"
54 #include "tensorflow/compiler/xla/shape_util.h"
55 #include "tensorflow/compiler/xla/status_macros.h"
56 #include "tensorflow/compiler/xla/types.h"
57 #include "tensorflow/compiler/xla/util.h"
58 #include "tensorflow/compiler/xla/window_util.h"
59 #include "tensorflow/compiler/xla/xla_data.pb.h"
60 #include "tensorflow/core/lib/core/bits.h"
61 #include "tensorflow/core/lib/core/errors.h"
62 #include "tensorflow/core/lib/core/status.h"
63 #include "tensorflow/core/platform/errors.h"
64 #include "tensorflow/core/platform/logging.h"
65 #include "tensorflow/core/platform/statusor.h"
66 #include "tensorflow/core/platform/types.h"
67 #include "tensorflow/stream_executor/lib/statusor.h"
68 
69 namespace xla {
70 
71 namespace {
72 
73 namespace m = match;
74 
IsAll(const HloInstruction * op,int8_t value)75 bool IsAll(const HloInstruction* op, int8_t value) {
76   switch (op->opcode()) {
77     case HloOpcode::kBroadcast:
78       return IsAll(op->operand(0), value);
79     case HloOpcode::kConstant:
80       return op->literal().IsAll(value);
81     default:
82       return false;
83   }
84 }
85 
IsAnyOperandComplex(const HloInstruction * hlo)86 bool IsAnyOperandComplex(const HloInstruction* hlo) {
87   for (auto operand : hlo->operands()) {
88     if (ShapeUtil::ElementIsComplex(operand->shape())) {
89       return true;
90     }
91   }
92   return false;
93 }
94 
IsPositive(const HloInstruction * hlo,const AlgebraicSimplifierOptions & options)95 bool IsPositive(const HloInstruction* hlo,
96                 const AlgebraicSimplifierOptions& options) {
97   // Utility only handles real types.
98   if (IsAnyOperandComplex(hlo)) {
99     return false;
100   }
101   switch (hlo->opcode()) {
102     case HloOpcode::kGetTupleElement: {
103       const HloInstruction* gte_operand = hlo->operand(0);
104       switch (gte_operand->opcode()) {
105         case HloOpcode::kCustomCall: {
106           const auto& target = gte_operand->custom_call_target();
107           return target ==
108                      options.get_cudnn_batchnorm_forward_training_metadata() &&
109                  hlo->tuple_index() == 2;
110         }
111         default:
112           return false;
113       }
114     }
115     case HloOpcode::kPower:
116     case HloOpcode::kAbs:
117     case HloOpcode::kRsqrt:
118     case HloOpcode::kSqrt:
119       return IsPositive(hlo->operand(0), options);
120 
121     case HloOpcode::kMultiply: {
122       return hlo->operand(0) == hlo->operand(1) &&
123              IsPositive(hlo->operand(0), options);
124     }
125     default:
126       return false;
127   }
128 }
129 
GetConstantValue(const HloInstruction * inst)130 absl::optional<double> GetConstantValue(const HloInstruction* inst) {
131   if (!ShapeUtil::IsEffectiveScalar(inst->shape())) {
132     return absl::nullopt;
133   }
134   switch (inst->shape().element_type()) {
135     case F16:
136       return static_cast<float>(inst->literal().GetFirstElement<half>());
137     case BF16:
138       return static_cast<float>(inst->literal().GetFirstElement<bfloat16>());
139     case F32:
140       return inst->literal().GetFirstElement<float>();
141     case F64:
142       return inst->literal().GetFirstElement<double>();
143     default:
144       return absl::nullopt;
145   }
146 }
147 
IsNonNegative(const HloInstruction * hlo,const AlgebraicSimplifierOptions & options)148 bool IsNonNegative(const HloInstruction* hlo,
149                    const AlgebraicSimplifierOptions& options) {
150   // Utility only handles real types.
151   if (IsAnyOperandComplex(hlo)) {
152     return false;
153   }
154   switch (hlo->opcode()) {
155     case HloOpcode::kMultiply: {
156       return hlo->operand(0) == hlo->operand(1);
157     }
158     case HloOpcode::kAbs: {
159       return true;
160     }
161     case HloOpcode::kBroadcast: {
162       return IsNonNegative(hlo->operand(0), options);
163     }
164     case HloOpcode::kConstant: {
165       if (absl::optional<double> value = GetConstantValue(hlo)) {
166         return *value >= 0.0;
167       }
168       return false;
169     }
170     case HloOpcode::kMaximum: {
171       return IsNonNegative(hlo->operand(0), options) ||
172              IsNonNegative(hlo->operand(1), options);
173     }
174     case HloOpcode::kSelect: {
175       return IsNonNegative(hlo->operand(1), options) &&
176              IsNonNegative(hlo->operand(2), options);
177     }
178     default:
179       return IsPositive(hlo, options);
180   }
181 }
182 
183 // Checks whether `op` is a floating-point constant or broadcast of a constant
184 // of the form +/- 2^k for some integer k positive, negative, or zero.  Such
185 // values are interesting because multiplying by a power of 2 just moves the
186 // exponent.
IsAllFpConstantPowerOf2(const HloInstruction * op)187 bool IsAllFpConstantPowerOf2(const HloInstruction* op) {
188   // Unwrap the broadcast if necessary.
189   const HloInstruction* c;
190   if (!Match(op, m::ConstantEffectiveScalar(&c)) &&
191       !Match(op, m::Broadcast(m::Constant(&c).WithShape(
192                      m::Shape().IsEffectiveScalar())))) {
193     return false;
194   }
195   auto val = [&]() -> absl::optional<double> {
196     switch (c->shape().element_type()) {
197       case BF16:
198         return static_cast<double>(c->literal().GetFirstElement<bfloat16>());
199       case F16:
200         return static_cast<double>(c->literal().GetFirstElement<Eigen::half>());
201       case F32:
202         return c->literal().GetFirstElement<float>();
203       case F64:
204         return c->literal().GetFirstElement<double>();
205       default:
206         // Cowardly refuse to consider complex types.
207         return absl::nullopt;
208     }
209   }();
210   if (!val) {
211     return false;
212   }
213 
214   int exp;
215   double mantissa = std::frexp(*val, &exp);
216   // frexp returns a value in the range (-1, -0.5] U [0.5, 1).  A return value
217   // of +/-0.5 therefore indicates that the floating point value is a power of
218   // 2.
219   return mantissa == 0.5 || mantissa == -0.5;
220 }
221 
222 // Returns whether the given transpose produces a result which is bit-wise
223 // identical to its operand and thus may be replaced with a bitcast.
TransposeIsBitcast(const HloInstruction * transpose)224 bool TransposeIsBitcast(const HloInstruction* transpose) {
225   CHECK_EQ(HloOpcode::kTranspose, transpose->opcode());
226   const HloInstruction* operand = transpose->operand(0);
227   return ShapeUtil::TransposeIsBitcast(operand->shape(), transpose->shape(),
228                                        transpose->dimensions());
229 }
230 
231 // Recursive helper for method below.
BitcastingOperandOfReshapeOrCopyChainHelper(HloInstruction * instr,HloInstruction * operand,const AlgebraicSimplifierOptions & options)232 HloInstruction* BitcastingOperandOfReshapeOrCopyChainHelper(
233     HloInstruction* instr, HloInstruction* operand,
234     const AlgebraicSimplifierOptions& options) {
235   // Can't replace chain of copies and reshapes with bitcasts if the compiler
236   // used a memory layout which isn't compatible.
237   if (options.ReshapeIsBitcast(operand->shape(), instr->shape())) {
238     return operand;
239   }
240 
241   // If the operand is a copy or reshape try to see if the operand's operand
242   // would produce a bitcast with initial instruction.
243   if (HloOpcode::kReshape == operand->opcode() ||
244       HloOpcode::kCopy == operand->opcode()) {
245     return BitcastingOperandOfReshapeOrCopyChainHelper(
246         instr, operand->mutable_operand(0), options);
247   }
248   return nullptr;
249 }
250 
251 // Returns an operand of a chain of reshapes and copies that is bit-wise
252 // identical to first reshape or copy in the chain.
BitcastingOperandOfReshapeOrCopyChain(HloInstruction * instr,const AlgebraicSimplifierOptions & options)253 HloInstruction* BitcastingOperandOfReshapeOrCopyChain(
254     HloInstruction* instr, const AlgebraicSimplifierOptions& options) {
255   if (!options.is_layout_sensitive()) {
256     return nullptr;
257   }
258   CHECK(HloOpcode::kReshape == instr->opcode() ||
259         HloOpcode::kCopy == instr->opcode());
260   return BitcastingOperandOfReshapeOrCopyChainHelper(
261       instr, instr->mutable_operand(0), options);
262 }
263 
IsUnstridedSlice(const HloInstruction * hlo)264 bool IsUnstridedSlice(const HloInstruction* hlo) {
265   return absl::c_all_of(hlo->slice_strides(),
266                         [](int64_t stride) { return stride == 1; });
267 }
268 
269 // Returns bool to determine whether a pair of converts can be eliminated.
IsConvertPairNoOp(const HloInstruction * convert)270 bool IsConvertPairNoOp(const HloInstruction* convert) {
271   //    [operand_convert]         [convert]
272   // (src)->convert-(intermediate)->convert-(dest)
273   const HloInstruction* operand_convert = convert->operand(0);
274   CHECK_EQ(operand_convert->opcode(), HloOpcode::kConvert);
275   const Shape& src_shape = operand_convert->operand(0)->shape();
276   const Shape& intermediate_shape = operand_convert->shape();
277   const Shape& dest_shape = convert->shape();
278 
279   const PrimitiveType src_type = src_shape.element_type();
280   const PrimitiveType intermediate_type = intermediate_shape.element_type();
281   const PrimitiveType dest_type = dest_shape.element_type();
282 
283   // src_type must be equal to dest_type.
284   if (src_type != dest_type) {
285     return false;
286   }
287 
288   // src_type must be a larger container than intermediate_type.
289   if (ShapeUtil::ByteSizeOfPrimitiveType(intermediate_type) <=
290       ShapeUtil::ByteSizeOfPrimitiveType(src_type)) {
291     return false;
292   }
293 
294   // Both src_type and intermediate_type must be either floating or integral.
295   bool is_conversion_floating =
296       ShapeUtil::ElementIsFloating(src_shape) &&
297       ShapeUtil::ElementIsFloating(intermediate_shape);
298   bool is_conversion_integral =
299       ShapeUtil::ElementIsIntegral(src_shape) &&
300       ShapeUtil::ElementIsIntegral(intermediate_shape);
301 
302   return is_conversion_floating || is_conversion_integral;
303 }
304 
// AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain
// algebraic expressions to simplified forms. Note: This only supports
// simplifications that simply look at the operands of an instruction. For the
// more general case a worklist based approach would be needed.
class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
 public:
  explicit AlgebraicSimplifierVisitor(const AlgebraicSimplifierOptions& options,
                                      AlgebraicSimplifier* simplifier)
      : options_(options), simplifier_(simplifier) {}

  // Opcode-specific rewrite hooks invoked by the DFS visitor; each attempts
  // local simplifications of the corresponding instruction kind.
  Status HandleAbs(HloInstruction* abs) override;

  Status HandleAdd(HloInstruction* add) override;

  Status HandleAnd(HloInstruction* logical_and) override;

  Status HandleBitcast(HloInstruction* bitcast) override;

  Status HandleBitcastConvert(HloInstruction* bitcast) override;

  Status HandleBroadcast(HloInstruction* broadcast) override;

  Status HandleCompare(HloInstruction* compare) override;

  Status HandleConcatenate(HloInstruction* concatenate) override;

  Status HandleConstant(HloInstruction* constant) override;

  Status HandleCopy(HloInstruction* copy) override;

  Status HandleConvert(HloInstruction* convert) override;

  Status HandleComplex(HloInstruction* complex) override;

  Status HandleReal(HloInstruction* real) override;

  Status HandleImag(HloInstruction* imag) override;

  Status HandleIota(HloInstruction* instruction) override;

  Status HandleConvolution(HloInstruction* convolution) override;

  Status HandleDivide(HloInstruction* divide) override;

  Status HandleDot(HloInstruction* dot) override;

  Status HandleGather(HloInstruction* gather) override;

  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;

  Status HandleLog(HloInstruction* log) override;

  Status HandleMaximum(HloInstruction* maximum) override;

  Status HandleMinimum(HloInstruction* minimum) override;

  Status HandleClamp(HloInstruction* clamp) override;

  Status HandleMultiply(HloInstruction* multiply) override;

  Status HandleNegate(HloInstruction* negate) override;

  Status HandleNot(HloInstruction* logical_not) override;

  Status HandleOr(HloInstruction* logical_or) override;

  Status HandlePad(HloInstruction* pad) override;

  Status HandlePower(HloInstruction* power) override;

  Status HandleRemainder(HloInstruction* remainder) override;

  Status HandleReshape(HloInstruction* reshape) override;

  Status HandleReduce(HloInstruction* hlo) override;

  Status HandleReduceWindow(HloInstruction* hlo) override;

  Status HandleReverse(HloInstruction* reverse) override;

  Status HandleRsqrt(HloInstruction* rsqrt) override;

  Status HandleSlice(HloInstruction* slice) override;

  Status HandleSqrt(HloInstruction* sqrt) override;

  Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;

  Status HandleDynamicUpdateSlice(
      HloInstruction* dynamic_update_slice) override;
  Status HandleScatter(HloInstruction* scatter) override;

  Status HandleSelect(HloInstruction* select) override;

  Status HandleSort(HloInstruction* sort) override;

  Status HandleTranspose(HloInstruction* transpose) override;

  Status HandleSubtract(HloInstruction* sub) override;

  Status HandleMap(HloInstruction* map) override;

  // Runs the visitor on a computation.  Returns whether any simplification
  // was performed.
  bool Run(HloComputation* computation,
           const AlgebraicSimplifierOptions& options,
           AlgebraicSimplifier* simplifier);

 private:
  // Removes degenerate dimension from dot.
  StatusOr<bool> RemoveDegenerateDimensionFromDot(HloInstruction* dot);

  // Converts to primitive type if the input hlo is not that type, otherwise
  // returns the original hlo.
  HloInstruction* AsType(HloInstruction* hlo,
                         const PrimitiveType element_type) {
    // No-op when the element type already matches.
    if (hlo->shape().element_type() == element_type) {
      return hlo;
    }
    Shape changed_shape =
        ShapeUtil::ChangeElementType(hlo->shape(), element_type);
    // Let the simplifier pick a layout consistent with its options.
    simplifier_->UpdateLayout(&changed_shape);
    return computation_->AddInstruction(
        HloInstruction::CreateConvert(changed_shape, hlo));
  }

  // Transposes a dot operand such that the batch dimensions are the most major,
  // and the contracting dimensions are most minor.
  StatusOr<HloInstruction*> NormalizeDotOperandToBatchMajorAndContractingMinor(
      HloInstruction* dot_operand, absl::Span<const int64> batch_dimensions,
      absl::Span<const int64> contracting_dimensions) {
    // Build the permutation: batch dims first, then non-batch/non-contracting
    // dims in their original order, then contracting dims last.
    std::vector<int64> transpose_dimensions(batch_dimensions.begin(),
                                            batch_dimensions.end());
    for (int64_t i = 0; i < dot_operand->shape().rank(); ++i) {
      if (!(absl::c_linear_search(batch_dimensions, i) ||
            absl::c_linear_search(contracting_dimensions, i))) {
        transpose_dimensions.push_back(i);
      }
    }
    transpose_dimensions.insert(transpose_dimensions.end(),
                                contracting_dimensions.begin(),
                                contracting_dimensions.end());
    // A sorted permutation is the identity: no transpose needed.
    if (absl::c_is_sorted(transpose_dimensions)) {
      return dot_operand;
    }
    return MakeTransposeHlo(dot_operand, transpose_dimensions);
  }

  // Helper method to perform and add reduction on a list of dimensions.
  HloInstruction* AddReduce(HloInstruction* hlo, absl::Span<const int64> dims,
                            PrimitiveType type) {
    HloInstruction* zero = computation_->AddInstruction(
        simplifier_->CreateConstantWithLayoutUpdated(
            LiteralUtil::Zero(hlo->shape().element_type()).Clone()));
    // NOTE(review): the init value uses hlo's element type while the add
    // computation uses `type` — presumably callers pass matching types;
    // confirm at call sites.
    HloComputation* AddReduce_computation =
        GetOrCreateScalarAddComputation(type);
    // The output shape drops the reduced dimensions.
    Shape shape = ShapeUtil::FilterDimensions(
        [&](int64_t dim) { return !absl::c_linear_search(dims, dim); },
        hlo->shape());
    simplifier_->UpdateLayout(&shape);
    return computation_->AddInstruction(HloInstruction::CreateReduce(
        shape, hlo, zero, dims, AddReduce_computation));
  }

  // Move scalar multiply to the smallest side of convolution to
  // reduce multiply computations.
  Status ScalarMultiplyReduction(HloInstruction* dot);

  // Convenience method for replacing an instruction with a bitcast. If operand
  // is not null, then the bitcast will use the specified operand instead of the
  // operand of the instruction.
  void ReplaceWithBitcast(HloInstruction* instruction,
                          HloInstruction* operand = nullptr);

  // Replace old instruction with new instruction if old and new instructions
  // have the same shape. Updates uses and root instruction. Returns whether a
  // replacement was made.
  bool ReplaceInstructionIfSameShape(HloInstruction* old_instruction,
                                     HloInstruction* new_instruction);

  // Returns whether the shape of the output of the given instructions are the
  // same for the purposes of simplification. If options_.is_layout_sensitive()
  // is true, then this tests shape equality including layout
  // (ShapeUtil::Equal). If options_.is_layout_sensitive() is false, then the
  // tests shape compatibility (ShapeUtil::Compatible).
  bool SameShape(const HloInstruction* lhs, const HloInstruction* rhs) const;

  // Returns whether it was possible to transform `root` to a clamp instruction.
  // With min a minimum instruction, max a maximum instruction, min_operand a
  // operand of min and max_operand a operand of max.
  // Precondition: root is either a minimum or a maximum.
  bool TransformToClampIfSameShape(HloInstruction* root, HloInstruction* min,
                                   HloInstruction* min_operand,
                                   HloInstruction* operand, HloInstruction* max,
                                   HloInstruction* max_operand);

  // A Broadcast that feeds an element-wise operation with a unique non-scalar
  // operand can sink to after the operation.
  StatusOr<bool> TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
      HloInstruction* broadcast);

  // Dot-specific rewrites; each returns the replacement instruction or
  // nullptr when the pattern does not apply.
  StatusOr<HloInstruction*> OptimizeDotOfConcat(HloInstruction* dot);
  StatusOr<HloInstruction*> OptimizeDotOfConcatHelper(
      const HloInstruction& dot, HloInstruction* lhs,
      int64_t lhs_contracting_dim, HloInstruction* rhs,
      int64_t rhs_contracting_dim, bool swapped);

  StatusOr<HloInstruction*> OptimizeDotOfGather(HloInstruction* dot);

  StatusOr<HloInstruction*> OptimizeDotOfReorderContractingDims(
      HloInstruction* dot);

  // Returns (building and caching on first use) a scalar computation that
  // adds two values of `type`; used as the reducer for AddReduce.
  HloComputation* GetOrCreateScalarAddComputation(PrimitiveType type) {
    HloComputation*& scalar_add_computation = scalar_add_computations_[type];
    if (scalar_add_computation) {
      return scalar_add_computation;
    }

    HloComputation::Builder b("scalar_add_computation");
    Shape shape = ShapeUtil::MakeShape(type, {});
    simplifier_->UpdateLayout(&shape);
    auto scalar_lhs = b.AddInstruction(
        HloInstruction::CreateParameter(0, shape, "scalar_lhs"));
    auto scalar_rhs = b.AddInstruction(
        HloInstruction::CreateParameter(1, shape, "scalar_rhs"));
    auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs));
    scalar_add_computation =
        computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
    return scalar_add_computation;
  }

  // Tries to fold a kPad in the input or filter into the convolution
  // instruction's window.
  StatusOr<bool> FoldConvInputPad(HloInstruction* convolution);
  StatusOr<bool> FoldConvFilterPad(HloInstruction* convolution);

  // Tries to swap convolution operands if they would result in a more efficient
  // convolution.
  StatusOr<bool> SwapConvOperands(HloInstruction* convolution);

  // Tries to use a kDot in place of the given convolution.
  StatusOr<bool> SimplifyConvToDot(HloInstruction* convolution);

  // Tries to simplify a slice where the result of the slice is a scalar.
  StatusOr<bool> TrySimplifyScalarSlice(HloInstruction* slice);

  // Tries to convert slice(reshape(X)) into reshape(slice(X))
  StatusOr<bool> TryToReorderSliceAndReshape(HloInstruction* slice);

  // Tries to convert slice(reverse(X)) into reverse(slice(X))
  StatusOr<bool> TryToReorderSliceAndReverse(HloInstruction* slice);

  // Tries to simplify `(and (< a N) (< a K))` in cases where `N <= K` into
  // `(< a N)`. This is crucial for being able to figure out the loop trip
  // count.
  //
  // Assumes that the input is conjunction.
  StatusOr<bool> TrySimplifyTautologicalCompare(HloInstruction* conjunction);

  // Useful when we want to use the same visitor over multiple computations.
  void ResetState(HloComputation* computation);

  // Current HloComputation instance the AlgebraicSimplifierVisitor is
  // traversing.
  HloComputation* computation_;

  // The backend-specific options selected for the algebraic simplifier.
  const AlgebraicSimplifierOptions& options_;

  // Whether algebraic simplification has occurred.
  bool changed_ = false;

  // Cached computation for adding two scalars of a given type.
  absl::flat_hash_map<PrimitiveType, HloComputation*> scalar_add_computations_;

  // Owning pass; used for layout updates and constant creation.
  AlgebraicSimplifier* simplifier_ = nullptr;
};
582 
583 }  // namespace
584 
ResetState(HloComputation * computation)585 void AlgebraicSimplifierVisitor::ResetState(HloComputation* computation) {
586   changed_ = false;
587   ResetVisitStates();
588   computation_ = computation;
589 }
590 
Run(HloComputation * computation,const AlgebraicSimplifierOptions & options,AlgebraicSimplifier * simplifier)591 bool AlgebraicSimplifierVisitor::Run(HloComputation* computation,
592                                      const AlgebraicSimplifierOptions& options,
593                                      AlgebraicSimplifier* simplifier) {
594   ResetState(computation);
595   TF_CHECK_OK(computation->Accept(this));
596   return changed_ || changed();
597 }
598 
SameShape(const HloInstruction * lhs,const HloInstruction * rhs) const599 bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs,
600                                            const HloInstruction* rhs) const {
601   if (options_.is_layout_sensitive()) {
602     return ShapeUtil::Equal(lhs->shape(), rhs->shape());
603   } else {
604     return ShapeUtil::Compatible(lhs->shape(), rhs->shape());
605   }
606 }
607 
608 namespace {
609 
IsOpCodeMultiplyCommutative(HloOpcode opcode)610 bool IsOpCodeMultiplyCommutative(HloOpcode opcode) {
611   switch (opcode) {
612     case HloOpcode::kMultiply:
613     case HloOpcode::kTranspose:
614     case HloOpcode::kReshape:
615     case HloOpcode::kSelect:
616       return true;
617     default:
618       return false;
619   }
620 }
621 
MakeScalarInstruction(HloInstruction * target,float multiplier)622 std::unique_ptr<HloInstruction> MakeScalarInstruction(HloInstruction* target,
623                                                       float multiplier) {
624   switch (target->shape().element_type()) {
625     case BF16:
626       return HloInstruction::CreateConstant(LiteralUtil::ConvertF32ToBF16(
627           LiteralUtil::CreateR0<float>(multiplier)));
628       break;
629     case F32:
630       return HloInstruction::CreateConstant(
631           LiteralUtil::CreateR0<float>(multiplier));
632       break;
633     default:
634       LOG(FATAL) << "Unsupported data type: " << target->shape().element_type();
635   }
636 }
637 
638 }  // namespace
639 
// Moves scalar multiplies surrounding a dot onto the dot's smallest-sized
// side (operand 0, operand 1, or the output), combining all scalar factors
// into a single multiply to reduce the total multiply work.
Status AlgebraicSimplifierVisitor::ScalarMultiplyReduction(
    HloInstruction* dot) {
  // We only process bfloat16 and float32 for now.
  if (dot->shape().element_type() != BF16 &&
      dot->shape().element_type() != F32) {
    return Status::OK();
  }

  auto lhs = dot->mutable_operand(0);
  auto rhs = dot->mutable_operand(1);

  const int64_t dot_size = ShapeUtil::ElementsIn(dot->shape());
  const int64_t lhs_size = ShapeUtil::ElementsIn(lhs->shape());
  const int64_t rhs_size = ShapeUtil::ElementsIn(rhs->shape());

  // `target` is the side the combined multiply will be attached to.
  HloInstruction* target = nullptr;
  // (current node, user, operand_index) triples for the operand-side DFS.
  std::vector<std::tuple<HloInstruction*, HloInstruction*, int64>> operands;
  std::vector<HloInstruction*> users;

  // Find which side of dot has the smallest size:
  // operand 0, operand 1, or output.  Only strictly larger sides are
  // searched for scalar multiplies to pull in.
  if (dot_size <= std::min(lhs_size, rhs_size)) {
    target = dot;
    if (dot_size < lhs_size) {
      operands.emplace_back(lhs, dot, 0);
    }
    if (dot_size < rhs_size) {
      operands.emplace_back(rhs, dot, 1);
    }
  } else if (lhs_size <= rhs_size) {
    target = lhs;
    if (lhs_size < rhs_size) {
      operands.emplace_back(rhs, dot, 1);
    }
    // Only descend into users when the dot has a single user; otherwise
    // removing a multiply there would change other users' values.
    if (lhs_size < dot_size && dot->user_count() == 1) {
      users.push_back(dot->users().front());
    }
  } else {
    target = rhs;
    if (rhs_size < lhs_size) {
      operands.emplace_back(lhs, dot, 0);
    }
    if (rhs_size < dot_size && dot->user_count() == 1) {
      users.push_back(dot->users().front());
    }
  }

  // Scalar factors collected from removed multiplies.
  std::vector<float> values;

  // DFS to find scalar multiply ops from the operands.
  while (!operands.empty()) {
    HloInstruction* inst;
    HloInstruction* user;
    int64_t index;
    std::tie(inst, user, index) = operands.back();
    operands.pop_back();

    // Skip the op types that are not commutative with multiply.
    if (!IsOpCodeMultiplyCommutative(inst->opcode())) {
      continue;
    }

    HloInstruction* operand;
    HloInstruction* multiplier;
    // Pattern match a scalar multiply.
    if (Match(inst, m::MultiplyAnyOrder(
                        m::Op(&operand),
                        m::Broadcast(m::ConstantScalar(&multiplier))))) {
      CHECK_LT(index, user->operand_count());
      CHECK_EQ(inst, user->operands()[index]);

      // When found a scalar multiply, save its scalar value.
      values.push_back(*GetConstantValue(multiplier));
      // And remove the scalar multiply op by splicing its operand into the
      // user, then continue the search below the multiply.
      TF_RETURN_IF_ERROR(user->ReplaceOperandWith(index, operand));
      inst = operand;
    }

    // Push the operands of inst.
    int64_t i = 0;
    for (auto* operand : inst->operands()) {
      operands.emplace_back(operand, inst, i++);
    }
  }

  // DFS to find scalar multiply ops from the users.
  while (!users.empty()) {
    auto inst = users.back();
    users.pop_back();

    if (!IsOpCodeMultiplyCommutative(inst->opcode())) {
      continue;
    }

    HloInstruction* operand;
    HloInstruction* multiplier;
    if (Match(inst, m::MultiplyAnyOrder(
                        m::Op(&operand),
                        m::Broadcast(m::ConstantScalar(&multiplier))))) {
      values.push_back(*GetConstantValue(multiplier));

      // Remove the multiply by forwarding its non-scalar operand to all
      // of its users.
      TF_RETURN_IF_ERROR(inst->ReplaceAllUsesWith(operand));
      inst = operand;
    }

    // Process the instructions with only one user.
    // Otherwise moving scalar multiply to the operands changes the values of
    // other users.
    if (inst->user_count() == 1) {
      users.push_back(inst->users().front());
    }
  }

  // Nothing found: leave the graph untouched.
  if (values.empty()) {
    return Status::OK();
  }

  changed_ = true;

  // Combine all constant multipliers.
  float multiplier = 1.0;
  for (const float v : values) {
    multiplier *= v;
  }

  // Create a new const scalar multiply instruction.
  HloInstruction* new_const_inst;
  new_const_inst =
      computation_->AddInstruction(MakeScalarInstruction(target, multiplier));

  // Broadcast the scalar multiplier.
  HloInstruction* new_broadcast = computation_->AddInstruction(
      HloInstruction::CreateBroadcast(target->shape(), new_const_inst, {}));
  // Create a new scalar multiply instruction.
  HloInstruction* new_multiply =
      computation_->AddInstruction(HloInstruction::CreateBinary(
          target->shape(), HloOpcode::kMultiply, target, new_broadcast));
  CHECK_EQ(new_multiply->shape(), target->shape());

  // Update the dependency with the rest of the instructions.
  if (target == lhs) {
    return dot->ReplaceOperandWith(0, new_multiply);
  } else if (target == rhs) {
    return dot->ReplaceOperandWith(1, new_multiply);
  } else {
    CHECK_EQ(target, dot);
    return dot->ReplaceAllUsesWith(new_multiply);
  }
}
790 
ReplaceWithBitcast(HloInstruction * instruction,HloInstruction * operand)791 void AlgebraicSimplifierVisitor::ReplaceWithBitcast(HloInstruction* instruction,
792                                                     HloInstruction* operand) {
793   CHECK_EQ(1, instruction->operand_count());
794   if (operand == nullptr) {
795     operand = instruction->mutable_operand(0);
796   }
797   CHECK_EQ(ShapeUtil::ElementsIn(instruction->shape()),
798            ShapeUtil::ElementsIn(operand->shape()));
799   CHECK_EQ(ShapeUtil::ByteSizeOf(instruction->shape()),
800            ShapeUtil::ByteSizeOf(operand->shape()));
801 
802   auto bitcast = computation_->AddInstruction(
803       HloInstruction::CreateBitcast(instruction->shape(), operand));
804   bitcast->set_metadata(instruction->metadata());
805   TF_CHECK_OK(ReplaceInstruction(instruction, bitcast));
806 }
807 
ReplaceInstructionIfSameShape(HloInstruction * old_instruction,HloInstruction * new_instruction)808 bool AlgebraicSimplifierVisitor::ReplaceInstructionIfSameShape(
809     HloInstruction* old_instruction, HloInstruction* new_instruction) {
810   if (!SameShape(old_instruction, new_instruction)) {
811     return false;
812   }
813   TF_CHECK_OK(ReplaceInstruction(old_instruction, new_instruction));
814   return true;
815 }
816 
HandleAbs(HloInstruction * abs)817 Status AlgebraicSimplifierVisitor::HandleAbs(HloInstruction* abs) {
818   HloInstruction* abs_operand = abs->mutable_operand(0);
819   VLOG(10) << "trying transform [Abs(A) => A] " << abs->ToString()
820            << " Abs operand is: " << abs_operand->ToString();
821   if (IsNonNegative(abs->operand(0), options_)) {
822     return ReplaceInstruction(abs, abs_operand);
823   }
824   return Status::OK();
825 }
826 
// Simplifies an add.  The rewrites are attempted in order; the first one
// that fires replaces `add` and returns:
//   1. A + 0 => A and 0 + A => A (only when shapes already match).
//   2. Const + A => A + Const (canonicalize constants to the right).
//   3. (A + C1) + C2 => A + (C1 + C2) (reassociate for constant folding).
//   4. lhs + dynamic_update_slice(broadcast(0), rhs, idx) => a partial-shape
//      add folded back into a dynamic_update_slice.
//   5. A*C + B*C => (A+B)*C (factor out a shared multiplicand).
//   6. Folding add-combiner scatters into this add (layout-insensitive only).
Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
  HloInstruction *lhs, *rhs;
  CHECK(Match(add, m::Add(m::Op(&lhs), m::Op(&rhs))));

  // A + 0 => A
  VLOG(10) << "trying transform [A + 0 => A]: " << add->ToString();
  if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(add, lhs)) {
    return Status::OK();
  }
  // 0 + A => A
  VLOG(10) << "trying transform [0 + A => A]: " << add->ToString();
  if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(add, rhs)) {
    return Status::OK();
  }

  // Canonicalization: Put constants on the right.  This makes the reassociation
  // rules below simpler.
  VLOG(10) << "trying transform [Const + A => A + Const]";
  if (Match(add, m::Add(m::Constant(), m::NonConstant()))) {
    return ReplaceWithNewInstruction(
        add,
        HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, rhs, lhs));
  }

  // Reassociate to allow constant folding.
  //
  // Note: This is not general.  For example, we won't reassociate
  //
  //   (A + C1) + (B + C2) =>  A + B + (C1 + C2).
  //
  VLOG(10) << "trying transform [(A + C1) + C2 => A + (C1 + C2)]";
  HloInstruction *a, *c1, *c2;
  if (Match(add, m::Add(m::Add(m::NonConstant(&a), m::Constant(&c1)),
                        m::Constant(&c2))) ||
      Match(add, m::Add(m::Add(m::NonConstant(&a),
                               m::Broadcast(m::ConstantScalar(&c1))),
                        m::Broadcast(m::ConstantScalar(&c2))))) {
    TF_ASSIGN_OR_RETURN(auto* sum_of_constants,
                        MakeBinaryHlo(HloOpcode::kAdd, c1, c2));
    // In the broadcast-of-scalar case the folded constant is a scalar; it
    // must be broadcast back up to the add's shape before use.
    if (ShapeUtil::IsScalar(sum_of_constants->shape()) &&
        !ShapeUtil::IsScalar(add->shape())) {
      sum_of_constants = computation_->AddInstruction(
          HloInstruction::CreateBroadcast(add->shape(), sum_of_constants, {}));
    }
    return ReplaceWithNewInstruction(
        add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, a,
                                          sum_of_constants));
  }

  // Convert add with fullshape into add with partial shape when a
  // portion of add is effective:
  //             zero (fullshape)   rhs (partialshape)
  // .           |                  |
  // . lhs .    dynamic_update_slice (fullshape)
  // . |         |
  // Add (fullshape)
  //
  // to:
  //              lhs
  //              |
  //             dynamic_slice (partialshape)   rhs (partialshape)
  // .           |                      |
  // . lhs .    add (partial_shape)+----+
  // . |         |
  // dynamic_update_slice (fullshape)
  //
  // This pattern is discovered in control flow V2 gradient update.
  if (Match(add,
            m::Add(m::Op(&lhs),
                   m::Op(&rhs)
                       .WithOpcode(HloOpcode::kDynamicUpdateSlice)
                       .WithOperand(
                           0, m::Broadcast(m::ConstantEffectiveScalar(0)))))) {
    const Shape& partial_shape = rhs->operand(1)->shape();
    // Operands 2.. of a dynamic_update_slice are the start indices; reuse
    // them to slice the corresponding window out of lhs.
    auto sliced_lhs =
        computation_->AddInstruction(HloInstruction::CreateDynamicSlice(
            partial_shape, lhs, absl::MakeSpan(rhs->operands()).subspan(2),
            partial_shape.dimensions()));

    auto add_partial = computation_->AddInstruction(
        HloInstruction::CreateBinary(rhs->operand(1)->shape(), HloOpcode::kAdd,
                                     sliced_lhs, rhs->mutable_operand(1)));

    auto dynamic_update_slice_full = HloInstruction::CreateDynamicUpdateSlice(
        lhs->shape(), lhs, add_partial,
        absl::MakeSpan(rhs->operands()).subspan(2));

    return ReplaceWithNewInstruction(add, std::move(dynamic_update_slice_full));
  }

  // A*C + B*C => (A+B)*C
  //
  //  - If A, B, and C are integers, do this unconditionally. Proof of
  //    correctness: https://rise4fun.com/Alive/u9X.
  //
  //  - If A, B, and C are floating point, do this if C is a scalar constant or
  //    broadcast of scalar constant and is equal to +/- 2^k for some (possibly
  //    negative) integer k.
  //
  //    Multiplying by a power of 2 just moves the exponent, so our answer is
  //    exact modulo rounding of intermediate results so long as
  //
  //     - none of the three products has an exponent which underflows (so the
  //       result is 0 or denormal), and
  //     - none of the three products overflows to inf.
  //
  //    Proof: See algebraic_simplifier_proof_distributive_property.py.
  //
  //    We deem these differences in rounding, underflow, and overflow
  //    acceptable in the ML context.
  //
  //    Furthermore, if `enable_floats_are_real` is true, the simplification is
  //    done nonetheless. This might cause numerical differences even if there
  //    is no underflow or overflow.
  HloInstruction *b, *c;
  if (((Match(lhs, m::Multiply(m::Op(&a), m::Op(&c))) &&
        Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b)))) ||
       (Match(lhs, m::Multiply(m::Op(&c), m::Op(&a))) &&
        Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b))))) &&
      // Make sure we would decrease the number of multiplies.
      (lhs->user_count() == 1 && rhs->user_count() == 1) &&
      (ShapeUtil::ElementIsIntegral(add->shape()) ||
       options_.enable_floats_are_real() || IsAllFpConstantPowerOf2(c))) {
    return ReplaceWithNewInstruction(
        add, HloInstruction::CreateBinary(
                 add->shape(), HloOpcode::kMultiply,
                 computation_->AddInstruction(HloInstruction::CreateBinary(
                     add->shape(), HloOpcode::kAdd, a, b)),
                 c));
  }

  // The scatter rewrites below create new instructions with unconstrained
  // layouts, so they are only valid in layout-insensitive mode.
  if (options_.is_layout_sensitive()) {
    return Status::OK();
  }

  HloInstruction* lhs_scatter_operand = nullptr;
  HloInstruction* rhs_scatter_operand = nullptr;
  HloInstruction* lhs_scatter_update = nullptr;
  HloInstruction* rhs_scatter_update = nullptr;
  HloInstruction* lhs_scatter_index = nullptr;
  HloInstruction* rhs_scatter_index = nullptr;
  // An operand is foldable if it is a single-use scatter whose combiner is a
  // plain add of its two parameters: adding into the scatter's input operand
  // is then equivalent to adding after the scatter.
  bool lhs_scatter = Match(lhs, m::Scatter(m::Op(&lhs_scatter_operand),
                                           m::Op(&lhs_scatter_index),
                                           m::Op(&lhs_scatter_update))
                                    .WithOneUse()) &&
                     Match(lhs->to_apply()->root_instruction(),
                           m::Add(m::Parameter(), m::Parameter()));
  bool rhs_scatter = Match(rhs, m::Scatter(m::Op(&rhs_scatter_operand),
                                           m::Op(&rhs_scatter_index),
                                           m::Op(&rhs_scatter_update))
                                    .WithOneUse()) &&
                     Match(rhs->to_apply()->root_instruction(),
                           m::Add(m::Parameter(), m::Parameter()));
  if (rhs_scatter && lhs_scatter) {
    const auto& lhs_dnums = lhs->scatter_dimension_numbers();
    const auto& rhs_dnums = rhs->scatter_dimension_numbers();
    absl::optional<int64> index_concat_dimension;
    absl::optional<int64> update_concat_dimension;
    // Don't try to combine scatters of different ranks.
    if (lhs_scatter_index->shape().rank() !=
        rhs_scatter_index->shape().rank()) {
      return Status::OK();
    }

    int64_t first_index_dim = lhs_scatter_index->shape().rank();
    int64_t first_update_dim = lhs_scatter_update->shape().rank();
    // Find a dimension where it is possible to concatenate the indices and
    // updates. This is the first and only non-equal dimension or the first
    // equally sized dimension.
    for (int64_t d = lhs_scatter_index->shape().rank() - 1,
                 update_dim = lhs_scatter_update->shape().rank() - 1;
         d >= 0; --d) {
      if (d == lhs_dnums.index_vector_dim()) {
        continue;
      }
      // Skip update dimensions that belong to the scatter window; only
      // batch-like update dimensions correspond to index dimensions.
      while (
          absl::c_linear_search(lhs_dnums.update_window_dims(), update_dim)) {
        --update_dim;
      }
      if (lhs_scatter_index->shape().dimensions(d) ==
          rhs_scatter_index->shape().dimensions(d)) {
        first_index_dim = d;
        first_update_dim = update_dim--;
        continue;
      }
      // More than one dimension of unequal size was found, bail out.
      if (index_concat_dimension) {
        return Status::OK();
      }
      index_concat_dimension = d;
      update_concat_dimension = update_dim--;
    }
    if (!index_concat_dimension) {
      index_concat_dimension = first_index_dim;
      update_concat_dimension = first_update_dim;
    }

    // A scalar scatter will require additional reshapes of the index and
    // update.
    if (*index_concat_dimension == lhs_scatter_index->shape().rank()) {
      return Status::OK();
    }
    // Only concatenate when the combined updates are smaller than the scatter
    // result; otherwise the concat would cost more than it saves.
    const bool update_concat_is_cheap =
        ShapeUtil::ElementsIn(rhs_scatter_update->shape()) +
            ShapeUtil::ElementsIn(lhs_scatter_update->shape()) <
        ShapeUtil::ElementsIn(lhs->shape());
    if (!update_concat_is_cheap) {
      return Status::OK();
    }
    const bool same_dimension_numbers =
        lhs_dnums.index_vector_dim() == rhs_dnums.index_vector_dim() &&
        absl::c_equal(lhs_dnums.scatter_dims_to_operand_dims(),
                      rhs_dnums.scatter_dims_to_operand_dims()) &&
        absl::c_equal(lhs_dnums.inserted_window_dims(),
                      rhs_dnums.inserted_window_dims()) &&
        absl::c_equal(lhs_dnums.update_window_dims(),
                      rhs_dnums.update_window_dims());
    // Concatenating indices is only safe when neither scatter promised
    // unique or sorted indices, since the merged index list keeps neither
    // property.
    const bool index_concat_is_safe =
        !lhs->unique_indices() && !rhs->unique_indices() &&
        !DynCast<HloScatterInstruction>(lhs)->indices_are_sorted() &&
        !DynCast<HloScatterInstruction>(rhs)->indices_are_sorted();

    Shape lhs_update_window = ShapeUtil::FilterDimensions(
        [&](int64_t dim) {
          return absl::c_linear_search(lhs_dnums.update_window_dims(), dim);
        },
        lhs_scatter_update->shape());
    Shape rhs_update_window = ShapeUtil::FilterDimensions(
        [&](int64_t dim) {
          return absl::c_linear_search(rhs_dnums.update_window_dims(), dim);
        },
        rhs_scatter_update->shape());
    // Concatenate the indices and updates
    if (index_concat_is_safe && same_dimension_numbers &&
        index_concat_dimension &&
        lhs_scatter_index->shape().element_type() ==
            rhs_scatter_index->shape().element_type() &&
        ShapeUtil::SameDimensions(lhs_update_window, rhs_update_window)) {
      TF_ASSIGN_OR_RETURN(HloInstruction * new_operand,
                          MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand,
                                        rhs_scatter_operand));
      TF_ASSIGN_OR_RETURN(HloInstruction * new_index,
                          MakeConcatHlo({lhs_scatter_index, rhs_scatter_index},
                                        *index_concat_dimension));
      TF_ASSIGN_OR_RETURN(
          HloInstruction * new_update,
          MakeConcatHlo({lhs_scatter_update, rhs_scatter_update},
                        *update_concat_dimension));
      return ReplaceWithNewInstruction(
          add, HloInstruction::CreateScatter(
                   add->shape(), new_operand, new_index, new_update,
                   lhs->to_apply(), lhs_dnums, false, false));
    }
    // Fallback: chain the two scatters — add the operands, feed rhs the sum,
    // and feed lhs the rhs scatter as its operand.
    TF_ASSIGN_OR_RETURN(HloInstruction * new_operand,
                        MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand,
                                      rhs_scatter_operand));
    TF_RETURN_IF_ERROR(rhs->ReplaceOperandWith(0, new_operand));
    TF_RETURN_IF_ERROR(lhs->ReplaceOperandWith(0, rhs));
    return ReplaceInstruction(add, lhs);
  } else if (rhs_scatter) {
    // add(A, scatter(B, i, u)) => scatter(add(A, B), i, u).
    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_operand,
        MakeBinaryHlo(HloOpcode::kAdd, lhs, rhs_scatter_operand));
    TF_RETURN_IF_ERROR(rhs->ReplaceOperandWith(0, new_operand));
    return ReplaceInstruction(add, rhs);
  } else if (lhs_scatter) {
    // add(scatter(A, i, u), B) => scatter(add(A, B), i, u).
    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_operand,
        MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand, rhs));
    TF_RETURN_IF_ERROR(lhs->ReplaceOperandWith(0, new_operand));
    return ReplaceInstruction(add, lhs);
  }
  return Status::OK();
}
1101 
TrySimplifyTautologicalCompare(HloInstruction * conjunction)1102 StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyTautologicalCompare(
1103     HloInstruction* conjunction) {
1104   HloInstruction *lhs, *rhs;
1105   if (!Match(conjunction, m::And(m::Op(&lhs), m::Op(&rhs)))) {
1106     return false;
1107   }
1108   struct LessThanCompareInfo {  // (LT var constant)
1109     HloInstruction* var;
1110     int64 constant;
1111   };
1112 
1113   auto get_compare_info =
1114       [&](HloInstruction* cmp) -> absl::optional<LessThanCompareInfo> {
1115     HloInstruction *lhs, *rhs;
1116     auto scalar_shape_matcher =
1117         m::Shape().IsEffectiveScalar().WithElementType(PrimitiveType::S32);
1118     if (Match(cmp, m::Compare(m::Op(&lhs),
1119                               m::Constant(&rhs).WithShape(scalar_shape_matcher))
1120                        .WithComparisonDirection(ComparisonDirection::kLt))) {
1121       return {LessThanCompareInfo{lhs, *rhs->literal().GetFirstInteger()}};
1122     } else if (Match(
1123                    cmp,
1124                    m::Compare(m::Constant(&lhs).WithShape(scalar_shape_matcher),
1125                               m::Op(&rhs))
1126                        .WithComparisonDirection(ComparisonDirection::kGt))) {
1127       return {LessThanCompareInfo{rhs, *lhs->literal().GetFirstInteger()}};
1128     }
1129     return absl::nullopt;
1130   };
1131 
1132   absl::optional<LessThanCompareInfo> lhs_info = get_compare_info(lhs);
1133   absl::optional<LessThanCompareInfo> rhs_info = get_compare_info(rhs);
1134   if (lhs_info && rhs_info && lhs_info->var == rhs_info->var) {
1135     int64_t new_bound = std::min(lhs_info->constant, rhs_info->constant);
1136     TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
1137         conjunction,
1138         HloInstruction::CreateCompare(lhs->shape(), lhs_info->var,
1139                                       MakeScalarLike(lhs_info->var, new_bound),
1140                                       ComparisonDirection::kLt)));
1141     return true;
1142   }
1143   return false;
1144 }
1145 
HandleAnd(HloInstruction * logical_and)1146 Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) {
1147   HloInstruction *lhs, *rhs;
1148   CHECK(Match(logical_and, m::And(m::Op(&lhs), m::Op(&rhs))));
1149   // Simplify logical and
1150   if (ShapeUtil::HasPrimitiveType(lhs->shape(), xla::PRED) &&
1151       ShapeUtil::HasPrimitiveType(rhs->shape(), xla::PRED)) {
1152     // A && True => A
1153     VLOG(10) << "trying transform [A && True => A]: "
1154              << logical_and->ToString();
1155     if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(logical_and, lhs)) {
1156       return Status::OK();
1157     }
1158     // True && A => A
1159     VLOG(10) << "trying transform [True && A => A]: "
1160              << logical_and->ToString();
1161     if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(logical_and, rhs)) {
1162       return Status::OK();
1163     }
1164   }
1165 
1166   // A && False => False or A & 0 => 0
1167   VLOG(10) << "trying transform [A && False => False]: "
1168            << logical_and->ToString();
1169   if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(logical_and, rhs)) {
1170     return Status::OK();
1171   }
1172 
1173   // False && A => False or A & 0 => 0
1174   VLOG(10) << "trying transform [False && A => False]: "
1175            << logical_and->ToString();
1176   if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(logical_and, lhs)) {
1177     return Status::OK();
1178   }
1179 
1180   // Simplify tautological conjunctions.
1181   TF_ASSIGN_OR_RETURN(bool found_tautological_compare,
1182                       TrySimplifyTautologicalCompare(logical_and));
1183   if (found_tautological_compare) {
1184     return Status::OK();
1185   }
1186 
1187   return Status::OK();
1188 }
1189 
HandleBitcast(HloInstruction * bitcast)1190 Status AlgebraicSimplifierVisitor::HandleBitcast(HloInstruction* bitcast) {
1191   // If a bitcast feeds a bitcast, make it a single bitcast.
1192   HloInstruction* op;
1193   if (Match(bitcast, m::Bitcast(m::Bitcast(m::Op(&op))))) {
1194     return ReplaceWithNewInstruction(
1195         bitcast, HloInstruction::CreateBitcast(bitcast->shape(), op));
1196   }
1197   // All bitcasts can be eliminated (assuming layout constraints are
1198   // satisfied).
1199   ReplaceInstructionIfSameShape(bitcast, bitcast->mutable_operand(0));
1200   return Status::OK();
1201 }
1202 
HandleBitcastConvert(HloInstruction * bitcast)1203 Status AlgebraicSimplifierVisitor::HandleBitcastConvert(
1204     HloInstruction* bitcast) {
1205   // Eliminate bitcast converts between same shape.
1206   ReplaceInstructionIfSameShape(bitcast, bitcast->mutable_operand(0));
1207   return Status::OK();
1208 }
1209 
HandleCopy(HloInstruction * copy)1210 Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) {
1211   // If a copy feeds a copy, make it a single copy.
1212   HloInstruction* op;
1213   if (Match(copy, m::Copy(m::Copy(m::Op(&op))))) {
1214     return ReplaceWithNewInstruction(
1215         copy, HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy, op));
1216   }
1217   // All copies can be eliminated (assuming layout constraints are satisfied).
1218   if (ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0))) {
1219     return Status::OK();
1220   }
1221 
1222   if (HloInstruction* bitcast_operand =
1223           BitcastingOperandOfReshapeOrCopyChain(copy, options_)) {
1224     ReplaceWithBitcast(copy, bitcast_operand);
1225     return Status::OK();
1226   }
1227 
1228   // Replace Copy(Reshape()) with Reshape() if the Reshape is a logical bitcast.
1229   if (copy->operand(0)->opcode() == HloOpcode::kReshape &&
1230       copy->operand(0)->user_count() == 1 &&
1231       ShapeUtil::ReshapeIsBitcast(copy->operand(0)->shape(), copy->shape())) {
1232     return ReplaceWithNewInstruction(
1233         copy,
1234         copy->operand(0)->CloneWithNewOperands(
1235             copy->shape(), {copy->mutable_operand(0)->mutable_operand(0)}));
1236   }
1237   return Status::OK();
1238 }
1239 
// Simplifies concatenate.  Rewrites attempted, in order:
//   1. Unary concatenate => its sole operand.
//   2. Drop zero-element operands.
//   3. concat(x, concat(y, z)) => concat(x, y, z)  (layout-insensitive only).
//   4. Merge "adjacent" unstrided slices of the same producer into one slice.
//   5. Binary concat with a broadcast-scalar operand => pad.
//   6. concat of N copies of a size-1 operand => broadcast.
Status AlgebraicSimplifierVisitor::HandleConcatenate(
    HloInstruction* concatenate) {
  absl::Span<HloInstruction* const> operands(concatenate->operands());
  if (operands.size() == 1) {
    // Unary concatenates are useless.
    ReplaceInstructionIfSameShape(concatenate, operands[0]);
    return Status::OK();
  }
  // Filter out and remove empty operands.
  std::vector<HloInstruction*> nonempty_operands;
  for (HloInstruction* operand : operands) {
    if (!ShapeUtil::IsZeroElementArray(operand->shape())) {
      nonempty_operands.push_back(operand);
    }
  }
  if (nonempty_operands.size() < operands.size()) {
    HloInstruction* replacement;
    if (nonempty_operands.empty()) {
      // All operands are empty, so the result is empty too; any operand
      // works as a same-shape replacement candidate.
      replacement = operands[0];
    } else if (nonempty_operands.size() == 1) {
      replacement = nonempty_operands[0];
    } else {
      replacement =
          computation_->AddInstruction(concatenate->CloneWithNewOperands(
              concatenate->shape(), nonempty_operands));
    }
    VLOG(10) << "trying to replace " << concatenate->ToString() << " with "
             << replacement->ToString();
    ReplaceInstructionIfSameShape(concatenate, replacement);
    return Status::OK();
  }

  if (options_.is_layout_sensitive()) {
    return Status::OK();
  }

  // concat(x, concat(y, z)) -> concat(x, y, z).  We only do this in
  // layout-insensitive mode because some backends may have (late,
  // layout-sensitive) passes that break up ops with many operands into smaller
  // pieces.  This would undo that.
  absl::InlinedVector<HloInstruction*, 8> unnested_concat_operands;
  for (HloInstruction* operand : operands) {
    // Only flatten nested concats along the same dimension.
    if (operand->opcode() == HloOpcode::kConcatenate &&
        operand->concatenate_dimension() ==
            concatenate->concatenate_dimension()) {
      for (HloInstruction* instr : operand->operands()) {
        unnested_concat_operands.push_back(instr);
      }
    } else {
      unnested_concat_operands.push_back(operand);
    }
  }
  if (unnested_concat_operands.size() != concatenate->operand_count()) {
    return ReplaceWithNewInstruction(
        concatenate, HloInstruction::CreateConcatenate(
                         concatenate->shape(), unnested_concat_operands,
                         concatenate->concatenate_dimension()));
  }

  // Check if we can merge "adjacent" slice operands which take slices from the
  // same other op. For simplicity we only merge unstrided slices.
  int64_t concatenate_dimension = concatenate->concatenate_dimension();
  std::vector<HloInstruction*> new_operands;
  int64_t i = 0;
  while (i < operands.size()) {
    if (operands[i]->opcode() != HloOpcode::kSlice ||
        !IsUnstridedSlice(operands[i])) {
      new_operands.push_back(operands[i]);
      ++i;
      continue;
    }
    // Greedily extend [i, j) while operands[j] continues the slice of the
    // same producer exactly where the previous slice ended.
    int64_t slice_end = operands[i]->slice_limits(concatenate_dimension);
    HloInstruction* slice_operand = operands[i]->mutable_operand(0);
    int64_t j = i + 1;
    while (j < operands.size()) {
      if (operands[j]->opcode() != HloOpcode::kSlice ||
          !IsUnstridedSlice(operands[j]) ||
          operands[j]->operand(0) != slice_operand ||
          operands[j]->slice_starts(concatenate_dimension) != slice_end) {
        break;
      }
      // Check that all the slice_start values are the same in all other
      // dimensions. This implies that the slice_limit values are also the same,
      // because operands of concatenate need to have the same shape, and we
      // already checked that the slices are unstrided.
      bool same_other_starts = true;
      for (int64_t k = 0; k < operands[j]->slice_starts().size(); ++k) {
        if (k == concatenate_dimension) {
          continue;
        }
        if (operands[i]->slice_starts(k) != operands[j]->slice_starts(k)) {
          same_other_starts = false;
          break;
        }
      }
      if (!same_other_starts) {
        break;
      }
      slice_end = operands[j]->slice_limits(concatenate_dimension);
      ++j;
    }
    if (j - i > 1) {
      // Two or more contiguous slices were found: replace the run with a
      // single wider slice covering the combined range.
      Shape new_slice_shape = operands[i]->shape();
      new_slice_shape.set_dimensions(
          concatenate_dimension,
          slice_end - operands[i]->slice_starts(concatenate_dimension));
      simplifier_->UpdateLayout(&new_slice_shape);
      auto new_limit_indices = operands[i]->slice_limits();
      new_limit_indices[concatenate_dimension] = slice_end;
      auto new_slice_op =
          computation_->AddInstruction(HloInstruction::CreateSlice(
              new_slice_shape, slice_operand,
              /*start_indices=*/operands[i]->slice_starts(),
              /*limit_indices=*/new_limit_indices,
              /*strides=*/operands[i]->slice_strides()));
      new_operands.push_back(new_slice_op);
    } else {
      new_operands.push_back(operands[i]);
    }
    i = j;
  }
  if (new_operands.size() < operands.size()) {
    auto replacement = computation_->AddInstruction(
        concatenate->CloneWithNewOperands(concatenate->shape(), new_operands));
    ReplaceInstructionIfSameShape(concatenate, replacement);
    return Status::OK();
  }

  if (operands.size() == 2) {
    // A binary concat with a broadcasted scalar as an operand can be converted
    // into a pad which is simpler to fold into other operations.
    bool is_effective_low_pad = Match(
        operands[0], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar())));
    bool is_effective_high_pad = Match(
        operands[1], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar())));
    if (!is_effective_low_pad && !is_effective_high_pad) {
      return Status::OK();
    }
    PaddingConfig padding_config;
    for (int64_t dim = 0; dim < operands[0]->shape().rank(); ++dim) {
      auto padding_config_dim = padding_config.add_dimensions();
      padding_config_dim->set_edge_padding_high(0);
      padding_config_dim->set_edge_padding_low(0);
      padding_config_dim->set_interior_padding(0);
      if (dim == concatenate_dimension) {
        if (is_effective_low_pad) {
          padding_config_dim->set_edge_padding_low(
              operands[0]->shape().dimensions(dim));
        } else {
          padding_config_dim->set_edge_padding_high(
              operands[1]->shape().dimensions(dim));
        }
      }
    }
    // The non-broadcast operand is padded; the broadcast's scalar input
    // becomes the padding value.
    int64_t operand_to_pad = is_effective_low_pad ? 1 : 0;
    int64_t pad_value_operand = is_effective_low_pad ? 0 : 1;
    HloInstruction* pad =
        computation_->AddInstruction(HloInstruction::CreatePad(
            concatenate->shape(), operands[operand_to_pad],
            operands[pad_value_operand]->mutable_operand(0), padding_config));
    return ReplaceInstruction(concatenate, pad);
  }

  // concat of N copies of a size-1 (along the concat dim) operand is just a
  // broadcast of that operand with the concat dimension dropped.
  if (absl::c_count(operands, operands[0]) == operands.size() &&
      operands[0]->shape().dimensions(concatenate_dimension) == 1) {
    Shape new_shape = operands[0]->shape();
    absl::InlinedVector<int64, 8> broadcast_dims;
    for (int64_t i = 0; i < new_shape.rank(); ++i) {
      if (i == concatenate_dimension) {
        continue;
      }
      broadcast_dims.push_back(i);
    }
    new_shape.DeleteDimension(concatenate_dimension);
    return ReplaceInstruction(
        concatenate,
        MakeBroadcastHlo(MakeReshapeHlo(new_shape, operands[0]).ValueOrDie(),
                         broadcast_dims, concatenate->shape()));
  }
  return Status::OK();
}
1421 
BuildTupleConstant(HloComputation * computation,const LiteralSlice & literal,AlgebraicSimplifier * simplifier)1422 static HloInstruction* BuildTupleConstant(HloComputation* computation,
1423                                           const LiteralSlice& literal,
1424                                           AlgebraicSimplifier* simplifier) {
1425   if (literal.shape().IsTuple()) {
1426     std::vector<HloInstruction*> elems;
1427     elems.reserve(ShapeUtil::TupleElementCount(literal.shape()));
1428     for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) {
1429       elems.push_back(BuildTupleConstant(
1430           computation, LiteralSlice(literal, {i}), simplifier));
1431     }
1432     return computation->AddInstruction(HloInstruction::CreateTuple(elems));
1433   } else {
1434     return computation->AddInstruction(
1435         simplifier->CreateConstantWithLayoutUpdated(literal.Clone()));
1436   }
1437 }
1438 
HandleConstant(HloInstruction * constant)1439 Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
1440   // Tuple constants aren't directly supported by any backend. Expand them into
1441   // explicit Tuple instructions.
1442   if (constant->shape().IsTuple()) {
1443     return ReplaceInstruction(
1444         constant,
1445         BuildTupleConstant(computation_, constant->literal(), simplifier_));
1446   }
1447 
1448   if (constant->shape().element_type() == TOKEN) {
1449     return Status::OK();
1450   }
1451 
1452   // If a literal is all the same element replace it with a scalar broadcast.
1453   if (ShapeUtil::ElementsIn(constant->shape()) > 1 &&
1454       constant->literal().IsAllFirst()) {
1455     Literal unique_scalar(
1456         LiteralUtil::GetFirstScalarLiteral(constant->literal()));
1457     HloInstruction* scalar = computation_->AddInstruction(
1458         simplifier_->CreateConstantWithLayoutUpdated(std::move(unique_scalar)));
1459     return ReplaceWithNewInstruction(
1460         constant,
1461         HloInstruction::CreateBroadcast(constant->shape(), scalar, {}));
1462   }
1463 
1464   // If a literal is an increasing sequence from zero, replace it with an iota.
1465   if (ShapeUtil::ElementsIn(constant->shape()) > 1 &&
1466       constant->literal().IsR1Iota()) {
1467     return ReplaceWithNewInstruction(
1468         constant, HloInstruction::CreateIota(constant->shape(), 0));
1469   }
1470 
1471   if (absl::optional<int64> stride = constant->literal().IsR1StridedIota()) {
1472     // Replace the constant with iota * stride.
1473     HloInstruction* stride_hlo = MakeScalarLike(constant, *stride);
1474     HloInstruction* iota = computation_->AddInstruction(
1475         HloInstruction::CreateIota(constant->shape(), 0));
1476     return ReplaceWithNewInstruction(
1477         constant,
1478         HloInstruction::CreateBinary(constant->shape(), HloOpcode::kMultiply,
1479                                      iota, stride_hlo));
1480   }
1481 
1482   return Status::OK();
1483 }
1484 
HandleSubtract(HloInstruction * sub)1485 Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) {
1486   HloInstruction *lhs, *rhs;
1487   CHECK(Match(sub, m::Subtract(m::Op(&lhs), m::Op(&rhs))));
1488   // A - 0 => A
1489   VLOG(10) << "trying transform [A - 0 => A]: " << sub->ToString();
1490   if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(sub, lhs)) {
1491     return Status::OK();
1492   }
1493 
1494   // Canonicalize subtraction of a constant to addition.
1495   VLOG(10) << "trying transform [A - Const => A + (-Const)]";
1496   if (Match(sub, m::Subtract(m::NonConstant(&lhs), m::Constant(&rhs))) ||
1497       Match(sub, m::Subtract(m::NonConstant(&lhs),
1498                              m::Broadcast(m::Constant(&rhs))))) {
1499     HloInstruction* negative_const = computation_->AddInstruction(
1500         HloInstruction::CreateUnary(rhs->shape(), HloOpcode::kNegate, rhs));
1501     if (const HloInstruction* broadcast =
1502             DynCast<HloBroadcastInstruction>(sub->operand(1))) {
1503       negative_const =
1504           computation_->AddInstruction(HloInstruction::CreateBroadcast(
1505               broadcast->shape(), negative_const, broadcast->dimensions()));
1506     }
1507     return ReplaceWithNewInstruction(
1508         sub, HloInstruction::CreateBinary(sub->shape(), HloOpcode::kAdd, lhs,
1509                                           negative_const));
1510   }
1511 
1512   return Status::OK();
1513 }
1514 namespace {
1515 template <typename T>
InvertConstant(const HloInstruction & constant,Literal * result)1516 Status InvertConstant(const HloInstruction& constant, Literal* result) {
1517   return result->Populate<T>([&](absl::Span<const int64> indices) {
1518     return T{1.0} / constant.literal().Get<T>(indices);
1519   });
1520 }
1521 
1522 template <typename T>
TryDivideToShift(HloInstruction * divide,HloComputation * computation,AlgebraicSimplifier * simplifier)1523 std::unique_ptr<HloInstruction> TryDivideToShift(
1524     HloInstruction* divide, HloComputation* computation,
1525     AlgebraicSimplifier* simplifier) {
1526   HloInstruction *a, *b, *c;
1527   CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b))));
1528 
1529   if (ShapeUtil::ElementIsIntegral(divide->shape()) &&
1530       !Match(b, m::ConstantEffectiveScalar(&c)) &&
1531       !Match(b, m::Broadcast(m::ConstantEffectiveScalar(&c)))) {
1532     return nullptr;
1533   }
1534 
1535   if (ShapeUtil::ElementIsSigned(divide->shape())) {
1536     int64_t b_value = c->literal().GetFirstElement<T>();
1537     if (b_value > 0 && IsPowerOfTwo(static_cast<uint64>(b_value))) {
1538       // Handle negative dividends by negating the result of the division.
1539       HloInstruction* zero_like_a = MakeScalarLike(a, 0);
1540 
1541       Shape changed_shape = ShapeUtil::ChangeElementType(a->shape(), PRED);
1542       simplifier->UpdateLayout(&changed_shape);
1543       auto* dividend_is_negative =
1544           computation->AddInstruction(HloInstruction::CreateCompare(
1545               changed_shape, a, zero_like_a, ComparisonDirection::kLt));
1546 
1547       auto* negated_dividend = computation->AddInstruction(
1548           HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a));
1549 
1550       auto* abs_dividend =
1551           computation->AddInstruction(HloInstruction::CreateTernary(
1552               a->shape(), HloOpcode::kSelect, dividend_is_negative,
1553               negated_dividend, a));
1554 
1555       auto* quotient = computation->AddInstruction(HloInstruction::CreateBinary(
1556           divide->shape(), HloOpcode::kShiftRightLogical, abs_dividend,
1557           MakeScalarLike(abs_dividend, tensorflow::Log2Floor64(b_value))));
1558 
1559       auto* neqated_quotient =
1560           computation->AddInstruction(HloInstruction::CreateUnary(
1561               quotient->shape(), HloOpcode::kNegate, quotient));
1562 
1563       return HloInstruction::CreateTernary(divide->shape(), HloOpcode::kSelect,
1564                                            dividend_is_negative,
1565                                            neqated_quotient, quotient);
1566     }
1567   } else {
1568     uint64 b_value = c->literal().GetFirstElement<T>();
1569     if (IsPowerOfTwo(b_value)) {
1570       return HloInstruction::CreateBinary(
1571           divide->shape(), HloOpcode::kShiftRightLogical, a,
1572           MakeScalarLike(a, tensorflow::Log2Floor64(b_value)));
1573     }
1574   }
1575 
1576   return nullptr;
1577 }
1578 }  // namespace
1579 
HandleDivide(HloInstruction * divide)1580 Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
1581   HloInstruction *a, *b, *c, *d;
1582   CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b))));
1583   // A/1 => A
1584   VLOG(10) << "trying transform [A/1 => A]: " << divide->ToString();
1585   if (IsAll(b, 1) && ReplaceInstructionIfSameShape(divide, a)) {
1586     return Status::OK();
1587   }
1588 
1589   // A / B => A >> log2(B) if B is a power of 2.
1590   switch (divide->shape().element_type()) {
1591     case S8:
1592       if (std::unique_ptr<HloInstruction> shift =
1593               TryDivideToShift<int8>(divide, computation_, simplifier_)) {
1594         return ReplaceWithNewInstruction(divide, std::move(shift));
1595       }
1596       break;
1597     case S16:
1598       if (std::unique_ptr<HloInstruction> shift =
1599               TryDivideToShift<int16>(divide, computation_, simplifier_)) {
1600         return ReplaceWithNewInstruction(divide, std::move(shift));
1601       }
1602       break;
1603     case S32:
1604       if (std::unique_ptr<HloInstruction> shift =
1605               TryDivideToShift<int32>(divide, computation_, simplifier_)) {
1606         return ReplaceWithNewInstruction(divide, std::move(shift));
1607       }
1608       break;
1609     case S64:
1610       if (std::unique_ptr<HloInstruction> shift =
1611               TryDivideToShift<int64>(divide, computation_, simplifier_)) {
1612         return ReplaceWithNewInstruction(divide, std::move(shift));
1613       }
1614       break;
1615     case U8:
1616       if (std::unique_ptr<HloInstruction> shift =
1617               TryDivideToShift<uint8>(divide, computation_, simplifier_)) {
1618         return ReplaceWithNewInstruction(divide, std::move(shift));
1619       }
1620       break;
1621     case U16:
1622       if (std::unique_ptr<HloInstruction> shift =
1623               TryDivideToShift<uint16>(divide, computation_, simplifier_)) {
1624         return ReplaceWithNewInstruction(divide, std::move(shift));
1625       }
1626       break;
1627     case U32:
1628       if (std::unique_ptr<HloInstruction> shift =
1629               TryDivideToShift<uint32>(divide, computation_, simplifier_)) {
1630         return ReplaceWithNewInstruction(divide, std::move(shift));
1631       }
1632       break;
1633     case U64:
1634       if (std::unique_ptr<HloInstruction> shift =
1635               TryDivideToShift<uint64>(divide, computation_, simplifier_)) {
1636         return ReplaceWithNewInstruction(divide, std::move(shift));
1637       }
1638       break;
1639     default:
1640       break;
1641   }
1642 
1643   Shape* shape;
1644   // exp(A)/exp(B) => exp(A-B)
1645   if (Match(divide, m::Divide(m::Exp(m::Op(&a)), m::Exp(m::Op(&b)))
1646                         .WithShape(m::Shape(&shape)))) {
1647     VLOG(10) << "transform [exp(A)/exp(B) => exp(A-B)]: " << divide->ToString();
1648     HloInstruction* subtract = computation_->AddInstruction(
1649         HloInstruction::CreateBinary(*shape, HloOpcode::kSubtract, a, b));
1650     return ReplaceWithNewInstruction(
1651         divide, HloInstruction::CreateUnary(*shape, HloOpcode::kExp, subtract));
1652   }
1653 
1654   // A/exp(B) => A*exp(-B)
1655   if (Match(divide, m::Divide(m::Op(&a), m::Exp(m::Op(&b))))) {
1656     VLOG(10) << "transform [A/exp(B) => A*exp(-B)]: " << divide->ToString();
1657     HloInstruction* negate = computation_->AddInstruction(
1658         HloInstruction::CreateUnary(divide->shape(), HloOpcode::kNegate, b));
1659     HloInstruction* new_exp = computation_->AddInstruction(
1660         HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp, negate));
1661     return ReplaceWithNewInstruction(
1662         divide, HloInstruction::CreateBinary(divide->shape(),
1663                                              HloOpcode::kMultiply, a, new_exp));
1664   }
1665 
1666   // A/pow(B,C) => A*pow(B,-C)
1667   if (Match(divide, m::Divide(m::Op(&a), m::Power(m::Op(&b), m::Op(&c))))) {
1668     VLOG(10) << "transform [A/pow(B,C) => A*pow(B,-C)]: " << divide->ToString();
1669     // The output shape of the created negate operator should be the same as the
1670     // input.
1671     const Shape& negate_shape = c->shape();
1672     HloInstruction* negate = computation_->AddInstruction(
1673         HloInstruction::CreateUnary(negate_shape, HloOpcode::kNegate, c));
1674     // And the power operator should retain the output shape of the old one.
1675     const Shape& new_power_shape = b->shape();
1676     HloInstruction* new_power =
1677         computation_->AddInstruction(HloInstruction::CreateBinary(
1678             new_power_shape, HloOpcode::kPower, b, negate));
1679     return ReplaceWithNewInstruction(
1680         divide, HloInstruction::CreateBinary(
1681                     divide->shape(), HloOpcode::kMultiply, a, new_power));
1682   }
1683 
1684   // A/sqrt(B) => A*rsqrt(X).
1685   if (Match(divide, m::Divide(m::Op(&a), m::Sqrt(m::Op(&b))))) {
1686     auto* rsqrt = computation_->AddInstruction(
1687         HloInstruction::CreateUnary(divide->shape(), HloOpcode::kRsqrt, b));
1688     return ReplaceWithNewInstruction(
1689         divide, HloInstruction::CreateBinary(rsqrt->shape(),
1690                                              HloOpcode::kMultiply, a, rsqrt));
1691   }
1692 
1693   // A/rsqrt(B) => A*sqrt(B).
1694   if (Match(divide, m::Divide(m::Op(&a), m::Rsqrt(m::Op(&b))))) {
1695     auto* sqrt = computation_->AddInstruction(
1696         HloInstruction::CreateUnary(divide->shape(), HloOpcode::kSqrt, b));
1697     return ReplaceWithNewInstruction(
1698         divide, HloInstruction::CreateBinary(sqrt->shape(),
1699                                              HloOpcode::kMultiply, a, sqrt));
1700   }
1701 
1702   // Simplifying integral division would produce unexpected results.
1703   if (ShapeUtil::ElementIsIntegral(divide->shape())) {
1704     return Status::OK();
1705   }
1706 
1707   // A / Const => A * (1 / Const)
1708   //
1709   // (Backends can do this transformation, but generally only if the constant is
1710   // a scalar.)
1711   if (Match(divide, m::Divide(m::NonConstant(&a), m::Op(&b))) &&
1712       (Match(b, m::Constant(&c)) || Match(b, m::Broadcast(m::Constant(&c))))) {
1713     Shape result_shape = c->literal().shape();
1714     Literal new_literal(result_shape);
1715     switch (result_shape.element_type()) {
1716       case F16:
1717         TF_RETURN_IF_ERROR(InvertConstant<half>(*c, &new_literal));
1718         break;
1719       case F32:
1720         TF_RETURN_IF_ERROR(InvertConstant<float>(*c, &new_literal));
1721         break;
1722       case BF16:
1723         TF_RETURN_IF_ERROR(InvertConstant<bfloat16>(*c, &new_literal));
1724         break;
1725       case F64:
1726         TF_RETURN_IF_ERROR(InvertConstant<double>(*c, &new_literal));
1727         break;
1728       case C64:
1729         TF_RETURN_IF_ERROR(InvertConstant<complex64>(*c, &new_literal));
1730         break;
1731       case C128:
1732         TF_RETURN_IF_ERROR(InvertConstant<complex128>(*c, &new_literal));
1733         break;
1734       default:
1735         return Status::OK();
1736     }
1737     auto inverse = computation_->AddInstruction(
1738         simplifier_->CreateConstantWithLayoutUpdated(new_literal.Clone()));
1739     if (b != c) {
1740       inverse = computation_->AddInstruction(HloInstruction::CreateBroadcast(
1741           b->shape(), inverse, b->dimensions()));
1742     }
1743     TF_ASSIGN_OR_RETURN(auto new_divide,
1744                         MakeBinaryHlo(HloOpcode::kMultiply, a, inverse));
1745     return ReplaceInstruction(divide, new_divide);
1746   }
1747 
1748   // (A / B) / (C / D)  =>  (A / B)*(D / C) => (A * D) / (B * C)
1749   if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)),
1750                               m::Divide(m::Op(&c), m::Op(&d))))) {
1751     TF_ASSIGN_OR_RETURN(auto a_times_d,
1752                         MakeBinaryHlo(HloOpcode::kMultiply, a, d));
1753     TF_ASSIGN_OR_RETURN(auto b_times_c,
1754                         MakeBinaryHlo(HloOpcode::kMultiply, b, c));
1755     TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kDivide,
1756                                                        a_times_d, b_times_c));
1757 
1758     return ReplaceInstruction(divide, new_divide);
1759   }
1760 
1761   // (A / B) / C => A / (B * C)
1762   if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), m::Op(&c)))) {
1763     TF_ASSIGN_OR_RETURN(auto b_times_c,
1764                         MakeBinaryHlo(HloOpcode::kMultiply, b, c));
1765     TF_ASSIGN_OR_RETURN(auto new_divide,
1766                         MakeBinaryHlo(HloOpcode::kDivide, a, b_times_c));
1767     return ReplaceInstruction(divide, new_divide);
1768   }
1769 
1770   // A / (B / C) => (A*C) / B
1771   if (Match(divide, m::Divide(m::Op(&a), m::Divide(m::Op(&b), m::Op(&c))))) {
1772     TF_ASSIGN_OR_RETURN(auto a_times_c,
1773                         MakeBinaryHlo(HloOpcode::kMultiply, a, c));
1774     TF_ASSIGN_OR_RETURN(auto new_divide,
1775                         MakeBinaryHlo(HloOpcode::kDivide, a_times_c, b));
1776     return ReplaceInstruction(divide, new_divide);
1777   }
1778 
1779   // If X is a convert from pred, then
1780   // X / broadcast(Y) => broadcast(1/Y) * X
1781   if (Match(divide,
1782             m::Divide(
1783                 m::Convert(&a,
1784                            m::Op().WithShape(m::Shape().WithElementType(PRED))),
1785                 m::Broadcast(m::Op(&b).WithShape(m::Shape().IsScalar()))))) {
1786     TF_ASSIGN_OR_RETURN(
1787         auto recip, MakeBinaryHlo(HloOpcode::kDivide, MakeScalarLike(b, 1), b));
1788     auto recip_bcast = computation_->AddInstruction(
1789         HloInstruction::CreateBroadcast(divide->shape(), recip, {}));
1790     TF_ASSIGN_OR_RETURN(auto mul,
1791                         MakeBinaryHlo(HloOpcode::kMultiply, recip_bcast, a));
1792     return ReplaceInstruction(divide, mul);
1793   }
1794 
1795   return Status::OK();
1796 }
1797 
RemoveDegenerateDimensionFromDot(HloInstruction * dot)1798 StatusOr<bool> AlgebraicSimplifierVisitor::RemoveDegenerateDimensionFromDot(
1799     HloInstruction* dot) {
1800   const Shape& lhs_shape = dot->operand(0)->shape();
1801   int64_t num_degenerate_lhs_dims = 0;
1802   std::vector<int64> lhs_dimension_map(lhs_shape.rank(), -1);
1803   for (int64_t i = 0; i < lhs_shape.rank(); ++i) {
1804     if (lhs_shape.dimensions(i) == 1) {
1805       ++num_degenerate_lhs_dims;
1806     } else {
1807       lhs_dimension_map[i] = i - num_degenerate_lhs_dims;
1808     }
1809   }
1810 
1811   const Shape& rhs_shape = dot->operand(1)->shape();
1812   int64_t num_degenerate_rhs_dims = 0;
1813   std::vector<int64> rhs_dimension_map(rhs_shape.rank(), -1);
1814   for (int64_t i = 0; i < rhs_shape.rank(); ++i) {
1815     if (rhs_shape.dimensions(i) == 1) {
1816       ++num_degenerate_rhs_dims;
1817     } else {
1818       rhs_dimension_map[i] = i - num_degenerate_rhs_dims;
1819     }
1820   }
1821   if (num_degenerate_lhs_dims == 0 && num_degenerate_rhs_dims == 0) {
1822     return false;
1823   }
1824   const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
1825   DotDimensionNumbers new_dnums;
1826   for (int64_t dim : dnums.lhs_batch_dimensions()) {
1827     int64_t new_dim = lhs_dimension_map[dim];
1828     if (new_dim != -1) {
1829       new_dnums.add_lhs_batch_dimensions(new_dim);
1830     }
1831   }
1832   for (int64_t dim : dnums.lhs_contracting_dimensions()) {
1833     int64_t new_dim = lhs_dimension_map[dim];
1834     if (new_dim != -1) {
1835       new_dnums.add_lhs_contracting_dimensions(new_dim);
1836     }
1837   }
1838 
1839   for (int64_t dim : dnums.rhs_batch_dimensions()) {
1840     int64_t new_dim = rhs_dimension_map[dim];
1841     if (new_dim != -1) {
1842       new_dnums.add_rhs_batch_dimensions(new_dim);
1843     }
1844   }
1845   for (int64_t dim : dnums.rhs_contracting_dimensions()) {
1846     int64_t new_dim = rhs_dimension_map[dim];
1847     if (new_dim != -1) {
1848       new_dnums.add_rhs_contracting_dimensions(new_dim);
1849     }
1850   }
1851 
1852   HloInstruction* new_lhs =
1853       num_degenerate_lhs_dims > 0
1854           ? dot->parent()->AddInstruction(HloInstruction::CreateReshape(
1855                 ShapeUtil::DropDegenerateDimensions(lhs_shape),
1856                 dot->mutable_operand(0)))
1857           : dot->mutable_operand(0);
1858   HloInstruction* new_rhs =
1859       num_degenerate_rhs_dims > 0
1860           ? dot->parent()->AddInstruction(HloInstruction::CreateReshape(
1861                 ShapeUtil::DropDegenerateDimensions(rhs_shape),
1862                 dot->mutable_operand(1)))
1863           : dot->mutable_operand(1);
1864   TF_ASSIGN_OR_RETURN(
1865       auto new_dot,
1866       MakeDotHlo(new_lhs, new_rhs, new_dnums, dot->precision_config(),
1867                  /*preferred_element_type=*/dot->shape().element_type()));
1868   if (ShapeUtil::Compatible(dot->shape(), new_dot->shape())) {
1869     TF_RETURN_IF_ERROR(ReplaceInstruction(dot, new_dot));
1870   } else {
1871     TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
1872         dot, HloInstruction::CreateReshape(dot->shape(), new_dot)));
1873   }
1874   return true;
1875 }
1876 
OptimizeDotOfConcat(HloInstruction * dot)1877 StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcat(
1878     HloInstruction* dot) {
1879   const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
1880   if (dnums.lhs_contracting_dimensions_size() != 1 ||
1881       dnums.lhs_batch_dimensions_size() != 0 ||
1882       dot->shape().dimensions_size() != 2) {  // dot output 2D
1883     return nullptr;
1884   }
1885 
1886   const int64_t lhs_contracting_dim = dnums.lhs_contracting_dimensions(0);
1887   const int64_t rhs_contracting_dim = dnums.rhs_contracting_dimensions(0);
1888   HloInstruction *lhs, *rhs;
1889   CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
1890 
1891   TF_ASSIGN_OR_RETURN(
1892       HloInstruction * optimized_lhs_concat,
1893       OptimizeDotOfConcatHelper(*dot, lhs, lhs_contracting_dim, rhs,
1894                                 rhs_contracting_dim, /*swapped=*/false));
1895   if (optimized_lhs_concat) {
1896     return optimized_lhs_concat;
1897   }
1898 
1899   return OptimizeDotOfConcatHelper(*dot, rhs, rhs_contracting_dim, lhs,
1900                                    lhs_contracting_dim, /*swapped=*/true);
1901 }
1902 
// Rewrites dot(concat(L_0..L_n), constant R) as sum_i dot(L_i, slice_i(R)),
// where the concatenate runs along the contracted dimension.  `swapped`
// indicates the concatenate was originally the dot's RHS; `lhs` here is
// always the concatenate candidate and `rhs` the constant candidate.
// Returns the summed replacement instruction, or nullptr if not applicable.
StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper(
    const HloInstruction& dot, HloInstruction* lhs, int64_t lhs_contracting_dim,
    HloInstruction* rhs, int64_t rhs_contracting_dim, bool swapped) {
  // Applies only when `lhs` is a concatenate along the contracted dimension
  // and `rhs` is a constant (so its slices can be folded at compile time).
  bool can_optimize = lhs->opcode() == HloOpcode::kConcatenate &&
                      lhs->concatenate_dimension() == lhs_contracting_dim &&
                      rhs->opcode() == HloOpcode::kConstant;
  if (!can_optimize) {
    return nullptr;
  }

  // We're replacing this:
  //
  //   +-----+-----+-----+      +-------------------+
  //   |     |     |     |      |                   |
  //   |     |     |     |      |        R_0        |
  //   |     |     |     |      |                   |
  //   |     |     |     |      +-------------------+
  //   |     |     |     |      |                   |
  //   | L_0 | L_1 | L_2 |   *  |        R_1        |
  //   |     |     |     |      |                   |
  //   |     |     |     |      +-------------------+
  //   |     |     |     |      |                   |
  //   |     |     |     |      |        R_2        |
  //   |     |     |     |      |                   |
  //   +-----+-----+-----+      +-------------------+
  //
  // with this:
  //
  // [Sum over i]
  //
  //   +-----+     +-------------------+
  //   |     |     |                   |
  //   |     |  *  |        R_i        |
  //   |     |     |                   |
  //   |     |     +-------------------+
  //   |     |
  //   | L_i |
  //   |     |
  //   |     |
  //   |     |
  //   |     |
  //   |     |
  //   +-----+
  //
  // where the LHS is a concatenate operation (so we can "split" the LHS tensor
  // for free) and the RHS is a constant tensor (and thus can be split at
  // compile time).  In the future, we may also want to do this when both the
  // LHS and the RHS are concatenate operations that line up along the dimension
  // being contracted over.
  //
  // We should be able to generalize this transform to work on a non-constant
  // RHS when/if we have in-place slices or support input-fusing slices into
  // Dots.

  // Dimension numbers for the new dot instructions we'll create (L_i * R_i in
  // the diagram above).
  DotDimensionNumbers new_dot_dnums;
  new_dot_dnums.add_lhs_contracting_dimensions(swapped ? rhs_contracting_dim
                                                       : lhs_contracting_dim);
  new_dot_dnums.add_rhs_contracting_dimensions(swapped ? lhs_contracting_dim
                                                       : rhs_contracting_dim);

  // Here we use the MKN notation, where the contracted dimension has K
  // elements and the two non-contracted dimensions have M and N elements.
  HloInstruction* add_result = nullptr;
  // Running offset into R's contracted dimension: slice i starts where the
  // previous concat operand's K-extent ended.
  int64_t rhs_contracting_dim_offset = 0;
  int64_t n = rhs->shape().dimensions(1 - rhs_contracting_dim);
  for (HloInstruction* concat_op : lhs->operands()) {
    // sub_k is this operand's extent along the contracted dimension.
    int64_t sub_k = concat_op->shape().dimensions(lhs_contracting_dim);
    Shape rhs_slice_shape(rhs->shape());
    rhs_slice_shape.set_dimensions(rhs_contracting_dim, sub_k);
    simplifier_->UpdateLayout(&rhs_slice_shape);

    // Slice R_i = R[offset : offset+sub_k] along the contracted dimension,
    // full width (0 : n) along the other.
    std::array<int64, 2> start_indices;
    start_indices[rhs_contracting_dim] = rhs_contracting_dim_offset;
    start_indices[1 - rhs_contracting_dim] = 0;

    std::array<int64, 2> limit_indices;
    limit_indices[rhs_contracting_dim] = rhs_contracting_dim_offset + sub_k;
    limit_indices[1 - rhs_contracting_dim] = n;

    HloInstruction* rhs_slice =
        computation_->AddInstruction(HloInstruction::CreateSlice(
            rhs_slice_shape, rhs, /*start_indices=*/start_indices,
            /*limit_indices=*/limit_indices, /*strides=*/{1, 1}));

    // TODO(b/69062148): We can get rid of `swapped` once all backends support
    // "non-canonical" contraction dimensions (that contracts dimension 1 of the
    // LHS with dimension 0 of the RHS).  But for now we keep the same
    // contraction dimensions as the incoming dot operation to ensure the new
    // dot operations can be lowered.
    HloInstruction *new_dot_lhs, *new_dot_rhs;
    if (swapped) {
      new_dot_lhs = rhs_slice;
      new_dot_rhs = concat_op;
    } else {
      new_dot_lhs = concat_op;
      new_dot_rhs = rhs_slice;
    }

    // Every partial dot L_i * R_i has the full output shape of the original.
    auto* new_dot = computation_->AddInstruction(
        HloInstruction::CreateDot(dot.shape(), new_dot_lhs, new_dot_rhs,
                                  new_dot_dnums, dot.precision_config()));

    // Accumulate the partial products into a chain of adds.
    if (add_result) {
      add_result = computation_->AddInstruction(HloInstruction::CreateBinary(
          dot.shape(), HloOpcode::kAdd, add_result, new_dot));
    } else {
      add_result = new_dot;
    }

    rhs_contracting_dim_offset += sub_k;
  }

  return add_result;
}
2019 
// Rewrites dot(DynamicSlice(ConstA), ConstB) (in either operand order) as
// DynamicSlice(dot(ConstA, ConstB)): the full dot of the two constants can be
// folded once at compile time, and each lookup becomes a cheap slice of the
// memoized result.  Returns the replacement instruction, or nullptr when the
// pattern does not apply.
StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
    HloInstruction* dot) {
  const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
  if (dnums.lhs_contracting_dimensions_size() != 1 ||
      dnums.rhs_contracting_dimensions_size() != 1 ||
      dnums.lhs_batch_dimensions_size() != 0 ||
      dnums.rhs_batch_dimensions_size() != 0 ||
      dot->shape().dimensions_size() != 2) {  // dot output 2D
    VLOG(10) << "DotOfGather: Can only optimize 2D, non-batch dot operations.";
    return nullptr;
  }

  // Optimize either dot(DS(ctA), ctB)) or dot(ctB, DS(ctA)).
  // Currently a Gather is a DynamicSlice.
  auto is_dynamic_slice_constant_combination =
      [](HloInstruction* a, HloInstruction* b, int a_contracting_dimension) {
        // First operand is a DynamicSlice(Constant).
        if (a->opcode() != HloOpcode::kDynamicSlice) {
          return false;
        }
        auto* dynamic_slice_op = a->operand(0);
        if (dynamic_slice_op->opcode() != HloOpcode::kConstant) {
          return false;
        }
        // Second operand is a Constant.
        if (b->opcode() != HloOpcode::kConstant) {
          return false;
        }
        // The DynamicSlice output is a vector.
        const Shape& dynamic_slice_shape = a->shape();
        if (dynamic_slice_shape.dimensions(1 - a_contracting_dimension) != 1) {
          return false;
        }
        // Constant size is the same before and after slice in the contracting
        // dimension, otherwise we either must precompute for all possible slice
        // indices or dot is invalid.
        const Shape& dynamic_slice_op_shape = dynamic_slice_op->shape();
        if (dynamic_slice_op_shape.dimensions(a_contracting_dimension) !=
            dynamic_slice_shape.dimensions(a_contracting_dimension)) {
          return false;
        }
        return true;
      };

  HloInstruction* lhs = dot->mutable_operand(0);
  HloInstruction* rhs = dot->mutable_operand(1);
  int lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0);
  int rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0);

  // Require the DynamicSlice-of-constant / constant combination on at least
  // one side (checked symmetrically for both operand orders).
  if (!is_dynamic_slice_constant_combination(
          lhs, rhs, /*a_contracting_dimension=*/lhs_contracting_dimension) &&
      !is_dynamic_slice_constant_combination(
          rhs, lhs, /*a_contracting_dimension=*/rhs_contracting_dimension)) {
    VLOG(10) << "DotOfGather: Can only optimize dot(DS(ctA), ctB)) or "
                "dot(ctB, DS(ctA)), where the two constants have equal "
                "contracting dimensions.";
    return nullptr;
  }

  // LHS is DynamicSlice:
  // input: dot(DS(ctA), ctB))
  // where DS(ctA) = DS({M x K}, {start, 0}, {1, K}) and ctB = {K x N}.
  // => input dimensions: dot({1 x K}, {K x N}) => {1 x N}.
  // output: DS(dot(ctA, ctB))
  // => output dimensions: DS ({M x N}, {start, 0}, {1, N}) => {1 x N}.

  // RHS is DynamicSlice:
  // input: dot(ctA, DS(ctB))
  // where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, start}, {K, 1}).
  // => input dimensions: dot({M x K}, {K x 1}) => {M x 1}.
  // output: DS(dot(ctA, ctB))
  // => output dimensions: DS ({M x N}, {0, start}, {M, 1}) => {M x 1}.

  bool lhs_is_dynamic_slice = lhs->opcode() == HloOpcode::kDynamicSlice;
  HloDynamicSliceInstruction* dynamic_slice =
      lhs_is_dynamic_slice ? Cast<HloDynamicSliceInstruction>(lhs)
                           : Cast<HloDynamicSliceInstruction>(rhs);

  // ctA:
  HloInstruction* left_operand =
      lhs_is_dynamic_slice ? lhs->mutable_operand(0) : lhs;
  // ctB:
  HloInstruction* right_operand =
      lhs_is_dynamic_slice ? rhs : rhs->mutable_operand(0);
  // Build ctA x ctB.
  const int m = left_operand->shape().dimensions(1 - lhs_contracting_dimension);
  const int n =
      right_operand->shape().dimensions(1 - rhs_contracting_dimension);
  auto memoized_shape =
      ShapeUtil::MakeShape(dot->shape().element_type(), {m, n});
  simplifier_->UpdateLayout(&memoized_shape);
  // The full {M x N} dot over the unsliced constants; constant folding is
  // expected to evaluate this once.
  auto* memoized_inst = computation_->AddInstruction(
      HloInstruction::CreateDot(memoized_shape, left_operand, right_operand,
                                dnums, dot->precision_config()));
  // Get pair {start, 0} or {0, start}.
  // Position of start:
  int index_of_non_zero_start = lhs_is_dynamic_slice
                                    ? 1 - lhs_contracting_dimension
                                    : 1 - rhs_contracting_dimension;
  // Position of zero:
  int index_of_zero_start = 1 - index_of_non_zero_start;

  // Slice out start and 0 components and reorder if necessary.
  auto indices_type = dynamic_slice->operand(1)->shape().element_type();
  Shape s_shape = ShapeUtil::MakeShape(indices_type, {1});
  simplifier_->UpdateLayout(&s_shape);
  Shape d_shape = ShapeUtil::MakeShape(indices_type, {2});
  simplifier_->UpdateLayout(&d_shape);
  // The dynamic slice's start-index operands begin at operand(1); pick the
  // one carrying the real start and the one that is the zero placeholder.
  HloInstruction* non_zero_start =
      dynamic_slice->mutable_operand(1 + index_of_non_zero_start);
  HloInstruction* zero_start =
      dynamic_slice->mutable_operand(1 + index_of_zero_start);
  std::vector<HloInstruction*> new_start_indices;
  if (lhs_is_dynamic_slice) {
    new_start_indices = {non_zero_start, zero_start};
  } else {
    new_start_indices = {zero_start, non_zero_start};
  }

  // Build DynamicSlice(ctA x ctB).
  // The slice keeps one row ({1, N}) when the LHS was sliced, one column
  // ({M, 1}) when the RHS was.
  const int new_slice_m = lhs_is_dynamic_slice ? 1 : m;
  const int new_slice_n = lhs_is_dynamic_slice ? n : 1;
  auto* memoized_lookup =
      computation_->AddInstruction(HloInstruction::CreateDynamicSlice(
          dot->shape(), memoized_inst, new_start_indices,
          {new_slice_m, new_slice_n}));

  return memoized_lookup;
}
2149 
// This function tries to transform
//   dot(reshape(transpose(A)), Const) to
//   dot(reshape(A), reshape(transpose(reshape(Const)))),
// so that the reshape and transpose on the Const side can be constant folded.
//
// The basic idea is that since the accumulation in the dot operation is
// associative, so as long as we permute the elements of the contracting
// dimensions on both sides of the dot in the same way, the result of the
// dot is not affected.
//
// Returns the replacement dot instruction on success, or nullptr (not an
// error) when the pattern does not apply.
StatusOr<HloInstruction*>
AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims(
    HloInstruction* dot) {
  // This transformation assumes layout is not assigned yet.
  if (options_.is_layout_sensitive()) {
    return nullptr;
  }

  // Canonicalize dot(<constant>, rhs) to dot(rhs, <constant>) to make the
  // remainder of this function easier.
  auto dnums = dot->dot_dimension_numbers();
  auto lhs_contracting_dims = dnums.lhs_contracting_dimensions();
  auto rhs_contracting_dims = dnums.rhs_contracting_dimensions();
  auto* lhs = dot->mutable_operand(0);
  auto* rhs = dot->mutable_operand(1);
  if (dot->operand(0)->IsConstant()) {
    // Note: only the local views are swapped; the swap is undone before the
    // replacement dot is built at the end of this function.
    std::swap(lhs, rhs);
    std::swap(lhs_contracting_dims, rhs_contracting_dims);
  }

  // Require single contracting dim to make the implementation easier to
  // track contracting dims.
  if (dnums.lhs_contracting_dimensions_size() != 1) {
    return nullptr;
  }

  // Pattern match Dot(reshape(transpose(input), constant))
  HloInstruction* reshape;
  HloInstruction* transpose;
  HloInstruction* input;
  HloInstruction* constant;
  if (!Match(lhs,
             m::Reshape(&reshape, m::Transpose(&transpose, m::Op(&input)))) ||
      !Match(rhs, m::Constant(&constant))) {
    return nullptr;
  }

  // Check that reshape squishes some dims into one dim and that this one
  // dim is the dot's lhs contracting dim. The size of unmodified_dims should
  // be N - 1, where N is the rank of the reshape output. This means that the
  // reshape squishes some dims into one dim. lhs contracting dim should not
  // be in unmodified_dims. This means that the squishing target dim is the
  // lhs contracting dim.
  auto unmodified_dims = ShapeUtil::DimensionsUnmodifiedByReshape(
      reshape->operand(0)->shape(), reshape->shape());
  CHECK_EQ(lhs_contracting_dims.size(), 1);
  if ((unmodified_dims.size() != reshape->shape().rank() - 1) ||
      absl::c_any_of(unmodified_dims, [&](const std::pair<int64, int64>& p) {
        return p.second == lhs_contracting_dims[0];
      })) {
    return nullptr;
  }

  // Virtually pull the reshape into the dot so the dot operates on the
  // transpose, with "unsquished" lhs contracting dims.  The new contracting
  // dims are all of the dims that are modified by the reshape -- that is, every
  // dimension that's not in `unmodified_dims[i].first`.
  //
  // (We don't need to actually create a new dot instruction. We can just keep
  // track of lhs and lhs_contracting_dims.)
  absl::flat_hash_set<int64> unmodified_transpose_dims;
  for (const auto& pair : unmodified_dims) {
    unmodified_transpose_dims.insert(pair.first);
  }
  lhs_contracting_dims.Clear();
  for (int64_t i = 0; i < transpose->shape().dimensions_size(); ++i) {
    if (!unmodified_transpose_dims.contains(i)) {
      lhs_contracting_dims.Add(i);
    }
  }
  // We require the "unsquished" lhs contracting dims to be consecutive.
  auto is_iota = [](absl::Span<const int64> dims) {
    return absl::c_adjacent_find(dims, [](const int64_t a, const int64_t b) {
             return (b != a + 1);
           }) == dims.end();
  };
  if (!is_iota(AsInt64Slice(lhs_contracting_dims))) {
    return nullptr;
  }
  // Peel off the reshape: lhs is now the transpose.
  lhs = lhs->mutable_operand(0);

  // Check that the transpose only permutes the contracting dims.
  const auto& transpose_dims = transpose->dimensions();
  for (int64_t i = 0; i < transpose_dims.size(); ++i) {
    if (transpose_dims[i] != i &&
        !absl::c_linear_search(lhs_contracting_dims, i)) {
      return nullptr;
    }
  }
  // Virtually pull the transpose into the dot. Now the dot is equivalent to
  // a new dot with "permuted" lhs contracting dims.
  std::vector<int64> permutation;
  for (auto dim : lhs_contracting_dims) {
    // Offset by the first contracting dim so `permutation` is a permutation
    // of [0, num_contracting_dims).
    permutation.push_back(transpose_dims[dim] - lhs_contracting_dims[0]);
  }
  CHECK(IsPermutation(permutation));
  auto new_lhs_contracting_dims =
      ComposePermutations(AsInt64Slice(lhs_contracting_dims), permutation);
  lhs_contracting_dims.Clear();
  for (auto dim : new_lhs_contracting_dims) {
    lhs_contracting_dims.Add(dim);
  }
  // Peel off the transpose: lhs is now the transpose's input.
  lhs = lhs->mutable_operand(0);

  // All checks are passed at this point.
  //
  // Transform lhs. Remove the transpose and reshape by sorting the lhs
  // contracting dims and squishing them into a single one. We don't actually
  // squish the lhs_contracting_dims here because we still need the unsquished
  // contracting dims to invert reshape and transpose.
  absl::c_sort(lhs_contracting_dims);
  lhs = computation_->AddInstruction(
      HloInstruction::CreateReshape(reshape->shape(), lhs));

  // Transform rhs. Say the input HLO is:
  //
  //   t0 = f32[2, 2, 3] parameter(0)
  //   t1 = f32[2, 3, 2] transpose(t0) dimensions={0, 2, 1}
  //   t2 = f32[2, 6] reshape(t1)
  //   t3 = f32[6, 2] constant(...)
  //   dot = f32[2, 2] dot(t2, t3) lhs_contracting_dims={1},
  //                               rhs_contracting_dims={0}
  //
  // At this point in the function, we have decided that the second and third
  // dims of t0 can be switched to remove the transpose, and we have
  // "virtually decomposed" the input HLO to:
  //
  //   t0 = f32[2, 2, 3] parameter(0)
  //   t2' = f32[2, 6] reshape(t0)
  //   t3' = f32[6, 2] ops-to-be-filled ...
  //   dot = f32[2, 2] dot(t2', t3') lhs_contracting_dims={1},
  //                                 rhs_contracting_dims={0}
  //
  // The rest of this function is to fill in the ops of t3'. To do this, we
  // unsquish the contracting dimensions in t3 and then apply the inverse of
  // the transpose from t1.

  // Invert reshape: replace the single rhs contracting dim with the
  // "unsquished" dims, in lhs order.
  CHECK_EQ(rhs_contracting_dims.size(), 1);
  std::vector<int64> rhs_unsquished_shape_dims =
      SpanToVector(constant->shape().dimensions());
  auto it = rhs_unsquished_shape_dims.erase(rhs_unsquished_shape_dims.begin() +
                                            rhs_contracting_dims[0]);
  for (auto dim : lhs_contracting_dims) {
    it = rhs_unsquished_shape_dims.insert(it,
                                          transpose->shape().dimensions(dim));
    ++it;
  }
  HloInstruction* rhs_reshape =
      computation_->AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(constant->shape().element_type(),
                               rhs_unsquished_shape_dims),
          constant));
  rhs = rhs_reshape;

  // Rhs reshape "unsquishes" the single rhs contracting dim into multiple dims.
  rhs_contracting_dims.Resize(lhs_contracting_dims.size(),
                              rhs_contracting_dims[0]);
  absl::c_iota(rhs_contracting_dims, rhs_contracting_dims[0]);

  // Invert transpose. First compute the shape.
  std::vector<int64> rhs_transpose_shape_dims =
      SpanToVector(rhs_reshape->shape().dimensions());
  it = rhs_transpose_shape_dims.erase(
      rhs_transpose_shape_dims.begin() + rhs_contracting_dims[0],
      rhs_transpose_shape_dims.begin() + rhs_contracting_dims[0] +
          rhs_contracting_dims.size());
  for (auto dim : lhs_contracting_dims) {
    it = rhs_transpose_shape_dims.insert(it, input->shape().dimensions(dim));
    ++it;
  }
  // Then compute the transpose dims.
  std::vector<int64> rhs_transpose_dims(rhs_reshape->shape().rank());
  absl::c_iota(rhs_transpose_dims, 0);
  it = rhs_transpose_dims.erase(
      rhs_transpose_dims.begin() + rhs_contracting_dims[0],
      rhs_transpose_dims.begin() + rhs_contracting_dims[0] +
          rhs_contracting_dims.size());
  auto inverse_lhs_transpose_dims = InversePermutation(transpose_dims);
  for (auto dim : lhs_contracting_dims) {
    // Map the lhs contracting dim through the inverse transpose, then shift
    // from lhs dim numbering into rhs dim numbering.
    it = rhs_transpose_dims.insert(it, inverse_lhs_transpose_dims[dim] -
                                           lhs_contracting_dims[0] +
                                           rhs_contracting_dims[0]);
    ++it;
  }
  HloInstruction* rhs_transpose =
      computation_->AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(constant->shape().element_type(),
                               rhs_transpose_shape_dims),
          rhs_reshape, rhs_transpose_dims));
  rhs = rhs_transpose;

  // Squish the multiple rhs contracting dims into a single one.
  rhs = computation_->AddInstruction(
      HloInstruction::CreateReshape(constant->shape(), rhs));

  // If we virtually swapped lhs and rhs, we need to swap it back before
  // creating new dot.
  if (dot->operand(0)->IsConstant()) {
    std::swap(lhs, rhs);
  }

  HloInstruction* new_dot =
      computation_->AddInstruction(HloInstruction::CreateDot(
          dot->shape(), lhs, rhs, dnums, dot->precision_config()));
  return new_dot;
}
2366 
// Applies a sequence of dot simplifications, in priority order; the first
// rewrite that fires replaces `dot` and ends the visit.
Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
  CHECK(computation_ == dot->parent());
  HloInstruction *lhs, *rhs;
  CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
  // All rewrites below assume layouts are not yet assigned.
  if (options_.is_layout_sensitive()) {
    return Status::OK();
  }
  // Replace a zero element dot with a broadcast of the constant 0.
  if (ShapeUtil::IsZeroElementArray(dot->shape()) ||
      ShapeUtil::IsZeroElementArray(lhs->shape()) ||
      ShapeUtil::IsZeroElementArray(rhs->shape())) {
    auto zero = computation_->AddInstruction(
        simplifier_->CreateConstantWithLayoutUpdated(
            LiteralUtil::Zero(dot->shape().element_type())));
    return ReplaceWithNewInstruction(
        dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {}));
  }

  // If there are no contracting dimensions, a dot can be rewritten as
  // mul(broadcast(transpose(x)),broadcast(transpose(y)))
  if (options_.enable_dot_to_multiply_rewrite() &&
      dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 0) {
    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_lhs,
        NormalizeDotOperandToBatchMajorAndContractingMinor(
            lhs,
            AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()),
            AsInt64Slice(
                dot->dot_dimension_numbers().lhs_contracting_dimensions())));
    // Match the dot's accumulation type if it differs from the operand type.
    if (!ShapeUtil::SameElementType(dot->shape(), new_lhs->shape())) {
      new_lhs = MakeConvertToHlo(new_lhs, dot->shape().element_type());
    }
    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_rhs,
        NormalizeDotOperandToBatchMajorAndContractingMinor(
            rhs,
            AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()),
            AsInt64Slice(
                dot->dot_dimension_numbers().rhs_contracting_dimensions())));
    if (!ShapeUtil::SameElementType(dot->shape(), new_rhs->shape())) {
      new_rhs = MakeConvertToHlo(new_rhs, dot->shape().element_type());
    }
    // lhs's free dims come first in the output: broadcast it over leading
    // dims of dot->shape().
    if (dot->shape().rank() != lhs->shape().rank()) {
      std::vector<int64> lhs_broadcast_dims(lhs->shape().rank());
      absl::c_iota(lhs_broadcast_dims, 0);
      new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
          dot->shape(), new_lhs, lhs_broadcast_dims));
    }
    // rhs keeps the batch dims, then occupies the trailing output dims.
    if (dot->shape().rank() != rhs->shape().rank()) {
      std::vector<int64> rhs_broadcast_dims(
          dot->dot_dimension_numbers().lhs_batch_dimensions_size());
      absl::c_iota(rhs_broadcast_dims, 0);
      for (int64_t i = lhs->shape().rank(); i < dot->shape().rank(); ++i) {
        rhs_broadcast_dims.push_back(i);
      }
      new_rhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
          dot->shape(), new_rhs, rhs_broadcast_dims));
    }
    return ReplaceWithNewInstruction(
        dot, HloInstruction::CreateBinary(dot->shape(), HloOpcode::kMultiply,
                                          new_lhs, new_rhs));
  }

  // If the lhs or rhs have only batch and contracting dimensions, a dot can be
  // rewritten as reduce(mul(broadcast(transpose(x)),broadcast(transpose(y))))
  if (options_.enable_dot_strength_reduction() &&
      ((dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
            dot->dot_dimension_numbers().lhs_contracting_dimensions_size() ==
        lhs->shape().rank()) ||
       (dot->dot_dimension_numbers().rhs_contracting_dimensions_size() +
            dot->dot_dimension_numbers().rhs_batch_dimensions_size() ==
        rhs->shape().rank()))) {
    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_lhs,
        NormalizeDotOperandToBatchMajorAndContractingMinor(
            lhs,
            AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()),
            AsInt64Slice(
                dot->dot_dimension_numbers().lhs_contracting_dimensions())));
    if (!ShapeUtil::SameElementType(dot->shape(), new_lhs->shape())) {
      new_lhs = MakeConvertToHlo(new_lhs, dot->shape().element_type());
    }

    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_rhs,
        NormalizeDotOperandToBatchMajorAndContractingMinor(
            rhs,
            AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()),
            AsInt64Slice(
                dot->dot_dimension_numbers().rhs_contracting_dimensions())));
    if (!ShapeUtil::SameElementType(dot->shape(), new_rhs->shape())) {
      new_rhs = MakeConvertToHlo(new_rhs, dot->shape().element_type());
    }

    // "Outer" dims = free (non-batch, non-contracting) dims. The guard above
    // ensures at most one side has them.
    int64_t lhs_outer_dims =
        lhs->shape().rank() -
        (dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
         dot->dot_dimension_numbers().lhs_contracting_dimensions_size());
    int64_t rhs_outer_dims =
        rhs->shape().rank() -
        (dot->dot_dimension_numbers().rhs_batch_dimensions_size() +
         dot->dot_dimension_numbers().rhs_contracting_dimensions_size());
    CHECK(lhs_outer_dims == 0 || rhs_outer_dims == 0);
    if (rhs_outer_dims > 0) {
      // Broadcast the outer-dim-free lhs over rhs's free dims.
      std::vector<int64> lhs_broadcast_dims(
          dot->dot_dimension_numbers().lhs_batch_dimensions_size());
      absl::c_iota(lhs_broadcast_dims, 0);
      lhs_broadcast_dims.resize(lhs->shape().rank());
      std::iota(lhs_broadcast_dims.begin() +
                    dot->dot_dimension_numbers().lhs_batch_dimensions_size(),
                lhs_broadcast_dims.end(),
                dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
                    rhs_outer_dims);
      new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
          new_rhs->shape(), new_lhs, lhs_broadcast_dims));
    } else if (lhs_outer_dims > 0) {
      // Symmetric case: broadcast rhs over lhs's free dims.
      std::vector<int64> rhs_broadcast_dims(
          dot->dot_dimension_numbers().rhs_batch_dimensions_size());
      absl::c_iota(rhs_broadcast_dims, 0);
      rhs_broadcast_dims.resize(rhs->shape().rank());
      std::iota(rhs_broadcast_dims.begin() +
                    dot->dot_dimension_numbers().rhs_batch_dimensions_size(),
                rhs_broadcast_dims.end(),
                dot->dot_dimension_numbers().rhs_batch_dimensions_size() +
                    lhs_outer_dims);
      new_rhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
          new_lhs->shape(), new_rhs, rhs_broadcast_dims));
    }

    TF_ASSIGN_OR_RETURN(HloInstruction * new_dot,
                        MakeBinaryHlo(HloOpcode::kMultiply, new_lhs, new_rhs));
    std::vector<int64> reduce_dims(
        dot->dot_dimension_numbers().lhs_contracting_dimensions_size());
    // Accumulate in at least F32 for floating-point dots to limit precision
    // loss in the reduce; convert back to the dot's type afterwards.
    PrimitiveType dot_type =
        ShapeUtil::ElementIsFloating(dot->shape())
            ? (dot->shape().element_type() == F64 ? F64 : F32)
            : dot->shape().element_type();
    new_dot = AsType(new_dot, dot_type);
    const int64_t outer_dims = std::max(rhs_outer_dims, lhs_outer_dims);
    absl::c_iota(
        reduce_dims,
        outer_dims + dot->dot_dimension_numbers().lhs_batch_dimensions_size());
    new_dot = AddReduce(new_dot, reduce_dims, dot_type);
    new_dot = AsType(new_dot, dot->shape().element_type());
    return ReplaceInstruction(dot, new_dot);
  }

  // Simplify dot(reshape(transpose(A)), Const) to:
  // dot(reshape(A), reshape(transpose(reshape(Const)))), so that the reshape
  // and transpose on the Const side can be constant folded.
  TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_reorder_optimized,
                      OptimizeDotOfReorderContractingDims(dot));
  if (dot_of_reorder_optimized) {
    VLOG(10) << " Replaced dot " << dot->ToString()
             << " with new dot operation: "
             << dot_of_reorder_optimized->ToString();
    return ReplaceInstruction(dot, dot_of_reorder_optimized);
  }

  TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_concat_optimized,
                      OptimizeDotOfConcat(dot));
  if (dot_of_concat_optimized) {
    VLOG(10) << "Replaced dot(concat(...), constant) with add(dot(..., "
                "constant)...)";
    return ReplaceInstruction(dot, dot_of_concat_optimized);
  }

  // Simplify dot(ConstA, Gather(Index, ConstB)) to:
  // Gather(Index, dot*(ConstA, ConstB)), where dot* is an appropriately
  // batched version of dot.
  TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_gather_optimized,
                      OptimizeDotOfGather(dot));
  if (dot_of_gather_optimized) {
    VLOG(10) << "Replaced dot(constA, gather(i, constB)) with "
                "gather(i, dot*(constA, constB))";
    return ReplaceInstruction(dot, dot_of_gather_optimized);
  }

  TF_ASSIGN_OR_RETURN(bool removed_degenerate_dimensions,
                      RemoveDegenerateDimensionFromDot(dot));
  if (removed_degenerate_dimensions) {
    return Status::OK();
  }

  // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)).
  if (dot->dot_dimension_numbers().lhs_batch_dimensions_size() == 0 &&
      dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 1 &&
      dot->dot_dimension_numbers().lhs_contracting_dimensions(0) == 1 &&
      dot->dot_dimension_numbers().rhs_contracting_dimensions(0) == 0 &&
      lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) {
    DotDimensionNumbers dot_dimension_numbers;
    dot_dimension_numbers.add_lhs_contracting_dimensions(1);
    dot_dimension_numbers.add_rhs_contracting_dimensions(0);
    auto new_dot = computation_->AddInstruction(HloInstruction::CreateDot(
        ShapeUtil::PermuteDimensions({1, 0}, dot->shape()),
        rhs->mutable_operand(0), lhs->mutable_operand(0), dot_dimension_numbers,
        dot->precision_config()));
    return ReplaceWithNewInstruction(
        dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0}));
  }

  return Status::OK();
}
2570 
HandleGather(HloInstruction * gather)2571 Status AlgebraicSimplifierVisitor::HandleGather(HloInstruction* gather) {
2572   const Shape& operand_shape = gather->operand(0)->shape();
2573   if (ShapeUtil::IsZeroElementArray(operand_shape)) {
2574     return ReplaceInstruction(gather, MakeScalarLike(gather, 0));
2575   }
2576 
2577   // Gathering from a scalar operand is simply a broadcast of that scalar
2578   if (ShapeUtil::IsEffectiveScalar(operand_shape)) {
2579     HloInstruction* new_operand = gather->mutable_operand(0);
2580     if (operand_shape.rank()) {
2581       TF_ASSIGN_OR_RETURN(new_operand,
2582                           MakeReshapeHlo(ShapeUtil::MakeScalarShape(
2583                                              operand_shape.element_type()),
2584                                          new_operand));
2585     }
2586     HloInstruction* new_gather =
2587         MakeBroadcastHlo(new_operand, {}, gather->shape());
2588     return ReplaceInstruction(gather, new_gather);
2589   }
2590   // If the operand of a gather is very small, it is easier to fuse a
2591   // sequence of selects.
2592   const Shape& index_shape = gather->operand(1)->shape();
2593   if (operand_shape.rank() == 1 &&
2594       operand_shape.dimensions(0) <= options_.very_small_gather_size() &&
2595       gather->gather_dimension_numbers().index_vector_dim() ==
2596           index_shape.rank() &&
2597       gather->gather_dimension_numbers().collapsed_slice_dims_size() == 1) {
2598     const int64_t operand_elements = operand_shape.dimensions(0);
2599     auto get_value = [&](int64_t i) {
2600       auto slice = computation_->AddInstruction(HloInstruction::CreateSlice(
2601           ShapeUtil::MakeShape(operand_shape.element_type(), {1}),
2602           gather->mutable_operand(0), {i}, {i + 1}, {1}));
2603       auto scalar = computation_->AddInstruction(HloInstruction::CreateReshape(
2604           ShapeUtil::MakeShape(operand_shape.element_type(), {}), slice));
2605       return computation_->AddInstruction(
2606           HloInstruction::CreateBroadcast(gather->shape(), scalar, {}));
2607     };
2608     auto result = get_value(0);
2609     auto pred_shape = ShapeUtil::ChangeElementType(gather->shape(), PRED);
2610     simplifier_->UpdateLayout(&pred_shape);
2611     auto iter_shape = ShapeUtil::ChangeElementType(gather->shape(),
2612                                                    index_shape.element_type());
2613     simplifier_->UpdateLayout(&iter_shape);
2614     for (int64_t i = 0; i < operand_elements; ++i) {
2615       auto index_mask =
2616           computation_->AddInstruction(HloInstruction::CreateCompare(
2617               pred_shape, gather->mutable_operand(1),
2618               MakeScalarLike(gather->mutable_operand(1), i),
2619               ComparisonDirection::kGe));
2620       result = computation_->AddInstruction(
2621           HloInstruction::CreateTernary(gather->shape(), HloOpcode::kSelect,
2622                                         index_mask, get_value(i), result));
2623     }
2624     return ReplaceInstruction(gather, result);
2625   }
2626   return Status::OK();
2627 }
2628 
2629 namespace {
MinMaxToClamp(HloInstruction * clamp_lower_bound_bcast,HloInstruction * to_clamp,HloInstruction * clamp_upper_bound_bcast,AlgebraicSimplifier * simplifier)2630 StatusOr<std::unique_ptr<HloInstruction>> MinMaxToClamp(
2631     HloInstruction* clamp_lower_bound_bcast, HloInstruction* to_clamp,
2632     HloInstruction* clamp_upper_bound_bcast, AlgebraicSimplifier* simplifier) {
2633   HloInstruction* clamp_lower_bound;
2634   CHECK(Match(clamp_lower_bound_bcast,
2635               m::Broadcast(m::ConstantEffectiveScalar(&clamp_lower_bound))))
2636       << clamp_lower_bound_bcast->ToString();
2637 
2638   HloInstruction* clamp_upper_bound;
2639   CHECK(Match(clamp_upper_bound_bcast,
2640               m::Broadcast(m::ConstantEffectiveScalar(&clamp_upper_bound))))
2641       << clamp_upper_bound_bcast->ToString();
2642 
2643   const Literal& lower_bound =
2644       Cast<HloConstantInstruction>(clamp_lower_bound)->literal();
2645   const Literal& upper_bound =
2646       Cast<HloConstantInstruction>(clamp_upper_bound)->literal();
2647 
2648   TF_ASSIGN_OR_RETURN(Literal lower_bound_literal_reshaped,
2649                       lower_bound.Reshape({}));
2650   TF_ASSIGN_OR_RETURN(Literal upper_bound_literal_reshaped,
2651                       upper_bound.Reshape({}));
2652   std::unique_ptr<HloInstruction> lower_bound_instr =
2653       HloInstruction::CreateConstant(std::move(lower_bound_literal_reshaped));
2654   std::unique_ptr<HloInstruction> upper_bound_instr =
2655       HloInstruction::CreateConstant(std::move(upper_bound_literal_reshaped));
2656 
2657   Shape compare_shape =
2658       ShapeUtil::ChangeElementType(lower_bound_instr->shape(), PRED);
2659   simplifier->UpdateLayout(&compare_shape);
2660   std::unique_ptr<HloInstruction> cloned_instruction =
2661       HloInstruction::CreateCompare(compare_shape, lower_bound_instr.get(),
2662                                     upper_bound_instr.get(),
2663                                     ComparisonDirection::kLt);
2664 
2665   HloEvaluator evaluator;
2666   TF_ASSIGN_OR_RETURN(auto result,
2667                       evaluator.Evaluate(cloned_instruction.get()));
2668   if (result.IsAll(true)) {
2669     return HloInstruction::CreateTernary(to_clamp->shape(), HloOpcode::kClamp,
2670                                          clamp_lower_bound_bcast, to_clamp,
2671                                          clamp_upper_bound_bcast);
2672   }
2673   return std::unique_ptr<HloInstruction>();
2674 }
2675 }  // namespace
2676 
HandleMaximum(HloInstruction * maximum)2677 Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) {
2678   HloInstruction *lhs, *rhs;
2679   CHECK(Match(maximum, m::Maximum(m::Op(&lhs), m::Op(&rhs))));
2680 
2681   HloInstruction* clamp_upper_bound_bcast;
2682   HloInstruction* clamp_lower_bound_bcast;
2683   HloInstruction* to_clamp;
2684   if (Match(maximum, m::MaximumAnyOrder(
2685                          m::Broadcast(&clamp_lower_bound_bcast,
2686                                       m::ConstantEffectiveScalar()),
2687                          m::MinimumAnyOrder(
2688                              m::Op(&to_clamp),
2689                              m::Broadcast(&clamp_upper_bound_bcast,
2690                                           m::ConstantEffectiveScalar()))))) {
2691     TF_ASSIGN_OR_RETURN(auto clamp,
2692                         MinMaxToClamp(clamp_lower_bound_bcast, to_clamp,
2693                                       clamp_upper_bound_bcast, simplifier_));
2694     if (clamp) {
2695       return ReplaceWithNewInstruction(maximum, std::move(clamp));
2696     }
2697   }
2698 
2699   HloInstruction* clamp_lower_bound;
2700   HloInstruction* clamp_upper_bound;
2701   HloInstruction* max_operand;
2702   HloInstruction* clamp;
2703   if (Match(maximum,
2704             m::MaximumAnyOrder(
2705                 m::Op(&max_operand),
2706                 m::Clamp(&clamp, m::Op(&clamp_lower_bound), m::Op(&to_clamp),
2707                          m::Op(&clamp_upper_bound))))) {
2708     if (max_operand == clamp_lower_bound &&
2709         ReplaceInstructionIfSameShape(maximum, clamp)) {
2710       return Status::OK();
2711     }
2712   }
2713 
2714   return Status::OK();
2715 }
2716 
HandleMinimum(HloInstruction * minimum)2717 Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum) {
2718   HloInstruction *lhs, *rhs;
2719   CHECK(Match(minimum, m::Minimum(m::Op(&lhs), m::Op(&rhs))));
2720 
2721   HloInstruction* clamp_upper_bound_bcast;
2722   HloInstruction* clamp_lower_bound_bcast;
2723   HloInstruction* to_clamp;
2724   if (Match(minimum, m::MinimumAnyOrder(
2725                          m::Broadcast(&clamp_upper_bound_bcast,
2726                                       m::ConstantEffectiveScalar()),
2727                          m::MaximumAnyOrder(
2728                              m::Op(&to_clamp),
2729                              m::Broadcast(&clamp_lower_bound_bcast,
2730                                           m::ConstantEffectiveScalar()))))) {
2731     TF_ASSIGN_OR_RETURN(auto clamp,
2732                         MinMaxToClamp(clamp_lower_bound_bcast, to_clamp,
2733                                       clamp_upper_bound_bcast, simplifier_));
2734     if (clamp) {
2735       return ReplaceWithNewInstruction(minimum, std::move(clamp));
2736     }
2737   }
2738 
2739   return Status::OK();
2740 }
2741 
HandleClamp(HloInstruction * clamp)2742 Status AlgebraicSimplifierVisitor::HandleClamp(HloInstruction* clamp) {
2743   HloInstruction* clamp_lower_bound;
2744   HloInstruction* clamp_upper_bound;
2745   HloInstruction* to_clamp;
2746   CHECK(Match(clamp, m::Clamp(m::Op(&clamp_lower_bound), m::Op(&to_clamp),
2747                               m::Op(&clamp_upper_bound))));
2748 
2749   // clamp(a, clamp(a, x, b), b) -> clamp(a, x, b)
2750   if (Match(to_clamp, m::Clamp(m::Op().Is(clamp_lower_bound), m::Op(),
2751                                m::Op().Is(clamp_upper_bound))) &&
2752       ReplaceInstructionIfSameShape(clamp, to_clamp)) {
2753     return Status::OK();
2754   }
2755 
2756   // Eliminate redundant clamping of replica-id or partition-id.
2757   if ((Match(to_clamp, m::PartitionId()) || Match(to_clamp, m::ReplicaId())) &&
2758       Match(clamp_lower_bound, m::ConstantScalar(0U)) &&
2759       Match(clamp_upper_bound, m::ConstantScalar())) {
2760     int64_t upper_bound = Cast<HloConstantInstruction>(clamp_upper_bound)
2761                               ->literal()
2762                               .GetFirstElement<uint32_t>();
2763     const HloModuleConfig& config = clamp->GetModule()->config();
2764     int64_t runtime_bound = Match(to_clamp, m::PartitionId())
2765                                 ? config.num_partitions()
2766                                 : config.replica_count();
2767 
2768     // If num_partitions or replica_count is 1, infer it as unknown.
2769     // pid/rid < runtime_bound => The clamp(0, pid/rid, upper_bound) is
2770     // redundant if the runtime_bound <= upper_bound + 1;
2771     if (runtime_bound != 1 && runtime_bound <= upper_bound + 1) {
2772       return ReplaceInstruction(clamp, to_clamp);
2773     }
2774   }
2775 
2776   return Status::OK();
2777 }
2778 
// Simplifies kMultiply: strength-reduces multiplications by 0/1, canonicalizes
// products involving constants, and rewrites a few algebraic identities
// (abs(x)*abs(x), convert(pred)*x, exp(a)*exp(b), rsqrt(b)*rsqrt(b)).
// The patterns are tried in order; each successful rewrite returns
// immediately, so at most one transformation fires per visit.
Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
  HloInstruction *lhs, *rhs;
  CHECK(Match(multiply, m::Multiply(m::Op(&lhs), m::Op(&rhs))));
  // LHS*1 => LHS
  VLOG(10) << "trying transform [LHS*1 => LHS]: " << multiply->ToString();
  if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(multiply, lhs)) {
    return Status::OK();
  }
  // 1*RHS => RHS
  VLOG(10) << "trying transform [1*RHS => RHS]: " << multiply->ToString();
  if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(multiply, rhs)) {
    return Status::OK();
  }

  // 0*RHS => 0. Only applies for integral types for correct NaN-handling.
  // (For floats, 0 * NaN is NaN, so folding to 0 would change results.)
  if (IsAll(lhs, 0) &&
      primitive_util::IsIntegralType(multiply->shape().element_type()) &&
      ReplaceInstructionIfSameShape(multiply, lhs)) {
    return Status::OK();
  }
  // LHS*0 => 0
  if (IsAll(rhs, 0) &&
      primitive_util::IsIntegralType(multiply->shape().element_type()) &&
      ReplaceInstructionIfSameShape(multiply, rhs)) {
    return Status::OK();
  }

  {
    // abs(x)*abs(x) => x*x. Valid for real types since the square is the same
    // either way; excluded for complex, where Abs yields the (real) magnitude
    // and would not round-trip. Operands are edited in place rather than
    // replacing the multiply itself.
    HloInstruction* abs_operand;
    if (lhs == rhs && Match(lhs, m::Abs(m::Op(&abs_operand))) &&
        !ShapeUtil::ElementIsComplex(abs_operand->shape())) {
      TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(0, abs_operand));
      TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(1, abs_operand));
      changed_ = true;
      return Status::OK();
    }
  }

  {
    HloInstruction *convert_operand, *operand;
    // Mul(Convert(Pred), operand) => select(pred, operand, 0)
    if (Match(multiply,
              m::MultiplyAnyOrder(
                  m::Op(&operand),
                  m::Convert(
                      m::Op(&convert_operand)
                          .WithShape(m::Shape().WithElementType(PRED)))))) {
      HloInstruction* zero_like_multiply =
          BroadcastZeros(computation_, multiply->shape().element_type(),
                         multiply->shape().dimensions());
      return ReplaceWithNewInstruction(
          multiply, HloInstruction::CreateTernary(
                        multiply->shape(), HloOpcode::kSelect, convert_operand,
                        operand, zero_like_multiply));
    }
  }

  {
    HloInstruction *a, *b, *c1, *c2;
    // Mul(Mul(x, constant1), Mul(y, constant2)) => Mul(Mul(x, y),
    // constant1*constant2)
    if (Match(multiply,
              m::MultiplyAnyOrder(
                  m::MultiplyAnyOrder(m::NonConstant(&a), m::Constant(&c1)),
                  m::MultiplyAnyOrder(m::NonConstant(&b), m::Constant(&c2))))) {
      TF_ASSIGN_OR_RETURN(auto* product_of_constants,
                          MakeBinaryHlo(HloOpcode::kMultiply, c1, c2));
      // A scalar constant product must be broadcast up to the multiply's
      // shape before it can be used as a binary operand.
      if (ShapeUtil::IsScalar(product_of_constants->shape()) &&
          !ShapeUtil::IsScalar(multiply->shape())) {
        product_of_constants =
            computation_->AddInstruction(HloInstruction::CreateBroadcast(
                multiply->shape(), product_of_constants, {}));
      }

      return ReplaceWithNewInstruction(
          multiply,
          HloInstruction::CreateBinary(
              multiply->shape(), HloOpcode::kMultiply,
              computation_->AddInstruction(HloInstruction::CreateBinary(
                  multiply->shape(), HloOpcode::kMultiply, a, b)),
              product_of_constants));
    }
  }

  {
    HloInstruction *a, *c1, *c2;
    // Mul(Mul(a, constant1), constant2) => Mul(a, constant1*constant2)
    if (Match(multiply,
              m::MultiplyAnyOrder(
                  m::MultiplyAnyOrder(m::NonConstant(&a), m::Constant(&c1)),
                  m::Constant(&c2)))) {
      TF_ASSIGN_OR_RETURN(auto* product_of_constants,
                          MakeBinaryHlo(HloOpcode::kMultiply, c1, c2));
      if (ShapeUtil::IsScalar(product_of_constants->shape()) &&
          !ShapeUtil::IsScalar(multiply->shape())) {
        product_of_constants =
            computation_->AddInstruction(HloInstruction::CreateBroadcast(
                multiply->shape(), product_of_constants, {}));
      }

      return ReplaceWithNewInstruction(
          multiply,
          HloInstruction::CreateBinary(multiply->shape(), HloOpcode::kMultiply,
                                       a, product_of_constants));
    }
  }

  {
    HloInstruction *a, *b, *constant, *op;
    // Mul(Mul(a, constant1), Broadcast(b)) =>
    // Mul(a, Broadcast(Mul(b, constant1)))
    // This folds a scalar constant into the (smaller) broadcast operand so the
    // broadcast carries the product.
    if (Match(multiply,
              m::MultiplyAnyOrder(m::MultiplyAnyOrder(m::NonConstant(&a),
                                                      m::Constant(&constant)),
                                  m::Op(&op))) ||
        Match(multiply,
              m::MultiplyAnyOrder(
                  m::MultiplyAnyOrder(m::NonConstant(&a),
                                      m::Broadcast(m::Constant(&constant))),
                  m::Op(&op)))) {
      // Check that the other side was a broadcast, and not of a constant.
      if (ShapeUtil::IsScalar(constant->shape()) &&
          Match(op, m::Broadcast(m::NonConstant()))) {
        auto dims = op->dimensions();
        b = op->mutable_operand(0);
        // Broadcast the scalar constant to b's shape so they can multiply.
        if (!ShapeUtil::IsScalar(b->shape())) {
          constant = computation_->AddInstruction(
              HloInstruction::CreateBroadcast(b->shape(), constant, {}));
        }

        auto new_mul =
            computation_->AddInstruction(HloInstruction::CreateBinary(
                b->shape(), HloOpcode::kMultiply, b, constant));

        return ReplaceWithNewInstruction(
            multiply,
            HloInstruction::CreateBinary(
                multiply->shape(), HloOpcode::kMultiply, a,
                computation_->AddInstruction(HloInstruction::CreateBroadcast(
                    multiply->shape(), new_mul, dims))));
      }
    }
  }

  VLOG(10) << "trying transform [(A * C1) * C2 => A * (C1 * C2)]";
  HloInstruction *a, *c1, *c2;
  if (Match(multiply,
            m::Multiply(m::Multiply(m::NonConstant(&a), m::Constant(&c1)),
                        m::Constant(&c2))) ||
      Match(multiply,
            m::Multiply(
                m::Multiply(m::Op(&a), m::Broadcast(m::ConstantScalar(&c1))),
                m::Broadcast(m::ConstantScalar(&c2))))) {
    TF_ASSIGN_OR_RETURN(auto* product_of_constants,
                        MakeBinaryHlo(HloOpcode::kMultiply, c1, c2));
    if (ShapeUtil::IsScalar(product_of_constants->shape()) &&
        !ShapeUtil::IsScalar(multiply->shape())) {
      product_of_constants =
          computation_->AddInstruction(HloInstruction::CreateBroadcast(
              multiply->shape(), product_of_constants, {}));
    }
    return ReplaceWithNewInstruction(
        multiply,
        HloInstruction::CreateBinary(multiply->shape(), HloOpcode::kMultiply, a,
                                     product_of_constants));
  }

  VLOG(10) << "trying to transform exp(LHS) * exp(RHS) => exp(LHS+RHS) "
           << multiply->ToString();
  if (Match(multiply, m::Multiply(m::Exp(m::Op(&lhs)), m::Exp(m::Op(&rhs))))) {
    auto add = computation_->AddInstruction(HloInstruction::CreateBinary(
        multiply->shape(), HloOpcode::kAdd, lhs, rhs));
    return ReplaceWithNewInstruction(
        multiply,
        HloInstruction::CreateUnary(multiply->shape(), HloOpcode::kExp, add));
  }

  VLOG(10) << "trying transform [rsqrt(B) * rsqrt(B) => 1/B] "
           << multiply->ToString();
  // Requires IsPositive(b): for b <= 0 rsqrt(b)*rsqrt(b) is NaN/inf while 1/b
  // is not, so the rewrite would change results.
  HloInstruction* b;
  if (Match(multiply, m::Multiply(m::Rsqrt(m::Op(&b)), m::Rsqrt(m::Op(&b)))) &&
      IsPositive(b, options_)) {
    return ReplaceWithNewInstruction(
        multiply,
        HloInstruction::CreateBinary(multiply->shape(), HloOpcode::kDivide,
                                     MakeScalarLike(b, 1), b));
  }

  return Status::OK();
}
2969 
HandleNegate(HloInstruction * negate)2970 Status AlgebraicSimplifierVisitor::HandleNegate(HloInstruction* negate) {
2971   // negate(negate(x)) => x
2972   HloInstruction* x;
2973   if (Match(negate, m::Negate(m::Negate(m::Op(&x)))) &&
2974       ReplaceInstructionIfSameShape(negate, x)) {
2975     return Status::OK();
2976   }
2977   return Status::OK();
2978 }
2979 
HandleNot(HloInstruction * logical_not)2980 Status AlgebraicSimplifierVisitor::HandleNot(HloInstruction* logical_not) {
2981   // not(not(x)) => x
2982   HloInstruction* x;
2983   if (Match(logical_not, m::Not(m::Not(m::Op(&x)))) &&
2984       ReplaceInstructionIfSameShape(logical_not, x)) {
2985     return Status::OK();
2986   }
2987   return Status::OK();
2988 }
2989 
HandleOr(HloInstruction * logical_or)2990 Status AlgebraicSimplifierVisitor::HandleOr(HloInstruction* logical_or) {
2991   HloInstruction *lhs, *rhs;
2992   CHECK(Match(logical_or, m::Or(m::Op(&lhs), m::Op(&rhs))));
2993 
2994   // Simplify logical or
2995   if (ShapeUtil::HasPrimitiveType(lhs->shape(), xla::PRED) &&
2996       ShapeUtil::HasPrimitiveType(rhs->shape(), xla::PRED)) {
2997     // A || True => True
2998     VLOG(10) << "trying transform [A || True => True]: "
2999              << logical_or->ToString();
3000     if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(logical_or, rhs)) {
3001       return Status::OK();
3002     }
3003     // True || A => True
3004     VLOG(10) << "trying transform [True || A => True]: "
3005              << logical_or->ToString();
3006     if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(logical_or, lhs)) {
3007       return Status::OK();
3008     }
3009   }
3010 
3011   // A || False => A and A | 0 => A
3012   VLOG(10) << "trying transform [A || False => A]: " << logical_or->ToString();
3013   if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(logical_or, lhs)) {
3014     return Status::OK();
3015   }
3016 
3017   // False || A => A and 0 | A => A
3018   VLOG(10) << "trying transform [False || A => A]: " << logical_or->ToString();
3019   if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(logical_or, rhs)) {
3020     return Status::OK();
3021   }
3022 
3023   return Status::OK();
3024 }
3025 
HandleLog(HloInstruction * log)3026 Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log) {
3027   // ln(exp(A)) => A
3028   VLOG(10) << "trying transform [ln(exp(A)) => A]: " << log->ToString();
3029   HloInstruction *a, *b;
3030   if (Match(log, m::Log(m::Exp(m::Op(&a)))) &&
3031       ReplaceInstructionIfSameShape(log, a)) {
3032     return Status::OK();
3033   }
3034 
3035   // ln(pow(A,B)) => B*ln(abs(A))
3036   // or B*ln(A) if A is complex.
3037   if (Match(log, m::Log(m::Power(m::Op(&a), m::Op(&b))))) {
3038     auto abs_a = ShapeUtil::ElementIsComplex(a->shape())
3039                      ? a
3040                      : computation_->AddInstruction(HloInstruction::CreateUnary(
3041                            log->shape(), HloOpcode::kAbs, a));
3042     auto new_log = computation_->AddInstruction(
3043         HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, abs_a));
3044     return ReplaceWithNewInstruction(
3045         log, HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply,
3046                                           new_log, b));
3047   }
3048 
3049   if (Match(log, m::Log(m::Sqrt(m::Op(&a))))) {
3050     auto new_log = computation_->AddInstruction(
3051         HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, a));
3052     return ReplaceWithNewInstruction(
3053         log, HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply,
3054                                           new_log, MakeScalarLike(log, 0.5)));
3055   }
3056 
3057   if (Match(log, m::Log(m::Rsqrt(m::Op(&a))))) {
3058     auto new_log = computation_->AddInstruction(
3059         HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, a));
3060     return ReplaceWithNewInstruction(
3061         log, HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply,
3062                                           new_log, MakeScalarLike(log, -0.5)));
3063   }
3064 
3065   return Status::OK();
3066 }
3067 
HandleGetTupleElement(HloInstruction * get_tuple_element)3068 Status AlgebraicSimplifierVisitor::HandleGetTupleElement(
3069     HloInstruction* get_tuple_element) {
3070   auto operand = get_tuple_element->mutable_operand(0);
3071   if (operand->opcode() == HloOpcode::kTuple) {
3072     // get_tuple_element(make_tuple({A_0, A_1, ..., A_n}), i) => A_i
3073     VLOG(10) << "trying transform "
3074              << "[get_tuple_element(make_tuple({...,A_i,...}), i)] => A_i: "
3075              << get_tuple_element->ToString();
3076     if (ReplaceInstructionIfSameShape(
3077             get_tuple_element,
3078             operand->mutable_operand(get_tuple_element->tuple_index()))) {
3079       return Status::OK();
3080     }
3081   }
3082   return Status::OK();
3083 }
3084 
3085 namespace {
3086 
ReshapeLeavesDimensionsUnmodified(const HloInstruction * hlo,absl::Span<const int64> input_dim_indices)3087 absl::optional<std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
3088     const HloInstruction* hlo, absl::Span<const int64> input_dim_indices) {
3089   CHECK_EQ(hlo->opcode(), HloOpcode::kReshape);
3090   return ShapeUtil::ReshapeLeavesDimensionsUnmodified(
3091       hlo->operand(0)->shape(), hlo->shape(), input_dim_indices);
3092 }
3093 
3094 // Returns true if the output of "instruction" is a permutation of the
3095 // elements of "operand". Precondition: "operand" is an operand of
3096 // "instruction".
OutputIsPermutationOfOperandElements(HloInstruction * instruction,HloInstruction * operand)3097 bool OutputIsPermutationOfOperandElements(HloInstruction* instruction,
3098                                           HloInstruction* operand) {
3099   DCHECK(!instruction->OperandIndices(operand).empty());
3100   switch (instruction->opcode()) {
3101     case HloOpcode::kReshape:
3102     case HloOpcode::kReverse:
3103     case HloOpcode::kTranspose:
3104       return true;
3105     case HloOpcode::kSort:
3106       return (!instruction->shape().IsTuple());
3107     default:
3108       return false;
3109   }
3110 }
3111 
3112 // Returns true if the output of "instruction" is a subset of the elements of
3113 // "operand". Precondition: "operand" is an operand of "instruction".
OutputIsSubsetOfOperandElements(HloInstruction * instruction,HloInstruction * operand)3114 bool OutputIsSubsetOfOperandElements(HloInstruction* instruction,
3115                                      HloInstruction* operand) {
3116   const auto operand_indices = instruction->OperandIndices(operand);
3117   CHECK(!operand_indices.empty());
3118   if (operand_indices.size() != 1) {
3119     return false;
3120   }
3121   int64_t operand_index = operand_indices[0];
3122   switch (instruction->opcode()) {
3123     case HloOpcode::kSlice:
3124       CHECK_EQ(0, operand_index);
3125       return true;
3126     case HloOpcode::kDynamicSlice:
3127       return operand_index == 0;
3128     default:
3129       return false;
3130   }
3131 }
3132 
3133 }  // namespace
3134 
// Simplifies kBroadcast: converts degenerate broadcasts into reshapes or
// transposes, merges consecutive broadcasts, sinks broadcasts past unary ops,
// folds broadcast(iota) into a bigger iota, and strips 1-sized dimensions
// from the operand. Patterns are tried in order; the first that matches wins.
Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
  HloInstruction* operand;
  CHECK(Match(broadcast, m::Broadcast(m::Op(&operand))));
  // Note: `dims` is a copy; it is edited in place by the reshape-elision
  // case below.
  auto dims = broadcast->dimensions();
  // A degenerate broadcast of a reshape that does not change the number of
  // elements can be replaced by a reshape.
  if (std::is_sorted(dims.begin(), dims.end()) &&
      ShapeUtil::ElementsIn(broadcast->shape()) ==
          ShapeUtil::ElementsIn(operand->shape())) {
    VLOG(10) << "transform broadcast(X) -> reshape(X) where "
                "n(broadcast(X)) == n(X)";
    return ReplaceWithNewInstruction(
        broadcast, HloInstruction::CreateReshape(broadcast->shape(), operand));
  }

  // A degenerate broadcast that has the same input and output rank can be
  // converted into a transpose.
  if (broadcast->shape().rank() == operand->shape().rank() &&
      ShapeUtil::ElementsIn(broadcast->shape()) ==
          ShapeUtil::ElementsIn(operand->shape())) {
    VLOG(10) << "transform broadcast(X) -> transpose(X) where "
                "n(broadcast(X)) == n(X)";
    return ReplaceWithNewInstruction(
        broadcast,
        HloInstruction::CreateTranspose(broadcast->shape(), operand, dims));
  }

  // A broadcast of a reshape which merely inserts 1-sized dimensions can
  // elide its operand.
  {
    bool merely_inserts_or_deletes_1_sized_dimensions;
    std::vector<int64> inserted_indices, deleted_indices;
    std::tie(merely_inserts_or_deletes_1_sized_dimensions, deleted_indices,
             inserted_indices) =
        operand->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
    if (merely_inserts_or_deletes_1_sized_dimensions &&
        deleted_indices.empty()) {
      // Erase from highest index to lowest so earlier erasures do not shift
      // the positions still to be removed.
      std::reverse(inserted_indices.begin(), inserted_indices.end());
      for (auto inserted_index : inserted_indices) {
        dims.erase(dims.begin() + inserted_index);
      }
      return ReplaceWithNewInstruction(
          broadcast,
          HloInstruction::CreateBroadcast(broadcast->shape(),
                                          operand->mutable_operand(0), dims));
    }
  }

  // A Broadcast that feeds a unary element-wise operation can sink the
  // broadcast after the unary element-wise operation.
  TF_ASSIGN_OR_RETURN(
      bool sink_succeeded,
      TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(broadcast));
  changed_ |= sink_succeeded;
  if (sink_succeeded) {
    return Status::OK();
  }

  // A scalar broadcast feeding an instruction which only permutes (reshape,
  // transpose, sort, reverse) or selects a subset of operand elements (slice,
  // dynamic slice) can be replaced with a broadcast directly to the output
  // shape of the instruction.
  if (ShapeUtil::IsScalar(operand->shape())) {
    for (HloInstruction* user : broadcast->users()) {
      // Skip if the broadcast user has no uses itself.
      if (user->user_count() == 0 && user != computation_->root_instruction()) {
        continue;
      }
      if (OutputIsPermutationOfOperandElements(user, broadcast) ||
          OutputIsSubsetOfOperandElements(user, broadcast)) {
        VLOG(10) << "transform permuting/subset  of a scalar broadcast into "
                 << "a single broadcast";
        HloInstruction* new_broadcast = computation_->AddInstruction(
            HloInstruction::CreateBroadcast(user->shape(), operand, {}));
        // Use HloInstruction::ReplaceAllUsesWith instead of
        // HloComputation::ReplaceWithNewInstruction because we are replacing an
        // instruction other than the visited instruction.
        changed_ = true;
        return user->ReplaceAllUsesWith(new_broadcast);
      }
    }
    return Status::OK();
  }

  // broadcast(iota) -> iota: the iota dimension is remapped through the
  // broadcast's dimension map.
  if (operand->opcode() == HloOpcode::kIota) {
    return ReplaceWithNewInstruction(
        broadcast,
        HloInstruction::CreateIota(
            broadcast->shape(),
            dims[Cast<HloIotaInstruction>(operand)->iota_dimension()]));
  }

  // Merge two consecutive broadcasts into a single one by composing the
  // dimension maps.
  if (operand->opcode() == HloOpcode::kBroadcast) {
    std::vector<int64> new_dimensions;
    for (auto dim : operand->dimensions()) {
      new_dimensions.push_back(dims[dim]);
    }
    return ReplaceWithNewInstruction(
        broadcast,
        HloInstruction::CreateBroadcast(
            broadcast->shape(), operand->mutable_operand(0), new_dimensions));
  }
  // The remaining rewrite inserts a reshape, which may change layouts; bail
  // out in layout-sensitive mode.
  if (options_.is_layout_sensitive()) {
    return Status::OK();
  }
  // Drop 1-sized dimensions from the operand via a reshape and broadcast the
  // smaller operand instead, keeping only the dimension-map entries for the
  // surviving (non-degenerate) dimensions.
  if (ShapeUtil::HasDegenerateDimensions(operand->shape())) {
    auto new_operand =
        operand->parent()->AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::DropDegenerateDimensions(operand->shape()), operand));
    std::vector<int64> new_dims;
    new_dims.reserve(new_operand->shape().rank());
    for (int64_t i = 0; i < operand->shape().rank(); ++i) {
      if (operand->shape().dimensions(i) != 1) {
        new_dims.push_back(dims[i]);
      }
    }
    return ReplaceWithNewInstruction(
        broadcast, HloInstruction::CreateBroadcast(broadcast->shape(),
                                                   new_operand, new_dims));
  }
  return Status::OK();
}
3259 
// Simplifies kCompare: moves an added constant across the comparison,
// folds comparisons that are tautologically true/false for unsigned
// operands, iota operands, and self-comparisons of integral values.
Status AlgebraicSimplifierVisitor::HandleCompare(HloInstruction* compare) {
  HloInstruction* lhs;
  HloInstruction* rhs;
  CHECK(Match(compare, m::Compare(m::Op(&lhs), m::Op(&rhs))));
  {
    // compare(broadcast(a) + x, broadcast(b)) ==>
    //   compare(x, broadcast(b-a)), only enabled for integral types.
    HloInstruction *x, *a, *b;
    if (Match(compare,
              m::Compare(
                  m::AddAnyOrder(m::Op(&x), m::Broadcast(m::Op(&a).WithShape(
                                                m::Shape().IsScalar()))),
                  m::Broadcast(m::Op(&b).WithShape(m::Shape().IsScalar()))))) {
      // Restricted to signed integral types: floating point would change
      // rounding behavior, and the signed restriction appears to guard the
      // unsigned-wraparound case — NOTE(review): confirm intent upstream.
      if (ShapeUtil::ElementIsSigned(x->shape()) &&
          ShapeUtil::ElementIsIntegral(x->shape())) {
        HloInstruction* sub =
            computation_->AddInstruction(HloInstruction::CreateBinary(
                b->shape(), HloOpcode::kSubtract, b, a));
        HloInstruction* broadcast = computation_->AddInstruction(
            HloInstruction::CreateBroadcast(x->shape(), sub, {}));
        HloInstruction* new_compare = computation_->AddInstruction(
            HloInstruction::CreateCompare(compare->shape(), x, broadcast,
                                          compare->comparison_direction()));
        return ReplaceInstruction(compare, new_compare);
      }
    }
  }

  // Unsigned values are always >= 0, so comparisons against zero can fold to
  // a constant predicate.
  if (Cast<HloCompareInstruction>(compare)->type() ==
      Comparison::Type::kUnsigned) {
    // X u<  0 -> false
    if (compare->comparison_direction() == ComparisonDirection::kLt &&
        IsAll(rhs, 0)) {
      return ReplaceInstruction(compare, MakeScalarLike(compare, false));
    }
    // X u>= 0 -> true
    if (compare->comparison_direction() == ComparisonDirection::kGe &&
        IsAll(rhs, 0)) {
      return ReplaceInstruction(compare, MakeScalarLike(compare, true));
    }
    // 0 u>  X -> false
    if (compare->comparison_direction() == ComparisonDirection::kGt &&
        IsAll(lhs, 0)) {
      return ReplaceInstruction(compare, MakeScalarLike(compare, false));
    }
    // 0 u<= X -> true
    if (compare->comparison_direction() == ComparisonDirection::kLe &&
        IsAll(lhs, 0)) {
      return ReplaceInstruction(compare, MakeScalarLike(compare, true));
    }
  }

  // An iota's values start at zero and never go below it, so comparing an
  // iota against zero with these directions is a constant predicate.
  if (compare->comparison_direction() == ComparisonDirection::kLt &&
      lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) {
    // iota < 0 -> false
    return ReplaceInstruction(compare, MakeScalarLike(compare, false));
  } else if (compare->comparison_direction() == ComparisonDirection::kGt &&
             IsAll(lhs, 0) && rhs->opcode() == HloOpcode::kIota) {
    // 0 > iota -> false
    return ReplaceInstruction(compare, MakeScalarLike(compare, false));
  } else if (compare->comparison_direction() == ComparisonDirection::kGe &&
             lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) {
    // iota >= 0 -> true
    return ReplaceInstruction(compare, MakeScalarLike(compare, true));
  } else if (compare->comparison_direction() == ComparisonDirection::kLe &&
             IsAll(lhs, 0) && rhs->opcode() == HloOpcode::kIota) {
    // 0 <= iota -> true
    return ReplaceInstruction(compare, MakeScalarLike(compare, true));
  }
  // Comparing an integral value against itself is decided by the direction
  // alone. (Restricted to integral types: a float could be NaN, for which
  // even x == x is false.)
  if (lhs == rhs &&
      primitive_util::IsIntegralType(lhs->shape().element_type())) {
    switch (compare->comparison_direction()) {
      case ComparisonDirection::kGt:
      case ComparisonDirection::kLt:
      case ComparisonDirection::kNe:
        return ReplaceInstruction(compare, MakeScalarLike(compare, false));
      case ComparisonDirection::kEq:
      case ComparisonDirection::kGe:
      case ComparisonDirection::kLe:
        return ReplaceInstruction(compare, MakeScalarLike(compare, true));
    }
  }
  return Status::OK();
}
3340 
HandleConvert(HloInstruction * convert)3341 Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) {
3342   PrimitiveType src_type = convert->operand(0)->shape().element_type();
3343   PrimitiveType dest_type = convert->shape().element_type();
3344   // A conversion to the same element type as the operand is a nop and can be
3345   // removed.  A conversion of a constant can be simplified by making a new
3346   // constant.
3347   if (src_type == dest_type) {
3348     return ReplaceInstruction(convert, convert->mutable_operand(0));
3349   }
3350 
3351   // Eliminate a convert pair if it is a no-op. The following are a few
3352   // example cases that are being handled:
3353   // 1. convert(convert(A, $TYPE1), $TYPE2) is simplified to A if A is of $TYPE2
3354   //    and convert(A, $TYPE1) is an upcast
3355   // 2. convert(convert(A, $TYPE1),$TYPE2) is simplified to A if A is of $TYPE2
3356   //    and convert(A, $TYPE1) is an upcast and is an integral conversion from
3357   //    unsigned to signed (only signed to unsigned conversion is NOT allowed)
3358   // 3. Tuple(convert(A, $TYPE1) , floor(convert(convert(A, $TYPE1), $TYPE2)),
3359   //    convert(convert(A, $TYPE1), $TYPE2)) is simplified to Tuple(convert(A,
3360   //    $TYPE1) , floor(A), A) -> a case where the first convert has a
3361   //    fan-out
3362   if (convert->operand(0)->opcode() == HloOpcode::kConvert &&
3363       IsConvertPairNoOp(convert)) {
3364     return ReplaceInstruction(convert,
3365                               convert->mutable_operand(0)->mutable_operand(0));
3366   }
3367   return Status::OK();
3368 }
3369 
3370 // Complex(Real(c), Imag(c)) -> c
HandleComplex(HloInstruction * complex)3371 Status AlgebraicSimplifierVisitor::HandleComplex(HloInstruction* complex) {
3372   HloInstruction *c0, *c1;
3373   if (Match(complex, m::Complex(m::Real(m::Op(&c0)), m::Imag(m::Op(&c1)))) &&
3374       c0 == c1) {
3375     return ReplaceInstruction(complex, c0);
3376   }
3377   return Status::OK();
3378 }
3379 
3380 // Real(Complex(r, i)) -> r
HandleReal(HloInstruction * real)3381 Status AlgebraicSimplifierVisitor::HandleReal(HloInstruction* real) {
3382   HloInstruction* op;
3383   if (Match(real, m::Real(m::Complex(m::Op(&op), m::Op())))) {
3384     return ReplaceInstruction(real, op);
3385   }
3386   return Status::OK();
3387 }
3388 
3389 // Imag(Complex(r, i)) -> i
HandleImag(HloInstruction * imag)3390 Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) {
3391   HloInstruction* op;
3392   if (Match(imag, m::Imag(m::Complex(m::Op(), m::Op(&op))))) {
3393     return ReplaceInstruction(imag, op);
3394   }
3395   return Status::OK();
3396 }
3397 
HandleIota(HloInstruction * instruction)3398 Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) {
3399   // iota -> zero if the iota dimension never produces an element other than
3400   // zero.
3401   auto* iota = Cast<HloIotaInstruction>(instruction);
3402   if (iota->shape().dimensions(iota->iota_dimension()) <= 1) {
3403     auto zero = computation_->AddInstruction(
3404         simplifier_->CreateConstantWithLayoutUpdated(
3405             LiteralUtil::Zero(iota->shape().element_type()).Clone()));
3406     return ReplaceWithNewInstruction(
3407         iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {}));
3408   }
3409   return Status::OK();
3410 }
3411 
HandlePad(HloInstruction * pad)3412 Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
3413   if (ShapeUtil::IsZeroElementArray(pad->operand(0)->shape())) {
3414     return ReplaceWithNewInstruction(
3415         pad, HloInstruction::CreateBroadcast(pad->shape(),
3416                                              pad->mutable_operand(1), {}));
3417   }
3418 
3419   // Interior padding on one sized dimensions have no effect. As a result it
3420   // makes other simplifications possible if there is no interior padding.
3421   if (HasInteriorPadding(pad->padding_config())) {
3422     PaddingConfig padding_config = pad->padding_config();
3423     bool cleared_interior_padding = false;
3424     for (int64_t i = 0; i < pad->shape().rank(); ++i) {
3425       if (padding_config.dimensions(i).interior_padding() > 0 &&
3426           pad->operand(0)->shape().dimensions(i) == 1) {
3427         cleared_interior_padding = true;
3428         padding_config.mutable_dimensions(i)->set_interior_padding(0);
3429       }
3430     }
3431     if (cleared_interior_padding) {
3432       return ReplaceWithNewInstruction(
3433           pad,
3434           HloInstruction::CreatePad(pad->shape(), pad->mutable_operand(0),
3435                                     pad->mutable_operand(1), padding_config));
3436     }
3437   }
3438 
3439   // Eliminate nop pads (padding all zero), and replace a pad with negative
3440   // padding with a pad with non-negative padding followed by a slice.
3441   bool all_zero = true;
3442   bool has_negative = false;
3443   // Used to possibly split off the unchanged padding dimensions.
3444   std::vector<int64> padding_dimensions;
3445   int64_t dimension_index = 0;
3446   for (auto& padding_dimension : pad->padding_config().dimensions()) {
3447     if (padding_dimension.edge_padding_low() < 0 ||
3448         padding_dimension.edge_padding_high() < 0) {
3449       has_negative = true;
3450     }
3451     if (padding_dimension.edge_padding_low() != 0 ||
3452         padding_dimension.edge_padding_high() != 0) {
3453       all_zero = false;
3454       padding_dimensions.push_back(dimension_index);
3455     } else if (padding_dimension.interior_padding()) {
3456       padding_dimensions.push_back(dimension_index);
3457     }
3458     dimension_index++;
3459   }
3460 
3461   if (all_zero) {
3462     if (ReplaceInstructionIfSameShape(pad, pad->mutable_operand(0))) {
3463       return Status::OK();
3464     }
3465   }
3466 
3467   // The context of this optimization can be found at b/163617402
3468   // It tries to capture the case of pad(broadcast(x)), where
3469   // x->shape().dimensions(), or broadcast(x)->dimensions(), is
3470   // a subset of the padded dimensions in pad->config(),
3471   // and the padded dimensions in pad->config() is in turn a strict
3472   // subset of broadcast->shape().dimensions(). The combined op can be
3473   // rewritten to broadcast2(pad(broadcast1(x))), where broadcast1 extends
3474   // x  with dimensions that need to be padded, and broadcast2 extends
3475   // the result of padding to full dimensions.
3476   // TODO(qyi): for future extensions: The condition for broadcast(x)
3477   // ->dimensions() to be a subset of padded dimensions in pad->config()
3478   // does not have to be strictly required, but it makes the calculation
3479   // for optimization easier, so it is required by the current implementation.
3480   // Only the second condition between the padded dimensions and the
3481   // dimensions of the final shape have to be enforced for the optimization
3482   // to make sense. If needed to remove the first constraint, the shape
3483   // calculations across the implementation need to be re-adjusted.
3484   auto pad_dims = padding_dimensions.size();
3485   if (pad_dims < dimension_index &&
3486       pad->operand(0)->opcode() == HloOpcode::kBroadcast &&
3487       pad->operand(0)->user_count() == 1 &&
3488       pad->operand(0)->operand(0)->shape().rank() <= pad_dims) {
    // Check that the broadcast operand's dimensions are a subset of
    // padding_dimensions. If not, skip the optimization.
3491     bool opt_is_valid = true;
3492     std::vector<int64> broadcast_dimensions;
3493     HloBroadcastInstruction* broadcast =
3494         static_cast<HloBroadcastInstruction*>(pad->mutable_operand(0));
3495     for (auto broadcast_index : broadcast->dimensions()) {
3496       bool found = false;
3497       for (int i = 0; i < pad_dims; ++i) {
3498         if (broadcast_index == padding_dimensions[i]) {
3499           broadcast_dimensions.push_back(i);
3500           found = true;
3501           break;
3502         }
3503       }
3504       if (!found) {
3505         opt_is_valid = false;
3506         break;
3507       }
3508     }
3509     if (opt_is_valid) {
3510       auto pad_shape = pad->shape();
3511       auto broadcast_shape = broadcast->shape();
3512       auto pad_shape1 = pad_shape;
3513       auto broadcast_shape1 = broadcast_shape;
3514       PaddingConfig pad_config;
3515       for (int i = padding_dimensions.size() - 1; i >= 0; --i) {
3516         int64_t j = padding_dimensions[i];
3517         while (--dimension_index > j) {
3518           broadcast_shape1.DeleteDimension(dimension_index);
3519           pad_shape1.DeleteDimension(dimension_index);
3520         }
3521       }
3522       while (--dimension_index >= 0) {
3523         broadcast_shape1.DeleteDimension(dimension_index);
3524         pad_shape1.DeleteDimension(dimension_index);
3525       }
3526       for (auto dimension_to_pad : padding_dimensions) {
3527         auto dimension = pad_config.add_dimensions();
3528         *dimension = pad->padding_config().dimensions(dimension_to_pad);
3529       }
3530       *broadcast->mutable_shape() = broadcast_shape1;
3531       *broadcast->mutable_dimensions() = broadcast_dimensions;
3532       simplifier_->UpdateLayout(broadcast->mutable_shape());
3533       auto pad2 =
3534           computation_->AddInstruction(pad->CloneWithNewShape(pad_shape1));
3535       *pad2->mutable_padding_config() = pad_config;
3536       simplifier_->UpdateLayout(pad2->mutable_shape());
3537       auto broadcast2 = computation_->AddInstruction(
3538           HloInstruction::CreateBroadcast(pad_shape, pad2, padding_dimensions));
3539       return ReplaceInstruction(pad, broadcast2);
3540     }
3541   }
3542 
3543   if (has_negative && options_.enable_negative_padding_replacement()) {
3544     // Pad has negative padding. Replace with a pad with the non-negative
3545     // padding followed by a slice which effectively performs the negative
3546     // padding.
3547     // TODO(b/34628603): Add support for negative padding in the backends, or
3548     // change kPad semantics to disallow negative padding and use slice
3549     // instead.
3550 
3551     // First construct the padding config with non-negative entries and the
3552     // compute the shape of this new pad instruction.
3553     PaddingConfig nonzero_padding = pad->padding_config();
3554     for (int i = 0; i < pad->padding_config().dimensions_size(); ++i) {
3555       PaddingConfig::PaddingConfigDimension* padding_dimension =
3556           nonzero_padding.mutable_dimensions(i);
3557       // Set negative padding to zero.
3558       if (padding_dimension->edge_padding_low() < 0) {
3559         padding_dimension->set_edge_padding_low(0);
3560       }
3561       if (padding_dimension->edge_padding_high() < 0) {
3562         padding_dimension->set_edge_padding_high(0);
3563       }
3564     }
3565 
3566     TF_ASSIGN_OR_RETURN(HloInstruction * nonzero_pad,
3567                         MakePadHlo(pad->mutable_operand(0),
3568                                    pad->mutable_operand(1), nonzero_padding));
3569     // Copy the layout from the original pad instructions. The new pad and the
3570     // slice instruction should all have the same layout.
3571     TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
3572         pad->shape(), nonzero_pad->mutable_shape()));
3573     simplifier_->UpdateLayout(nonzero_pad->mutable_shape());
3574 
3575     // Second, construct the slice instruction to perform the negative
3576     // padding.
3577     std::vector<int64> start_indices;
3578     std::vector<int64> end_indices;
3579     std::vector<int64> strides;
3580     for (int64_t i = 0; i < pad->padding_config().dimensions_size(); ++i) {
3581       const PaddingConfig::PaddingConfigDimension& padding_dimension =
3582           pad->padding_config().dimensions(i);
3583       int64_t start = 0;
3584       if (padding_dimension.edge_padding_low() < 0) {
3585         start = -1 * padding_dimension.edge_padding_low();
3586       }
3587       int64_t end = nonzero_pad->shape().dimensions(i);
3588       if (padding_dimension.edge_padding_high() < 0) {
3589         end += padding_dimension.edge_padding_high();
3590       }
3591       start_indices.push_back(start);
3592       end_indices.push_back(end);
3593       strides.push_back(1);
3594     }
3595 
3596     TF_ASSIGN_OR_RETURN(
3597         HloInstruction * slice,
3598         MakeSliceHlo(nonzero_pad, start_indices, end_indices, strides));
3599     TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
3600         pad->shape(), slice->mutable_shape()));
3601     simplifier_->UpdateLayout(slice->mutable_shape());
3602 
3603     // Verify that the slice shape matches the pad shape.
3604     auto equal = Shape::Equal();
3605     if (!options_.is_layout_sensitive()) {
3606       equal.IgnoreTilesInLayout();
3607     }
3608     TF_RET_CHECK(equal(slice->shape(), pad->shape()));
3609 
3610     return ReplaceInstruction(pad, slice);
3611   }
3612 
3613   return Status::OK();
3614 }
3615 
HandlePower(HloInstruction * power)3616 Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
3617   VLOG(10) << "trying transform [pow(A, 0) => 1]: " << power->ToString();
3618   HloInstruction *lhs, *rhs;
3619   CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs))));
3620   if (IsAll(rhs, 0)) {
3621     return ReplaceInstruction(power, MakeScalarLike(power, 1));
3622   }
3623 
3624   VLOG(10) << "trying transform [pow(A, 1) => A]: " << power->ToString();
3625   if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(power, lhs)) {
3626     return Status::OK();
3627   }
3628 
3629   // pow(exp(A),B) => exp(A*B)
3630   HloInstruction *a, *b;
3631   if (Match(power, m::Power(m::Exp(m::Op(&a)), m::Op(&b)))) {
3632     auto a_times_b = computation_->AddInstruction(HloInstruction::CreateBinary(
3633         power->shape(), HloOpcode::kMultiply, a, b));
3634     return ReplaceWithNewInstruction(
3635         power, HloInstruction::CreateUnary(power->shape(), HloOpcode::kExp,
3636                                            a_times_b));
3637   }
3638 
3639   VLOG(10) << "trying transform [pow(A, 2) => A*A]: " << power->ToString();
3640   if (IsAll(rhs, 2)) {
3641     return ReplaceWithNewInstruction(
3642         power, HloInstruction::CreateBinary(power->shape(),
3643                                             HloOpcode::kMultiply, lhs, lhs));
3644   }
3645 
3646   // Pow(A, 3) is used in GELU.
3647   VLOG(10) << "trying transform [pow(A, 3) => A*A*A]: " << power->ToString();
3648   if (IsAll(rhs, 3)) {
3649     HloInstruction* tmp =
3650         computation_->AddInstruction(HloInstruction::CreateBinary(
3651             power->shape(), HloOpcode::kMultiply, lhs, lhs));
3652     return ReplaceWithNewInstruction(
3653         power, HloInstruction::CreateBinary(power->shape(),
3654                                             HloOpcode::kMultiply, lhs, tmp));
3655   }
3656 
3657   VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
3658   if (IsAll(rhs, -1)) {
3659     return ReplaceWithNewInstruction(
3660         power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide,
3661                                             MakeScalarLike(lhs, 1), lhs));
3662   }
3663 
3664   return Status::OK();
3665 }
3666 
// Tries to sink `broadcast` below its elementwise users: rewrites
// user(broadcast(x), ...) into broadcast(user(x, ...)) so the elementwise op
// runs on the smaller, pre-broadcast shape.  A user is rewritten only when
// every one of its operands is either this broadcast, a broadcast of a
// scalar, or a broadcast from an identically-shaped operand with the same
// broadcast dimensions.  Returns true if at least one user was rewritten.
StatusOr<bool>
AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
    HloInstruction* broadcast) {
  TF_RET_CHECK(broadcast->opcode() == HloOpcode::kBroadcast);
  bool changed = false;
  // A scalar broadcast has nothing to sink.
  if (ShapeUtil::IsScalar(broadcast->shape())) {
    return false;
  }
  HloInstruction* operand = broadcast->mutable_operand(0);
  // Broadcast from a scalar: can be re-broadcast to any shape below.
  auto is_scalar_broadcast = [](const HloInstruction* instruction) {
    return instruction->opcode() == HloOpcode::kBroadcast &&
           ShapeUtil::IsScalar(instruction->operand(0)->shape());
  };
  // Broadcast whose operand shape and broadcast dimensions match ours: its
  // operand can be used directly in place of `operand`'s shape.
  auto is_equal_broadcast = [operand,
                             broadcast](const HloInstruction* instruction) {
    return instruction->opcode() == HloOpcode::kBroadcast &&
           ShapeUtil::Equal(operand->shape(),
                            instruction->operand(0)->shape()) &&
           broadcast->dimensions() == instruction->dimensions();
  };
  auto is_compatible_broadcast = [&](const HloInstruction* instruction) {
    return is_scalar_broadcast(instruction) || is_equal_broadcast(instruction);
  };
  for (HloInstruction* user : broadcast->users()) {
    // Skip dead users (except the root, which legitimately has no users).
    if (user->user_count() == 0 && user != computation_->root_instruction()) {
      continue;
    }
    // Do not move reshapes or broadcasts past copies since the shape the copy
    // will operate on will change.
    if (user->opcode() == HloOpcode::kCopy) {
      continue;
    }
    // Do not change the shape of fusion nodes in case there a multiple shapes
    // inside the fusion node already.
    if (user->opcode() == HloOpcode::kFusion) {
      continue;
    }
    // Sinking is only valid through elementwise ops.
    if (!user->IsElementwise()) {
      continue;
    }

    // Check if all the operands of the user are compatible broadcasts for
    // sinking. (They are either scalar broadcasts or broadcasts casting
    // from/to the same shape/dimensions)
    int64_t compatible_broadcast_count = 0;
    int64_t broadcast_use_count = 0;
    for (HloInstruction* user_operand : user->operands()) {
      if (is_compatible_broadcast(user_operand)) {
        ++compatible_broadcast_count;
      } else if (broadcast == user_operand) {
        ++broadcast_use_count;
      }
    }
    // Every operand must be accounted for, otherwise skip this user.
    if (compatible_broadcast_count + broadcast_use_count !=
        user->operand_count()) {
      continue;
    }
    std::vector<HloInstruction*> new_operands;
    new_operands.reserve(user->operand_count());

    Shape changed_shape;
    for (HloInstruction* user_operand : user->operands()) {
      // If this is a broadcast operand that is not our original broadcast input
      // to this function then we might need to change the input.
      if (is_compatible_broadcast(user_operand)) {
        // If this is a broadcast from a scalar value rewrite a broadcast from
        // the scalar to the new shape enforced from the other broadcast
        // operands.
        if (is_scalar_broadcast(user_operand)) {
          changed_shape = ShapeUtil::ChangeElementType(
              operand->shape(), user_operand->shape().element_type());
          simplifier_->UpdateLayout(&changed_shape);
          new_operands.push_back(
              computation_->AddInstruction(HloInstruction::CreateBroadcast(
                  changed_shape, user_operand->mutable_operand(0), {})));
          user_operand->SetupDerivedInstruction(new_operands.back());
        } else {
          // For the non-scalar broadcasts we guarantee that the shape of the
          // operand of the broadcast needs to be already a compatible shape.
          new_operands.push_back(user_operand->mutable_operand(0));
        }
      } else {
        CHECK_EQ(broadcast, user_operand);
        new_operands.push_back(operand);
      }
    }
    VLOG(4) << "Sinking broadcast after user:";
    VLOG(4) << "  old broadcast: " << broadcast->ToString();
    VLOG(4) << "  old user: " << user->ToString();
    // Clone the user on the pre-broadcast shape, then broadcast its result
    // back to the user's original shape.
    changed_shape = ShapeUtil::ChangeElementType(operand->shape(),
                                                 user->shape().element_type());
    simplifier_->UpdateLayout(&changed_shape);
    HloInstruction* new_user = computation_->AddInstruction(
        user->CloneWithNewOperands(changed_shape, new_operands));
    VLOG(4) << "  new user: " << new_user->ToString();
    HloInstruction* new_broadcast =
        computation_->AddInstruction(HloInstruction::CreateBroadcast(
            user->shape(), new_user, broadcast->dimensions()));
    broadcast->SetupDerivedInstruction(new_broadcast);
    VLOG(4) << "  new broadcast: " << new_broadcast->ToString();
    TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_broadcast));
    changed = true;
  }
  return changed;
}
3772 
3773 namespace {
3774 template <typename T>
TryRemainderToAnd(HloInstruction * remainder,HloComputation * computation,AlgebraicSimplifier * simplifier)3775 std::unique_ptr<HloInstruction> TryRemainderToAnd(
3776     HloInstruction* remainder, HloComputation* computation,
3777     AlgebraicSimplifier* simplifier) {
3778   HloInstruction *a, *b, *c;
3779   CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b))));
3780 
3781   if (ShapeUtil::ElementIsIntegral(remainder->shape()) &&
3782       !Match(b, m::ConstantEffectiveScalar(&c)) &&
3783       !Match(b, m::Broadcast(m::ConstantEffectiveScalar(&c)))) {
3784     return nullptr;
3785   }
3786 
3787   if (ShapeUtil::ElementIsSigned(remainder->shape())) {
3788     int64_t b_value = c->literal().GetFirstElement<T>();
3789     if (b_value > 0 && IsPowerOfTwo(static_cast<uint64>(b_value))) {
3790       // Handle negative dividends by negating the result of the division.
3791       HloInstruction* zero_like_a = BroadcastZeros(
3792           computation, a->shape().element_type(), a->shape().dimensions());
3793 
3794       Shape compare_shape = ShapeUtil::ChangeElementType(a->shape(), PRED);
3795       simplifier->UpdateLayout(&compare_shape);
3796       auto* dividend_is_negative =
3797           computation->AddInstruction(HloInstruction::CreateCompare(
3798               compare_shape, a, zero_like_a, ComparisonDirection::kLt));
3799 
3800       auto* negated_dividend = computation->AddInstruction(
3801           HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a));
3802 
3803       auto* abs_dividend =
3804           computation->AddInstruction(HloInstruction::CreateTernary(
3805               a->shape(), HloOpcode::kSelect, dividend_is_negative,
3806               negated_dividend, a));
3807 
3808       auto* quotient = computation->AddInstruction(HloInstruction::CreateBinary(
3809           remainder->shape(), HloOpcode::kAnd, abs_dividend,
3810           MakeScalarLike(abs_dividend, b_value - 1)));
3811 
3812       auto* neqated_quotient =
3813           computation->AddInstruction(HloInstruction::CreateUnary(
3814               quotient->shape(), HloOpcode::kNegate, quotient));
3815 
3816       return HloInstruction::CreateTernary(
3817           remainder->shape(), HloOpcode::kSelect, dividend_is_negative,
3818           neqated_quotient, quotient);
3819     }
3820   } else {
3821     uint64 b_value = c->literal().GetFirstElement<T>();
3822     if (IsPowerOfTwo(b_value)) {
3823       HloInstruction* mask_amount = computation->AddInstruction(
3824           simplifier->CreateConstantWithLayoutUpdated(
3825               LiteralUtil::CreateR0<T>(b_value - 1)));
3826       if (!ShapeUtil::IsScalar(b->shape())) {
3827         mask_amount = computation->AddInstruction(
3828             HloInstruction::CreateBroadcast(b->shape(), mask_amount, {}));
3829       }
3830       return HloInstruction::CreateBinary(remainder->shape(), HloOpcode::kAnd,
3831                                           a, mask_amount);
3832     }
3833   }
3834   return nullptr;
3835 }
3836 }  // namespace
3837 
HandleRemainder(HloInstruction * remainder)3838 Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) {
3839   HloInstruction *a, *b;
3840   CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b))));
3841 
3842   // (A % B) % B == A % B.
3843   if (Match(a, m::Remainder(m::Op(), m::Op().Is(b)))) {
3844     return ReplaceInstruction(remainder, a);
3845   }
3846 
3847   // A % B => A & (B - 1) if B is a power of 2.
3848   switch (remainder->shape().element_type()) {
3849     case S8:
3850       if (std::unique_ptr<HloInstruction> shift =
3851               TryRemainderToAnd<int8>(remainder, computation_, simplifier_)) {
3852         return ReplaceWithNewInstruction(remainder, std::move(shift));
3853       }
3854       break;
3855     case S16:
3856       if (std::unique_ptr<HloInstruction> shift =
3857               TryRemainderToAnd<int16>(remainder, computation_, simplifier_)) {
3858         return ReplaceWithNewInstruction(remainder, std::move(shift));
3859       }
3860       break;
3861     case S32:
3862       if (std::unique_ptr<HloInstruction> shift =
3863               TryRemainderToAnd<int32>(remainder, computation_, simplifier_)) {
3864         return ReplaceWithNewInstruction(remainder, std::move(shift));
3865       }
3866       break;
3867     case S64:
3868       if (std::unique_ptr<HloInstruction> shift =
3869               TryRemainderToAnd<int64>(remainder, computation_, simplifier_)) {
3870         return ReplaceWithNewInstruction(remainder, std::move(shift));
3871       }
3872       break;
3873     case U8:
3874       if (std::unique_ptr<HloInstruction> shift =
3875               TryRemainderToAnd<uint8>(remainder, computation_, simplifier_)) {
3876         return ReplaceWithNewInstruction(remainder, std::move(shift));
3877       }
3878       break;
3879     case U16:
3880       if (std::unique_ptr<HloInstruction> shift =
3881               TryRemainderToAnd<uint16>(remainder, computation_, simplifier_)) {
3882         return ReplaceWithNewInstruction(remainder, std::move(shift));
3883       }
3884       break;
3885     case U32:
3886       if (std::unique_ptr<HloInstruction> shift =
3887               TryRemainderToAnd<uint32>(remainder, computation_, simplifier_)) {
3888         return ReplaceWithNewInstruction(remainder, std::move(shift));
3889       }
3890       break;
3891     case U64:
3892       if (std::unique_ptr<HloInstruction> shift =
3893               TryRemainderToAnd<uint64>(remainder, computation_, simplifier_)) {
3894         return ReplaceWithNewInstruction(remainder, std::move(shift));
3895       }
3896       break;
3897     default:
3898       break;
3899   }
3900 
3901   // If M < N, then {0, ..., M} % N ==> {0, ..., M}.
3902   //
3903   // Currently this only covers the case when N is a broadcasted constant
3904   // scalar.  We could also cover the case when N is a non-broadcasted constant
3905   // with the same value repeated.
3906   HloInstruction* iota;
3907   HloInstruction* divisor;
3908   if (Match(remainder,
3909             m::Remainder(m::Iota(&iota),
3910                          m::Broadcast(m::ConstantEffectiveScalar(&divisor))))) {
3911     // The iota counts {0, ..., iota_upper_bound - 1}.  (Actually this is
3912     // conservative; the iota may overflow and count up to a smaller value than
3913     // this.  But that's OK for our purposes here.)
3914     int64_t iota_upper_bound = iota->shape().dimensions(
3915         Cast<HloIotaInstruction>(iota)->iota_dimension());
3916     absl::optional<int64> divisor_val = divisor->literal().GetIntegralAsS64(
3917         std::vector<int64>(0, divisor->shape().dimensions_size()));
3918     if (divisor_val && *divisor_val >= iota_upper_bound) {
3919       return ReplaceInstruction(remainder, iota);
3920     }
3921   }
3922 
3923   // (X + N) % N = X % N, so long as X + N does not overflow.
3924   //
3925   // We don't have range tracking in XLA that would let us know whether X + N
3926   // overflows, so for now we only do this simplification when X is an iota.  We
3927   // could add other operations where it's easy to see a range, such as
3928   // remainder, convert, etc., though at some point we'd probably want a
3929   // range-tracking analysis.
3930   HloInstruction* bcast;
3931   HloInstruction* addend;
3932   if (Match(
3933           remainder,
3934           m::Remainder(
3935               m::AddAnyOrder(m::Iota(&iota),
3936                              m::Broadcast(m::ConstantEffectiveScalar(&addend))),
3937               m::Broadcast(&bcast, m::ConstantEffectiveScalar(&divisor)))) &&
3938       addend == divisor) {
3939     // The iota counts {0, ...iota_upper_bound - 1}, with the same caveat above
3940     // that iota_upper_bound is conservative, and the true upper bound may be
3941     // smaller.
3942     int64_t iota_upper_bound = iota->shape().dimensions(
3943         Cast<HloIotaInstruction>(iota)->iota_dimension());
3944     absl::optional<int64> divisor_val = divisor->literal().GetIntegralAsS64(
3945         std::vector<int64>(0, divisor->shape().dimensions_size()));
3946     if (divisor_val) {
3947       // Check whether divisor_val + iota_upper_bound - 1 overflows.
3948       absl::optional<int64> max_val =
3949           OverflowSafeAdd(*divisor_val, iota_upper_bound);
3950       if (max_val.has_value() &&
3951           FitsInIntegralType(*max_val, iota->shape().element_type())) {
3952         return ReplaceWithNewInstruction(
3953             remainder,
3954             HloInstruction::CreateBinary(remainder->shape(),
3955                                          HloOpcode::kRemainder, iota, bcast));
3956       }
3957     }
3958   }
3959 
3960   return Status::OK();
3961 }
3962 
// Simplifies kReshape:
//  * reshape of a zero-element array        -> empty constant
//  * no-op reshape (identical shapes)       -> operand
//  * reshape(reshape(x))                    -> reshape(x)
//  * reshape(rng), single use               -> rng reshaped in place
//  * reshape(broadcast(x))                  -> broadcast, when the broadcast
//    dimensions pass through the reshape unmodified
//  * reshape(iota)                          -> iota, or a mixed-radix sum of
//    scaled iotas
//  * reshape(dynamic-update-slice(...))     -> dus of reshaped operands,
//    when the reshape only adds/removes 1-sized dimensions
//  * otherwise, a bitcast when the layouts make that legal
Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
  auto operand = reshape->mutable_operand(0);

  // Reshape directly to empty constant if the shape contains zero-element
  // dimension.
  if (ShapeUtil::IsZeroElementArray(reshape->shape())) {
    // If the instruction doesn't have a layout, use a default layout for
    // the literal result.
    Shape reshaped_shape = reshape->shape();
    if (!LayoutUtil::HasLayout(reshaped_shape)) {
      LayoutUtil::SetToDefaultLayout(&reshaped_shape);
    }
    auto empty_constant = simplifier_->CreateConstantWithLayoutUpdated(
        Literal::CreateFromShape(reshaped_shape));

    return ReplaceWithNewInstruction(reshape, std::move(empty_constant));
  }

  // Delete no-op reshapes, i.e. where shape = operand shape.
  if (SameShape(reshape, operand)) {
    VLOG(10) << "deleting no-op reshape";
    return ReplaceInstruction(reshape, operand);
  }

  // Merge reshapes.
  if (HloOpcode::kReshape == operand->opcode()) {
    return ReplaceWithNewInstruction(
        reshape, HloInstruction::CreateReshape(reshape->shape(),
                                               operand->mutable_operand(0)));
  }

  // An Rng feeding only this reshape can simply be given the reshaped shape
  // directly (element count is preserved by reshape).
  if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) {
    *operand->mutable_shape() = reshape->shape();
    return ReplaceInstruction(reshape, operand);
  }

  // reshape(broadcast(x)) -> broadcast(x) with remapped dimensions, when
  // all broadcast dimensions survive the reshape unmodified.
  if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) {
    auto opt_dims = ReshapeLeavesDimensionsUnmodified(
        reshape, reshape->operand(0)->dimensions());
    if (opt_dims.has_value()) {
      return ReplaceWithNewInstruction(
          reshape,
          HloInstruction::CreateBroadcast(
              reshape->shape(), reshape->mutable_operand(0)->mutable_operand(0),
              *opt_dims));
    }
  }

  // reshape(iota) -> iota or a mixed radix calculation like
  // s32[2,3,4] reshape(s32[24] iota()) to
  // add(
  //    add(s32[2,3,4] iota() iota_dimension=2,
  //        4 * s32[2,3,4] iota() iota_dimension=1),
  //    12 * s32[2,3,4] iota() iota_dimension=0).
  if (operand->opcode() == HloOpcode::kIota) {
    auto* iota = Cast<HloIotaInstruction>(operand);
    // CommonFactors pairs up operand/result dimension boundaries; the span
    // [iota_dim, next_dim) is the run of result dimensions the iota
    // dimension was split into.
    auto common_factors =
        CommonFactors(reshape->operand(0)->shape().dimensions(),
                      reshape->shape().dimensions());
    auto iota_dim = absl::c_find_if(
        common_factors, [&](const std::pair<int64, int64>& dim_pair) {
          return dim_pair.first == iota->iota_dimension() &&
                 reshape->shape().dimensions(dim_pair.second) > 1;
        });
    auto next_dim = absl::c_find_if(
        common_factors, [&](const std::pair<int64, int64>& dim_pair) {
          return dim_pair.first == iota->iota_dimension() + 1;
        });
    if (iota_dim != common_factors.end() && next_dim != common_factors.end()) {
      // Walk the split dimensions minor-to-major, accumulating the
      // mixed-radix multiplier for each digit.
      int64_t multiplier = 1;
      HloInstruction* new_reshape = nullptr;

      for (int64_t dim = (iota_dim + 1)->second - 1; dim >= iota_dim->second;
           --dim) {
        HloInstruction* new_iota = computation_->AddInstruction(
            HloInstruction::CreateIota(reshape->shape(), dim));
        iota->SetupDerivedInstruction(new_iota);
        if (new_reshape) {
          new_reshape =
              computation_->AddInstruction(HloInstruction::CreateBinary(
                  reshape->shape(), HloOpcode::kAdd, new_reshape,
                  computation_->AddInstruction(HloInstruction::CreateBinary(
                      reshape->shape(), HloOpcode::kMultiply, new_iota,
                      MakeScalarLike(reshape, multiplier)))));
          reshape->SetupDerivedInstruction(new_reshape);
        } else {
          new_reshape = new_iota;
        }
        multiplier *= reshape->shape().dimensions(dim);
      }
      reshape->SetupDerivedInstruction(new_reshape);
      return ReplaceInstruction(reshape, new_reshape);
    }
  }

  // Moves the reshape in reshape(dus(...), x, ...)) before dus so that it can
  // enable other optimizations, e.g., merging with broadcast, and sparse update
  // (add(x, dus(broadcast(0), y, ...)) -> dus(x, add(ds(x), y), ...)).
  if (!options_.is_layout_sensitive()) {
    bool trivial_reshape;
    std::vector<int64> deleted_dims;
    std::vector<int64> inserted_dims;

    HloInstruction* dus;
    HloInstruction* slice;
    std::tie(trivial_reshape, deleted_dims, inserted_dims) =
        reshape->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
    // 1-sized dimensions added and removed will be one sized in both the update
    // slice and the dynamic-update-slice result.
    if (trivial_reshape &&
        Match(reshape->mutable_operand(0),
              m::Op(&dus)
                  .WithOpcode(HloOpcode::kDynamicUpdateSlice)
                  .WithOperand(1, m::Op(&slice))) &&
        !dus->has_sharding() && !dus->operand(0)->has_sharding()) {
      auto new_operand =
          computation_->AddInstruction(HloInstruction::CreateReshape(
              reshape->shape(), dus->mutable_operand(0)));
      std::vector<int64> new_slice_shape;
      std::vector<HloInstruction*> new_dus_operands;
      new_dus_operands.push_back(new_operand);
      new_dus_operands.push_back(nullptr);  // placeholder for the new update.
      auto zero = MakeScalarLike(dus->mutable_operand(2), 0);
      const Shape& old_slice_shape = dus->operand(1)->shape();
      // Note: iterates one past rank() so that 1-sized dimensions inserted
      // after the last old dimension are still emitted; the i < rank() guard
      // below keeps the final iteration from reading out of range.
      for (int64 i = 0; i <= old_slice_shape.rank(); ++i) {
        if (absl::c_linear_search(deleted_dims, i)) {
          continue;
        }
        while (absl::c_linear_search(inserted_dims, new_slice_shape.size())) {
          new_slice_shape.push_back(1);
          new_dus_operands.push_back(zero);  // inserted dims start at index 0.
        }
        if (i < old_slice_shape.rank()) {
          new_slice_shape.push_back(old_slice_shape.dimensions(i));
          new_dus_operands.push_back(dus->mutable_operand(2 + i));
        }
      }
      auto new_slice =
          computation_->AddInstruction(HloInstruction::CreateReshape(
              ShapeUtil::MakeShape(old_slice_shape.element_type(),
                                   new_slice_shape),
              slice));
      new_dus_operands[1] = new_slice;
      auto new_dus =
          dus->CloneWithNewOperands(reshape->shape(), new_dus_operands);
      return ReplaceWithNewInstruction(reshape, std::move(new_dus));
    }
  }

  // Make this a bitcast if possible.
  if (HloInstruction* bitcast_operand =
          BitcastingOperandOfReshapeOrCopyChain(reshape, options_)) {
    ReplaceWithBitcast(reshape, bitcast_operand);
  }
  return Status::OK();
}
4119 
HandleReverse(HloInstruction * reverse)4120 Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) {
4121   // When all the dimensions to reverse are trivial (i.e. the bound is 1),
4122   // there is nothing to be done.
4123   auto dim_is_one = [&](int64_t i) -> bool {
4124     return reverse->shape().dimensions(i) == 1;
4125   };
4126   if (absl::c_all_of(reverse->dimensions(), dim_is_one)) {
4127     return ReplaceInstruction(reverse, reverse->mutable_operand(0));
4128   }
4129   return Status::OK();
4130 }
4131 
// Tries to replace an effective-scalar slice of a concatenate with (a slice
// of) the single concatenate operand that the sliced element comes from.
// Returns true iff a replacement was made.
StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyScalarSlice(
    HloInstruction* slice) {
  // Only try to do this for effective scalars. We could do the same for slicing
  // out larger pieces of padding (replacing with a broadcast of the padding
  // value), but this is probably not worth it.
  if (!ShapeUtil::IsEffectiveScalar(slice->shape())) {
    return false;
  }

  if (slice->operand(0)->opcode() == HloOpcode::kConcatenate) {
    VLOG(10) << "Trying to simplify scalar slice of concat";
    // Only do this for R1, there's no chance of this being useful otherwise.
    if (slice->shape().rank() != 1) {
      VLOG(10) << "Not folding, slice is not rank 1";
      return false;
    }
    HloConcatenateInstruction* concat =
        Cast<HloConcatenateInstruction>(slice->mutable_operand(0));
    // operand_start: offset within the concatenated dimension at which
    // operand `operand_num` begins.  The loop advances until the sliced
    // index falls inside the current operand.
    int64_t operand_start = 0;
    int64_t operand_num = 0;
    // Weird loop structure to avoid annoying off-by-one errors.
    while (true) {
      TF_RET_CHECK(operand_num < concat->operand_count());
      const HloInstruction* operand = concat->operand(operand_num);
      int64_t next_operand_start =
          operand_start + operand->shape().dimensions(0);
      if (next_operand_start > slice->slice_starts(0)) {
        break;
      }
      operand_start = next_operand_start;
      operand_num++;
    }

    // If the chosen operand already has the slice's shape (a one-element
    // operand), use it directly; otherwise slice the element out of it.
    bool replaced = ReplaceInstructionIfSameShape(
        slice, concat->mutable_operand(operand_num));
    if (replaced) {
      VLOG(10) << "Folding scalar slice of concat into concat operand";
    } else {
      VLOG(10) << "Folding scalar slice of concat into slice of concat operand";
      TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
          slice, HloInstruction::CreateSlice(
                     slice->shape(), concat->mutable_operand(operand_num),
                     {slice->slice_starts(0) - operand_start},
                     {slice->slice_starts(0) - operand_start + 1},
                     slice->slice_strides())));
    }
    return true;
  }

  return false;
}
4183 
TryToReorderSliceAndReshape(HloInstruction * slice)4184 StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape(
4185     HloInstruction* slice) {
4186   CHECK_EQ(slice->opcode(), HloOpcode::kSlice);
4187   if (!IsUnstridedSlice(slice)) {
4188     return false;
4189   }
4190   HloInstruction* reshape = slice->mutable_operand(0);
4191   if (reshape->opcode() != HloOpcode::kReshape) {
4192     return false;
4193   }
4194   HloInstruction* new_slice_operand = reshape->mutable_operand(0);
4195   int64_t slice_rank = slice->shape().rank();
4196   std::vector<int64> sliced_dims;
4197   for (int64_t i = 0; i < slice_rank; ++i) {
4198     if (slice->slice_starts(i) != 0 ||
4199         slice->slice_limits(i) != reshape->shape().dimensions(i)) {
4200       sliced_dims.push_back(i);
4201     }
4202   }
4203 
4204   if (sliced_dims.size() == 1 && sliced_dims[0] == 0 &&
4205       slice->slice_starts(0) == 0) {
4206     const Shape& new_slice_shape = new_slice_operand->shape();
4207     const int64_t rank = new_slice_shape.rank();
4208     std::vector<int64> new_slice_starts(rank, 0);
4209     std::vector<int64> new_slice_stides(rank, 1);
4210     std::vector<int64> new_slice_limits(new_slice_shape.dimensions().begin(),
4211                                         new_slice_shape.dimensions().end());
4212     int64_t slice_elements = ShapeUtil::ElementsIn(slice->shape());
4213     for (int64_t i = rank - 1; i >= 0; --i) {
4214       if (slice_elements >= new_slice_limits[i]) {
4215         if (slice_elements % new_slice_limits[i] != 0) {
4216           return false;
4217         }
4218         slice_elements /= new_slice_limits[i];
4219       } else {
4220         new_slice_limits[i] = slice_elements;
4221         slice_elements = 1;
4222       }
4223     }
4224     HloInstruction* new_slice =
4225         computation_->AddInstruction(HloInstruction::CreateSlice(
4226             ShapeUtil::MakeShape(new_slice_shape.element_type(),
4227                                  new_slice_limits),
4228             new_slice_operand, new_slice_starts, new_slice_limits,
4229             new_slice_stides));
4230     simplifier_->UpdateLayout(new_slice->mutable_shape());
4231     TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
4232         slice, HloInstruction::CreateReshape(slice->shape(), new_slice)));
4233     return true;
4234   }
4235   return false;
4236 }
4237 
// Allowing a slice to move through a reverse with any necessary updates to the
// slice config: slice(reverse(x)) is rewritten as reverse(slice'(x)), where
// slice' has its start/limit indices mirrored in every reversed dimension.
//
// Returns true iff `slice` was replaced.
StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReverse(
    HloInstruction* slice) {
  VLOG(2) << "Entered TryToReorderSliceAndReverse for slice:"
          << slice->ToString();
  if (Match(slice, m::Slice(m::Reverse()))) {
    HloInstruction* reverse = slice->mutable_operand(0);
    HloInstruction* reverse_operand = reverse->mutable_operand(0);
    // Start from the original slice config; only the reversed dimensions are
    // rewritten below.
    std::vector<int64> new_starts = slice->slice_starts();
    std::vector<int64> new_limits = slice->slice_limits();
    std::vector<int64> new_strides = slice->slice_strides();
    for (auto rdim : reverse->dimensions()) {
      int64_t start = slice->slice_starts(rdim);
      int64_t limit = slice->slice_limits(rdim);
      int64_t stride = slice->slice_strides(rdim);
      // find_nth allows us to compute the appropriate index to begin
      // with during reverse even in the presence of non-unit strides:
      // it is the index of the last element the strided slice actually
      // touches, so the effective limit is tightened to find_nth + 1 before
      // mirroring.
      int64_t find_nth = (limit - start - 1) / stride;
      find_nth = start + find_nth * stride;
      limit = find_nth + 1;
      // Mirror the [start, limit) window around the dimension: an index i in
      // the reversed array corresponds to (dim_size - 1 - i) in the operand.
      new_starts[rdim] =
          (reverse->shape().dimensions(rdim) - start) - (limit - start);
      new_limits[rdim] = reverse->shape().dimensions(rdim) - start;
      VLOG(2) << "Analyzing dim:" << rdim << " (start,limit):" << start << ","
              << limit << " and new (start, limit):" << new_starts[rdim] << ","
              << new_limits[rdim];
    }
    // New slice formed from the reverse_operand, but strides and shape of the
    // slice output remains the same. New slice's starts and limits are updated
    // for ONLY the reversed dimensions as indicated above.
    HloInstruction* new_slice = computation_->AddInstruction(
        HloInstruction::CreateSlice(slice->shape(), reverse_operand, new_starts,
                                    new_limits, new_strides));
    simplifier_->UpdateLayout(new_slice->mutable_shape());
    TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
        slice, HloInstruction::CreateReverse(new_slice->shape(), new_slice,
                                             reverse->dimensions())));
    // We do not delete the old reverse, since there might be another
    // consumer of that reverse (i.e., full reverse output). DCE should take
    // care of any deletion that is necessary if there was no use of reverse.
    return true;
  }
  return false;
}
4283 
// Simplifies kSlice by trying, in order: no-op elision, slice-of-pad,
// slice-of-slice fusion, slice-of-broadcast rewrites, scalar slice of concat,
// concat-operand matching, and reordering with reshape/reverse. The first
// rewrite that fires returns; later ones are not attempted.
Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
  // Delete no-op slices, i.e. where shape = operand shape.
  if (ReplaceInstructionIfSameShape(slice, slice->mutable_operand(0))) {
    return Status::OK();
  }

  // Simplify slice(pad(x)) in three mutually exclusive ways:
  //   * the slice reads only padding      -> broadcast of the padding value;
  //   * the slice exactly undoes the pad  -> the pad operand itself;
  //   * the slice stays inside the data   -> slice of the pad operand with
  //     shifted start/limit indices.
  HloInstruction* pad;
  HloInstruction* pad_operand;
  if (Match(slice, m::Slice(m::Pad(&pad, m::Op(&pad_operand), m::Op())))) {
    // Is the result of the slice the pad operand.
    bool slice_undoes_pad = true;
    // Can the slice be moved to the pad_operand without any padding being read.
    bool slice_inside_pad = true;
    // Does this slice slice out padding only.
    bool slice_in_padding = false;
    std::vector<int64> new_starts = slice->slice_starts();
    std::vector<int64> new_limits = slice->slice_limits();
    for (int64_t i = 0; i < slice->shape().rank(); ++i) {
      const int64_t start = slice->slice_starts(i);
      const int64_t stride = slice->slice_strides(i);
      const int64_t limit = slice->slice_limits(i);
      const int64_t size = pad->shape().dimensions(i);

      const auto& dim = pad->padding_config().dimensions(i);
      const int64_t low = dim.edge_padding_low();
      const int64_t high = dim.edge_padding_high();
      const int64_t interior = dim.interior_padding();
      // First index (in the padded shape) past the original data region.
      const int64_t edge = size - high;

      // If the slice window lies entirely within the low or high padding of
      // this dimension, the whole result consists of padding values.
      if (limit <= low || start >= edge) {
        slice_in_padding = true;
        break;
      }

      // Undoing the pad requires starting at the first data element and a
      // stride that steps over exactly the interior padding.
      if (start != low || stride - 1 != interior) {
        slice_undoes_pad = false;
      }

      // Staying inside the data requires a unit-stride window within
      // [low, edge) and no interior padding in this dimension.
      if (start < low || limit > edge || interior != 0 || stride != 1) {
        slice_inside_pad = false;
      }
      // Rebase start/limit onto the unpadded operand; only meaningful when
      // slice_inside_pad ends up true (then interior == 0 in every dim).
      new_starts[i] -= low;
      new_limits[i] -= low;
    }
    if (slice_in_padding) {
      // The result is uniform: broadcast the scalar padding value.
      HloInstruction* broadcast =
          MakeBroadcastHlo(pad->mutable_operand(1), {}, slice->shape());
      *(broadcast->mutable_shape()) = slice->shape();
      return ReplaceInstruction(slice, broadcast);
    }
    if (slice_undoes_pad && ReplaceInstructionIfSameShape(slice, pad_operand)) {
      return Status::OK();
    }
    if (slice_inside_pad) {
      TF_ASSIGN_OR_RETURN(HloInstruction * new_slice,
                          MakeSliceHlo(pad_operand, new_starts, new_limits,
                                       slice->slice_strides()));
      // Preserve the original result shape (including layout).
      *(new_slice->mutable_shape()) = slice->shape();
      return ReplaceInstruction(slice, new_slice);
    }
  }

  // Fuse slice(slice(x)) into a single slice when both are unstrided, by
  // offsetting the outer slice's starts/limits by the inner slice's starts.
  if (slice->operand(0)->opcode() == HloOpcode::kSlice &&
      IsUnstridedSlice(slice) && IsUnstridedSlice(slice->operand(0))) {
    HloInstruction* operand_slice = slice->mutable_operand(0);
    std::vector<int64> new_slice_starts = slice->slice_starts();
    std::vector<int64> new_slice_limits = slice->slice_limits();
    for (int64_t i = 0; i < new_slice_starts.size(); ++i) {
      new_slice_starts[i] += operand_slice->slice_starts(i);
      new_slice_limits[i] += operand_slice->slice_starts(i);
    }
    return ReplaceWithNewInstruction(
        slice, HloInstruction::CreateSlice(
                   slice->shape(), operand_slice->mutable_operand(0),
                   new_slice_starts, new_slice_limits, slice->slice_strides()));
  }

  // slice(broadcast(x)) where the slice is a no-op on every dimension that
  // maps back to x (it only narrows broadcast-created dimensions) can be
  // re-expressed as a broadcast of x to the slice's shape.
  auto only_broadcast_dims_sliced = [&] {
    if (slice->operand(0)->opcode() != HloOpcode::kBroadcast) {
      return false;
    }
    for (int64_t dim : slice->operand(0)->dimensions()) {
      if (slice->slice_starts(dim) != 0 || slice->slice_strides(dim) != 1 ||
          slice->slice_limits(dim) !=
              slice->operand(0)->shape().dimensions(dim)) {
        return false;
      }
    }
    return true;
  };
  if (only_broadcast_dims_sliced()) {
    return ReplaceWithNewInstruction(
        slice,
        HloInstruction::CreateBroadcast(
            slice->shape(), slice->mutable_operand(0)->mutable_operand(0),
            slice->mutable_operand(0)->dimensions()));
  }

  TF_ASSIGN_OR_RETURN(bool replaced, TrySimplifyScalarSlice(slice));
  if (replaced) {
    return Status::OK();
  }

  // Sink a broadcast below a slice: slice(broadcast(x)) ->
  // broadcast(slice(x)), slicing x only along the dimensions the broadcast
  // maps from x.
  HloInstruction* broadcast;
  HloInstruction* broadcast_operand;
  if (Match(slice,
            m::Slice(m::Broadcast(&broadcast, m::Op(&broadcast_operand))))) {
    std::vector<int64> new_slice_starts;
    std::vector<int64> new_slice_strides;
    std::vector<int64> new_slice_limits;
    new_slice_starts.reserve(broadcast_operand->shape().rank());
    new_slice_strides.reserve(broadcast_operand->shape().rank());
    new_slice_limits.reserve(broadcast_operand->shape().rank());
    // broadcast->dimensions()[i] is the output dimension fed by operand
    // dimension i; project the slice config through that mapping.
    for (int64_t dim : broadcast->dimensions()) {
      new_slice_starts.push_back(slice->slice_starts(dim));
      new_slice_strides.push_back(slice->slice_strides(dim));
      new_slice_limits.push_back(slice->slice_limits(dim));
    }
    VLOG(3) << "Sink broadcast through slice";
    VLOG(3) << "Original slice: " << slice->ToString();
    VLOG(3) << "Original broadcast: " << broadcast->ToString();
    auto new_slice_shape = broadcast_operand->shape();
    for (int64_t i = 0; i < broadcast_operand->shape().rank(); ++i) {
      // Ceiling division: number of elements a strided slice produces.
      int64_t size_i = (new_slice_limits[i] - new_slice_starts[i] +
                        new_slice_strides[i] - 1) /
                       new_slice_strides[i];
      new_slice_shape.set_dimensions(i, size_i);
    }
    simplifier_->UpdateLayout(&new_slice_shape);
    HloComputation* computation = broadcast_operand->parent();
    auto new_slice = computation->AddInstruction(HloInstruction::CreateSlice(
        new_slice_shape, broadcast_operand, new_slice_starts, new_slice_limits,
        new_slice_strides));
    auto new_broadcast = HloInstruction::CreateBroadcast(
        slice->shape(), new_slice, broadcast->dimensions());
    VLOG(3) << "New slice: " << slice->ToString();
    VLOG(3) << "New broadcast: " << new_broadcast->ToString();
    return ReplaceWithNewInstruction(slice, std::move(new_broadcast));
  }

  // Try to simplify concat -> slice to an operand of concat.
  if (slice->operand(0)->opcode() == HloOpcode::kConcatenate &&
      IsUnstridedSlice(slice)) {
    auto concat = slice->operand(0);
    int64_t concat_dim = concat->concatenate_dimension();
    int64_t piece_start = 0;
    // The slice equals a concat operand iff their shapes match and the slice
    // starts exactly at that operand's offset along the concat dimension.
    for (auto piece : concat->operands()) {
      if (!SameShape(piece, slice)) {
        piece_start += piece->shape().dimensions(concat_dim);
        continue;
      }
      if (slice->slice_starts(concat_dim) == piece_start) {
        return ReplaceInstruction(slice, piece);
      }
      piece_start += piece->shape().dimensions(concat_dim);
    }
  }

  // Do not try to reorder slices and reshapes after layout assignment as it may
  // be invalid.
  if (!options_.is_layout_sensitive()) {
    TF_ASSIGN_OR_RETURN(replaced, TryToReorderSliceAndReshape(slice));
  }
  if (replaced) {
    return Status::OK();
  }

  // Finally, try to move the slice above a reverse.
  bool reversed = false;
  if (Match(slice, m::Slice(m::Reverse(m::Op())))) {
    TF_ASSIGN_OR_RETURN(reversed, TryToReorderSliceAndReverse(slice));
  }
  if (reversed) {
    return Status::OK();
  }

  return Status::OK();
}
4461 
HandleRsqrt(HloInstruction * rsqrt)4462 Status AlgebraicSimplifierVisitor::HandleRsqrt(HloInstruction* rsqrt) {
4463   VLOG(10) << "trying transform [rsqrt(Pow(A, -2)) => |A|] "
4464            << rsqrt->ToString();
4465   HloInstruction* rsqrt_operand = rsqrt->mutable_operand(0);
4466   if (rsqrt_operand->opcode() == HloOpcode::kPower &&
4467       IsAll(rsqrt_operand->operand(1), -2) &&
4468       IsPositive(rsqrt_operand, options_)) {
4469     return ReplaceWithNewInstruction(
4470         rsqrt, HloInstruction::CreateUnary(rsqrt->shape(), HloOpcode::kAbs,
4471                                            rsqrt_operand->mutable_operand(0)));
4472   }
4473 
4474   VLOG(10) << "trying transform [rsqrt(Divide(1, A)) => sqrt(A)] "
4475            << rsqrt->ToString();
4476   if (rsqrt_operand->opcode() == HloOpcode::kDivide &&
4477       IsAll(rsqrt_operand->operand(0), 1) &&
4478       IsPositive(rsqrt_operand->operand(1), options_)) {
4479     return ReplaceWithNewInstruction(
4480         rsqrt, HloInstruction::CreateUnary(rsqrt->shape(), HloOpcode::kSqrt,
4481                                            rsqrt_operand->mutable_operand(1)));
4482   }
4483 
4484   return Status::OK();
4485 }
4486 
// Simplifies kDynamicSlice: elides no-op slices, sinks broadcasts, converts
// constant-index dynamic slices to static slices, and rewrites size-1 dynamic
// slices of (scaled) iotas into direct index computations.
Status AlgebraicSimplifierVisitor::HandleDynamicSlice(
    HloInstruction* dynamic_slice) {
  auto operand = dynamic_slice->mutable_operand(0);
  // A scalar result implies a scalar operand (ranks match), so the slice is
  // the identity.
  if (ShapeUtil::IsScalar(dynamic_slice->shape())) {
    return ReplaceInstruction(dynamic_slice, operand);
  }
  // DynamicSlice where operand has the same size as the output is simply equal
  // to operand.
  if (SameShape(operand, dynamic_slice)) {
    return ReplaceInstruction(dynamic_slice, operand);
  }

  // Sink a broadcast below the dynamic slice:
  // dynamic_slice(broadcast(x), ...) -> broadcast(dynamic_slice(x, ...)),
  // keeping only the start indices / slice sizes for the dimensions the
  // broadcast maps from x.
  HloInstruction* broadcast_operand;
  if (Match(operand, m::Broadcast(m::Op(&broadcast_operand)))) {
    std::vector<HloInstruction*> new_indices;
    new_indices.reserve(broadcast_operand->shape().rank());
    std::vector<int64> new_slice_sizes;
    new_slice_sizes.reserve(broadcast_operand->shape().rank());

    // operand->dimensions()[i] is the output dimension fed by broadcast
    // operand dimension i; start index operands begin at operand index 1.
    for (int64_t dim : operand->dimensions()) {
      new_indices.push_back(dynamic_slice->mutable_operand(1 + dim));
      new_slice_sizes.push_back(dynamic_slice->slice_sizes(dim));
    }

    VLOG(3) << "Sink broadcast through dynamic slice";
    VLOG(3) << "Original dynamic slice: " << dynamic_slice->ToString();
    VLOG(3) << "Original broadcast: " << operand->ToString();
    HloInstruction* new_dynamic_slice = broadcast_operand;
    // If the broadcast operand is a scalar, no inner slice is needed at all.
    if (!new_slice_sizes.empty()) {
      auto new_ds_shape = broadcast_operand->shape();
      for (int64_t i = 0; i < broadcast_operand->shape().rank(); ++i) {
        new_ds_shape.set_dimensions(i, new_slice_sizes[i]);
      }
      simplifier_->UpdateLayout(&new_ds_shape);
      HloComputation* computation = broadcast_operand->parent();
      new_dynamic_slice =
          computation->AddInstruction(HloInstruction::CreateDynamicSlice(
              new_ds_shape, broadcast_operand, new_indices, new_slice_sizes));
    }
    auto new_broadcast = HloInstruction::CreateBroadcast(
        dynamic_slice->shape(), new_dynamic_slice, operand->dimensions());
    VLOG(3) << "New dynamic slice: " << dynamic_slice->ToString();
    VLOG(3) << "New broadcast: " << new_broadcast->ToString();
    return ReplaceWithNewInstruction(dynamic_slice, std::move(new_broadcast));
  }

  // Convert a dynamic slice into a slice if all offsets are constant and the
  // operand is not constant.
  if (operand->opcode() != HloOpcode::kConstant &&
      absl::c_all_of(absl::MakeSpan(dynamic_slice->operands().begin() + 1,
                                    dynamic_slice->operands().end()),
                     [](HloInstruction* operand) {
                       return operand->opcode() == HloOpcode::kConstant &&
                              ShapeUtil::ElementIsIntegral(operand->shape());
                     })) {
    const int64_t rank = operand->shape().rank();
    std::vector<int64> slice_starts(rank);
    std::vector<int64> slice_limits(rank);
    std::vector<int64> slice_strides(rank, 1);

    for (int64_t i = 0; i < rank; ++i) {
      absl::optional<int64> offset =
          dynamic_slice->operand(i + 1)->literal().GetFirstInteger();
      // Give up on unreadable or negative offsets.
      if (!offset || *offset < 0) {
        return Status::OK();
      }
      // Mimic DynamicSlice's out-of-bounds clamping: the start index is
      // clamped so the full slice size fits inside the operand.
      const int64_t max_offset =
          dynamic_slice->operand(0)->shape().dimensions(i) -
          dynamic_slice->shape().dimensions(i);
      slice_starts[i] = std::min(max_offset, *offset);
      slice_limits[i] =
          std::min(max_offset, *offset) + dynamic_slice->shape().dimensions(i);
    }
    return ReplaceWithNewInstruction(
        dynamic_slice,
        HloInstruction::CreateSlice(dynamic_slice->shape(), operand,
                                    slice_starts, slice_limits, slice_strides));
  }

  // Convert the dynamic slice of an iota to just a reference to the index
  // (possibly clamped and scaled). Index is always a scalar integer. Output
  // should be a rank 1 array of size 1 with element type matching that of the
  // scalar index (except the signedness).
  const PrimitiveType element_type = dynamic_slice->shape().element_type();
  if (operand->shape().rank() == 1 && dynamic_slice->shape().rank() == 1 &&
      dynamic_slice->shape().dimensions(0) == 1 &&
      (element_type == S32 || element_type == U32)) {
    // Match multiply(x, broadcast(scalar)) and return the scalar
    // constant.
    auto match_multiply_with_scalar =
        [&](HloInstruction* hlo) -> HloInstruction* {
      if (hlo->opcode() != HloOpcode::kMultiply) {
        return nullptr;
      }
      HloInstruction* broadcast = hlo->mutable_operand(1);
      if (broadcast->opcode() == HloOpcode::kBroadcast &&
          broadcast->dimensions().empty() &&
          ShapeUtil::IsScalar(broadcast->operand(0)->shape())) {
        return broadcast->mutable_operand(0);
      }
      return nullptr;
    };

    // Peel off a scalar multiplier so the iota check below also handles
    // multiply(iota, broadcast(scalar)).
    HloInstruction* multiplier = match_multiply_with_scalar(operand);
    if (multiplier) {
      operand = operand->mutable_operand(0);
    }

    if (operand->opcode() == HloOpcode::kIota) {
      // This dynamic_slice will have a single start_index operand (since its
      // operand is rank 1).
      HloInstruction* index = dynamic_slice->mutable_operand(1);
      const PrimitiveType index_type = index->shape().element_type();

      auto create_constant = [&](int64_t value) {
        if (index_type == S32) {
          return MakeScalarLike<int32_t>(index, value);
        } else {
          return MakeScalarLike<uint32_t>(index, value);
        }
      };

      if (index_type == S32 || index_type == U32) {
        // Clamp the index to the range of the iota, matching DynamicSlice's
        // out-of-bounds semantics: iota[clamp(i)] == clamp(i) for iota.
        int64_t iota_size = operand->shape().dimensions(0);
        HloInstruction* low = create_constant(0);
        HloInstruction* high = create_constant(iota_size - 1);
        HloInstruction* clamped =
            computation_->AddInstruction(HloInstruction::CreateTernary(
                index->shape(), HloOpcode::kClamp, low, index, high));

        // Convert the clamped index from index_type to element_type and
        // multiply with the multiplier.
        HloInstruction* result = clamped;
        if (index_type != element_type) {
          result = computation_->AddInstruction(HloInstruction::CreateConvert(
              ShapeUtil::MakeScalarShape(element_type), clamped));
        }

        if (multiplier) {
          result = computation_->AddInstruction(HloInstruction::CreateBinary(
              result->shape(), HloOpcode::kMultiply, result, multiplier));
        }

        // Reshape the scalar result back to the rank-1, size-1 output shape.
        return ReplaceWithNewInstruction(
            dynamic_slice,
            HloInstruction::CreateReshape(
                ShapeUtil::MakeShape(element_type, {1}), result));
      }
    }
  }

  return Status::OK();
}
4641 
// Simplifies kDynamicUpdateSlice: rewrites an update into a broadcast of a
// scalar as a Pad, elides full-shape updates, and elides zero-element updates.
Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
    HloInstruction* dynamic_update_slice) {
  // Rewriting DynamicUpdateSlice when it matches
  // dynamic_update_slice(broadcast(constant),data,constant_index0,...)
  // to a Pad(x, constant)
  // Only Broadcast considered currently, other ops need to be considered
  // in the future.
  HloInstruction* updated = dynamic_update_slice->mutable_operand(0);
  HloInstruction* dus_update = dynamic_update_slice->mutable_operand(1);
  HloInstruction* pad_value;
  if (Match(updated,
            m::Broadcast(m::Op(&pad_value).WithShape(m::Shape().IsScalar())))) {
    auto updated_shape = updated->shape();
    auto update_shape = dus_update->shape();
    auto update_start_indx = dynamic_update_slice->operand(2);
    int64_t offset = 0;
    bool compatible = true;
    // Whether the start indices to dynamic update slice is a list of scalars
    // or the output of a tuple/concatenate, we set up update_start_indx
    // (the instruction whose operands hold the per-dimension start indices)
    // and offset (the operand index of dimension 0's start) accordingly.
    if (ShapeUtil::IsScalar(update_start_indx->shape())) {
      // Per-dimension scalar start indices: they are operands 2.. of the DUS
      // itself.
      update_start_indx = dynamic_update_slice;
      offset = 2;
    } else {
      if (update_start_indx->opcode() == HloOpcode::kTuple ||
          update_start_indx->opcode() == HloOpcode::kConcatenate) {
        offset = 0;
      } else {
        compatible = false;
      }
    }
    PaddingConfig padding_config;
    if (compatible) {
      for (int64_t dim = 0; dim < updated_shape.rank(); ++dim) {
        auto padding_config_dim = padding_config.add_dimensions();
        auto slice_dim_start = update_start_indx->operand(dim + offset);
        // Every start index must be a scalar constant for the Pad rewrite.
        if (!Match(slice_dim_start, m::ConstantScalar())) {
          compatible = false;
          break;
        }
        VLOG(2) << "slice: " << slice_dim_start->ToString();
        absl::optional<int64> beg =
            slice_dim_start->literal().GetFirstInteger();
        if (!beg) {
          compatible = false;
          break;
        }
        VLOG(2) << "beg value: " << *beg;
        auto update_width = ShapeUtil::GetDimension(update_shape, dim);
        auto bcast_width = ShapeUtil::GetDimension(updated_shape, dim);
        // Clamp beg so that it is non-negative.
        *beg = std::max<int64>(0, *beg);
        // Clamp beg so that it is in-bounds.
        *beg = std::min<int64>(bcast_width - update_width, *beg);
        VLOG(2) << "adjusted beg value: " << *beg;
        // The update occupies [beg, beg + update_width); everything else in
        // this dimension becomes edge padding with pad_value.
        padding_config_dim->set_edge_padding_low(*beg);
        padding_config_dim->set_edge_padding_high(bcast_width -
                                                  (*beg + update_width));
        // dynamic_update_slice does not specify a stride
        padding_config_dim->set_interior_padding(0);
      }
    }

    if (compatible) {
      HloInstruction* pad =
          computation_->AddInstruction(HloInstruction::CreatePad(
              updated_shape, dus_update, pad_value, padding_config));
      VLOG(2) << dynamic_update_slice->ToString();
      VLOG(2) << " with pad:" << pad->ToString();
      VLOG(2) << " Computation before rewrite is: "
              << dynamic_update_slice->parent()->ToString();
      return ReplaceInstruction(dynamic_update_slice, pad);
    }
  }

  // DynamicUpdateSlice where operand and dus_update have the same size is
  // equal to dus_update.
  if (SameShape(dynamic_update_slice, dus_update)) {
    return ReplaceInstruction(dynamic_update_slice, dus_update);
  }

  // If any dimension of dus_update is 0, elide the DynamicUpdateSlice.  This
  // optimization becomes invalid should we later prefer to warn about out of
  // bound indices.
  if (ShapeUtil::IsZeroElementArray(dus_update->shape())) {
    return ReplaceInstruction(dynamic_update_slice, updated);
  }
  return Status::OK();
}
4731 
HandleReduce(HloInstruction * hlo)4732 Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
4733   HloReduceInstruction* reduce = Cast<HloReduceInstruction>(hlo);
4734   bool multi_output_reduce = reduce->shape().IsTuple();
4735   // For tuple reduce, we require all reduce shapes to be the same, up to the
4736   // element types, so we can just the first operand and the first result as a
4737   // representative.
4738   auto arg = reduce->inputs()[0];
4739   auto init_value = reduce->init_values()[0];
4740   const Shape& reduce_result_shape =
4741       multi_output_reduce ? reduce->shape().tuple_shapes(0) : reduce->shape();
4742 
4743   absl::Span<const int64> dimensions(reduce->dimensions());
4744   HloComputation* function = reduce->to_apply();
4745   if (ShapeUtil::IsZeroElementArray(arg->shape()) ||
4746       ShapeUtil::IsZeroElementArray(reduce_result_shape)) {
4747     if (multi_output_reduce) {
4748       std::vector<HloInstruction*> broadcast_inits;
4749       int64_t inputs = reduce->input_count();
4750       for (int64_t i = 0; i < inputs; ++i) {
4751         broadcast_inits.push_back(computation_->AddInstruction(
4752             HloInstruction::CreateBroadcast(reduce->shape().tuple_shapes(i),
4753                                             reduce->init_values()[i], {})));
4754       }
4755       return ReplaceWithNewInstruction(
4756           reduce, HloInstruction::CreateTuple(broadcast_inits));
4757     } else {
4758       return ReplaceWithNewInstruction(
4759           reduce,
4760           HloInstruction::CreateBroadcast(reduce_result_shape, init_value, {}));
4761     }
4762   }
4763 
4764   // Turn trivial variadic reductions into normal reductions.
4765   if (multi_output_reduce && reduce->shape().tuple_shapes_size() == 1 &&
4766       reduce->input_count() == 1 &&
4767       Match(function->root_instruction(), m::Tuple())) {
4768     absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
4769         replacements;
4770     replacements[function->root_instruction()] = nullptr;
4771     auto new_function = computation_->parent()->AddEmbeddedComputation(
4772         function->CloneWithReplacements(
4773             std::move(replacements), /*extra_parameters=*/{},
4774             /*context=*/nullptr,
4775             /*suffix=*/"clone",
4776             /*new_root=*/function->root_instruction()->operand(0)));
4777     auto new_reduce = computation_->AddInstruction(
4778         HloInstruction::CreateReduce(reduce_result_shape, arg, init_value,
4779                                      reduce->dimensions(), new_function));
4780     return ReplaceWithNewInstruction(reduce,
4781                                      HloInstruction::CreateTuple({new_reduce}));
4782   }
4783 
4784   if (options_.is_layout_sensitive()) {
4785     return Status::OK();
4786   }
4787 
4788   // If the reduction results in the same number of elements, then the only
4789   // possible side effect would be a reshape. Since the init_value is an
4790   // identity of the reduction function, we can therefore replace the reduce
4791   // with a simple reshape, ignoring the reduction function completely.
4792   if (ShapeUtil::ElementsIn(reduce_result_shape) ==
4793       ShapeUtil::ElementsIn(arg->shape())) {
4794     if (multi_output_reduce) {
4795       std::vector<HloInstruction*> reshaped_args;
4796       int64_t inputs = reduce->input_count();
4797       for (int64_t i = 0; i < inputs; ++i) {
4798         reshaped_args.push_back(
4799             computation_->AddInstruction(HloInstruction::CreateReshape(
4800                 reduce->shape().tuple_shapes(i), reduce->inputs()[i])));
4801       }
4802       return ReplaceWithNewInstruction(
4803           reduce, HloInstruction::CreateTuple(reshaped_args));
4804     } else {
4805       return ReplaceWithNewInstruction(
4806           reduce, HloInstruction::CreateReshape(reduce_result_shape, arg));
4807     }
4808   }
4809 
4810   // TODO(b/131122694): Most of those optimizations below can be done for
4811   // multi-output reduces.
4812   if (multi_output_reduce) {
4813     return Status::OK();
4814   }
4815 
4816   // A Transpose feeding a reduce can simply permute the reduction dimensions
4817   // field if the output of the reduce is a vector or scalar. Higher ranked
4818   // result may require a transpose of the output.
4819   if (arg->opcode() == HloOpcode::kTranspose &&
4820       (reduce->shape().rank() < 2 || arg->user_count() == 1 ||
4821        absl::c_all_of(arg->users(), [](HloInstruction* use) {
4822          return use->opcode() == HloOpcode::kReduce;
4823        }))) {
4824     auto transpose_dimensions = arg->dimensions();
4825     std::vector<int64> new_reduce_dimensions;
4826     for (auto dim : dimensions) {
4827       new_reduce_dimensions.push_back(transpose_dimensions[dim]);
4828     }
4829 
4830     Shape new_reduce_result_shape = ShapeUtil::FilterDimensions(
4831         [&](const int64_t dim) {
4832           return !absl::c_linear_search(new_reduce_dimensions, dim);
4833         },
4834         arg->mutable_operand(0)->shape());
4835     HloInstruction* new_reduce =
4836         computation_->AddInstruction(HloInstruction::CreateReduce(
4837             new_reduce_result_shape, arg->mutable_operand(0), init_value,
4838             new_reduce_dimensions, function));
4839     reduce->SetupDerivedInstruction(new_reduce);
4840     std::vector<int64> new_transpose_dimensions;
4841     for (auto dim : transpose_dimensions) {
4842       if (absl::c_linear_search(new_reduce_dimensions, dim)) {
4843         continue;
4844       }
4845       new_transpose_dimensions.push_back(dim);
4846     }
4847 
4848     // If new transpose dimensions are sorted, then there is no need to
4849     // transpose reduce result.
4850     if (absl::c_is_sorted(new_transpose_dimensions)) {
4851       return ReplaceInstruction(reduce, new_reduce);
4852     }
4853     for (auto& d : new_transpose_dimensions) {
4854       auto old_dim = d;
4855       for (auto reduced_dim : new_reduce_dimensions) {
4856         if (old_dim > reduced_dim) {
4857           --d;
4858         }
4859       }
4860     }
4861     TF_ASSIGN_OR_RETURN(HloInstruction * new_transpose,
4862                         MakeTransposeHlo(new_reduce, new_transpose_dimensions));
4863     return ReplaceInstruction(reduce, new_transpose);
4864   }
4865 
4866   // If a reduce feeds a reduce with the same computation and initial value,
4867   // they can be combined into a single reduce.
4868   if (arg->opcode() == HloOpcode::kReduce &&
4869       init_value->Identical(*arg->operand(1)) &&
4870       *function == *arg->to_apply()) {
4871     // Create a new reduce with the combined reduction dimensions of both
4872     // reduces.
4873     std::vector<int64> arg_dims = arg->dimensions();
4874     absl::c_sort(arg_dims);
4875     std::vector<int64> reduce_dims = reduce->dimensions();
4876     absl::c_sort(reduce_dims);
4877     // Transform reduce_dims to the same rank as the operand of the operand.
4878     for (int64_t arg_dim : arg_dims) {
4879       for (int64& dim : reduce_dims) {
4880         if (dim >= arg_dim) {
4881           ++dim;
4882         }
4883       }
4884     }
4885     std::vector<int64> new_dimensions;
4886     new_dimensions.reserve(arg->dimensions().size() +
4887                            reduce->dimensions().size());
4888     std::merge(arg_dims.begin(), arg_dims.end(), reduce_dims.begin(),
4889                reduce_dims.end(), std::back_inserter(new_dimensions));
4890     return ReplaceWithNewInstruction(
4891         reduce, HloInstruction::CreateReduce(
4892                     reduce_result_shape, arg->mutable_operand(0), init_value,
4893                     new_dimensions, function));
4894   }
4895 
4896   // A reshape that collapses multiple dimensions into a dimension being
4897   // reduced can just reduce all of those dimensions instead of doing a
4898   // collapsing reshape before a reduction.
4899   if (options_.enable_reduce_of_reshape() &&
4900       arg->opcode() == HloOpcode::kReshape) {
4901     std::vector<std::pair<int64, int64>> unmodified_dims =
4902         ShapeUtil::DimensionsUnmodifiedByReshape(arg->operand(0)->shape(),
4903                                                  arg->shape());
4904     std::vector<bool> arg_dim_in_output(arg->shape().rank(), true);
4905     std::vector<bool> arg_dim_unmodified(arg->shape().rank(), false);
4906     for (auto dim : dimensions) {
4907       arg_dim_in_output[dim] = false;
4908     }
4909     for (auto dim_pair : unmodified_dims) {
4910       arg_dim_unmodified[dim_pair.second] = true;
4911     }
4912     // The goal is to verify that all dimensions that are not removed in the
4913     // reduce are unmodified by the reshape. For example:
4914     // reduce(reshape([A,B*C], a[A,B,C]),[1]) = reduce(a[A, B, C], [1, 2])
4915     bool can_move_reshape_into_reduce = true;
4916     for (int64_t i = 0; i < arg_dim_in_output.size(); ++i) {
4917       if (arg_dim_in_output[i] && !arg_dim_unmodified[i]) {
4918         can_move_reshape_into_reduce = false;
4919       }
4920     }
4921     if (can_move_reshape_into_reduce) {
4922       changed_ = true;
4923       absl::flat_hash_set<int64> dimensions_not_to_reduce;
4924       for (auto dim_pair : unmodified_dims) {
4925         if (arg_dim_in_output[dim_pair.second]) {
4926           dimensions_not_to_reduce.insert(dim_pair.first);
4927         }
4928       }
4929       std::vector<int64> new_reduce_dimensions;
4930       for (int64_t i = 0; i < arg->operand(0)->shape().rank(); ++i) {
4931         if (!dimensions_not_to_reduce.contains(i)) {
4932           new_reduce_dimensions.push_back(i);
4933         }
4934       }
4935       return ReplaceWithNewInstruction(
4936           reduce, HloInstruction::CreateReduce(
4937                       reduce_result_shape, arg->mutable_operand(0), init_value,
4938                       new_reduce_dimensions, function));
4939     }
4940   }
4941   // Convert Reduce(concat({a,b,...})) to
4942   //  map(reduce(a),map(reduce(b),...,))
4943   //
4944   // This should make fusion easier or use less memory bandwidth in the unfused
4945   // case.
4946   if (arg->opcode() == HloOpcode::kConcatenate &&
4947       absl::c_linear_search(reduce->dimensions(),
4948                             arg->concatenate_dimension())) {
4949     HloInstruction* old_reduce = nullptr;
4950     for (HloInstruction* operand : arg->operands()) {
4951       HloInstruction* new_reduce = computation_->AddInstruction(
4952           HloInstruction::CreateReduce(reduce_result_shape, operand, init_value,
4953                                        reduce->dimensions(), function));
4954       if (old_reduce != nullptr) {
4955         new_reduce = computation_->AddInstruction(HloInstruction::CreateMap(
4956             reduce_result_shape, {old_reduce, new_reduce}, function));
4957       }
4958       old_reduce = new_reduce;
4959     }
4960     return ReplaceInstruction(reduce, old_reduce);
4961   }
4962 
4963   HloInstruction *dot, *lhs, *rhs;
4964   // Convert Reduce(Dot(X,Y)) to Dot(X,Y) if any of the dimensions reduced were
4965   // batch dimensions of the dot. The transformation supports reducing other
4966   // dimensions as well.
4967   if (options_.enable_dot_strength_reduction() &&
4968       Match(arg, m::Dot(&dot, m::Op(&lhs), m::Op(&rhs)).WithOneUser()) &&
4969       Match(reduce->to_apply()->root_instruction(),
4970             m::Add(m::Parameter(), m::Parameter())) &&
4971       absl::c_any_of(reduce->dimensions(), [&](int64_t dim) {
4972         return dim < dot->dot_dimension_numbers().lhs_batch_dimensions_size();
4973       })) {
4974     const auto& dnums = dot->dot_dimension_numbers();
4975     DotDimensionNumbers new_dnums = dnums;
4976     new_dnums.clear_lhs_batch_dimensions();
4977     new_dnums.clear_rhs_batch_dimensions();
4978     int64_t removed_dims = 0;
4979     for (int64_t batch_dim = 0; batch_dim < dnums.lhs_batch_dimensions_size();
4980          ++batch_dim) {
4981       if (absl::c_linear_search(reduce->dimensions(), batch_dim)) {
4982         new_dnums.add_rhs_contracting_dimensions(
4983             dnums.rhs_batch_dimensions(batch_dim));
4984         new_dnums.add_lhs_contracting_dimensions(
4985             dnums.lhs_batch_dimensions(batch_dim));
4986         ++removed_dims;
4987       } else {
4988         new_dnums.add_rhs_batch_dimensions(
4989             dnums.rhs_batch_dimensions(batch_dim));
4990         new_dnums.add_lhs_batch_dimensions(
4991             dnums.lhs_batch_dimensions(batch_dim));
4992       }
4993     }
4994     std::vector<int64> reduce_dims;
4995     for (int64_t dim : reduce->dimensions()) {
4996       if (dim >= dnums.lhs_batch_dimensions_size()) {
4997         reduce_dims.push_back(dim - removed_dims);
4998       }
4999     }
5000     TF_ASSIGN_OR_RETURN(
5001         auto new_dot,
5002         MakeDotHlo(lhs, rhs, new_dnums, dot->precision_config(),
5003                    /*preferred_element_type=*/dot->shape().element_type()));
5004     dot->SetupDerivedInstruction(new_dot);
5005     if (reduce_dims.empty()) {
5006       return ReplaceInstruction(hlo, new_dot);
5007     }
5008     TF_ASSIGN_OR_RETURN(
5009         auto new_reduce,
5010         MakeReduceHlo(new_dot, init_value, reduce_dims, HloOpcode::kAdd));
5011     reduce->SetupDerivedInstruction(new_reduce);
5012     return ReplaceInstruction(hlo, new_reduce);
5013   }
5014   return Status::OK();
5015 }
5016 
// Simplifies reduce-window in several ways (checked in order):
//  * zero-element input or output -> broadcast of the init value(s);
//  * scalar input -> map of the reduction computation over (init, input);
//  * trivial variadic (single-input, tuple-rooted) reduce-window -> plain
//    reduce-window wrapped in a tuple;
//  * window that covers either one element or an entire dimension, with no
//    stride/dilation/padding -> reduce + reshape;
//  * pad (possibly through a convert) feeding the reduce-window -> folded
//    into the window's padding/base dilation, or replaced by a broadcast
//    when each window sees exactly one unpadded element.
Status AlgebraicSimplifierVisitor::HandleReduceWindow(HloInstruction* hlo) {
  auto* reduce_window = Cast<HloReduceWindowInstruction>(hlo);
  const bool multi_output_reduce_window = reduce_window->shape().IsTuple();
  auto inputs = reduce_window->inputs();
  auto init_values = reduce_window->init_values();
  auto input_count = reduce_window->input_count();
  auto input_shapes = reduce_window->input_shapes();
  auto output_shapes = reduce_window->output_shapes();
  // Replaces the (possibly variadic) reduce-window with one replacement
  // instruction per input, tupling them when the original produced a tuple.
  auto replace_with_span = [&](const std::vector<HloInstruction*>& elements) {
    CHECK(multi_output_reduce_window || elements.size() == 1);
    if (multi_output_reduce_window) {
      return ReplaceWithNewInstruction(reduce_window,
                                       HloInstruction::CreateTuple(elements));
    }
    return ReplaceInstruction(reduce_window, elements[0]);
  };
  // For tuple reduce, we require all reduce shapes to be the same, up to the
  // element types, so we can use just the first operand and the first result as
  // a representative.
  if (ShapeUtil::IsZeroElementArray(*input_shapes[0]) ||
      ShapeUtil::IsZeroElementArray(*output_shapes[0])) {
    // Degenerate input or output: every output element is the init value.
    std::vector<HloInstruction*> broadcast_inits;
    for (int64_t i = 0; i < input_count; ++i) {
      broadcast_inits.push_back(
          computation_->AddInstruction(HloInstruction::CreateBroadcast(
              *output_shapes[i], init_values[i], {})));
    }
    return replace_with_span(broadcast_inits);
  }
  if (ShapeUtil::IsScalar(*input_shapes[0]) &&
      (!multi_output_reduce_window ||
       reduce_window->to_apply()->root_instruction()->opcode() ==
           HloOpcode::kTuple)) {
    // Scalar input: the reduction computation is applied once to
    // (init, input), which is exactly a map over those two scalars.
    std::vector<HloInstruction*> maps;
    for (int64_t i = 0; i < input_count; ++i) {
      TF_RET_CHECK(ShapeUtil::IsScalar(*input_shapes[i]));
      TF_RET_CHECK(ShapeUtil::IsScalar(*output_shapes[i]));
      HloInstruction* map_computation_root;
      absl::flat_hash_map<const HloInstruction*,
                          std::unique_ptr<HloInstruction>>
          replacements;
      if (multi_output_reduce_window) {
        // Clone the reducer with the i-th element of its root tuple as the
        // new root, dropping the tuple instruction itself.
        map_computation_root =
            reduce_window->to_apply()->root_instruction()->mutable_operand(i);
        replacements[reduce_window->to_apply()->root_instruction()] = nullptr;
      } else {
        map_computation_root = reduce_window->to_apply()->root_instruction();
      }
      auto map_computation = computation_->parent()->AddEmbeddedComputation(
          reduce_window->to_apply()->CloneWithReplacements(
              std::move(replacements),
              /*extra_parameters=*/{}, nullptr, "clone", map_computation_root));
      auto map = computation_->AddInstruction(HloInstruction::CreateMap(
          reduce_window->shape(), {init_values[i], inputs[i]},
          map_computation));
      maps.push_back(map);
    }
    return replace_with_span(maps);
  }
  // Turn trivial variadic reduce windows into normal reduce windows.
  auto reduce_function_root = reduce_window->to_apply()->root_instruction();
  if (multi_output_reduce_window && input_count == 1 &&
      Match(reduce_function_root, m::Tuple())) {
    // Make a new reducer which is identical but does not have a tuple
    // instruction at the bottom.
    absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
        replacements;
    replacements[reduce_function_root] = nullptr;
    auto new_function = computation_->parent()->AddEmbeddedComputation(
        reduce_window->to_apply()->CloneWithReplacements(
            std::move(replacements), /*extra_parameters=*/{},
            /*context=*/nullptr,
            /*suffix=*/"clone",
            /*new_root=*/reduce_function_root->operand(0)));
    auto new_reduce =
        computation_->AddInstruction(HloInstruction::CreateReduceWindow(
            *output_shapes[0], inputs[0], init_values[0],
            reduce_window->window(), new_function));
    return ReplaceWithNewInstruction(reduce_window,
                                     HloInstruction::CreateTuple({new_reduce}));
  }
  // TODO(b/73062247) Variadic reduce window is not yet supported in simplifier.
  if (multi_output_reduce_window) {
    return Status::OK();
  }
  // From here on the reduce-window is known to be single-output.
  auto operand = reduce_window->mutable_operand(0);
  auto function = reduce_window->to_apply();
  const Window& window = reduce_window->window();

  if (options_.enable_window_reduce_to_reduce_replacement()) {
    // A reduce window can be expressed as a reduce and a reshape if all
    // dimensions either have a window size of one or the entire dimension. If
    // there is no stride, dilation, or padding, this is as easy as checking the
    // size of the output shape and window dimension.
    //
    // The reshape is a bitcast since it adds one-sized dimensions. Often these
    // ones are immediately removed as well with another reshape. The
    // implementation of reduce tends to be slightly more efficient at reducing
    // entire dimensions compared to reduce window.
    auto effective_reduce_dims = [&] {
      if (window_util::HasStride(window) || window_util::HasDilation(window) ||
          window_util::HasPadding(window)) {
        return absl::InlinedVector<int64, 8>{};
      }
      absl::InlinedVector<int64, 8> reduce_dims;
      for (int64_t i = 0; i < window.dimensions_size(); ++i) {
        if (window.dimensions(i).size() == 1) {
          // Window of size one leaves this dimension untouched.
          continue;
        } else if (reduce_window->shape().dimensions(i) == 1) {
          // Output dimension of size one: the window spans the whole input
          // dimension, i.e. a plain reduce over it.
          reduce_dims.push_back(i);
        } else {
          // Partial window: not expressible as a reduce; empty result means
          // "no rewrite".
          return absl::InlinedVector<int64, 8>{};
        }
      }
      return reduce_dims;
    }();

    // If a reduce window can be expressed as a reduce, do so and reshape the
    // output.
    if (!effective_reduce_dims.empty()) {
      Shape reduce_shape = ShapeUtil::FilterDimensions(
          [&](int64_t dim) {
            return !absl::c_linear_search(effective_reduce_dims, dim);
          },
          reduce_window->shape());
      simplifier_->UpdateLayout(&reduce_shape);
      HloInstruction* reduce =
          computation_->AddInstruction(HloInstruction::CreateReduce(
              /*shape=*/reduce_shape,
              /*operand=*/operand,
              /*init_value=*/reduce_window->mutable_operand(1),
              /*dimensions_to_reduce=*/effective_reduce_dims,
              /*reduce_computation=*/function));
      return ReplaceWithNewInstruction(
          reduce_window,
          HloInstruction::CreateReshape(reduce_window->shape(), reduce));
    }
  }

  // This optimization folds a pad op into reduce_window.
  HloInstruction* pad;
  const HloInstruction* convert = nullptr;
  if (operand->opcode() == HloOpcode::kPad) {
    pad = operand;
  } else if (operand->opcode() == HloOpcode::kConvert &&
             operand->operand(0)->opcode() == HloOpcode::kPad) {
    // Also look through an intervening convert between pad and reduce-window.
    convert = operand;
    pad = operand->mutable_operand(0);
  } else {
    VLOG(10) << "Not folding pad into reduce-window as there is no pad.";
    return Status::OK();
  }

  VLOG(10) << "Considering folding Pad: " << pad->ToString()
           << "\ninto reduce-window: " << reduce_window->ToString()
           << (convert != nullptr
                   ? absl::StrCat("\nvia convert: ", convert->ToString())
                   : "");

  // Do not fold interior padding into ReduceWindow since the backends do not
  // support it.
  const PaddingConfig& pad_config = pad->padding_config();
  if (HasInteriorPadding(pad_config) && window_util::HasBaseDilation(window)) {
    VLOG(10) << "Not folding interior pad into base-dilated reduce-window.";
    return Status::OK();
  }

  // If reduce_window already has padding, the pad value of the pad op and the
  // init value of reduce_window must match to allow folding the pad.
  const HloInstruction* pad_value = pad->operand(1);
  const HloInstruction* reduce_init_value = reduce_window->operand(1);
  if (pad_value != reduce_init_value) {
    auto literals_are_equivalent = [&] {
      auto& pad_literal = pad_value->literal();
      auto& reduce_init_literal = reduce_init_value->literal();
      if (pad_literal == reduce_init_literal) {
        return true;
      }
      // The values may differ only in element type (e.g. due to the convert);
      // compare after converting the pad literal to the init value's shape.
      auto converted_pad_literal =
          pad_literal.ConvertToShape(reduce_init_value->shape());
      if (!converted_pad_literal.ok()) {
        return false;
      }
      return converted_pad_literal.ValueOrDie() == reduce_init_literal;
    };
    // The pad value is usually a constant, so we handle that case and do not
    // try to get more fancy about proving equivalence in cases beyond that.
    if (pad_value->opcode() != HloOpcode::kConstant ||
        reduce_init_value->opcode() != HloOpcode::kConstant ||
        !literals_are_equivalent()) {
      VLOG(10) << "Not folding pad into reduce-window due to different pad "
                  "values.";
      return Status::OK();
    }
  }

  // If the pad puts a single non-identity value in each window that we're
  // reducing, then this is a broadcast.
  HloInstruction* pad_operand = pad->mutable_operand(0);
  auto is_effective_broadcast = [&] {
    if (window_util::HasStride(window)) {
      VLOG(10) << "Window has stride.";
      return false;
    }
    if (!window_util::HasSymmetricPadding(pad_config)) {
      VLOG(10) << "Window has uneven padding.";
      return false;
    }
    if (HasInteriorPadding(pad_config)) {
      VLOG(10) << "Window has interior padding.";
      return false;
    }
    // Padding may only appear on trivial (size-1) dimensions of the pad input.
    for (int64_t i = 0; i < pad_config.dimensions_size(); ++i) {
      const auto& pad_dimension = pad_config.dimensions(i);
      if ((pad_dimension.edge_padding_low() != 0 ||
           pad_dimension.edge_padding_high() != 0) &&
          pad_operand->shape().dimensions(i) != 1) {
        VLOG(10) << "Found non-trivial dimension being padded: " << i;
        return false;
      }
    }
    VLOG(10) << "Found to be padding trivial dimensions only.";

    // Each window must see the single unpadded element and nothing more.
    for (int64_t i = 0; i < window.dimensions_size(); ++i) {
      const auto& pad_dimension = pad_config.dimensions(i);
      const WindowDimension& window_dimension = window.dimensions(i);
      bool dimension_has_padding = (pad_dimension.edge_padding_low() != 0 ||
                                    pad_dimension.edge_padding_high() != 0);
      if (dimension_has_padding &&
          window_dimension.size() < pad_dimension.edge_padding_low() + 1) {
        VLOG(10) << "Found window did not cover single unpadded element in "
                    "dimension: "
                 << i;
        return false;
      }
      if (pad_operand->shape().dimensions(i) != 1 &&
          window_dimension.size() != 1) {
        VLOG(10) << "Found window covers more than one element in non-trivial "
                    "dimension: "
                 << i;
        return false;
      }
    }
    VLOG(10) << "Found window covers a single unpadded element.";
    return true;
  };

  // If we looked through a convert, apply it to the pad's input instead so
  // the pad itself can be removed.
  HloInstruction* new_reduce_window_operand;
  if (convert != nullptr) {
    Shape changed_shape = ShapeUtil::ChangeElementType(
        pad_operand->shape(), convert->shape().element_type());
    simplifier_->UpdateLayout(&changed_shape);
    new_reduce_window_operand = computation_->AddInstruction(
        HloInstruction::CreateConvert(changed_shape, pad_operand));
  } else {
    new_reduce_window_operand = pad_operand;
  }

  if (is_effective_broadcast()) {
    VLOG(10) << "Replacing pad/reduce-window with broadcast.";
    auto fadd = [this](std::unique_ptr<HloInstruction> x) {
      return computation_->AddInstruction(std::move(x));
    };
    return ReplaceWithNewInstruction(
        reduce_window, HloInstruction::CreateBroadcastSequence(
                           /*output_shape=*/reduce_window->shape(),
                           /*operand=*/new_reduce_window_operand, fadd));
  }

  // Carry out the folding of the pad into reduce_window.
  VLOG(10) << "Folding pad into reduce-window.";
  Window new_window = window;
  const int64_t rank = reduce_window->shape().rank();
  TF_RET_CHECK(pad_config.dimensions_size() == rank);
  TF_RET_CHECK(window.dimensions_size() == rank);
  for (int64_t i = 0; i < rank; ++i) {
    const auto& pad_dim = pad_config.dimensions(i);
    auto& window_dim = *new_window.mutable_dimensions(i);
    window_dim.set_padding_low(window_dim.padding_low() +
                               pad_dim.edge_padding_low());
    window_dim.set_padding_high(window_dim.padding_high() +
                                pad_dim.edge_padding_high());
    if (pad_dim.interior_padding() != 0) {
      // Interior padding of k is expressed as base dilation of 1 + k. The
      // earlier HasBaseDilation check guarantees the window had none.
      CHECK_EQ(window_dim.base_dilation(), 1);
      window_dim.set_base_dilation(1 + pad_dim.interior_padding());
    }
  }

  return ReplaceWithNewInstruction(
      reduce_window, HloInstruction::CreateReduceWindow(
                         /*shape=*/reduce_window->shape(),
                         /*operand=*/new_reduce_window_operand,
                         /*init_value=*/reduce_window->mutable_operand(1),
                         /*window=*/new_window,
                         /*reduce_computation=*/function));
}
5313 
HandleSelect(HloInstruction * select)5314 Status AlgebraicSimplifierVisitor::HandleSelect(HloInstruction* select) {
5315   // select(x, y, y) -> y.
5316   if (select->operand(1) == select->operand(2)) {
5317     return ReplaceInstruction(select, select->mutable_operand(1));
5318   }
5319   // select(true, x, y) -> x.
5320   if (IsAll(select->operand(0), true)) {
5321     return ReplaceInstruction(select, select->mutable_operand(1));
5322   }
5323   // select(false, x, y) -> y.
5324   if (IsAll(select->operand(0), false)) {
5325     return ReplaceInstruction(select, select->mutable_operand(2));
5326   }
5327   return Status::OK();
5328 }
5329 
HandleScatter(HloInstruction * scatter)5330 Status AlgebraicSimplifierVisitor::HandleScatter(HloInstruction* scatter) {
5331   if (ShapeUtil::IsZeroElementArray(scatter->operand(2)->shape()) &&
5332       ReplaceInstructionIfSameShape(scatter, scatter->mutable_operand(0))) {
5333     return Status::OK();
5334   }
5335   if (ShapeUtil::IsZeroElementArray(scatter->operand(1)->shape()) &&
5336       SameShape(scatter, scatter->operand(0)) &&
5337       SameShape(scatter, scatter->operand(2))) {
5338     return ReplaceWithNewInstruction(
5339         scatter, HloInstruction::CreateMap(
5340                      scatter->shape(),
5341                      {scatter->mutable_operand(0), scatter->mutable_operand(2)},
5342                      scatter->to_apply()));
5343   }
5344   return Status::OK();
5345 }
HandleSort(HloInstruction * sort)5346 Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) {
5347   auto operand = sort->mutable_operand(0);
5348   int64_t dimension_to_sort = sort->dimensions(0);
5349   if (ShapeUtil::IsZeroElementArray(operand->shape()) ||
5350       operand->shape().dimensions(dimension_to_sort) <= 1) {
5351     if (sort->operand_count() == 1) {
5352       return ReplaceInstruction(sort, operand);
5353     }
5354     // If it is key/value sort, the output of sort is a tuple.
5355     return ReplaceWithNewInstruction(
5356         sort, HloInstruction::CreateTuple(sort->operands()));
5357   }
5358   return Status::OK();
5359 }
5360 
HandleSqrt(HloInstruction * sqrt)5361 Status AlgebraicSimplifierVisitor::HandleSqrt(HloInstruction* sqrt) {
5362   VLOG(10) << "trying transform [sqrt(A*A) => |A|] " << sqrt->ToString();
5363   HloInstruction* sqrt_operand = sqrt->mutable_operand(0);
5364   if (sqrt_operand->opcode() == HloOpcode::kMultiply &&
5365       sqrt_operand->operand(0) == sqrt_operand->operand(1)) {
5366     return ReplaceWithNewInstruction(
5367         sqrt, HloInstruction::CreateUnary(
5368                   sqrt_operand->mutable_operand(0)->shape(), HloOpcode::kAbs,
5369                   sqrt_operand->mutable_operand(0)));
5370   }
5371   return Status::OK();
5372 }
5373 
5374 namespace {
OnlyPermutesDegenerateDims(const Shape & shape,absl::Span<const int64> perm)5375 bool OnlyPermutesDegenerateDims(const Shape& shape,
5376                                 absl::Span<const int64> perm) {
5377   std::vector<int64> new_permutation;
5378   int64_t degenerate_count = 0;
5379   for (int64_t i = 0; i < perm.size(); ++i) {
5380     if (shape.dimensions(i) != 1) {
5381       new_permutation.push_back(perm[i]);
5382     } else {
5383       ++degenerate_count;
5384     }
5385   }
5386   return degenerate_count > 0 && absl::c_is_sorted(new_permutation);
5387 }
5388 }  // namespace
5389 
// Simplifies transposes: removes identity permutations, composes stacked
// transposes, commutes transpose through rank-2 dots, turns pure
// degenerate-dimension permutations into reshapes, absorbs the transpose
// into a single-use rng, and (layout-sensitive only) replaces
// layout-preserving transposes with bitcasts.
Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) {
  auto operand = transpose->mutable_operand(0);
  // A sorted permutation is the identity: the transpose is a no-op.
  if (std::is_sorted(transpose->dimensions().begin(),
                     transpose->dimensions().end())) {
    VLOG(10) << "deleting no-op transpose";
    return ReplaceInstruction(transpose, operand);
  }

  // transpose(transpose(x, p1), p2) -> transpose(x, compose(p1, p2)).
  if (HloOpcode::kTranspose == operand->opcode()) {
    return ReplaceWithNewInstruction(
        transpose, HloInstruction::CreateTranspose(
                       transpose->shape(), operand->mutable_operand(0),
                       ComposePermutations(operand->dimensions(),
                                           transpose->dimensions())));
  }

  // Convert transpose(dot(a,b)) to dot(b,a).
  if (operand->opcode() == HloOpcode::kDot && operand->user_count() == 1 &&
      operand->shape().rank() == 2) {
    TF_ASSIGN_OR_RETURN(bool did_transform, [&]() -> StatusOr<bool> {
      const auto& dnums = operand->dot_dimension_numbers();
      // Only handle batch-free dots whose operands have exactly one
      // non-contracting dimension each; transposing the rank-2 result is
      // then equivalent to swapping the dot's operands.
      if (dnums.lhs_batch_dimensions_size() != 0) {
        return false;
      }
      HloInstruction* lhs = operand->mutable_operand(0);
      if (lhs->shape().rank() != 1 + dnums.lhs_contracting_dimensions_size()) {
        return false;
      }
      HloInstruction* rhs = operand->mutable_operand(1);
      if (rhs->shape().rank() != 1 + dnums.rhs_contracting_dimensions_size()) {
        return false;
      }
      // Swap the roles of lhs and rhs, including their contracting dims.
      DotDimensionNumbers new_dnums;
      *new_dnums.mutable_lhs_contracting_dimensions() =
          dnums.rhs_contracting_dimensions();
      *new_dnums.mutable_rhs_contracting_dimensions() =
          dnums.lhs_contracting_dimensions();
      TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
          transpose, HloInstruction::CreateDot(transpose->shape(), /*lhs=*/rhs,
                                               /*rhs=*/lhs, new_dnums,
                                               operand->precision_config())));
      return true;
    }());
    if (did_transform) {
      return Status::OK();
    }
  }

  // Replace transpose with a reshape if it only permutes degenerate
  // (size-1) dimensions, i.e. the data layout of the non-trivial dims is
  // unchanged.
  if (OnlyPermutesDegenerateDims(transpose->shape(), transpose->dimensions())) {
    return ReplaceWithNewInstruction(
        transpose, HloInstruction::CreateReshape(
                       transpose->shape(), transpose->mutable_operand(0)));
  }

  // transpose(rng) -> rng with the transposed shape. Requires the rng to
  // have a single user so mutating its shape is safe; presumably valid
  // because rng elements are drawn independently, so reordering them does
  // not change the distribution.
  if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) {
    *operand->mutable_shape() = transpose->shape();
    return ReplaceInstruction(transpose, operand);
  }

  // If the transpose is a bitcast under the current layouts, emit a bitcast
  // instead (only meaningful in layout-sensitive mode).
  if (options_.is_layout_sensitive() &&
      options_.replace_transpose_with_bitcast() &&
      TransposeIsBitcast(transpose)) {
    ReplaceWithBitcast(transpose);
    return Status::OK();
  }

  return Status::OK();
}
5460 
// Tries to fold a zero-valued kPad feeding the convolution's activations into
// the convolution's window: edge padding becomes window padding, interior
// padding becomes base dilation. Returns true iff the convolution was
// replaced.
StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvInputPad(
    HloInstruction* convolution) {
  HloInstruction *lhs, *a, *b;
  if (Match(convolution,
            m::Convolution(m::Pad(&lhs, m::Op(&a), m::ConstantScalar(0)),
                           m::Op(&b)))) {
    const auto& window = convolution->window();
    const ConvolutionDimensionNumbers& dnums =
        convolution->convolution_dimension_numbers();

    const auto& padding = lhs->padding_config();

    // Can't pad batch or feature dims.
    for (int64_t dim :
         {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) {
      const auto& p = padding.dimensions(dim);
      if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
          p.interior_padding() != 0) {
        return false;
      }
    }

    // Compute the window which is the result of merging the kPad and the
    // convolution's existing window.
    Window new_window = window;
    for (int64_t dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) {
      auto& w = *new_window.mutable_dimensions(dim);
      const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim));
      // Edge padding composes with itself in the straightforward way, but
      // composing interior padding is nontrivial, and we cowardly refuse to
      // think about it. If we see interior padding in either the kPad or conv,
      // bail if there's any sort of padding in the other.
      if (p.interior_padding() != 0 &&
          (w.padding_low() != 0 || w.padding_high() != 0 ||
           w.base_dilation() != 1)) {
        return false;
      }
      if (w.base_dilation() != 1 &&
          (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
           p.interior_padding() != 0)) {
        return false;
      }

      // Merge: edge padding adds to window padding; interior padding of k
      // (k zeros between elements) is expressed as base dilation of 1 + k.
      w.set_padding_low(w.padding_low() + p.edge_padding_low());
      w.set_padding_high(w.padding_high() + p.edge_padding_high());
      if (p.interior_padding() != 0) {
        CHECK_EQ(w.base_dilation(), 1);
        w.set_base_dilation(1 + p.interior_padding());
      }
    }

    // Re-create the convolution reading the pad's input directly, with the
    // merged window.
    auto new_conv =
        convolution->CloneWithNewOperands(convolution->shape(), {a, b});
    new_conv->set_window(new_window);
    TF_RETURN_IF_ERROR(
        ReplaceWithNewInstruction(convolution, std::move(new_conv)));
    return true;
  }
  return false;
}
5521 
// Tries to absorb a kPad feeding the convolution's filter (RHS operand) into
// the convolution itself, by turning the pad's interior padding into window
// dilation. Returns true iff the convolution was rewritten.
StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvFilterPad(
    HloInstruction* convolution) {
  auto* lhs = convolution->mutable_operand(0);
  auto* rhs = convolution->mutable_operand(1);
  const ConvolutionDimensionNumbers& dnums =
      convolution->convolution_dimension_numbers();

  // Only applies when the filter is produced by a kPad.
  if (rhs->opcode() != HloOpcode::kPad) {
    return false;
  }

  // Convolution's padding is always zero, so bail if the kPad is adding
  // something other than zero.
  if (!IsAll(rhs->operand(1), 0)) {
    return false;
  }

  const auto& padding = rhs->padding_config();

  // Can't pad or dilate feature dims.
  for (int64_t dim : {dnums.kernel_input_feature_dimension(),
                      dnums.kernel_output_feature_dimension()}) {
    const auto& p = padding.dimensions(dim);
    if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
        p.interior_padding() != 0) {
      return false;
    }
  }

  // Compute the window which is the result of merging the kPad and the
  // convolution's existing window.
  Window new_window = convolution->window();
  for (int64_t dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) {
    auto& w = *new_window.mutable_dimensions(dim);
    const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim));

    // We can only do this transformation if p adds dilation to the filter --
    // edge padding on the filter is not supported in conv.
    if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) {
      return false;
    }

    // Nothing to do if the kPad for this dim is entirely a nop.
    if (p.interior_padding() == 0) {
      continue;
    }

    // We cowardly refuse to think about how dilation composes with itself;
    // bail if both the kPad and conv have dilation on this dimension.
    if (w.window_dilation() > 1) {
      return false;
    }
    CHECK_EQ(w.window_dilation(), 1);
    // Interior padding of k zeros between filter elements is equivalent to a
    // window dilation factor of (k + 1).
    w.set_window_dilation(1 + p.interior_padding());
    // The new conv reads the un-padded filter, so the window size along this
    // dimension is the kPad input's bound there.
    w.set_size(rhs->operand(0)->shape().dimensions(
        dnums.kernel_spatial_dimensions(dim)));
  }

  // Rewrite conv(lhs, pad(x, 0)) as conv'(lhs, x) with the merged window.
  auto new_conv = convolution->CloneWithNewOperands(
      convolution->shape(), {lhs, rhs->mutable_operand(0)});
  new_conv->set_window(new_window);
  TF_RETURN_IF_ERROR(
      ReplaceWithNewInstruction(convolution, std::move(new_conv)));
  return true;
}
5587 
// Rewrites conv(input, kernel) as an equivalent convolution with the two
// operands swapped (kernel becomes the "activations") when a naive
// implementation of the swapped form would do fewer flops. Returns true iff
// the convolution was replaced.
StatusOr<bool> AlgebraicSimplifierVisitor::SwapConvOperands(
    HloInstruction* convolution) {
  // Only in layout-insensitive mode and when the option is on.
  if (!options_.enable_conv_operand_swap() || options_.is_layout_sensitive()) {
    return false;
  }
  // Grouped and batch-grouped convolutions are not handled.
  if (convolution->feature_group_count() > 1 ||
      convolution->batch_group_count() > 1) {
    return false;
  }

  const auto& dnums = convolution->convolution_dimension_numbers();
  const auto& window_dims = convolution->window().dimensions();
  Window swapped_window;

  HloInstruction *input = convolution->mutable_operand(0),
                 *kernel = convolution->mutable_operand(1);
  // Products of per-dimension kernel extents for the original and swapped
  // convs; used below as a proxy for naive flop count.
  int64_t kernel_product = 1;
  int64_t swapped_kernel_product = 1;
  // Kernel spatial dims that must be reversed before the swap.
  DimensionVector reverse_dimensions;
  for (int64_t spatial_dim = 0;
       spatial_dim < dnums.input_spatial_dimensions_size(); ++spatial_dim) {
    const int64_t kernel_size = window_dims[spatial_dim].size();
    const bool can_be_group_or_contraction =
        !window_dims[spatial_dim].window_reversal() &&
        window_dims[spatial_dim].padding_low() == 0 &&
        window_dims[spatial_dim].padding_high() == 0 &&
        window_dims[spatial_dim].window_dilation() == 1;
    // NOTE(review): this base_dilation/stride pattern looks like it detects a
    // spatial dim used to emulate feature grouping -- confirm against the
    // producers of such convs.
    const bool is_group_dim =
        can_be_group_or_contraction &&
        window_dims[spatial_dim].base_dilation() == kernel_size &&
        window_dims[spatial_dim].stride() == kernel_size - 1;
    const int64_t input_size =
        input->shape().dimensions(dnums.input_spatial_dimensions(spatial_dim));
    // A dim where the window covers the whole input with unit stride/dilation
    // acts as a pure contraction.
    const bool is_pure_contraction_dim =
        kernel_size == input_size && can_be_group_or_contraction &&
        window_dims[spatial_dim].base_dilation() == 1 &&
        window_dims[spatial_dim].stride() == 1;
    if (is_group_dim || is_pure_contraction_dim) {
      // Such dimensions are kept as-is in the swapped window.
      *(swapped_window.add_dimensions()) = window_dims[spatial_dim];
      continue;
    }

    const int64_t dilated_kernel_size =
        1 + (kernel_size - 1) * window_dims[spatial_dim].window_dilation();
    const int64_t dilated_input_size =
        1 + (input_size - 1) * window_dims[spatial_dim].base_dilation();

    // Don't decide to swap if the input size is one, since many convolution
    // implementations can easily hand that special case efficiently.
    kernel_product *= kernel_size;
    swapped_kernel_product *= input_size == 1 ? kernel_size : input_size;

    auto new_dim = swapped_window.add_dimensions();
    new_dim->set_size(input_size);
    // If the kernel is not reversed, the activations must be manually reversed.
    if (!window_dims[spatial_dim].window_reversal()) {
      reverse_dimensions.push_back(
          dnums.kernel_spatial_dimensions(spatial_dim));
    }
    // The input is not originally reversed so it must be reversed to move the
    // kernel.
    new_dim->set_window_reversal(true);
    // Base dilation and window dilation switch places.
    new_dim->set_base_dilation(window_dims[spatial_dim].window_dilation());
    new_dim->set_window_dilation(window_dims[spatial_dim].base_dilation());
    new_dim->set_stride(window_dims[spatial_dim].stride());
    // NOTE(review): these paddings can come out negative when the dilated
    // kernel is larger than the dilated input plus original padding -- verify
    // downstream consumers accept negative window padding.
    new_dim->set_padding_low(dilated_input_size +
                             window_dims[spatial_dim].padding_low() -
                             dilated_kernel_size);
    new_dim->set_padding_high(dilated_input_size +
                              window_dims[spatial_dim].padding_high() -
                              dilated_kernel_size);
  }

  // Don't transform if a naive convolution implementation would not have fewer
  // flops.
  if (kernel_product <= swapped_kernel_product) {
    return false;
  }
  ConvolutionDimensionNumbers swapped_dnums;
  *swapped_dnums.mutable_output_spatial_dimensions() =
      dnums.output_spatial_dimensions();
  // Swap batch and output feature of the output.
  swapped_dnums.set_output_batch_dimension(dnums.output_feature_dimension());
  swapped_dnums.set_output_feature_dimension(dnums.output_batch_dimension());

  // Swap input dnums with kernel dnums
  *swapped_dnums.mutable_input_spatial_dimensions() =
      dnums.kernel_spatial_dimensions();
  swapped_dnums.set_input_batch_dimension(
      dnums.kernel_output_feature_dimension());
  swapped_dnums.set_input_feature_dimension(
      dnums.kernel_input_feature_dimension());

  // Swap kernel dnums with input dnums
  *swapped_dnums.mutable_kernel_spatial_dimensions() =
      dnums.input_spatial_dimensions();
  swapped_dnums.set_kernel_output_feature_dimension(
      dnums.input_batch_dimension());
  swapped_dnums.set_kernel_input_feature_dimension(
      dnums.input_feature_dimension());

  // The operand precisions swap along with the operands.
  PrecisionConfig precision_config;
  precision_config.add_operand_precision(
      convolution->precision_config().operand_precision(1));
  precision_config.add_operand_precision(
      convolution->precision_config().operand_precision(0));
  if (!reverse_dimensions.empty()) {
    TF_ASSIGN_OR_RETURN(kernel, MakeReverseHlo(kernel, reverse_dimensions));
  }
  TF_ASSIGN_OR_RETURN(
      HloInstruction * new_convolution,
      MakeConvolveHlo(
          kernel, input, /*feature_group_count=*/1,
          /*batch_group_count=*/1, swapped_window, swapped_dnums,
          precision_config,
          /*preferred_element_type=*/convolution->shape().element_type()));

  convolution->SetupDerivedInstruction(new_convolution);
  TF_RETURN_IF_ERROR(ReplaceInstruction(convolution, new_convolution));

  return true;
}
5711 
SimplifyConvToDot(HloInstruction * convolution)5712 StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot(
5713     HloInstruction* convolution) {
5714   auto* lhs = convolution->mutable_operand(0);
5715   auto* rhs = convolution->mutable_operand(1);
5716   const auto& window = convolution->window();
5717   const ConvolutionDimensionNumbers& dnums =
5718       convolution->convolution_dimension_numbers();
5719 
5720   if (!options_.enable_conv_simplification()) {
5721     return false;
5722   }
5723 
5724   // TODO(b/31337498): For now, we cowardly refuse to do this optimization in
5725   // layout-insensitive mode, for fear of adding nontrivial reshapes.
5726   if (!options_.is_layout_sensitive()) {
5727     return false;
5728   }
5729 
5730   const Shape& input_shape = lhs->shape();
5731   const Shape& filter_shape = rhs->shape();
5732   const Shape& convolution_shape = convolution->shape();
5733   TF_RET_CHECK(LayoutUtil::HasLayout(input_shape));
5734   TF_RET_CHECK(LayoutUtil::HasLayout(filter_shape));
5735   TF_RET_CHECK(LayoutUtil::HasLayout(convolution_shape));
5736 
5737   // Require the spatial dimensions in the kernel to have a bound of one.
5738   for (int64_t i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) {
5739     if (filter_shape.dimensions(dnums.kernel_spatial_dimensions(i)) != 1) {
5740       return false;
5741     }
5742   }
5743 
5744   // Stride ignores part of the output, which matrix multiplication does not do,
5745   // so require no stride. Padding and base (lhs) dilation both implicitly
5746   // extend the data, which matrix multiplication also does not do, so require
5747   // no padding and no base (lhs) dilation. Window (rhs) dilation has no effect
5748   // for a 1x1 window, so window dilation is no problem.
5749   if (window_util::HasStride(window) || window_util::HasPadding(window) ||
5750       window_util::HasBaseDilation(window)) {
5751     return false;
5752   }
5753 
5754   // Also, the shapes must align for a rowmajor matmul:
5755   // - the input and output have the same layout.
5756   // - for input/output, the channel dimension must be the most minor. Other
5757   //   spatial dims can be in any order.
5758   // - for filters, the input channel dimension must be more major than the
5759   //   output channel dimension. The width+height don't matter because
5760   //   they are 1.
5761   //
5762   // These constraints are harsh. If the channel dimension is the most major
5763   // and/or the layout of input/output feature dimensions are reversed, we can
5764   // still convert Conv into more efficient Matmul with operand transposition
5765   // (such as the transposition flags in cuBLAS SGEMM).
5766   if (!LayoutUtil::Equal(input_shape.layout(), convolution_shape.layout()) ||
5767       LayoutUtil::Minor(input_shape.layout(), 0) !=
5768           dnums.input_feature_dimension() ||
5769       LayoutUtil::Minor(convolution_shape.layout(), 0) !=
5770           dnums.output_feature_dimension() ||
5771       // The input feature dimension should come later in the minor-to-major
5772       // order.
5773       (PositionInContainer(LayoutUtil::MinorToMajor(filter_shape),
5774                            dnums.kernel_input_feature_dimension()) <
5775        PositionInContainer(LayoutUtil::MinorToMajor(filter_shape),
5776                            dnums.kernel_output_feature_dimension()))) {
5777     return false;
5778   }
5779 
5780   auto add_bitcast = [&](Shape shape, HloInstruction* operand) {
5781     std::vector<int64> dims(operand->shape().dimensions_size());
5782     std::iota(dims.begin(), dims.end(), 0);
5783     return computation_->AddInstruction(
5784         HloInstruction::CreateBitcast(shape, operand));
5785   };
5786 
5787   // Replace it with a dot, with bitcasts around it to get the right shape.
5788   const int64_t input_channels =
5789       input_shape.dimensions(dnums.input_feature_dimension());
5790   const int64_t output_channels =
5791       filter_shape.dimensions(dnums.kernel_output_feature_dimension());
5792 
5793   // Computes the product of the non-feature dimensions.
5794   int64_t conv_width = 1;
5795   for (int i = 0; i < input_shape.dimensions_size(); ++i) {
5796     if (i != dnums.input_feature_dimension()) {
5797       conv_width *= input_shape.dimensions(i);
5798     }
5799   }
5800 
5801   // We already checked feature_dimension is most minor, so data in input_shape
5802   // and row-major {conv_width,input_channels} are bitwise identical.
5803   Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout(
5804       input_shape.element_type(), {conv_width, input_channels});
5805   simplifier_->UpdateLayout(&new_input_shape);
5806   // We already checked input_feature_dimension is more major than
5807   // output_feature_dimension, so data in filter_shape and row-major
5808   // {input_channels,output_channels} are bitwise identical.
5809   Shape new_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout(
5810       filter_shape.element_type(), {input_channels, output_channels});
5811   simplifier_->UpdateLayout(&new_filter_shape);
5812   Shape dot_output_shape = ShapeUtil::MakeShapeWithDescendingLayout(
5813       convolution_shape.element_type(), {conv_width, output_channels});
5814   simplifier_->UpdateLayout(&dot_output_shape);
5815 
5816   auto new_lhs = add_bitcast(new_input_shape, lhs);
5817   auto new_rhs = add_bitcast(new_filter_shape, rhs);
5818   DotDimensionNumbers dot_dimension_numbers;
5819   dot_dimension_numbers.add_lhs_contracting_dimensions(1);
5820   dot_dimension_numbers.add_rhs_contracting_dimensions(0);
5821   auto dot = computation_->AddInstruction(HloInstruction::CreateDot(
5822       dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers,
5823       convolution->precision_config()));
5824 
5825   TF_RETURN_IF_ERROR(
5826       ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)));
5827   return true;
5828 }
5829 
HandleConvolution(HloInstruction * convolution)5830 Status AlgebraicSimplifierVisitor::HandleConvolution(
5831     HloInstruction* convolution) {
5832   if (options_.enable_scalar_multiply_reduction()) {
5833     TF_RETURN_IF_ERROR(ScalarMultiplyReduction(convolution));
5834   }
5835 
5836   // Zero-sized input or filter.
5837   if (ShapeUtil::IsZeroElementArray(convolution->operand(0)->shape()) ||
5838       ShapeUtil::IsZeroElementArray(convolution->operand(1)->shape())) {
5839     return ReplaceInstruction(convolution, MakeScalarLike(convolution, 0));
5840   }
5841 
5842   // Try to merge padding/dilation of the input with the convolution's window.
5843   TF_ASSIGN_OR_RETURN(bool folded_input_pad, FoldConvInputPad(convolution));
5844   if (folded_input_pad) {
5845     return Status::OK();
5846   }
5847 
5848   // Try to merge dilation of the filter with the convolution's window.
5849   TF_ASSIGN_OR_RETURN(bool folded_filter_pad, FoldConvFilterPad(convolution));
5850   if (folded_filter_pad) {
5851     return Status::OK();
5852   }
5853 
5854   // Try to swap convolution operands.
5855   TF_ASSIGN_OR_RETURN(bool swapped, SwapConvOperands(convolution));
5856   if (swapped) {
5857     return Status::OK();
5858   }
5859   // Try to replace the convolution with a kDot instruction.
5860   TF_ASSIGN_OR_RETURN(bool replaced_with_dot, SimplifyConvToDot(convolution));
5861   if (replaced_with_dot) {
5862     return Status::OK();
5863   }
5864 
5865   return Status::OK();
5866 }
5867 
TransformToClampIfSameShape(HloInstruction * root,HloInstruction * min,HloInstruction * min_operand,HloInstruction * operand,HloInstruction * max,HloInstruction * max_operand)5868 bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape(
5869     HloInstruction* root, HloInstruction* min, HloInstruction* min_operand,
5870     HloInstruction* operand, HloInstruction* max, HloInstruction* max_operand) {
5871   // Ensure shapes of min and max operand are equal to match current shape
5872   // inference.
5873   if (!SameShape(min_operand, max_operand)) {
5874     return false;
5875   }
5876 
5877   auto clamp = HloInstruction::CreateTernary(root->shape(), HloOpcode::kClamp,
5878                                              max_operand, operand, min_operand);
5879   TF_CHECK_OK(ReplaceWithNewInstruction(root, std::move(clamp)));
5880   return true;
5881 }
5882 
HandleMap(HloInstruction * map)5883 Status AlgebraicSimplifierVisitor::HandleMap(HloInstruction* map) {
5884   auto* map_computation = map->to_apply();
5885   auto* map_root = map_computation->root_instruction();
5886   if (map_root->opcode() == HloOpcode::kParameter) {
5887     ReplaceInstructionIfSameShape(
5888         map, map->mutable_operand(map_root->parameter_number()));
5889     return Status::OK();
5890   }
5891   if (map_root->opcode() == HloOpcode::kConstant) {
5892     if (!ShapeUtil::IsScalar(map_root->shape())) {
5893       return Status::OK();
5894     }
5895     auto clone = map_root->CloneWithNewOperands(map_root->shape(), {});
5896     if (ShapeUtil::IsScalar(map->shape())) {
5897       return ReplaceWithNewInstruction(map, std::move(clone));
5898     }
5899     return ReplaceWithNewInstruction(
5900         map,
5901         HloInstruction::CreateBroadcast(
5902             map->shape(), computation_->AddInstruction(std::move(clone)), {}));
5903   }
5904   // Inline the map if the map computation only contains an elementwise
5905   // operation that can accept arbitrary shapes.
5906   if (map_root->opcode() == HloOpcode::kFusion || !map_root->IsElementwise()) {
5907     return Status::OK();
5908   }
5909   std::vector<HloInstruction*> new_operands;
5910   for (auto* root_operand : map_root->operands()) {
5911     if (root_operand->opcode() != HloOpcode::kParameter) {
5912       return Status::OK();
5913     }
5914     new_operands.push_back(
5915         map->mutable_operand(root_operand->parameter_number()));
5916   }
5917   auto clone = map_root->CloneWithNewOperands(map->shape(), new_operands);
5918   return ReplaceWithNewInstruction(map, std::move(clone));
5919 }
5920 
Run(HloModule * module)5921 StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) {
5922   XLA_VLOG_LINES(2,
5923                  "AlgebraicSimplifier::Run(), before:\n" + module->ToString());
5924   bool changed = false;
5925   AlgebraicSimplifierVisitor visitor(options_, this);
5926   for (auto* comp : module->MakeNonfusionComputations()) {
5927     if (visitor.Run(comp, options_, this)) {
5928       changed = true;
5929     }
5930   }
5931   XLA_VLOG_LINES(2,
5932                  "AlgebraicSimplifier::Run(), after:\n" + module->ToString());
5933   return changed;
5934 }
5935 
5936 }  // namespace xla
5937