/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"

#include <algorithm>
#include <cmath>
#include <functional>
#include <iterator>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/overflow_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace xla {

namespace {

namespace m = match;

bool IsAll(const HloInstruction* op, int8_t value) {
  switch (op->opcode()) {
    case HloOpcode::kBroadcast:
      return IsAll(op->operand(0), value);
    case HloOpcode::kConstant:
      return op->literal().IsAll(value);
    default:
      return false;
  }
}

bool IsAnyOperandComplex(const HloInstruction* hlo) {
  for (auto operand : hlo->operands()) {
    if (ShapeUtil::ElementIsComplex(operand->shape())) {
      return true;
    }
  }
  return false;
}

bool IsPositive(const HloInstruction* hlo,
                const AlgebraicSimplifierOptions& options) {
  // Utility only handles real types.
  if (IsAnyOperandComplex(hlo)) {
    return false;
  }
  switch (hlo->opcode()) {
    case HloOpcode::kGetTupleElement: {
      const HloInstruction* gte_operand = hlo->operand(0);
      switch (gte_operand->opcode()) {
        case HloOpcode::kCustomCall: {
          const auto& target = gte_operand->custom_call_target();
          return target ==
                     options.get_cudnn_batchnorm_forward_training_metadata() &&
                 hlo->tuple_index() == 2;
        }
        default:
          return false;
      }
    }
    case HloOpcode::kPower:
    case HloOpcode::kAbs:
    case HloOpcode::kRsqrt:
    case HloOpcode::kSqrt:
      return IsPositive(hlo->operand(0), options);

    case HloOpcode::kMultiply: {
      return hlo->operand(0) == hlo->operand(1) &&
             IsPositive(hlo->operand(0), options);
    }
    default:
      return false;
  }
}

absl::optional<double> GetConstantValue(const HloInstruction* inst) {
  if (!ShapeUtil::IsEffectiveScalar(inst->shape())) {
    return absl::nullopt;
  }
  switch (inst->shape().element_type()) {
    case F16:
      return static_cast<float>(inst->literal().GetFirstElement<half>());
    case BF16:
      return static_cast<float>(inst->literal().GetFirstElement<bfloat16>());
    case F32:
      return inst->literal().GetFirstElement<float>();
    case F64:
      return inst->literal().GetFirstElement<double>();
    default:
      return absl::nullopt;
  }
}

bool IsNonNegative(const HloInstruction* hlo,
                   const AlgebraicSimplifierOptions& options) {
  // Utility only handles real types.
  if (IsAnyOperandComplex(hlo)) {
    return false;
  }
  switch (hlo->opcode()) {
    case HloOpcode::kMultiply: {
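      // x * x is non-negative for any real x (ignoring NaN inputs).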
      return hlo->operand(0) == hlo->operand(1);
    }
    case HloOpcode::kAbs: {
      return true;
    }
    case HloOpcode::kBroadcast: {
      return IsNonNegative(hlo->operand(0), options);
    }
    case HloOpcode::kConstant: {
      if (absl::optional<double> value = GetConstantValue(hlo)) {
        return *value >= 0.0;
      }
      return false;
    }
    case HloOpcode::kMaximum: {
      return IsNonNegative(hlo->operand(0), options) ||
             IsNonNegative(hlo->operand(1), options);
    }
    case HloOpcode::kSelect: {
      return IsNonNegative(hlo->operand(1), options) &&
             IsNonNegative(hlo->operand(2), options);
    }
    default:
      return IsPositive(hlo, options);
  }
}

// Checks whether `op` is a floating-point constant or broadcast of a constant
// of the form +/- 2^k for some integer k (positive, negative, or zero). Such
// values are interesting because multiplying by a power of 2 just moves the
// exponent.
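// For example, 0.125 = 2^-3 and -4.0 = -2^2 qualify, but 3.0 and 0.0 do not.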
bool IsAllFpConstantPowerOf2(const HloInstruction* op) {
  // Unwrap the broadcast if necessary.
  const HloInstruction* c;
  if (!Match(op, m::ConstantEffectiveScalar(&c)) &&
      !Match(op, m::Broadcast(m::Constant(&c).WithShape(
                     m::Shape().IsEffectiveScalar())))) {
    return false;
  }
  auto val = [&]() -> absl::optional<double> {
    switch (c->shape().element_type()) {
      case BF16:
        return static_cast<double>(c->literal().GetFirstElement<bfloat16>());
      case F16:
        return static_cast<double>(c->literal().GetFirstElement<Eigen::half>());
      case F32:
        return c->literal().GetFirstElement<float>();
      case F64:
        return c->literal().GetFirstElement<double>();
      default:
        // Cowardly refuse to consider complex types.
        return absl::nullopt;
    }
  }();
  if (!val) {
    return false;
  }

  int exp;
  double mantissa = std::frexp(*val, &exp);
  // frexp returns a value in the range (-1, -0.5] U [0.5, 1). A return value
  // of +/-0.5 therefore indicates that the floating point value is a power of
  // 2.
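  // For example, frexp(4.0) yields mantissa 0.5 with exp == 3, while
  // frexp(3.0) yields mantissa 0.75, so only the former is accepted.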
  return mantissa == 0.5 || mantissa == -0.5;
}

// Returns whether the given transpose produces a result which is bit-wise
// identical to its operand and thus may be replaced with a bitcast.
bool TransposeIsBitcast(const HloInstruction* transpose) {
  CHECK_EQ(HloOpcode::kTranspose, transpose->opcode());
  const HloInstruction* operand = transpose->operand(0);
  return ShapeUtil::TransposeIsBitcast(operand->shape(), transpose->shape(),
                                       transpose->dimensions());
}

// Recursive helper for the method below.
HloInstruction* BitcastingOperandOfReshapeOrCopyChainHelper(
    HloInstruction* instr, HloInstruction* operand,
    const AlgebraicSimplifierOptions& options) {
  // A chain of copies and reshapes can only be replaced with a bitcast if the
  // compiler-chosen memory layouts are compatible.
  if (options.ReshapeIsBitcast(operand->shape(), instr->shape())) {
    return operand;
  }

  // If the operand is a copy or reshape, check whether the operand's operand
  // would produce a bitcast with the initial instruction.
  if (HloOpcode::kReshape == operand->opcode() ||
      HloOpcode::kCopy == operand->opcode()) {
    return BitcastingOperandOfReshapeOrCopyChainHelper(
        instr, operand->mutable_operand(0), options);
  }
  return nullptr;
}

// Returns an operand of a chain of reshapes and copies that is bit-wise
// identical to the first reshape or copy in the chain.
HloInstruction* BitcastingOperandOfReshapeOrCopyChain(
    HloInstruction* instr, const AlgebraicSimplifierOptions& options) {
  if (!options.is_layout_sensitive()) {
    return nullptr;
  }
  CHECK(HloOpcode::kReshape == instr->opcode() ||
        HloOpcode::kCopy == instr->opcode());
  return BitcastingOperandOfReshapeOrCopyChainHelper(
      instr, instr->mutable_operand(0), options);
}

bool IsUnstridedSlice(const HloInstruction* hlo) {
  return absl::c_all_of(hlo->slice_strides(),
                        [](int64_t stride) { return stride == 1; });
}

// Returns whether a pair of converts can be eliminated.
bool IsConvertPairNoOp(const HloInstruction* convert) {
  //    [operand_convert]         [convert]
  // (src)->convert-(intermediate)->convert-(dest)
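  // For example, (f32)->convert-(f64)->convert-(f32) round-trips exactly,
  // because f64 represents every f32 value; (f32)->convert-(f16)->convert-
  // (f32) does not, so it is not a no-op pair.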
  const HloInstruction* operand_convert = convert->operand(0);
  CHECK_EQ(operand_convert->opcode(), HloOpcode::kConvert);
  const Shape& src_shape = operand_convert->operand(0)->shape();
  const Shape& intermediate_shape = operand_convert->shape();
  const Shape& dest_shape = convert->shape();

  const PrimitiveType src_type = src_shape.element_type();
  const PrimitiveType intermediate_type = intermediate_shape.element_type();
  const PrimitiveType dest_type = dest_shape.element_type();

  // src_type must be equal to dest_type.
  if (src_type != dest_type) {
    return false;
  }

  // intermediate_type must be a larger container than src_type.
  if (ShapeUtil::ByteSizeOfPrimitiveType(intermediate_type) <=
      ShapeUtil::ByteSizeOfPrimitiveType(src_type)) {
    return false;
  }

  // Both src_type and intermediate_type must be either floating or integral.
  bool is_conversion_floating =
      ShapeUtil::ElementIsFloating(src_shape) &&
      ShapeUtil::ElementIsFloating(intermediate_shape);
  bool is_conversion_integral =
      ShapeUtil::ElementIsIntegral(src_shape) &&
      ShapeUtil::ElementIsIntegral(intermediate_shape);

  return is_conversion_floating || is_conversion_integral;
}

// AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain
// algebraic expressions to simplified forms. Note: This only supports
// simplifications that simply look at the operands of an instruction. For the
// more general case a worklist-based approach would be needed.
class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
 public:
  explicit AlgebraicSimplifierVisitor(const AlgebraicSimplifierOptions& options,
                                      AlgebraicSimplifier* simplifier)
      : options_(options), simplifier_(simplifier) {}

  Status HandleAbs(HloInstruction* abs) override;

  Status HandleAdd(HloInstruction* add) override;

  Status HandleAnd(HloInstruction* logical_and) override;

  Status HandleBitcast(HloInstruction* bitcast) override;

  Status HandleBitcastConvert(HloInstruction* bitcast) override;

  Status HandleBroadcast(HloInstruction* broadcast) override;

  Status HandleCompare(HloInstruction* compare) override;

  Status HandleConcatenate(HloInstruction* concatenate) override;

  Status HandleConstant(HloInstruction* constant) override;

  Status HandleCopy(HloInstruction* copy) override;

  Status HandleConvert(HloInstruction* convert) override;

  Status HandleComplex(HloInstruction* complex) override;

  Status HandleReal(HloInstruction* real) override;

  Status HandleImag(HloInstruction* imag) override;

  Status HandleIota(HloInstruction* instruction) override;

  Status HandleConvolution(HloInstruction* convolution) override;

  Status HandleDivide(HloInstruction* divide) override;

  Status HandleDot(HloInstruction* dot) override;

  Status HandleGather(HloInstruction* gather) override;

  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;

  Status HandleLog(HloInstruction* log) override;

  Status HandleMaximum(HloInstruction* maximum) override;

  Status HandleMinimum(HloInstruction* minimum) override;

  Status HandleClamp(HloInstruction* clamp) override;

  Status HandleMultiply(HloInstruction* multiply) override;

  Status HandleNegate(HloInstruction* negate) override;

  Status HandleNot(HloInstruction* logical_not) override;

  Status HandleOr(HloInstruction* logical_or) override;

  Status HandlePad(HloInstruction* pad) override;

  Status HandlePower(HloInstruction* power) override;

  Status HandleRemainder(HloInstruction* remainder) override;

  Status HandleReshape(HloInstruction* reshape) override;

  Status HandleReduce(HloInstruction* hlo) override;

  Status HandleReduceWindow(HloInstruction* hlo) override;

  Status HandleReverse(HloInstruction* reverse) override;

  Status HandleRsqrt(HloInstruction* rsqrt) override;

  Status HandleSlice(HloInstruction* slice) override;

  Status HandleSqrt(HloInstruction* sqrt) override;

  Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;

  Status HandleDynamicUpdateSlice(
      HloInstruction* dynamic_update_slice) override;
  Status HandleScatter(HloInstruction* scatter) override;

  Status HandleSelect(HloInstruction* select) override;

  Status HandleSort(HloInstruction* sort) override;

  Status HandleTranspose(HloInstruction* transpose) override;

  Status HandleSubtract(HloInstruction* sub) override;

  Status HandleMap(HloInstruction* map) override;

  // Runs the visitor on a computation.
  bool Run(HloComputation* computation,
           const AlgebraicSimplifierOptions& options,
           AlgebraicSimplifier* simplifier);

 private:
  // Removes degenerate dimension from dot.
  StatusOr<bool> RemoveDegenerateDimensionFromDot(HloInstruction* dot);

  // Converts the hlo to the given primitive type if it is not already that
  // type; otherwise returns the original hlo.
  HloInstruction* AsType(HloInstruction* hlo,
                         const PrimitiveType element_type) {
    if (hlo->shape().element_type() == element_type) {
      return hlo;
    }
    Shape changed_shape =
        ShapeUtil::ChangeElementType(hlo->shape(), element_type);
    simplifier_->UpdateLayout(&changed_shape);
    return computation_->AddInstruction(
        HloInstruction::CreateConvert(changed_shape, hlo));
  }

  // Transposes a dot operand such that the batch dimensions are the most
  // major and the contracting dimensions are the most minor.
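  // For example, a rank-3 operand with batch_dimensions = {2} and
  // contracting_dimensions = {0} is transposed with permutation {2, 1, 0}.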
  StatusOr<HloInstruction*> NormalizeDotOperandToBatchMajorAndContractingMinor(
      HloInstruction* dot_operand, absl::Span<const int64> batch_dimensions,
      absl::Span<const int64> contracting_dimensions) {
    std::vector<int64> transpose_dimensions(batch_dimensions.begin(),
                                            batch_dimensions.end());
    for (int64_t i = 0; i < dot_operand->shape().rank(); ++i) {
      if (!(absl::c_linear_search(batch_dimensions, i) ||
            absl::c_linear_search(contracting_dimensions, i))) {
        transpose_dimensions.push_back(i);
      }
    }
    transpose_dimensions.insert(transpose_dimensions.end(),
                                contracting_dimensions.begin(),
                                contracting_dimensions.end());
    if (absl::c_is_sorted(transpose_dimensions)) {
      return dot_operand;
    }
    return MakeTransposeHlo(dot_operand, transpose_dimensions);
  }

  // Helper method to perform an add-reduction over a list of dimensions.
  HloInstruction* AddReduce(HloInstruction* hlo, absl::Span<const int64> dims,
                            PrimitiveType type) {
    HloInstruction* zero = computation_->AddInstruction(
        simplifier_->CreateConstantWithLayoutUpdated(
            LiteralUtil::Zero(hlo->shape().element_type()).Clone()));
    HloComputation* AddReduce_computation =
        GetOrCreateScalarAddComputation(type);
    Shape shape = ShapeUtil::FilterDimensions(
        [&](int64_t dim) { return !absl::c_linear_search(dims, dim); },
        hlo->shape());
    simplifier_->UpdateLayout(&shape);
    return computation_->AddInstruction(HloInstruction::CreateReduce(
        shape, hlo, zero, dims, AddReduce_computation));
  }

  // Moves a scalar multiply to the smallest side of a dot or convolution to
  // reduce the number of multiply computations.
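  // For example, with a scalar s, dot(a, s * b) becomes s * dot(a, b) when
  // the dot result is the smallest of the three arrays, so the scalar
  // multiply runs over fewer elements.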
  Status ScalarMultiplyReduction(HloInstruction* dot);

  // Convenience method for replacing an instruction with a bitcast. If operand
  // is not null, then the bitcast will use the specified operand instead of the
  // operand of the instruction.
  void ReplaceWithBitcast(HloInstruction* instruction,
                          HloInstruction* operand = nullptr);

  // Replaces the old instruction with the new instruction if they have the
  // same shape. Updates uses and root instruction. Returns whether a
  // replacement was made.
  bool ReplaceInstructionIfSameShape(HloInstruction* old_instruction,
                                     HloInstruction* new_instruction);

  // Returns whether the output shapes of the given instructions are the same
  // for the purposes of simplification. If options_.is_layout_sensitive() is
  // true, then this tests shape equality including layout (ShapeUtil::Equal).
  // If options_.is_layout_sensitive() is false, then this tests shape
  // compatibility (ShapeUtil::Compatible).
  bool SameShape(const HloInstruction* lhs, const HloInstruction* rhs) const;

  // Returns whether it was possible to transform `root` to a clamp
  // instruction, with min a minimum instruction, max a maximum instruction,
  // min_operand an operand of min and max_operand an operand of max.
  // Precondition: root is either a minimum or a maximum.
  bool TransformToClampIfSameShape(HloInstruction* root, HloInstruction* min,
                                   HloInstruction* min_operand,
                                   HloInstruction* operand, HloInstruction* max,
                                   HloInstruction* max_operand);

  // A Broadcast that feeds an element-wise operation with a unique non-scalar
  // operand can sink to after the operation.
  StatusOr<bool> TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
      HloInstruction* broadcast);

  StatusOr<HloInstruction*> OptimizeDotOfConcat(HloInstruction* dot);
  StatusOr<HloInstruction*> OptimizeDotOfConcatHelper(
      const HloInstruction& dot, HloInstruction* lhs,
      int64_t lhs_contracting_dim, HloInstruction* rhs,
      int64_t rhs_contracting_dim, bool swapped);

  StatusOr<HloInstruction*> OptimizeDotOfGather(HloInstruction* dot);

  StatusOr<HloInstruction*> OptimizeDotOfReorderContractingDims(
      HloInstruction* dot);

  HloComputation* GetOrCreateScalarAddComputation(PrimitiveType type) {
    HloComputation*& scalar_add_computation = scalar_add_computations_[type];
    if (scalar_add_computation) {
      return scalar_add_computation;
    }

    HloComputation::Builder b("scalar_add_computation");
    Shape shape = ShapeUtil::MakeShape(type, {});
    simplifier_->UpdateLayout(&shape);
    auto scalar_lhs = b.AddInstruction(
        HloInstruction::CreateParameter(0, shape, "scalar_lhs"));
    auto scalar_rhs = b.AddInstruction(
        HloInstruction::CreateParameter(1, shape, "scalar_rhs"));
    auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs));
    scalar_add_computation =
        computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
    return scalar_add_computation;
  }

  // Tries to fold a kPad in the input or filter into the convolution
  // instruction's window.
  StatusOr<bool> FoldConvInputPad(HloInstruction* convolution);
  StatusOr<bool> FoldConvFilterPad(HloInstruction* convolution);

  // Tries to swap convolution operands if they would result in a more
  // efficient convolution.
  StatusOr<bool> SwapConvOperands(HloInstruction* convolution);

  // Tries to use a kDot in place of the given convolution.
  StatusOr<bool> SimplifyConvToDot(HloInstruction* convolution);

  // Tries to simplify a slice where the result of the slice is a scalar.
  StatusOr<bool> TrySimplifyScalarSlice(HloInstruction* slice);

  // Tries to convert slice(reshape(X)) into reshape(slice(X))
  StatusOr<bool> TryToReorderSliceAndReshape(HloInstruction* slice);

  // Tries to convert slice(reverse(X)) into reverse(slice(X))
  StatusOr<bool> TryToReorderSliceAndReverse(HloInstruction* slice);

  // Tries to simplify `(and (< a N) (< a K))` in cases where `N <= K` into
  // `(< a N)`. This is crucial for being able to figure out the loop trip
  // count.
  //
  // Assumes that the input is a conjunction.
  StatusOr<bool> TrySimplifyTautologicalCompare(HloInstruction* conjunction);

  // Useful when we want to use the same visitor over multiple computations.
  void ResetState(HloComputation* computation);

  // Current HloComputation instance the AlgebraicSimplifierVisitor is
  // traversing.
  HloComputation* computation_;

  // The backend-specific options selected for the algebraic simplifier.
  const AlgebraicSimplifierOptions& options_;

  // Whether algebraic simplification has occurred.
  bool changed_ = false;

  // Cached computation for adding two scalars of a given type.
  absl::flat_hash_map<PrimitiveType, HloComputation*> scalar_add_computations_;

  AlgebraicSimplifier* simplifier_ = nullptr;
};

}  // namespace

void AlgebraicSimplifierVisitor::ResetState(HloComputation* computation) {
  changed_ = false;
  ResetVisitStates();
  computation_ = computation;
}

bool AlgebraicSimplifierVisitor::Run(HloComputation* computation,
                                     const AlgebraicSimplifierOptions& options,
                                     AlgebraicSimplifier* simplifier) {
  ResetState(computation);
  TF_CHECK_OK(computation->Accept(this));
  return changed_ || changed();
}

bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs,
                                           const HloInstruction* rhs) const {
  if (options_.is_layout_sensitive()) {
    return ShapeUtil::Equal(lhs->shape(), rhs->shape());
  } else {
    return ShapeUtil::Compatible(lhs->shape(), rhs->shape());
  }
}

namespace {

bool IsOpCodeMultiplyCommutative(HloOpcode opcode) {
  switch (opcode) {
    case HloOpcode::kMultiply:
    case HloOpcode::kTranspose:
    case HloOpcode::kReshape:
    case HloOpcode::kSelect:
      return true;
    default:
      return false;
  }
}

std::unique_ptr<HloInstruction> MakeScalarInstruction(HloInstruction* target,
                                                      float multiplier) {
  switch (target->shape().element_type()) {
    case BF16:
      return HloInstruction::CreateConstant(LiteralUtil::ConvertF32ToBF16(
          LiteralUtil::CreateR0<float>(multiplier)));
    case F32:
      return HloInstruction::CreateConstant(
          LiteralUtil::CreateR0<float>(multiplier));
    default:
      LOG(FATAL) << "Unsupported data type: " << target->shape().element_type();
  }
}

}  // namespace

Status AlgebraicSimplifierVisitor::ScalarMultiplyReduction(
    HloInstruction* dot) {
  // We only process bfloat16 and float32 for now.
  if (dot->shape().element_type() != BF16 &&
      dot->shape().element_type() != F32) {
    return Status::OK();
  }

  auto lhs = dot->mutable_operand(0);
  auto rhs = dot->mutable_operand(1);

  const int64_t dot_size = ShapeUtil::ElementsIn(dot->shape());
  const int64_t lhs_size = ShapeUtil::ElementsIn(lhs->shape());
  const int64_t rhs_size = ShapeUtil::ElementsIn(rhs->shape());

  HloInstruction* target = nullptr;
  // (current node, user, operand_index)
  std::vector<std::tuple<HloInstruction*, HloInstruction*, int64>> operands;
  std::vector<HloInstruction*> users;

  // Find which side of the dot has the smallest size:
  // operand 0, operand 1, or output.
  if (dot_size <= std::min(lhs_size, rhs_size)) {
    target = dot;
    if (dot_size < lhs_size) {
      operands.emplace_back(lhs, dot, 0);
    }
    if (dot_size < rhs_size) {
      operands.emplace_back(rhs, dot, 1);
    }
  } else if (lhs_size <= rhs_size) {
    target = lhs;
    if (lhs_size < rhs_size) {
      operands.emplace_back(rhs, dot, 1);
    }
    if (lhs_size < dot_size && dot->user_count() == 1) {
      users.push_back(dot->users().front());
    }
  } else {
    target = rhs;
    if (rhs_size < lhs_size) {
      operands.emplace_back(lhs, dot, 0);
    }
    if (rhs_size < dot_size && dot->user_count() == 1) {
      users.push_back(dot->users().front());
    }
  }

  std::vector<float> values;

  // DFS to find scalar multiply ops from the operands.
  while (!operands.empty()) {
    HloInstruction* inst;
    HloInstruction* user;
    int64_t index;
    std::tie(inst, user, index) = operands.back();
    operands.pop_back();

    // Skip the op types that are not commutative with multiply.
    if (!IsOpCodeMultiplyCommutative(inst->opcode())) {
      continue;
    }

    HloInstruction* operand;
    HloInstruction* multiplier;
    // Pattern match a scalar multiply.
    if (Match(inst, m::MultiplyAnyOrder(
                        m::Op(&operand),
                        m::Broadcast(m::ConstantScalar(&multiplier))))) {
      CHECK_LT(index, user->operand_count());
      CHECK_EQ(inst, user->operands()[index]);

      // When a scalar multiply is found, save its scalar value.
      values.push_back(*GetConstantValue(multiplier));
      // And remove the scalar multiply op.
      TF_RETURN_IF_ERROR(user->ReplaceOperandWith(index, operand));
      inst = operand;
    }

    // Push the operands of inst.
    int64_t i = 0;
    for (auto* operand : inst->operands()) {
      operands.emplace_back(operand, inst, i++);
    }
  }

  // DFS to find scalar multiply ops from the users.
  while (!users.empty()) {
    auto inst = users.back();
    users.pop_back();

    if (!IsOpCodeMultiplyCommutative(inst->opcode())) {
      continue;
    }

    HloInstruction* operand;
    HloInstruction* multiplier;
    if (Match(inst, m::MultiplyAnyOrder(
                        m::Op(&operand),
                        m::Broadcast(m::ConstantScalar(&multiplier))))) {
      values.push_back(*GetConstantValue(multiplier));

      TF_RETURN_IF_ERROR(inst->ReplaceAllUsesWith(operand));
      inst = operand;
    }

    // Only continue through instructions with a single user; otherwise moving
    // the scalar multiply to the operands would change the values seen by
    // other users.
    if (inst->user_count() == 1) {
      users.push_back(inst->users().front());
    }
  }

  if (values.empty()) {
    return Status::OK();
  }

  changed_ = true;

  // Combine all constant multipliers.
  float multiplier = 1.0;
  for (const float v : values) {
    multiplier *= v;
  }

  // Create a new const scalar multiply instruction.
  HloInstruction* new_const_inst =
      computation_->AddInstruction(MakeScalarInstruction(target, multiplier));

  // Broadcast the scalar multiplier.
  HloInstruction* new_broadcast = computation_->AddInstruction(
      HloInstruction::CreateBroadcast(target->shape(), new_const_inst, {}));
  // Create a new scalar multiply instruction.
  HloInstruction* new_multiply =
      computation_->AddInstruction(HloInstruction::CreateBinary(
          target->shape(), HloOpcode::kMultiply, target, new_broadcast));
  CHECK_EQ(new_multiply->shape(), target->shape());

  // Update the dependency with the rest of the instructions.
  if (target == lhs) {
    return dot->ReplaceOperandWith(0, new_multiply);
  } else if (target == rhs) {
    return dot->ReplaceOperandWith(1, new_multiply);
  } else {
    CHECK_EQ(target, dot);
    return dot->ReplaceAllUsesWith(new_multiply);
  }
}

void AlgebraicSimplifierVisitor::ReplaceWithBitcast(HloInstruction* instruction,
                                                    HloInstruction* operand) {
  CHECK_EQ(1, instruction->operand_count());
  if (operand == nullptr) {
    operand = instruction->mutable_operand(0);
  }
  CHECK_EQ(ShapeUtil::ElementsIn(instruction->shape()),
           ShapeUtil::ElementsIn(operand->shape()));
  CHECK_EQ(ShapeUtil::ByteSizeOf(instruction->shape()),
           ShapeUtil::ByteSizeOf(operand->shape()));

  auto bitcast = computation_->AddInstruction(
      HloInstruction::CreateBitcast(instruction->shape(), operand));
  bitcast->set_metadata(instruction->metadata());
  TF_CHECK_OK(ReplaceInstruction(instruction, bitcast));
}

bool AlgebraicSimplifierVisitor::ReplaceInstructionIfSameShape(
    HloInstruction* old_instruction, HloInstruction* new_instruction) {
  if (!SameShape(old_instruction, new_instruction)) {
    return false;
  }
  TF_CHECK_OK(ReplaceInstruction(old_instruction, new_instruction));
  return true;
}

Status AlgebraicSimplifierVisitor::HandleAbs(HloInstruction* abs) {
  HloInstruction* abs_operand = abs->mutable_operand(0);
  VLOG(10) << "trying transform [Abs(A) => A] " << abs->ToString()
           << " Abs operand is: " << abs_operand->ToString();
  if (IsNonNegative(abs->operand(0), options_)) {
    return ReplaceInstruction(abs, abs_operand);
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
  HloInstruction *lhs, *rhs;
  CHECK(Match(add, m::Add(m::Op(&lhs), m::Op(&rhs))));

  // A + 0 => A
  VLOG(10) << "trying transform [A + 0 => A]: " << add->ToString();
  if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(add, lhs)) {
    return Status::OK();
  }
  // 0 + A => A
  VLOG(10) << "trying transform [0 + A => A]: " << add->ToString();
  if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(add, rhs)) {
    return Status::OK();
  }

  // Canonicalization: Put constants on the right. This makes the reassociation
  // rules below simpler.
  VLOG(10) << "trying transform [Const + A => A + Const]";
  if (Match(add, m::Add(m::Constant(), m::NonConstant()))) {
    return ReplaceWithNewInstruction(
        add,
        HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, rhs, lhs));
  }

  // Reassociate to allow constant folding.
  //
  // Note: This is not general. For example, we won't reassociate
  //
  //   (A + C1) + (B + C2) => A + B + (C1 + C2).
  //
  VLOG(10) << "trying transform [(A + C1) + C2 => A + (C1 + C2)]";
  HloInstruction *a, *c1, *c2;
  if (Match(add, m::Add(m::Add(m::NonConstant(&a), m::Constant(&c1)),
                        m::Constant(&c2))) ||
      Match(add, m::Add(m::Add(m::NonConstant(&a),
                               m::Broadcast(m::ConstantScalar(&c1))),
                        m::Broadcast(m::ConstantScalar(&c2))))) {
    TF_ASSIGN_OR_RETURN(auto* sum_of_constants,
                        MakeBinaryHlo(HloOpcode::kAdd, c1, c2));
    if (ShapeUtil::IsScalar(sum_of_constants->shape()) &&
        !ShapeUtil::IsScalar(add->shape())) {
      sum_of_constants = computation_->AddInstruction(
          HloInstruction::CreateBroadcast(add->shape(), sum_of_constants, {}));
    }
    return ReplaceWithNewInstruction(
        add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, a,
                                          sum_of_constants));
  }

  // Convert an add over the full shape into an add over a partial shape when
  // only a portion of the add is effective:
  //             zero (fullshape)   rhs (partialshape)
  // .           |                  |
  // . lhs .    dynamic_update_slice (fullshape)
  // . |         |
  // Add (fullshape)
  //
  // to:
  //              lhs
  //              |
  //             dynamic_slice (partialshape)   rhs (partialshape)
  // .           |                              |
  // . lhs .    add (partial_shape)+------------+
  // . |         |
  // dynamic_update_slice (fullshape)
  //
  // This pattern is discovered in the control flow V2 gradient update.
  if (Match(add,
            m::Add(m::Op(&lhs),
                   m::Op(&rhs)
                       .WithOpcode(HloOpcode::kDynamicUpdateSlice)
                       .WithOperand(
                           0, m::Broadcast(m::ConstantEffectiveScalar(0)))))) {
    const Shape& partial_shape = rhs->operand(1)->shape();
    auto sliced_lhs =
        computation_->AddInstruction(HloInstruction::CreateDynamicSlice(
            partial_shape, lhs, absl::MakeSpan(rhs->operands()).subspan(2),
            partial_shape.dimensions()));

    auto add_partial = computation_->AddInstruction(
        HloInstruction::CreateBinary(rhs->operand(1)->shape(), HloOpcode::kAdd,
                                     sliced_lhs, rhs->mutable_operand(1)));

    auto dynamic_update_slice_full = HloInstruction::CreateDynamicUpdateSlice(
        lhs->shape(), lhs, add_partial,
        absl::MakeSpan(rhs->operands()).subspan(2));

    return ReplaceWithNewInstruction(add, std::move(dynamic_update_slice_full));
  }

  // A*C + B*C => (A+B)*C
  //
  //  - If A, B, and C are integers, do this unconditionally. Proof of
  //    correctness: https://rise4fun.com/Alive/u9X.
  //
  //  - If A, B, and C are floating point, do this if C is a scalar constant or
  //    a broadcast of a scalar constant and is equal to +/- 2^k for some
  //    (possibly negative) integer k.
  //
  //    Multiplying by a power of 2 just moves the exponent, so our answer is
  //    exact modulo rounding of intermediate results so long as
  //
  //     - none of the three products has an exponent which underflows (so the
  //       result is 0 or denormal), and
  //     - none of the three products overflows to inf.
  //
  //    Proof: See algebraic_simplifier_proof_distributive_property.py.
  //
  //    We deem these differences in rounding, underflow, and overflow
  //    acceptable in the ML context.
  //
  //    Furthermore, if `enable_floats_are_real` is true, the simplification is
  //    done nonetheless. This might cause numerical differences even if there
  //    is no underflow or overflow.
  HloInstruction *b, *c;
  if (((Match(lhs, m::Multiply(m::Op(&a), m::Op(&c))) &&
        Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b)))) ||
       (Match(lhs, m::Multiply(m::Op(&c), m::Op(&a))) &&
        Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b))))) &&
      // Make sure we would decrease the number of multiplies.
      (lhs->user_count() == 1 && rhs->user_count() == 1) &&
      (ShapeUtil::ElementIsIntegral(add->shape()) ||
       options_.enable_floats_are_real() || IsAllFpConstantPowerOf2(c))) {
    return ReplaceWithNewInstruction(
        add, HloInstruction::CreateBinary(
                 add->shape(), HloOpcode::kMultiply,
                 computation_->AddInstruction(HloInstruction::CreateBinary(
                     add->shape(), HloOpcode::kAdd, a, b)),
                 c));
  }

  if (options_.is_layout_sensitive()) {
    return Status::OK();
  }

  HloInstruction* lhs_scatter_operand = nullptr;
  HloInstruction* rhs_scatter_operand = nullptr;
  HloInstruction* lhs_scatter_update = nullptr;
  HloInstruction* rhs_scatter_update = nullptr;
  HloInstruction* lhs_scatter_index = nullptr;
  HloInstruction* rhs_scatter_index = nullptr;
  bool lhs_scatter = Match(lhs, m::Scatter(m::Op(&lhs_scatter_operand),
                                           m::Op(&lhs_scatter_index),
                                           m::Op(&lhs_scatter_update))
                                    .WithOneUse()) &&
                     Match(lhs->to_apply()->root_instruction(),
                           m::Add(m::Parameter(), m::Parameter()));
  bool rhs_scatter = Match(rhs, m::Scatter(m::Op(&rhs_scatter_operand),
                                           m::Op(&rhs_scatter_index),
                                           m::Op(&rhs_scatter_update))
                                    .WithOneUse()) &&
                     Match(rhs->to_apply()->root_instruction(),
                           m::Add(m::Parameter(), m::Parameter()));
  if (rhs_scatter && lhs_scatter) {
    const auto& lhs_dnums = lhs->scatter_dimension_numbers();
    const auto& rhs_dnums = rhs->scatter_dimension_numbers();
    absl::optional<int64> index_concat_dimension;
    absl::optional<int64> update_concat_dimension;
    // Don't try to combine scatters of different ranks.
    if (lhs_scatter_index->shape().rank() !=
        rhs_scatter_index->shape().rank()) {
      return Status::OK();
    }

    int64_t first_index_dim = lhs_scatter_index->shape().rank();
    int64_t first_update_dim = lhs_scatter_update->shape().rank();
    // Find a dimension where it is possible to concatenate the indices and
    // updates. This is the first and only non-equal dimension or the first
    // equally sized dimension.
    for (int64_t d = lhs_scatter_index->shape().rank() - 1,
                 update_dim = lhs_scatter_update->shape().rank() - 1;
         d >= 0; --d) {
      if (d == lhs_dnums.index_vector_dim()) {
        continue;
      }
      while (
          absl::c_linear_search(lhs_dnums.update_window_dims(), update_dim)) {
        --update_dim;
      }
      if (lhs_scatter_index->shape().dimensions(d) ==
          rhs_scatter_index->shape().dimensions(d)) {
        first_index_dim = d;
        first_update_dim = update_dim--;
        continue;
      }
      // More than one dimension of unequal size was found, bail out.
      if (index_concat_dimension) {
        return Status::OK();
      }
      index_concat_dimension = d;
      update_concat_dimension = update_dim--;
    }
    if (!index_concat_dimension) {
      index_concat_dimension = first_index_dim;
      update_concat_dimension = first_update_dim;
    }

    // A scalar scatter will require additional reshapes of the index and
    // update.
    if (*index_concat_dimension == lhs_scatter_index->shape().rank()) {
      return Status::OK();
    }
    const bool update_concat_is_cheap =
        ShapeUtil::ElementsIn(rhs_scatter_update->shape()) +
            ShapeUtil::ElementsIn(lhs_scatter_update->shape()) <
        ShapeUtil::ElementsIn(lhs->shape());
    if (!update_concat_is_cheap) {
      return Status::OK();
    }
    const bool same_dimension_numbers =
        lhs_dnums.index_vector_dim() == rhs_dnums.index_vector_dim() &&
        absl::c_equal(lhs_dnums.scatter_dims_to_operand_dims(),
                      rhs_dnums.scatter_dims_to_operand_dims()) &&
        absl::c_equal(lhs_dnums.inserted_window_dims(),
                      rhs_dnums.inserted_window_dims()) &&
        absl::c_equal(lhs_dnums.update_window_dims(),
                      rhs_dnums.update_window_dims());
    const bool index_concat_is_safe =
        !lhs->unique_indices() && !rhs->unique_indices() &&
        !DynCast<HloScatterInstruction>(lhs)->indices_are_sorted() &&
        !DynCast<HloScatterInstruction>(rhs)->indices_are_sorted();

    Shape lhs_update_window = ShapeUtil::FilterDimensions(
        [&](int64_t dim) {
          return absl::c_linear_search(lhs_dnums.update_window_dims(), dim);
        },
        lhs_scatter_update->shape());
    Shape rhs_update_window = ShapeUtil::FilterDimensions(
        [&](int64_t dim) {
          return absl::c_linear_search(rhs_dnums.update_window_dims(), dim);
        },
        rhs_scatter_update->shape());
    // Concatenate the indices and updates.
    if (index_concat_is_safe && same_dimension_numbers &&
        index_concat_dimension &&
        lhs_scatter_index->shape().element_type() ==
            rhs_scatter_index->shape().element_type() &&
        ShapeUtil::SameDimensions(lhs_update_window, rhs_update_window)) {
      TF_ASSIGN_OR_RETURN(HloInstruction * new_operand,
                          MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand,
                                        rhs_scatter_operand));
      TF_ASSIGN_OR_RETURN(HloInstruction * new_index,
                          MakeConcatHlo({lhs_scatter_index, rhs_scatter_index},
                                        *index_concat_dimension));
      TF_ASSIGN_OR_RETURN(
          HloInstruction * new_update,
          MakeConcatHlo({lhs_scatter_update, rhs_scatter_update},
                        *update_concat_dimension));
      return ReplaceWithNewInstruction(
          add, HloInstruction::CreateScatter(
                   add->shape(), new_operand, new_index, new_update,
                   lhs->to_apply(), lhs_dnums, false, false));
    }
    TF_ASSIGN_OR_RETURN(HloInstruction * new_operand,
                        MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand,
                                      rhs_scatter_operand));
    TF_RETURN_IF_ERROR(rhs->ReplaceOperandWith(0, new_operand));
    TF_RETURN_IF_ERROR(lhs->ReplaceOperandWith(0, rhs));
    return ReplaceInstruction(add, lhs);
  } else if (rhs_scatter) {
    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_operand,
        MakeBinaryHlo(HloOpcode::kAdd, lhs, rhs_scatter_operand));
    TF_RETURN_IF_ERROR(rhs->ReplaceOperandWith(0, new_operand));
    return ReplaceInstruction(add, rhs);
  } else if (lhs_scatter) {
    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_operand,
        MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand, rhs));
    TF_RETURN_IF_ERROR(lhs->ReplaceOperandWith(0, new_operand));
    return ReplaceInstruction(add, lhs);
  }
  return Status::OK();
}

StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyTautologicalCompare(
    HloInstruction* conjunction) {
  HloInstruction *lhs, *rhs;
  if (!Match(conjunction, m::And(m::Op(&lhs), m::Op(&rhs)))) {
    return false;
  }
  struct LessThanCompareInfo {  // (LT var constant)
    HloInstruction* var;
    int64 constant;
  };

  auto get_compare_info =
      [&](HloInstruction* cmp) -> absl::optional<LessThanCompareInfo> {
    HloInstruction *lhs, *rhs;
    auto scalar_shape_matcher =
        m::Shape().IsEffectiveScalar().WithElementType(PrimitiveType::S32);
    if (Match(cmp, m::Compare(m::Op(&lhs),
                              m::Constant(&rhs).WithShape(scalar_shape_matcher))
                       .WithComparisonDirection(ComparisonDirection::kLt))) {
      return {LessThanCompareInfo{lhs, *rhs->literal().GetFirstInteger()}};
    } else if (Match(
                   cmp,
                   m::Compare(m::Constant(&lhs).WithShape(scalar_shape_matcher),
                              m::Op(&rhs))
                       .WithComparisonDirection(ComparisonDirection::kGt))) {
      return {LessThanCompareInfo{rhs, *lhs->literal().GetFirstInteger()}};
    }
    return absl::nullopt;
  };

  absl::optional<LessThanCompareInfo> lhs_info = get_compare_info(lhs);
  absl::optional<LessThanCompareInfo> rhs_info = get_compare_info(rhs);
  if (lhs_info && rhs_info && lhs_info->var == rhs_info->var) {
    int64_t new_bound = std::min(lhs_info->constant, rhs_info->constant);
    TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
        conjunction,
        HloInstruction::CreateCompare(lhs->shape(), lhs_info->var,
                                      MakeScalarLike(lhs_info->var, new_bound),
                                      ComparisonDirection::kLt)));
    return true;
  }
  return false;
}

Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) {
  HloInstruction *lhs, *rhs;
  CHECK(Match(logical_and, m::And(m::Op(&lhs), m::Op(&rhs))));
  // Simplify logical and.
  if (ShapeUtil::HasPrimitiveType(lhs->shape(), xla::PRED) &&
      ShapeUtil::HasPrimitiveType(rhs->shape(), xla::PRED)) {
    // A && True => A
    VLOG(10) << "trying transform [A && True => A]: "
             << logical_and->ToString();
    if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(logical_and, lhs)) {
      return Status::OK();
    }
    // True && A => A
    VLOG(10) << "trying transform [True && A => A]: "
             << logical_and->ToString();
    if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(logical_and, rhs)) {
      return Status::OK();
    }
  }

  // A && False => False or A & 0 => 0
  VLOG(10) << "trying transform [A && False => False]: "
           << logical_and->ToString();
  if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(logical_and, rhs)) {
    return Status::OK();
  }

  // False && A => False or 0 & A => 0
  VLOG(10) << "trying transform [False && A => False]: "
           << logical_and->ToString();
  if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(logical_and, lhs)) {
    return Status::OK();
  }

  // Simplify tautological conjunctions.
  TF_ASSIGN_OR_RETURN(bool found_tautological_compare,
                      TrySimplifyTautologicalCompare(logical_and));
  if (found_tautological_compare) {
    return Status::OK();
  }

  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleBitcast(HloInstruction* bitcast) {
  // If a bitcast feeds a bitcast, make it a single bitcast.
  HloInstruction* op;
  if (Match(bitcast, m::Bitcast(m::Bitcast(m::Op(&op))))) {
    return ReplaceWithNewInstruction(
        bitcast, HloInstruction::CreateBitcast(bitcast->shape(), op));
  }
  // All bitcasts can be eliminated (assuming layout constraints are
  // satisfied).
  ReplaceInstructionIfSameShape(bitcast, bitcast->mutable_operand(0));
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleBitcastConvert(
    HloInstruction* bitcast) {
  // Eliminate bitcast converts between same shape.
  ReplaceInstructionIfSameShape(bitcast, bitcast->mutable_operand(0));
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) {
  // If a copy feeds a copy, make it a single copy.
  HloInstruction* op;
  if (Match(copy, m::Copy(m::Copy(m::Op(&op))))) {
    return ReplaceWithNewInstruction(
        copy, HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy, op));
  }
  // All copies can be eliminated (assuming layout constraints are satisfied).
  if (ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0))) {
    return Status::OK();
  }

  if (HloInstruction* bitcast_operand =
          BitcastingOperandOfReshapeOrCopyChain(copy, options_)) {
    ReplaceWithBitcast(copy, bitcast_operand);
    return Status::OK();
  }

  // Replace Copy(Reshape()) with Reshape() if the Reshape is a logical bitcast.
  if (copy->operand(0)->opcode() == HloOpcode::kReshape &&
      copy->operand(0)->user_count() == 1 &&
      ShapeUtil::ReshapeIsBitcast(copy->operand(0)->shape(), copy->shape())) {
    return ReplaceWithNewInstruction(
        copy,
        copy->operand(0)->CloneWithNewOperands(
            copy->shape(), {copy->mutable_operand(0)->mutable_operand(0)}));
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleConcatenate(
    HloInstruction* concatenate) {
  absl::Span<HloInstruction* const> operands(concatenate->operands());
  if (operands.size() == 1) {
    // Unary concatenates are useless.
    ReplaceInstructionIfSameShape(concatenate, operands[0]);
    return Status::OK();
  }
  // Filter out and remove empty operands.
  std::vector<HloInstruction*> nonempty_operands;
  for (HloInstruction* operand : operands) {
    if (!ShapeUtil::IsZeroElementArray(operand->shape())) {
      nonempty_operands.push_back(operand);
    }
  }
  if (nonempty_operands.size() < operands.size()) {
    HloInstruction* replacement;
    if (nonempty_operands.empty()) {
      replacement = operands[0];
    } else if (nonempty_operands.size() == 1) {
      replacement = nonempty_operands[0];
    } else {
      replacement =
          computation_->AddInstruction(concatenate->CloneWithNewOperands(
              concatenate->shape(), nonempty_operands));
    }
    VLOG(10) << "trying to replace " << concatenate->ToString() << " with "
             << replacement->ToString();
    ReplaceInstructionIfSameShape(concatenate, replacement);
    return Status::OK();
  }

  if (options_.is_layout_sensitive()) {
    return Status::OK();
  }

  // concat(x, concat(y, z)) -> concat(x, y, z). We only do this in
  // layout-insensitive mode because some backends may have (late,
  // layout-sensitive) passes that break up ops with many operands into smaller
  // pieces. This would undo that.
  absl::InlinedVector<HloInstruction*, 8> unnested_concat_operands;
  for (HloInstruction* operand : operands) {
    if (operand->opcode() == HloOpcode::kConcatenate &&
        operand->concatenate_dimension() ==
            concatenate->concatenate_dimension()) {
      for (HloInstruction* instr : operand->operands()) {
        unnested_concat_operands.push_back(instr);
      }
    } else {
      unnested_concat_operands.push_back(operand);
    }
  }
  if (unnested_concat_operands.size() != concatenate->operand_count()) {
    return ReplaceWithNewInstruction(
        concatenate, HloInstruction::CreateConcatenate(
                         concatenate->shape(), unnested_concat_operands,
                         concatenate->concatenate_dimension()));
  }

  // Check if we can merge "adjacent" slice operands which take slices from the
  // same other op. For simplicity we only merge unstrided slices.
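  // For example, concat(slice(x, [0:3]), slice(x, [3:7])) along the sliced
  // dimension becomes slice(x, [0:7]).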
  int64_t concatenate_dimension = concatenate->concatenate_dimension();
  std::vector<HloInstruction*> new_operands;
  int64_t i = 0;
  while (i < operands.size()) {
    if (operands[i]->opcode() != HloOpcode::kSlice ||
        !IsUnstridedSlice(operands[i])) {
      new_operands.push_back(operands[i]);
      ++i;
      continue;
    }
    int64_t slice_end = operands[i]->slice_limits(concatenate_dimension);
    HloInstruction* slice_operand = operands[i]->mutable_operand(0);
    int64_t j = i + 1;
    while (j < operands.size()) {
      if (operands[j]->opcode() != HloOpcode::kSlice ||
          !IsUnstridedSlice(operands[j]) ||
          operands[j]->operand(0) != slice_operand ||
          operands[j]->slice_starts(concatenate_dimension) != slice_end) {
        break;
      }
      // Check that all the slice_start values are the same in all other
      // dimensions. This implies that the slice_limit values are also the same,
      // because operands of concatenate need to have the same shape, and we
      // already checked that the slices are unstrided.
      bool same_other_starts = true;
      for (int64_t k = 0; k < operands[j]->slice_starts().size(); ++k) {
        if (k == concatenate_dimension) {
          continue;
        }
        if (operands[i]->slice_starts(k) != operands[j]->slice_starts(k)) {
          same_other_starts = false;
          break;
        }
      }
      if (!same_other_starts) {
        break;
      }
      slice_end = operands[j]->slice_limits(concatenate_dimension);
      ++j;
    }
    if (j - i > 1) {
      Shape new_slice_shape = operands[i]->shape();
      new_slice_shape.set_dimensions(
          concatenate_dimension,
          slice_end - operands[i]->slice_starts(concatenate_dimension));
      simplifier_->UpdateLayout(&new_slice_shape);
      auto new_limit_indices = operands[i]->slice_limits();
      new_limit_indices[concatenate_dimension] = slice_end;
      auto new_slice_op =
          computation_->AddInstruction(HloInstruction::CreateSlice(
              new_slice_shape, slice_operand,
              /*start_indices=*/operands[i]->slice_starts(),
              /*limit_indices=*/new_limit_indices,
              /*strides=*/operands[i]->slice_strides()));
      new_operands.push_back(new_slice_op);
    } else {
      new_operands.push_back(operands[i]);
    }
    i = j;
  }
  if (new_operands.size() < operands.size()) {
    auto replacement = computation_->AddInstruction(
        concatenate->CloneWithNewOperands(concatenate->shape(), new_operands));
    ReplaceInstructionIfSameShape(concatenate, replacement);
    return Status::OK();
  }

  if (operands.size() == 2) {
    // A binary concat with a broadcasted scalar as an operand can be converted
    // into a pad which is simpler to fold into other operations.
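    // For example, concat(broadcast(s), x) along dimension d becomes
    // pad(x, s) with low edge padding in dimension d equal to the size of
    // the broadcast operand in that dimension.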
    bool is_effective_low_pad = Match(
        operands[0], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar())));
    bool is_effective_high_pad = Match(
        operands[1], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar())));
    if (!is_effective_low_pad && !is_effective_high_pad) {
      return Status::OK();
    }
    PaddingConfig padding_config;
    for (int64_t dim = 0; dim < operands[0]->shape().rank(); ++dim) {
      auto padding_config_dim = padding_config.add_dimensions();
      padding_config_dim->set_edge_padding_high(0);
      padding_config_dim->set_edge_padding_low(0);
      padding_config_dim->set_interior_padding(0);
      if (dim == concatenate_dimension) {
        if (is_effective_low_pad) {
          padding_config_dim->set_edge_padding_low(
              operands[0]->shape().dimensions(dim));
        } else {
          padding_config_dim->set_edge_padding_high(
              operands[1]->shape().dimensions(dim));
        }
      }
    }
    int64_t operand_to_pad = is_effective_low_pad ? 1 : 0;
    int64_t pad_value_operand = is_effective_low_pad ? 0 : 1;
    HloInstruction* pad =
        computation_->AddInstruction(HloInstruction::CreatePad(
            concatenate->shape(), operands[operand_to_pad],
            operands[pad_value_operand]->mutable_operand(0), padding_config));
    return ReplaceInstruction(concatenate, pad);
  }

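  // concat(x, x, ..., x), where x has size 1 in the concatenate dimension,
  // is a broadcast of x reshaped to drop that dimension.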
1403 if (absl::c_count(operands, operands[0]) == operands.size() &&
1404 operands[0]->shape().dimensions(concatenate_dimension) == 1) {
1405 Shape new_shape = operands[0]->shape();
1406 absl::InlinedVector<int64, 8> broadcast_dims;
1407 for (int64_t i = 0; i < new_shape.rank(); ++i) {
1408 if (i == concatenate_dimension) {
1409 continue;
1410 }
1411 broadcast_dims.push_back(i);
1412 }
1413 new_shape.DeleteDimension(concatenate_dimension);
1414 return ReplaceInstruction(
1415 concatenate,
1416 MakeBroadcastHlo(MakeReshapeHlo(new_shape, operands[0]).ValueOrDie(),
1417 broadcast_dims, concatenate->shape()));
1418 }
1419 return Status::OK();
1420 }
1421
BuildTupleConstant(HloComputation * computation,const LiteralSlice & literal,AlgebraicSimplifier * simplifier)1422 static HloInstruction* BuildTupleConstant(HloComputation* computation,
1423 const LiteralSlice& literal,
1424 AlgebraicSimplifier* simplifier) {
1425 if (literal.shape().IsTuple()) {
1426 std::vector<HloInstruction*> elems;
1427 elems.reserve(ShapeUtil::TupleElementCount(literal.shape()));
1428 for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) {
1429 elems.push_back(BuildTupleConstant(
1430 computation, LiteralSlice(literal, {i}), simplifier));
1431 }
1432 return computation->AddInstruction(HloInstruction::CreateTuple(elems));
1433 } else {
1434 return computation->AddInstruction(
1435 simplifier->CreateConstantWithLayoutUpdated(literal.Clone()));
1436 }
1437 }
1438
HandleConstant(HloInstruction * constant)1439 Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
1440 // Tuple constants aren't directly supported by any backend. Expand them into
1441 // explicit Tuple instructions.
1442 if (constant->shape().IsTuple()) {
1443 return ReplaceInstruction(
1444 constant,
1445 BuildTupleConstant(computation_, constant->literal(), simplifier_));
1446 }
1447
1448 if (constant->shape().element_type() == TOKEN) {
1449 return Status::OK();
1450 }
1451
1452 // If a literal is all the same element replace it with a scalar broadcast.
1453 if (ShapeUtil::ElementsIn(constant->shape()) > 1 &&
1454 constant->literal().IsAllFirst()) {
1455 Literal unique_scalar(
1456 LiteralUtil::GetFirstScalarLiteral(constant->literal()));
1457 HloInstruction* scalar = computation_->AddInstruction(
1458 simplifier_->CreateConstantWithLayoutUpdated(std::move(unique_scalar)));
1459 return ReplaceWithNewInstruction(
1460 constant,
1461 HloInstruction::CreateBroadcast(constant->shape(), scalar, {}));
1462 }
1463
1464 // If a literal is an increasing sequence from zero, replace it with an iota.
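// For example (illustrative shapes), the s32[4] constant {0, 1, 2, 3}
// becomes iota(s32[4], 0).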
1465 if (ShapeUtil::ElementsIn(constant->shape()) > 1 &&
1466 constant->literal().IsR1Iota()) {
1467 return ReplaceWithNewInstruction(
1468 constant, HloInstruction::CreateIota(constant->shape(), 0));
1469 }
1470
1471 if (absl::optional<int64> stride = constant->literal().IsR1StridedIota()) {
1472 // Replace the constant with iota * stride.
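// For example (illustrative), the s32[4] constant {0, 3, 6, 9} becomes
// multiply(iota(s32[4], 0), broadcast(3)).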
1473 HloInstruction* stride_hlo = MakeScalarLike(constant, *stride);
1474 HloInstruction* iota = computation_->AddInstruction(
1475 HloInstruction::CreateIota(constant->shape(), 0));
1476 return ReplaceWithNewInstruction(
1477 constant,
1478 HloInstruction::CreateBinary(constant->shape(), HloOpcode::kMultiply,
1479 iota, stride_hlo));
1480 }
1481
1482 return Status::OK();
1483 }
1484
1485 Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) {
1486 HloInstruction *lhs, *rhs;
1487 CHECK(Match(sub, m::Subtract(m::Op(&lhs), m::Op(&rhs))));
1488 // A - 0 => A
1489 VLOG(10) << "trying transform [A - 0 => A]: " << sub->ToString();
1490 if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(sub, lhs)) {
1491 return Status::OK();
1492 }
1493
1494 // Canonicalize subtraction of a constant to addition.
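// For example, (A - 3) becomes (A + (-3)), so that later passes can treat
// the expression like any other addition of a constant.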
1495 VLOG(10) << "trying transform [A - Const => A + (-Const)]";
1496 if (Match(sub, m::Subtract(m::NonConstant(&lhs), m::Constant(&rhs))) ||
1497 Match(sub, m::Subtract(m::NonConstant(&lhs),
1498 m::Broadcast(m::Constant(&rhs))))) {
1499 HloInstruction* negative_const = computation_->AddInstruction(
1500 HloInstruction::CreateUnary(rhs->shape(), HloOpcode::kNegate, rhs));
1501 if (const HloInstruction* broadcast =
1502 DynCast<HloBroadcastInstruction>(sub->operand(1))) {
1503 negative_const =
1504 computation_->AddInstruction(HloInstruction::CreateBroadcast(
1505 broadcast->shape(), negative_const, broadcast->dimensions()));
1506 }
1507 return ReplaceWithNewInstruction(
1508 sub, HloInstruction::CreateBinary(sub->shape(), HloOpcode::kAdd, lhs,
1509 negative_const));
1510 }
1511
1512 return Status::OK();
1513 }
1514 namespace {
1515 template <typename T>
1516 Status InvertConstant(const HloInstruction& constant, Literal* result) {
1517 return result->Populate<T>([&](absl::Span<const int64> indices) {
1518 return T{1.0} / constant.literal().Get<T>(indices);
1519 });
1520 }
1521
1522 template <typename T>
1523 std::unique_ptr<HloInstruction> TryDivideToShift(
1524 HloInstruction* divide, HloComputation* computation,
1525 AlgebraicSimplifier* simplifier) {
1526 HloInstruction *a, *b, *c;
1527 CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b))));
1528
1529 if (ShapeUtil::ElementIsIntegral(divide->shape()) &&
1530 !Match(b, m::ConstantEffectiveScalar(&c)) &&
1531 !Match(b, m::Broadcast(m::ConstantEffectiveScalar(&c)))) {
1532 return nullptr;
1533 }
1534
1535 if (ShapeUtil::ElementIsSigned(divide->shape())) {
1536 int64_t b_value = c->literal().GetFirstElement<T>();
1537 if (b_value > 0 && IsPowerOfTwo(static_cast<uint64>(b_value))) {
1538 // Handle negative dividends by negating the result of the division.
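// For example (illustrative values), dividing by b == 8: -20 / 8 == -2 under
// truncating division, and -((-(-20)) >> 3) == -(20 >> 3) == -2, which is
// why the dividend is made non-negative before the logical shift.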
1539 HloInstruction* zero_like_a = MakeScalarLike(a, 0);
1540
1541 Shape changed_shape = ShapeUtil::ChangeElementType(a->shape(), PRED);
1542 simplifier->UpdateLayout(&changed_shape);
1543 auto* dividend_is_negative =
1544 computation->AddInstruction(HloInstruction::CreateCompare(
1545 changed_shape, a, zero_like_a, ComparisonDirection::kLt));
1546
1547 auto* negated_dividend = computation->AddInstruction(
1548 HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a));
1549
1550 auto* abs_dividend =
1551 computation->AddInstruction(HloInstruction::CreateTernary(
1552 a->shape(), HloOpcode::kSelect, dividend_is_negative,
1553 negated_dividend, a));
1554
1555 auto* quotient = computation->AddInstruction(HloInstruction::CreateBinary(
1556 divide->shape(), HloOpcode::kShiftRightLogical, abs_dividend,
1557 MakeScalarLike(abs_dividend, tensorflow::Log2Floor64(b_value))));
1558
1559 auto* negated_quotient =
1560 computation->AddInstruction(HloInstruction::CreateUnary(
1561 quotient->shape(), HloOpcode::kNegate, quotient));
1562
1563 return HloInstruction::CreateTernary(divide->shape(), HloOpcode::kSelect,
1564 dividend_is_negative,
1565 negated_quotient, quotient);
1566 }
1567 } else {
1568 uint64 b_value = c->literal().GetFirstElement<T>();
1569 if (IsPowerOfTwo(b_value)) {
1570 return HloInstruction::CreateBinary(
1571 divide->shape(), HloOpcode::kShiftRightLogical, a,
1572 MakeScalarLike(a, tensorflow::Log2Floor64(b_value)));
1573 }
1574 }
1575
1576 return nullptr;
1577 }
1578 } // namespace
1579
1580 Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
1581 HloInstruction *a, *b, *c, *d;
1582 CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b))));
1583 // A/1 => A
1584 VLOG(10) << "trying transform [A/1 => A]: " << divide->ToString();
1585 if (IsAll(b, 1) && ReplaceInstructionIfSameShape(divide, a)) {
1586 return Status::OK();
1587 }
1588
1589 // A / B => A >> log2(B) if B is a power of 2.
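// For example (illustrative), u32 division by 16 becomes a logical
// right-shift by 4; the signed case additionally fixes up negative
// dividends inside TryDivideToShift.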
1590 switch (divide->shape().element_type()) {
1591 case S8:
1592 if (std::unique_ptr<HloInstruction> shift =
1593 TryDivideToShift<int8>(divide, computation_, simplifier_)) {
1594 return ReplaceWithNewInstruction(divide, std::move(shift));
1595 }
1596 break;
1597 case S16:
1598 if (std::unique_ptr<HloInstruction> shift =
1599 TryDivideToShift<int16>(divide, computation_, simplifier_)) {
1600 return ReplaceWithNewInstruction(divide, std::move(shift));
1601 }
1602 break;
1603 case S32:
1604 if (std::unique_ptr<HloInstruction> shift =
1605 TryDivideToShift<int32>(divide, computation_, simplifier_)) {
1606 return ReplaceWithNewInstruction(divide, std::move(shift));
1607 }
1608 break;
1609 case S64:
1610 if (std::unique_ptr<HloInstruction> shift =
1611 TryDivideToShift<int64>(divide, computation_, simplifier_)) {
1612 return ReplaceWithNewInstruction(divide, std::move(shift));
1613 }
1614 break;
1615 case U8:
1616 if (std::unique_ptr<HloInstruction> shift =
1617 TryDivideToShift<uint8>(divide, computation_, simplifier_)) {
1618 return ReplaceWithNewInstruction(divide, std::move(shift));
1619 }
1620 break;
1621 case U16:
1622 if (std::unique_ptr<HloInstruction> shift =
1623 TryDivideToShift<uint16>(divide, computation_, simplifier_)) {
1624 return ReplaceWithNewInstruction(divide, std::move(shift));
1625 }
1626 break;
1627 case U32:
1628 if (std::unique_ptr<HloInstruction> shift =
1629 TryDivideToShift<uint32>(divide, computation_, simplifier_)) {
1630 return ReplaceWithNewInstruction(divide, std::move(shift));
1631 }
1632 break;
1633 case U64:
1634 if (std::unique_ptr<HloInstruction> shift =
1635 TryDivideToShift<uint64>(divide, computation_, simplifier_)) {
1636 return ReplaceWithNewInstruction(divide, std::move(shift));
1637 }
1638 break;
1639 default:
1640 break;
1641 }
1642
1643 Shape* shape;
1644 // exp(A)/exp(B) => exp(A-B)
1645 if (Match(divide, m::Divide(m::Exp(m::Op(&a)), m::Exp(m::Op(&b)))
1646 .WithShape(m::Shape(&shape)))) {
1647 VLOG(10) << "transform [exp(A)/exp(B) => exp(A-B)]: " << divide->ToString();
1648 HloInstruction* subtract = computation_->AddInstruction(
1649 HloInstruction::CreateBinary(*shape, HloOpcode::kSubtract, a, b));
1650 return ReplaceWithNewInstruction(
1651 divide, HloInstruction::CreateUnary(*shape, HloOpcode::kExp, subtract));
1652 }
1653
1654 // A/exp(B) => A*exp(-B)
1655 if (Match(divide, m::Divide(m::Op(&a), m::Exp(m::Op(&b))))) {
1656 VLOG(10) << "transform [A/exp(B) => A*exp(-B)]: " << divide->ToString();
1657 HloInstruction* negate = computation_->AddInstruction(
1658 HloInstruction::CreateUnary(divide->shape(), HloOpcode::kNegate, b));
1659 HloInstruction* new_exp = computation_->AddInstruction(
1660 HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp, negate));
1661 return ReplaceWithNewInstruction(
1662 divide, HloInstruction::CreateBinary(divide->shape(),
1663 HloOpcode::kMultiply, a, new_exp));
1664 }
1665
1666 // A/pow(B,C) => A*pow(B,-C)
1667 if (Match(divide, m::Divide(m::Op(&a), m::Power(m::Op(&b), m::Op(&c))))) {
1668 VLOG(10) << "transform [A/pow(B,C) => A*pow(B,-C)]: " << divide->ToString();
1669 // The output shape of the created negate operator should be the same as the
1670 // input.
1671 const Shape& negate_shape = c->shape();
1672 HloInstruction* negate = computation_->AddInstruction(
1673 HloInstruction::CreateUnary(negate_shape, HloOpcode::kNegate, c));
1674 // And the power operator should retain the output shape of the old one.
1675 const Shape& new_power_shape = b->shape();
1676 HloInstruction* new_power =
1677 computation_->AddInstruction(HloInstruction::CreateBinary(
1678 new_power_shape, HloOpcode::kPower, b, negate));
1679 return ReplaceWithNewInstruction(
1680 divide, HloInstruction::CreateBinary(
1681 divide->shape(), HloOpcode::kMultiply, a, new_power));
1682 }
1683
1684 // A/sqrt(B) => A*rsqrt(B).
1685 if (Match(divide, m::Divide(m::Op(&a), m::Sqrt(m::Op(&b))))) {
1686 auto* rsqrt = computation_->AddInstruction(
1687 HloInstruction::CreateUnary(divide->shape(), HloOpcode::kRsqrt, b));
1688 return ReplaceWithNewInstruction(
1689 divide, HloInstruction::CreateBinary(rsqrt->shape(),
1690 HloOpcode::kMultiply, a, rsqrt));
1691 }
1692
1693 // A/rsqrt(B) => A*sqrt(B).
1694 if (Match(divide, m::Divide(m::Op(&a), m::Rsqrt(m::Op(&b))))) {
1695 auto* sqrt = computation_->AddInstruction(
1696 HloInstruction::CreateUnary(divide->shape(), HloOpcode::kSqrt, b));
1697 return ReplaceWithNewInstruction(
1698 divide, HloInstruction::CreateBinary(sqrt->shape(),
1699 HloOpcode::kMultiply, a, sqrt));
1700 }
1701
1702 // Simplifying integral division would produce unexpected results.
1703 if (ShapeUtil::ElementIsIntegral(divide->shape())) {
1704 return Status::OK();
1705 }
1706
1707 // A / Const => A * (1 / Const)
1708 //
1709 // (Backends can do this transformation, but generally only if the constant is
1710 // a scalar.)
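// For example, A / 4.0 becomes A * 0.25; the reciprocal literal is computed
// element-wise below via InvertConstant.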
1711 if (Match(divide, m::Divide(m::NonConstant(&a), m::Op(&b))) &&
1712 (Match(b, m::Constant(&c)) || Match(b, m::Broadcast(m::Constant(&c))))) {
1713 Shape result_shape = c->literal().shape();
1714 Literal new_literal(result_shape);
1715 switch (result_shape.element_type()) {
1716 case F16:
1717 TF_RETURN_IF_ERROR(InvertConstant<half>(*c, &new_literal));
1718 break;
1719 case F32:
1720 TF_RETURN_IF_ERROR(InvertConstant<float>(*c, &new_literal));
1721 break;
1722 case BF16:
1723 TF_RETURN_IF_ERROR(InvertConstant<bfloat16>(*c, &new_literal));
1724 break;
1725 case F64:
1726 TF_RETURN_IF_ERROR(InvertConstant<double>(*c, &new_literal));
1727 break;
1728 case C64:
1729 TF_RETURN_IF_ERROR(InvertConstant<complex64>(*c, &new_literal));
1730 break;
1731 case C128:
1732 TF_RETURN_IF_ERROR(InvertConstant<complex128>(*c, &new_literal));
1733 break;
1734 default:
1735 return Status::OK();
1736 }
1737 auto inverse = computation_->AddInstruction(
1738 simplifier_->CreateConstantWithLayoutUpdated(new_literal.Clone()));
1739 if (b != c) {
1740 inverse = computation_->AddInstruction(HloInstruction::CreateBroadcast(
1741 b->shape(), inverse, b->dimensions()));
1742 }
1743 TF_ASSIGN_OR_RETURN(auto new_divide,
1744 MakeBinaryHlo(HloOpcode::kMultiply, a, inverse));
1745 return ReplaceInstruction(divide, new_divide);
1746 }
1747
1748 // (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C)
1749 if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)),
1750 m::Divide(m::Op(&c), m::Op(&d))))) {
1751 TF_ASSIGN_OR_RETURN(auto a_times_d,
1752 MakeBinaryHlo(HloOpcode::kMultiply, a, d));
1753 TF_ASSIGN_OR_RETURN(auto b_times_c,
1754 MakeBinaryHlo(HloOpcode::kMultiply, b, c));
1755 TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kDivide,
1756 a_times_d, b_times_c));
1757
1758 return ReplaceInstruction(divide, new_divide);
1759 }
1760
1761 // (A / B) / C => A / (B * C)
1762 if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), m::Op(&c)))) {
1763 TF_ASSIGN_OR_RETURN(auto b_times_c,
1764 MakeBinaryHlo(HloOpcode::kMultiply, b, c));
1765 TF_ASSIGN_OR_RETURN(auto new_divide,
1766 MakeBinaryHlo(HloOpcode::kDivide, a, b_times_c));
1767 return ReplaceInstruction(divide, new_divide);
1768 }
1769
1770 // A / (B / C) => (A*C) / B
1771 if (Match(divide, m::Divide(m::Op(&a), m::Divide(m::Op(&b), m::Op(&c))))) {
1772 TF_ASSIGN_OR_RETURN(auto a_times_c,
1773 MakeBinaryHlo(HloOpcode::kMultiply, a, c));
1774 TF_ASSIGN_OR_RETURN(auto new_divide,
1775 MakeBinaryHlo(HloOpcode::kDivide, a_times_c, b));
1776 return ReplaceInstruction(divide, new_divide);
1777 }
1778
1779 // If X is a convert from pred, then
1780 // X / broadcast(Y) => broadcast(1/Y) * X
1781 if (Match(divide,
1782 m::Divide(
1783 m::Convert(&a,
1784 m::Op().WithShape(m::Shape().WithElementType(PRED))),
1785 m::Broadcast(m::Op(&b).WithShape(m::Shape().IsScalar()))))) {
1786 TF_ASSIGN_OR_RETURN(
1787 auto recip, MakeBinaryHlo(HloOpcode::kDivide, MakeScalarLike(b, 1), b));
1788 auto recip_bcast = computation_->AddInstruction(
1789 HloInstruction::CreateBroadcast(divide->shape(), recip, {}));
1790 TF_ASSIGN_OR_RETURN(auto mul,
1791 MakeBinaryHlo(HloOpcode::kMultiply, recip_bcast, a));
1792 return ReplaceInstruction(divide, mul);
1793 }
1794
1795 return Status::OK();
1796 }
1797
1798 StatusOr<bool> AlgebraicSimplifierVisitor::RemoveDegenerateDimensionFromDot(
1799 HloInstruction* dot) {
1800 const Shape& lhs_shape = dot->operand(0)->shape();
1801 int64_t num_degenerate_lhs_dims = 0;
1802 std::vector<int64> lhs_dimension_map(lhs_shape.rank(), -1);
1803 for (int64_t i = 0; i < lhs_shape.rank(); ++i) {
1804 if (lhs_shape.dimensions(i) == 1) {
1805 ++num_degenerate_lhs_dims;
1806 } else {
1807 lhs_dimension_map[i] = i - num_degenerate_lhs_dims;
1808 }
1809 }
1810
1811 const Shape& rhs_shape = dot->operand(1)->shape();
1812 int64_t num_degenerate_rhs_dims = 0;
1813 std::vector<int64> rhs_dimension_map(rhs_shape.rank(), -1);
1814 for (int64_t i = 0; i < rhs_shape.rank(); ++i) {
1815 if (rhs_shape.dimensions(i) == 1) {
1816 ++num_degenerate_rhs_dims;
1817 } else {
1818 rhs_dimension_map[i] = i - num_degenerate_rhs_dims;
1819 }
1820 }
1821 if (num_degenerate_lhs_dims == 0 && num_degenerate_rhs_dims == 0) {
1822 return false;
1823 }
1824 const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
1825 DotDimensionNumbers new_dnums;
1826 for (int64_t dim : dnums.lhs_batch_dimensions()) {
1827 int64_t new_dim = lhs_dimension_map[dim];
1828 if (new_dim != -1) {
1829 new_dnums.add_lhs_batch_dimensions(new_dim);
1830 }
1831 }
1832 for (int64_t dim : dnums.lhs_contracting_dimensions()) {
1833 int64_t new_dim = lhs_dimension_map[dim];
1834 if (new_dim != -1) {
1835 new_dnums.add_lhs_contracting_dimensions(new_dim);
1836 }
1837 }
1838
1839 for (int64_t dim : dnums.rhs_batch_dimensions()) {
1840 int64_t new_dim = rhs_dimension_map[dim];
1841 if (new_dim != -1) {
1842 new_dnums.add_rhs_batch_dimensions(new_dim);
1843 }
1844 }
1845 for (int64_t dim : dnums.rhs_contracting_dimensions()) {
1846 int64_t new_dim = rhs_dimension_map[dim];
1847 if (new_dim != -1) {
1848 new_dnums.add_rhs_contracting_dimensions(new_dim);
1849 }
1850 }
1851
1852 HloInstruction* new_lhs =
1853 num_degenerate_lhs_dims > 0
1854 ? dot->parent()->AddInstruction(HloInstruction::CreateReshape(
1855 ShapeUtil::DropDegenerateDimensions(lhs_shape),
1856 dot->mutable_operand(0)))
1857 : dot->mutable_operand(0);
1858 HloInstruction* new_rhs =
1859 num_degenerate_rhs_dims > 0
1860 ? dot->parent()->AddInstruction(HloInstruction::CreateReshape(
1861 ShapeUtil::DropDegenerateDimensions(rhs_shape),
1862 dot->mutable_operand(1)))
1863 : dot->mutable_operand(1);
1864 TF_ASSIGN_OR_RETURN(
1865 auto new_dot,
1866 MakeDotHlo(new_lhs, new_rhs, new_dnums, dot->precision_config(),
1867 /*preferred_element_type=*/dot->shape().element_type()));
1868 if (ShapeUtil::Compatible(dot->shape(), new_dot->shape())) {
1869 TF_RETURN_IF_ERROR(ReplaceInstruction(dot, new_dot));
1870 } else {
1871 TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
1872 dot, HloInstruction::CreateReshape(dot->shape(), new_dot)));
1873 }
1874 return true;
1875 }
1876
1877 StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcat(
1878 HloInstruction* dot) {
1879 const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
1880 if (dnums.lhs_contracting_dimensions_size() != 1 ||
1881 dnums.lhs_batch_dimensions_size() != 0 ||
1882 dot->shape().dimensions_size() != 2) { // dot output 2D
1883 return nullptr;
1884 }
1885
1886 const int64_t lhs_contracting_dim = dnums.lhs_contracting_dimensions(0);
1887 const int64_t rhs_contracting_dim = dnums.rhs_contracting_dimensions(0);
1888 HloInstruction *lhs, *rhs;
1889 CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
1890
1891 TF_ASSIGN_OR_RETURN(
1892 HloInstruction * optimized_lhs_concat,
1893 OptimizeDotOfConcatHelper(*dot, lhs, lhs_contracting_dim, rhs,
1894 rhs_contracting_dim, /*swapped=*/false));
1895 if (optimized_lhs_concat) {
1896 return optimized_lhs_concat;
1897 }
1898
1899 return OptimizeDotOfConcatHelper(*dot, rhs, rhs_contracting_dim, lhs,
1900 lhs_contracting_dim, /*swapped=*/true);
1901 }
1902
1903 StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper(
1904 const HloInstruction& dot, HloInstruction* lhs, int64_t lhs_contracting_dim,
1905 HloInstruction* rhs, int64_t rhs_contracting_dim, bool swapped) {
1906 bool can_optimize = lhs->opcode() == HloOpcode::kConcatenate &&
1907 lhs->concatenate_dimension() == lhs_contracting_dim &&
1908 rhs->opcode() == HloOpcode::kConstant;
1909 if (!can_optimize) {
1910 return nullptr;
1911 }
1912
1913 // We're replacing this:
1914 //
1915 // +-----+-----+-----+ +-------------------+
1916 // | | | | | |
1917 // | | | | | R_0 |
1918 // | | | | | |
1919 // | | | | +-------------------+
1920 // | | | | | |
1921 // | L_0 | L_1 | L_2 | * | R_1 |
1922 // | | | | | |
1923 // | | | | +-------------------+
1924 // | | | | | |
1925 // | | | | | R_2 |
1926 // | | | | | |
1927 // +-----+-----+-----+ +-------------------+
1928 //
1929 // with this:
1930 //
1931 // [Sum over i]
1932 //
1933 // +-----+ +-------------------+
1934 // | | | |
1935 // | | * | R_i |
1936 // | | | |
1937 // | | +-------------------+
1938 // | |
1939 // | L_i |
1940 // | |
1941 // | |
1942 // | |
1943 // | |
1944 // | |
1945 // +-----+
1946 //
1947 // where the LHS is a concatenate operation (so we can "split" the LHS tensor
1948 // for free) and the RHS is a constant tensor (and thus can be split at
1949 // compile time). In the future, we may also want to do this when both the
1950 // LHS and the RHS are concatenate operations that line up along the dimension
1951 // being contracted over.
1952 //
1953 // We should be able to generalize this transform to work on a non-constant
1954 // RHS when/if we have in-place slices or support input-fusing slices into
1955 // Dots.
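//
// Concretely (illustrative shapes): if L = concat(L_0[4x2], L_1[4x3]) along
// dim 1 and R is a 5x7 constant, then dot(L, R) ==
// dot(L_0, R[0:2, :]) + dot(L_1, R[2:5, :]), where each slice of R is
// itself a constant.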
1956
1957 // Dimension numbers for the new dot instructions we'll create (L_i * R_i in
1958 // the diagram above).
1959 DotDimensionNumbers new_dot_dnums;
1960 new_dot_dnums.add_lhs_contracting_dimensions(swapped ? rhs_contracting_dim
1961 : lhs_contracting_dim);
1962 new_dot_dnums.add_rhs_contracting_dimensions(swapped ? lhs_contracting_dim
1963 : rhs_contracting_dim);
1964
1965 // Here we use the MKN notation, where the contracted dimension has K
1966 // elements and the two non-contracted dimensions have M and N elements.
1967 HloInstruction* add_result = nullptr;
1968 int64_t rhs_contracting_dim_offset = 0;
1969 int64_t n = rhs->shape().dimensions(1 - rhs_contracting_dim);
1970 for (HloInstruction* concat_op : lhs->operands()) {
1971 int64_t sub_k = concat_op->shape().dimensions(lhs_contracting_dim);
1972 Shape rhs_slice_shape(rhs->shape());
1973 rhs_slice_shape.set_dimensions(rhs_contracting_dim, sub_k);
1974 simplifier_->UpdateLayout(&rhs_slice_shape);
1975
1976 std::array<int64, 2> start_indices;
1977 start_indices[rhs_contracting_dim] = rhs_contracting_dim_offset;
1978 start_indices[1 - rhs_contracting_dim] = 0;
1979
1980 std::array<int64, 2> limit_indices;
1981 limit_indices[rhs_contracting_dim] = rhs_contracting_dim_offset + sub_k;
1982 limit_indices[1 - rhs_contracting_dim] = n;
1983
1984 HloInstruction* rhs_slice =
1985 computation_->AddInstruction(HloInstruction::CreateSlice(
1986 rhs_slice_shape, rhs, /*start_indices=*/start_indices,
1987 /*limit_indices=*/limit_indices, /*strides=*/{1, 1}));
1988
1989 // TODO(b/69062148): We can get rid of `swapped` once all backends support
1990 // "non-canonical" contraction dimensions (that contracts dimension 1 of the
1991 // LHS with dimension 0 of the RHS). But for now we keep the same
1992 // contraction dimensions as the incoming dot operation to ensure the new
1993 // dot operations can be lowered.
1994 HloInstruction *new_dot_lhs, *new_dot_rhs;
1995 if (swapped) {
1996 new_dot_lhs = rhs_slice;
1997 new_dot_rhs = concat_op;
1998 } else {
1999 new_dot_lhs = concat_op;
2000 new_dot_rhs = rhs_slice;
2001 }
2002
2003 auto* new_dot = computation_->AddInstruction(
2004 HloInstruction::CreateDot(dot.shape(), new_dot_lhs, new_dot_rhs,
2005 new_dot_dnums, dot.precision_config()));
2006
2007 if (add_result) {
2008 add_result = computation_->AddInstruction(HloInstruction::CreateBinary(
2009 dot.shape(), HloOpcode::kAdd, add_result, new_dot));
2010 } else {
2011 add_result = new_dot;
2012 }
2013
2014 rhs_contracting_dim_offset += sub_k;
2015 }
2016
2017 return add_result;
2018 }
2019
2020 StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
2021 HloInstruction* dot) {
2022 const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
2023 if (dnums.lhs_contracting_dimensions_size() != 1 ||
2024 dnums.rhs_contracting_dimensions_size() != 1 ||
2025 dnums.lhs_batch_dimensions_size() != 0 ||
2026 dnums.rhs_batch_dimensions_size() != 0 ||
2027 dot->shape().dimensions_size() != 2) { // dot output 2D
2028 VLOG(10) << "DotOfGather: Can only optimize 2D, non-batch dot operations.";
2029 return nullptr;
2030 }
2031
2032 // Optimize either dot(DS(ctA), ctB) or dot(ctB, DS(ctA)).
2033 // Currently the only "gather" handled is a DynamicSlice of a constant.
2034 auto is_dynamic_slice_constant_combination =
2035 [](HloInstruction* a, HloInstruction* b, int a_contracting_dimension) {
2036 // First operand is a DynamicSlice(Constant).
2037 if (a->opcode() != HloOpcode::kDynamicSlice) {
2038 return false;
2039 }
2040 auto* dynamic_slice_op = a->operand(0);
2041 if (dynamic_slice_op->opcode() != HloOpcode::kConstant) {
2042 return false;
2043 }
2044 // Second operand is a Constant.
2045 if (b->opcode() != HloOpcode::kConstant) {
2046 return false;
2047 }
2048 // The DynamicSlice output is a vector.
2049 const Shape& dynamic_slice_shape = a->shape();
2050 if (dynamic_slice_shape.dimensions(1 - a_contracting_dimension) != 1) {
2051 return false;
2052 }
2053 // The constant's size in the contracting dimension must be the same
2054 // before and after the slice; otherwise we would have to precompute the
2055 // dot for every possible slice index, or the dot would be invalid.
2056 const Shape& dynamic_slice_op_shape = dynamic_slice_op->shape();
2057 if (dynamic_slice_op_shape.dimensions(a_contracting_dimension) !=
2058 dynamic_slice_shape.dimensions(a_contracting_dimension)) {
2059 return false;
2060 }
2061 return true;
2062 };
2063
2064 HloInstruction* lhs = dot->mutable_operand(0);
2065 HloInstruction* rhs = dot->mutable_operand(1);
2066 int lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0);
2067 int rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0);
2068
2069 if (!is_dynamic_slice_constant_combination(
2070 lhs, rhs, /*a_contracting_dimension=*/lhs_contracting_dimension) &&
2071 !is_dynamic_slice_constant_combination(
2072 rhs, lhs, /*a_contracting_dimension=*/rhs_contracting_dimension)) {
2073 VLOG(10) << "DotOfGather: Can only optimize dot(DS(ctA), ctB)) or "
2074 "dot(ctB, DS(ctA)), where the two constants have equal "
2075 "contracting dimensions.";
2076 return nullptr;
2077 }
2078
2079 // LHS is DynamicSlice:
2080 // input: dot(DS(ctA), ctB)
2081 // where DS(ctA) = DS({M x K}, {start, 0}, {1, K}) and ctB = {K x N}.
2082 // => input dimensions: dot({1 x K}, {K x N}) => {1 x N}.
2083 // output: DS(dot(ctA, ctB))
2084 // => output dimensions: DS ({M x N}, {start, 0}, {1, N}) => {1 x N}.
2085
2086 // RHS is DynamicSlice:
2087 // input: dot(ctA, DS(ctB))
2088 // where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, start}, {K, 1}).
2089 // => input dimensions: dot({M x K}, {K x 1}) => {M x 1}.
2090 // output: DS(dot(ctA, ctB))
2091 // => output dimensions: DS ({M x N}, {0, start}, {M, 1}) => {M x 1}.
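//
// Illustrative instance: with ctA = {8 x 3}, ctB = {3 x 5}, and
// DS(ctA) = {1 x 3} starting at row `start`, the rewrite computes the full
// {8 x 5} product once and then dynamic-slices row `start` out of it.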
2092
2093 bool lhs_is_dynamic_slice = lhs->opcode() == HloOpcode::kDynamicSlice;
2094 HloDynamicSliceInstruction* dynamic_slice =
2095 lhs_is_dynamic_slice ? Cast<HloDynamicSliceInstruction>(lhs)
2096 : Cast<HloDynamicSliceInstruction>(rhs);
2097
2098 // ctA:
2099 HloInstruction* left_operand =
2100 lhs_is_dynamic_slice ? lhs->mutable_operand(0) : lhs;
2101 // ctB:
2102 HloInstruction* right_operand =
2103 lhs_is_dynamic_slice ? rhs : rhs->mutable_operand(0);
2104 // Build ctA x ctB.
2105 const int m = left_operand->shape().dimensions(1 - lhs_contracting_dimension);
2106 const int n =
2107 right_operand->shape().dimensions(1 - rhs_contracting_dimension);
2108 auto memoized_shape =
2109 ShapeUtil::MakeShape(dot->shape().element_type(), {m, n});
2110 simplifier_->UpdateLayout(&memoized_shape);
2111 auto* memoized_inst = computation_->AddInstruction(
2112 HloInstruction::CreateDot(memoized_shape, left_operand, right_operand,
2113 dnums, dot->precision_config()));
2114 // Get pair {start, 0} or {0, start}.
2115 // Position of start:
2116 int index_of_non_zero_start = lhs_is_dynamic_slice
2117 ? 1 - lhs_contracting_dimension
2118 : 1 - rhs_contracting_dimension;
2119 // Position of zero:
2120 int index_of_zero_start = 1 - index_of_non_zero_start;
2121
2122 // Slice out start and 0 components and reorder if necessary.
2123 auto indices_type = dynamic_slice->operand(1)->shape().element_type();
2124 Shape s_shape = ShapeUtil::MakeShape(indices_type, {1});
2125 simplifier_->UpdateLayout(&s_shape);
2126 Shape d_shape = ShapeUtil::MakeShape(indices_type, {2});
2127 simplifier_->UpdateLayout(&d_shape);
2128 HloInstruction* non_zero_start =
2129 dynamic_slice->mutable_operand(1 + index_of_non_zero_start);
2130 HloInstruction* zero_start =
2131 dynamic_slice->mutable_operand(1 + index_of_zero_start);
2132 std::vector<HloInstruction*> new_start_indices;
2133 if (lhs_is_dynamic_slice) {
2134 new_start_indices = {non_zero_start, zero_start};
2135 } else {
2136 new_start_indices = {zero_start, non_zero_start};
2137 }
2138
2139 // Build DynamicSlice(ctA x ctB).
2140 const int new_slice_m = lhs_is_dynamic_slice ? 1 : m;
2141 const int new_slice_n = lhs_is_dynamic_slice ? n : 1;
2142 auto* memoized_lookup =
2143 computation_->AddInstruction(HloInstruction::CreateDynamicSlice(
2144 dot->shape(), memoized_inst, new_start_indices,
2145 {new_slice_m, new_slice_n}));
2146
2147 return memoized_lookup;
2148 }
2149
2150 // This function tries to transform
2151 // dot(reshape(transpose(A)), Const) to
2152 // dot(reshape(A), reshape(transpose(reshape(Const)))),
2153 // so that the reshape and transpose on the Const side can be constant folded.
2154 //
2155 // The basic idea is that the accumulation in the dot operation is
2156 // associative, so as long as we permute the elements of the contracting
2157 // dimensions on both sides of the dot in the same way, the result of the
2158 // dot is not affected.
2159 StatusOr<HloInstruction*>
2160 AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims(
2161 HloInstruction* dot) {
2162 // This transformation assumes layout is not assigned yet.
2163 if (options_.is_layout_sensitive()) {
2164 return nullptr;
2165 }
2166
2167 // Canonicalize dot(<constant>, rhs) to dot(rhs, <constant>) to make the
2168 // remainder of this function easier.
2169 auto dnums = dot->dot_dimension_numbers();
2170 auto lhs_contracting_dims = dnums.lhs_contracting_dimensions();
2171 auto rhs_contracting_dims = dnums.rhs_contracting_dimensions();
2172 auto* lhs = dot->mutable_operand(0);
2173 auto* rhs = dot->mutable_operand(1);
2174 if (dot->operand(0)->IsConstant()) {
2175 std::swap(lhs, rhs);
2176 std::swap(lhs_contracting_dims, rhs_contracting_dims);
2177 }
2178
2179 // Require single contracting dim to make the implementation easier to
2180 // track contracting dims.
2181 if (dnums.lhs_contracting_dimensions_size() != 1) {
2182 return nullptr;
2183 }
2184
2185 // Pattern match dot(reshape(transpose(input)), constant).
2186 HloInstruction* reshape;
2187 HloInstruction* transpose;
2188 HloInstruction* input;
2189 HloInstruction* constant;
2190 if (!Match(lhs,
2191 m::Reshape(&reshape, m::Transpose(&transpose, m::Op(&input)))) ||
2192 !Match(rhs, m::Constant(&constant))) {
2193 return nullptr;
2194 }
2195
2196 // Check that reshape squishes some dims into one dim and that this one
2197 // dim is the dot's lhs contracting dim. The size of unmodified_dims should
2198 // be N - 1, where N is the rank of the reshape output. This means that the
2199 // reshape squishes some dims into one dim. lhs contracting dim should not
2200 // be in unmodified_dims. This means that the squishing target dim is the
2201 // lhs contracting dim.
2202 auto unmodified_dims = ShapeUtil::DimensionsUnmodifiedByReshape(
2203 reshape->operand(0)->shape(), reshape->shape());
2204 CHECK_EQ(lhs_contracting_dims.size(), 1);
2205 if ((unmodified_dims.size() != reshape->shape().rank() - 1) ||
2206 absl::c_any_of(unmodified_dims, [&](const std::pair<int64, int64>& p) {
2207 return p.second == lhs_contracting_dims[0];
2208 })) {
2209 return nullptr;
2210 }
2211
2212 // Virtually pull the reshape into the dot so the dot operates on the
2213 // transpose, with "unsquished" lhs contracting dims. The new contracting
2214 // dims are all of the dims that are modified by the reshape -- that is, every
2215 // dimension that's not in `unmodified_dims[i].first`.
2216 //
2217 // (We don't need to actually create a new dot instruction. We can just keep
2218 // track of lhs and lhs_contracting_dims.)
2219 absl::flat_hash_set<int64> unmodified_transpose_dims;
2220 for (const auto& pair : unmodified_dims) {
2221 unmodified_transpose_dims.insert(pair.first);
2222 }
2223 lhs_contracting_dims.Clear();
2224 for (int64_t i = 0; i < transpose->shape().dimensions_size(); ++i) {
2225 if (!unmodified_transpose_dims.contains(i)) {
2226 lhs_contracting_dims.Add(i);
2227 }
2228 }
2229 // We require the "unsquished" lhs contracting dims to be consecutive.
2230 auto is_iota = [](absl::Span<const int64> dims) {
2231 return absl::c_adjacent_find(dims, [](const int64_t a, const int64_t b) {
2232 return (b != a + 1);
2233 }) == dims.end();
2234 };
2235 if (!is_iota(AsInt64Slice(lhs_contracting_dims))) {
2236 return nullptr;
2237 }
2238 lhs = lhs->mutable_operand(0);
2239
2240 // Check that the transpose only permutes the contracting dims.
2241 const auto& transpose_dims = transpose->dimensions();
2242 for (int64_t i = 0; i < transpose_dims.size(); ++i) {
2243 if (transpose_dims[i] != i &&
2244 !absl::c_linear_search(lhs_contracting_dims, i)) {
2245 return nullptr;
2246 }
2247 }
2248 // Virtually pull the transpose into the dot. Now the dot is equivalent to
2249 // a new dot with "permuted" lhs contracting dims.
2250 std::vector<int64> permutation;
2251 for (auto dim : lhs_contracting_dims) {
2252 permutation.push_back(transpose_dims[dim] - lhs_contracting_dims[0]);
2253 }
2254 CHECK(IsPermutation(permutation));
2255 auto new_lhs_contracting_dims =
2256 ComposePermutations(AsInt64Slice(lhs_contracting_dims), permutation);
2257 lhs_contracting_dims.Clear();
2258 for (auto dim : new_lhs_contracting_dims) {
2259 lhs_contracting_dims.Add(dim);
2260 }
2261 lhs = lhs->mutable_operand(0);
2262
2263 // All checks are passed at this point.
2264 //
2265 // Transform lhs. Remove the transpose and reshape by sorting the lhs
2266 // contracting dims and squishing them into a single one. We don't actually
2267 // squish the lhs_contracting_dims here because we still need the unsquished
2268 // contracting dims to invert reshape and transpose.
2269 absl::c_sort(lhs_contracting_dims);
2270 lhs = computation_->AddInstruction(
2271 HloInstruction::CreateReshape(reshape->shape(), lhs));
2272
2273 // Transform rhs. Say the input HLO is:
2274 //
2275 // t0 = f32[2, 2, 3] parameter(0)
2276 // t1 = f32[2, 3, 2] transpose(t0) dimensions={0, 2, 1}
2277 // t2 = f32[2, 6] reshape(t1)
2278 // t3 = f32[6, 2] constant(...)
2279 // dot = f32[2, 2] dot(t2, t3) lhs_contracting_dims={1},
2280 // rhs_contracting_dims={0}
2281 //
2282 // At this point in the function, we have decided that the second and third
2283 // dims of t0 can be switched to remove the transpose, and we have
2284 // "virtually decomposed" the input HLO to:
2285 //
2286 // t0 = f32[2, 2, 3] parameter(0)
2287 // t2' = f32[2, 6] reshape(t0)
2288 // t3' = f32[6, 2] ops-to-be-filled ...
2289 // dot = f32[2, 2] dot(t2', t3') lhs_contracting_dims={1},
2290 // rhs_contracting_dims={0}
2291 //
2292 // The rest of this function is to fill in the ops of t3'. To do this, we
2293 // unsquish the contracting dimensions in t3 and then apply the inverse of
2294 // the transpose from t1.
2295
2296 // Invert reshape.
2297 CHECK_EQ(rhs_contracting_dims.size(), 1);
2298 std::vector<int64> rhs_unsquished_shape_dims =
2299 SpanToVector(constant->shape().dimensions());
2300 auto it = rhs_unsquished_shape_dims.erase(rhs_unsquished_shape_dims.begin() +
2301 rhs_contracting_dims[0]);
2302 for (auto dim : lhs_contracting_dims) {
2303 it = rhs_unsquished_shape_dims.insert(it,
2304 transpose->shape().dimensions(dim));
2305 ++it;
2306 }
2307 HloInstruction* rhs_reshape =
2308 computation_->AddInstruction(HloInstruction::CreateReshape(
2309 ShapeUtil::MakeShape(constant->shape().element_type(),
2310 rhs_unsquished_shape_dims),
2311 constant));
2312 rhs = rhs_reshape;
2313
2314 // Rhs reshape "unsquishes" the single rhs contracting dim into multiple dims.
2315 rhs_contracting_dims.Resize(lhs_contracting_dims.size(),
2316 rhs_contracting_dims[0]);
2317 absl::c_iota(rhs_contracting_dims, rhs_contracting_dims[0]);
2318
2319 // Invert transpose. First compute the shape.
2320 std::vector<int64> rhs_transpose_shape_dims =
2321 SpanToVector(rhs_reshape->shape().dimensions());
2322 it = rhs_transpose_shape_dims.erase(
2323 rhs_transpose_shape_dims.begin() + rhs_contracting_dims[0],
2324 rhs_transpose_shape_dims.begin() + rhs_contracting_dims[0] +
2325 rhs_contracting_dims.size());
2326 for (auto dim : lhs_contracting_dims) {
2327 it = rhs_transpose_shape_dims.insert(it, input->shape().dimensions(dim));
2328 ++it;
2329 }
2330 // Then compute the transpose dims.
2331 std::vector<int64> rhs_transpose_dims(rhs_reshape->shape().rank());
2332 absl::c_iota(rhs_transpose_dims, 0);
2333 it = rhs_transpose_dims.erase(
2334 rhs_transpose_dims.begin() + rhs_contracting_dims[0],
2335 rhs_transpose_dims.begin() + rhs_contracting_dims[0] +
2336 rhs_contracting_dims.size());
2337 auto inverse_lhs_transpose_dims = InversePermutation(transpose_dims);
2338 for (auto dim : lhs_contracting_dims) {
2339 it = rhs_transpose_dims.insert(it, inverse_lhs_transpose_dims[dim] -
2340 lhs_contracting_dims[0] +
2341 rhs_contracting_dims[0]);
2342 ++it;
2343 }
2344 HloInstruction* rhs_transpose =
2345 computation_->AddInstruction(HloInstruction::CreateTranspose(
2346 ShapeUtil::MakeShape(constant->shape().element_type(),
2347 rhs_transpose_shape_dims),
2348 rhs_reshape, rhs_transpose_dims));
2349 rhs = rhs_transpose;
2350
2351 // Squish the multiple rhs contracting dims into a single one.
2352 rhs = computation_->AddInstruction(
2353 HloInstruction::CreateReshape(constant->shape(), rhs));
2354
2355 // If we virtually swapped lhs and rhs, we need to swap it back before
2356 // creating the new dot.
2357 if (dot->operand(0)->IsConstant()) {
2358 std::swap(lhs, rhs);
2359 }
2360
2361 HloInstruction* new_dot =
2362 computation_->AddInstruction(HloInstruction::CreateDot(
2363 dot->shape(), lhs, rhs, dnums, dot->precision_config()));
2364 return new_dot;
2365 }
2366
2367 Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
2368 CHECK(computation_ == dot->parent());
2369 HloInstruction *lhs, *rhs;
2370 CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
2371 if (options_.is_layout_sensitive()) {
2372 return Status::OK();
2373 }
2374 // Replace a zero element dot with a broadcast of the constant 0.
2375 if (ShapeUtil::IsZeroElementArray(dot->shape()) ||
2376 ShapeUtil::IsZeroElementArray(lhs->shape()) ||
2377 ShapeUtil::IsZeroElementArray(rhs->shape())) {
2378 auto zero = computation_->AddInstruction(
2379 simplifier_->CreateConstantWithLayoutUpdated(
2380 LiteralUtil::Zero(dot->shape().element_type())));
2381 return ReplaceWithNewInstruction(
2382 dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {}));
2383 }
2384
2385 // If there are no contracting dimensions, a dot can be rewritten as
2386 // mul(broadcast(transpose(x)),broadcast(transpose(y)))
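// For example (illustrative, no batch dims), dot(f32[3], f32[4]) with no
// contracting dims is the outer product
// mul(broadcast(x, dims={0}), broadcast(y, dims={1})) of shape f32[3, 4].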
2387 if (options_.enable_dot_to_multiply_rewrite() &&
2388 dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 0) {
2389 TF_ASSIGN_OR_RETURN(
2390 HloInstruction * new_lhs,
2391 NormalizeDotOperandToBatchMajorAndContractingMinor(
2392 lhs,
2393 AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()),
2394 AsInt64Slice(
2395 dot->dot_dimension_numbers().lhs_contracting_dimensions())));
2396 if (!ShapeUtil::SameElementType(dot->shape(), new_lhs->shape())) {
2397 new_lhs = MakeConvertToHlo(new_lhs, dot->shape().element_type());
2398 }
2399 TF_ASSIGN_OR_RETURN(
2400 HloInstruction * new_rhs,
2401 NormalizeDotOperandToBatchMajorAndContractingMinor(
2402 rhs,
2403 AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()),
2404 AsInt64Slice(
2405 dot->dot_dimension_numbers().rhs_contracting_dimensions())));
2406 if (!ShapeUtil::SameElementType(dot->shape(), new_rhs->shape())) {
2407 new_rhs = MakeConvertToHlo(new_rhs, dot->shape().element_type());
2408 }
2409 if (dot->shape().rank() != lhs->shape().rank()) {
2410 std::vector<int64> lhs_broadcast_dims(lhs->shape().rank());
2411 absl::c_iota(lhs_broadcast_dims, 0);
2412 new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
2413 dot->shape(), new_lhs, lhs_broadcast_dims));
2414 }
2415 if (dot->shape().rank() != rhs->shape().rank()) {
2416 std::vector<int64> rhs_broadcast_dims(
2417 dot->dot_dimension_numbers().lhs_batch_dimensions_size());
2418 absl::c_iota(rhs_broadcast_dims, 0);
2419 for (int64_t i = lhs->shape().rank(); i < dot->shape().rank(); ++i) {
2420 rhs_broadcast_dims.push_back(i);
2421 }
2422 new_rhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
2423 dot->shape(), new_rhs, rhs_broadcast_dims));
2424 }
2425 return ReplaceWithNewInstruction(
2426 dot, HloInstruction::CreateBinary(dot->shape(), HloOpcode::kMultiply,
2427 new_lhs, new_rhs));
2428 }
2429
2430 // If the lhs or rhs have only batch and contracting dimensions, a dot can be
2431 // rewritten as reduce(mul(broadcast(transpose(x)),broadcast(transpose(y))))
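// For example (illustrative), dot(f32[16, 8], f32[8]) contracting on the
// rhs's only dimension becomes
// reduce(multiply(lhs, broadcast(rhs, dims={1})), dims={1}), i.e. a
// matrix-vector product expressed as multiply + reduce.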
2432 if (options_.enable_dot_strength_reduction() &&
2433 ((dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
2434 dot->dot_dimension_numbers().lhs_contracting_dimensions_size() ==
2435 lhs->shape().rank()) ||
2436 (dot->dot_dimension_numbers().rhs_contracting_dimensions_size() +
2437 dot->dot_dimension_numbers().rhs_batch_dimensions_size() ==
2438 rhs->shape().rank()))) {
2439 TF_ASSIGN_OR_RETURN(
2440 HloInstruction * new_lhs,
2441 NormalizeDotOperandToBatchMajorAndContractingMinor(
2442 lhs,
2443 AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()),
2444 AsInt64Slice(
2445 dot->dot_dimension_numbers().lhs_contracting_dimensions())));
2446 if (!ShapeUtil::SameElementType(dot->shape(), new_lhs->shape())) {
2447 new_lhs = MakeConvertToHlo(new_lhs, dot->shape().element_type());
2448 }
2449
2450 TF_ASSIGN_OR_RETURN(
2451 HloInstruction * new_rhs,
2452 NormalizeDotOperandToBatchMajorAndContractingMinor(
2453 rhs,
2454 AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()),
2455 AsInt64Slice(
2456 dot->dot_dimension_numbers().rhs_contracting_dimensions())));
2457 if (!ShapeUtil::SameElementType(dot->shape(), new_rhs->shape())) {
2458 new_rhs = MakeConvertToHlo(new_rhs, dot->shape().element_type());
2459 }
2460
2461 int64_t lhs_outer_dims =
2462 lhs->shape().rank() -
2463 (dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
2464 dot->dot_dimension_numbers().lhs_contracting_dimensions_size());
2465 int64_t rhs_outer_dims =
2466 rhs->shape().rank() -
2467 (dot->dot_dimension_numbers().rhs_batch_dimensions_size() +
2468 dot->dot_dimension_numbers().rhs_contracting_dimensions_size());
2469 CHECK(lhs_outer_dims == 0 || rhs_outer_dims == 0);
2470 if (rhs_outer_dims > 0) {
2471 std::vector<int64> lhs_broadcast_dims(
2472 dot->dot_dimension_numbers().lhs_batch_dimensions_size());
2473 absl::c_iota(lhs_broadcast_dims, 0);
2474 lhs_broadcast_dims.resize(lhs->shape().rank());
2475 std::iota(lhs_broadcast_dims.begin() +
2476 dot->dot_dimension_numbers().lhs_batch_dimensions_size(),
2477 lhs_broadcast_dims.end(),
2478 dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
2479 rhs_outer_dims);
2480 new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
2481 new_rhs->shape(), new_lhs, lhs_broadcast_dims));
2482 } else if (lhs_outer_dims > 0) {
2483 std::vector<int64> rhs_broadcast_dims(
2484 dot->dot_dimension_numbers().rhs_batch_dimensions_size());
2485 absl::c_iota(rhs_broadcast_dims, 0);
2486 rhs_broadcast_dims.resize(rhs->shape().rank());
2487 std::iota(rhs_broadcast_dims.begin() +
2488 dot->dot_dimension_numbers().rhs_batch_dimensions_size(),
2489 rhs_broadcast_dims.end(),
2490 dot->dot_dimension_numbers().rhs_batch_dimensions_size() +
2491 lhs_outer_dims);
2492 new_rhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
2493 new_lhs->shape(), new_rhs, rhs_broadcast_dims));
2494 }
2495
2496 TF_ASSIGN_OR_RETURN(HloInstruction * new_dot,
2497 MakeBinaryHlo(HloOpcode::kMultiply, new_lhs, new_rhs));
2498 std::vector<int64> reduce_dims(
2499 dot->dot_dimension_numbers().lhs_contracting_dimensions_size());
2500 PrimitiveType dot_type =
2501 ShapeUtil::ElementIsFloating(dot->shape())
2502 ? (dot->shape().element_type() == F64 ? F64 : F32)
2503 : dot->shape().element_type();
2504 new_dot = AsType(new_dot, dot_type);
2505 const int64_t outer_dims = std::max(rhs_outer_dims, lhs_outer_dims);
2506 absl::c_iota(
2507 reduce_dims,
2508 outer_dims + dot->dot_dimension_numbers().lhs_batch_dimensions_size());
2509 new_dot = AddReduce(new_dot, reduce_dims, dot_type);
2510 new_dot = AsType(new_dot, dot->shape().element_type());
2511 return ReplaceInstruction(dot, new_dot);
2512 }
2513
2514 // Simplify dot(reshape(transpose(A)), Const) to:
2515 // dot(reshape(A), reshape(transpose(reshape(Const)))), so that the reshape
2516 // and transpose on the Const side can be constant folded.
2517 TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_reorder_optimized,
2518 OptimizeDotOfReorderContractingDims(dot));
2519 if (dot_of_reorder_optimized) {
2520 VLOG(10) << " Replaced dot " << dot->ToString()
2521 << " with new dot operation: "
2522 << dot_of_reorder_optimized->ToString();
2523 return ReplaceInstruction(dot, dot_of_reorder_optimized);
2524 }
2525
2526 TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_concat_optimized,
2527 OptimizeDotOfConcat(dot));
2528 if (dot_of_concat_optimized) {
2529 VLOG(10) << "Replaced dot(concat(...), constant) with add(dot(..., "
2530 "constant)...)";
2531 return ReplaceInstruction(dot, dot_of_concat_optimized);
2532 }
2533
2534 // Simplify dot(ConstA, Gather(Index, ConstB)) to:
2535 // Gather(Index, dot*(ConstA, ConstB)), where dot* is an appropriately
2536 // batched version of dot.
2537 TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_gather_optimized,
2538 OptimizeDotOfGather(dot));
2539 if (dot_of_gather_optimized) {
2540 VLOG(10) << "Replaced dot(constA, gather(i, constB)) with "
2541 "gather(i, dot*(constA, constB))";
2542 return ReplaceInstruction(dot, dot_of_gather_optimized);
2543 }
2544
2545 TF_ASSIGN_OR_RETURN(bool removed_degenerate_dimensions,
2546 RemoveDegenerateDimensionFromDot(dot));
2547 if (removed_degenerate_dimensions) {
2548 return Status::OK();
2549 }
2550
2551 // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)).
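// This is the rank-2 identity transpose(A) * transpose(B) == transpose(B * A),
// which removes both operand transposes at the cost of one on the result.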
2552 if (dot->dot_dimension_numbers().lhs_batch_dimensions_size() == 0 &&
2553 dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 1 &&
2554 dot->dot_dimension_numbers().lhs_contracting_dimensions(0) == 1 &&
2555 dot->dot_dimension_numbers().rhs_contracting_dimensions(0) == 0 &&
2556 lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) {
2557 DotDimensionNumbers dot_dimension_numbers;
2558 dot_dimension_numbers.add_lhs_contracting_dimensions(1);
2559 dot_dimension_numbers.add_rhs_contracting_dimensions(0);
2560 auto new_dot = computation_->AddInstruction(HloInstruction::CreateDot(
2561 ShapeUtil::PermuteDimensions({1, 0}, dot->shape()),
2562 rhs->mutable_operand(0), lhs->mutable_operand(0), dot_dimension_numbers,
2563 dot->precision_config()));
2564 return ReplaceWithNewInstruction(
2565 dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0}));
2566 }
2567
2568 return Status::OK();
2569 }
2570
2571 Status AlgebraicSimplifierVisitor::HandleGather(HloInstruction* gather) {
2572 const Shape& operand_shape = gather->operand(0)->shape();
2573 if (ShapeUtil::IsZeroElementArray(operand_shape)) {
2574 return ReplaceInstruction(gather, MakeScalarLike(gather, 0));
2575 }
2576
2577 // Gathering from a scalar operand is simply a broadcast of that scalar.
2578 if (ShapeUtil::IsEffectiveScalar(operand_shape)) {
2579 HloInstruction* new_operand = gather->mutable_operand(0);
2580 if (operand_shape.rank()) {
2581 TF_ASSIGN_OR_RETURN(new_operand,
2582 MakeReshapeHlo(ShapeUtil::MakeScalarShape(
2583 operand_shape.element_type()),
2584 new_operand));
2585 }
2586 HloInstruction* new_gather =
2587 MakeBroadcastHlo(new_operand, {}, gather->shape());
2588 return ReplaceInstruction(gather, new_gather);
2589 }
2590 // If the operand of a gather is very small, it is easier to fuse a
2591 // sequence of selects.
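// Sketch of the emitted pattern for a 3-element operand {v0, v1, v2}
// (illustrative): result = select(i >= 2, v2, select(i >= 1, v1, v0)),
// built below by chaining compares against each possible index.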
2592 const Shape& index_shape = gather->operand(1)->shape();
2593 if (operand_shape.rank() == 1 &&
2594 operand_shape.dimensions(0) <= options_.very_small_gather_size() &&
2595 gather->gather_dimension_numbers().index_vector_dim() ==
2596 index_shape.rank() &&
2597 gather->gather_dimension_numbers().collapsed_slice_dims_size() == 1) {
2598 const int64_t operand_elements = operand_shape.dimensions(0);
2599 auto get_value = [&](int64_t i) {
2600 auto slice = computation_->AddInstruction(HloInstruction::CreateSlice(
2601 ShapeUtil::MakeShape(operand_shape.element_type(), {1}),
2602 gather->mutable_operand(0), {i}, {i + 1}, {1}));
2603 auto scalar = computation_->AddInstruction(HloInstruction::CreateReshape(
2604 ShapeUtil::MakeShape(operand_shape.element_type(), {}), slice));
2605 return computation_->AddInstruction(
2606 HloInstruction::CreateBroadcast(gather->shape(), scalar, {}));
2607 };
2608 auto result = get_value(0);
2609 auto pred_shape = ShapeUtil::ChangeElementType(gather->shape(), PRED);
2610 simplifier_->UpdateLayout(&pred_shape);
2611 auto iter_shape = ShapeUtil::ChangeElementType(gather->shape(),
2612 index_shape.element_type());
2613 simplifier_->UpdateLayout(&iter_shape);
2614 for (int64_t i = 0; i < operand_elements; ++i) {
2615 auto index_mask =
2616 computation_->AddInstruction(HloInstruction::CreateCompare(
2617 pred_shape, gather->mutable_operand(1),
2618 MakeScalarLike(gather->mutable_operand(1), i),
2619 ComparisonDirection::kGe));
2620 result = computation_->AddInstruction(
2621 HloInstruction::CreateTernary(gather->shape(), HloOpcode::kSelect,
2622 index_mask, get_value(i), result));
2623 }
2624 return ReplaceInstruction(gather, result);
2625 }
2626 return Status::OK();
2627 }
2628
2629 namespace {
2630 StatusOr<std::unique_ptr<HloInstruction>> MinMaxToClamp(
2631 HloInstruction* clamp_lower_bound_bcast, HloInstruction* to_clamp,
2632 HloInstruction* clamp_upper_bound_bcast, AlgebraicSimplifier* simplifier) {
2633 HloInstruction* clamp_lower_bound;
2634 CHECK(Match(clamp_lower_bound_bcast,
2635 m::Broadcast(m::ConstantEffectiveScalar(&clamp_lower_bound))))
2636 << clamp_lower_bound_bcast->ToString();
2637
2638 HloInstruction* clamp_upper_bound;
2639 CHECK(Match(clamp_upper_bound_bcast,
2640 m::Broadcast(m::ConstantEffectiveScalar(&clamp_upper_bound))))
2641 << clamp_upper_bound_bcast->ToString();
2642
2643 const Literal& lower_bound =
2644 Cast<HloConstantInstruction>(clamp_lower_bound)->literal();
2645 const Literal& upper_bound =
2646 Cast<HloConstantInstruction>(clamp_upper_bound)->literal();
2647
2648 TF_ASSIGN_OR_RETURN(Literal lower_bound_literal_reshaped,
2649 lower_bound.Reshape({}));
2650 TF_ASSIGN_OR_RETURN(Literal upper_bound_literal_reshaped,
2651 upper_bound.Reshape({}));
2652 std::unique_ptr<HloInstruction> lower_bound_instr =
2653 HloInstruction::CreateConstant(std::move(lower_bound_literal_reshaped));
2654 std::unique_ptr<HloInstruction> upper_bound_instr =
2655 HloInstruction::CreateConstant(std::move(upper_bound_literal_reshaped));
2656
2657 Shape compare_shape =
2658 ShapeUtil::ChangeElementType(lower_bound_instr->shape(), PRED);
2659 simplifier->UpdateLayout(&compare_shape);
2660 std::unique_ptr<HloInstruction> cloned_instruction =
2661 HloInstruction::CreateCompare(compare_shape, lower_bound_instr.get(),
2662 upper_bound_instr.get(),
2663 ComparisonDirection::kLt);
2664
2665 HloEvaluator evaluator;
2666 TF_ASSIGN_OR_RETURN(auto result,
2667 evaluator.Evaluate(cloned_instruction.get()));
2668 if (result.IsAll(true)) {
2669 return HloInstruction::CreateTernary(to_clamp->shape(), HloOpcode::kClamp,
2670 clamp_lower_bound_bcast, to_clamp,
2671 clamp_upper_bound_bcast);
2672 }
2673 return std::unique_ptr<HloInstruction>();
2674 }
2675 } // namespace
2676
2677 Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) {
2678 HloInstruction *lhs, *rhs;
2679 CHECK(Match(maximum, m::Maximum(m::Op(&lhs), m::Op(&rhs))));
2680
2681 HloInstruction* clamp_upper_bound_bcast;
2682 HloInstruction* clamp_lower_bound_bcast;
2683 HloInstruction* to_clamp;
2684 if (Match(maximum, m::MaximumAnyOrder(
2685 m::Broadcast(&clamp_lower_bound_bcast,
2686 m::ConstantEffectiveScalar()),
2687 m::MinimumAnyOrder(
2688 m::Op(&to_clamp),
2689 m::Broadcast(&clamp_upper_bound_bcast,
2690 m::ConstantEffectiveScalar()))))) {
2691 TF_ASSIGN_OR_RETURN(auto clamp,
2692 MinMaxToClamp(clamp_lower_bound_bcast, to_clamp,
2693 clamp_upper_bound_bcast, simplifier_));
2694 if (clamp) {
2695 return ReplaceWithNewInstruction(maximum, std::move(clamp));
2696 }
2697 }
2698
2699 HloInstruction* clamp_lower_bound;
2700 HloInstruction* clamp_upper_bound;
2701 HloInstruction* max_operand;
2702 HloInstruction* clamp;
2703 if (Match(maximum,
2704 m::MaximumAnyOrder(
2705 m::Op(&max_operand),
2706 m::Clamp(&clamp, m::Op(&clamp_lower_bound), m::Op(&to_clamp),
2707 m::Op(&clamp_upper_bound))))) {
2708 if (max_operand == clamp_lower_bound &&
2709 ReplaceInstructionIfSameShape(maximum, clamp)) {
2710 return Status::OK();
2711 }
2712 }
2713
2714 return Status::OK();
2715 }
2716
2717 Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum) {
2718 HloInstruction *lhs, *rhs;
2719 CHECK(Match(minimum, m::Minimum(m::Op(&lhs), m::Op(&rhs))));
2720
2721 HloInstruction* clamp_upper_bound_bcast;
2722 HloInstruction* clamp_lower_bound_bcast;
2723 HloInstruction* to_clamp;
2724 if (Match(minimum, m::MinimumAnyOrder(
2725 m::Broadcast(&clamp_upper_bound_bcast,
2726 m::ConstantEffectiveScalar()),
2727 m::MaximumAnyOrder(
2728 m::Op(&to_clamp),
2729 m::Broadcast(&clamp_lower_bound_bcast,
2730 m::ConstantEffectiveScalar()))))) {
2731 TF_ASSIGN_OR_RETURN(auto clamp,
2732 MinMaxToClamp(clamp_lower_bound_bcast, to_clamp,
2733 clamp_upper_bound_bcast, simplifier_));
2734 if (clamp) {
2735 return ReplaceWithNewInstruction(minimum, std::move(clamp));
2736 }
2737 }
2738
2739 return Status::OK();
2740 }
2741
2742 Status AlgebraicSimplifierVisitor::HandleClamp(HloInstruction* clamp) {
2743 HloInstruction* clamp_lower_bound;
2744 HloInstruction* clamp_upper_bound;
2745 HloInstruction* to_clamp;
2746 CHECK(Match(clamp, m::Clamp(m::Op(&clamp_lower_bound), m::Op(&to_clamp),
2747 m::Op(&clamp_upper_bound))));
2748
2749 // clamp(a, clamp(a, x, b), b) -> clamp(a, x, b)
2750 if (Match(to_clamp, m::Clamp(m::Op().Is(clamp_lower_bound), m::Op(),
2751 m::Op().Is(clamp_upper_bound))) &&
2752 ReplaceInstructionIfSameShape(clamp, to_clamp)) {
2753 return Status::OK();
2754 }
2755
2756 // Eliminate redundant clamping of replica-id or partition-id.
2757 if ((Match(to_clamp, m::PartitionId()) || Match(to_clamp, m::ReplicaId())) &&
2758 Match(clamp_lower_bound, m::ConstantScalar(0U)) &&
2759 Match(clamp_upper_bound, m::ConstantScalar())) {
2760 int64_t upper_bound = Cast<HloConstantInstruction>(clamp_upper_bound)
2761 ->literal()
2762 .GetFirstElement<uint32_t>();
2763 const HloModuleConfig& config = clamp->GetModule()->config();
2764 int64_t runtime_bound = Match(to_clamp, m::PartitionId())
2765 ? config.num_partitions()
2766 : config.replica_count();
2767
2768 // If num_partitions or replica_count is 1, treat the runtime bound as
2769 // unknown. Since pid/rid < runtime_bound, the clamp(0, pid/rid,
2770 // upper_bound) is redundant when runtime_bound <= upper_bound + 1.
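// For example, with num_partitions == 4, clamp(0, partition-id, 3) is
// redundant because partition-id is always in [0, 3].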
2771 if (runtime_bound != 1 && runtime_bound <= upper_bound + 1) {
2772 return ReplaceInstruction(clamp, to_clamp);
2773 }
2774 }
2775
2776 return Status::OK();
2777 }
2778
2779 Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
2780 HloInstruction *lhs, *rhs;
2781 CHECK(Match(multiply, m::Multiply(m::Op(&lhs), m::Op(&rhs))));
2782 // LHS*1 => LHS
2783 VLOG(10) << "trying transform [LHS*1 => LHS]: " << multiply->ToString();
2784 if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(multiply, lhs)) {
2785 return Status::OK();
2786 }
2787 // 1*RHS => RHS
2788 VLOG(10) << "trying transform [1*RHS => RHS]: " << multiply->ToString();
2789 if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(multiply, rhs)) {
2790 return Status::OK();
2791 }
2792
2793 // 0*RHS => 0. Only applies for integral types for correct NaN-handling.
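// (For floating point, 0 * NaN == NaN and 0 * inf == NaN, so rewriting the
// product to a constant 0 would change results.)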
2794 if (IsAll(lhs, 0) &&
2795 primitive_util::IsIntegralType(multiply->shape().element_type()) &&
2796 ReplaceInstructionIfSameShape(multiply, lhs)) {
2797 return Status::OK();
2798 }
2799 // LHS*0 => 0
2800 if (IsAll(rhs, 0) &&
2801 primitive_util::IsIntegralType(multiply->shape().element_type()) &&
2802 ReplaceInstructionIfSameShape(multiply, rhs)) {
2803 return Status::OK();
2804 }
2805
2806 {
2807 HloInstruction* abs_operand;
2808 if (lhs == rhs && Match(lhs, m::Abs(m::Op(&abs_operand))) &&
2809 !ShapeUtil::ElementIsComplex(abs_operand->shape())) {
2810 TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(0, abs_operand));
2811 TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(1, abs_operand));
2812 changed_ = true;
2813 return Status::OK();
2814 }
2815 }
2816
2817 {
2818 HloInstruction *convert_operand, *operand;
2819 // Mul(Convert(Pred), operand) => select(pred, operand, 0)
2820 if (Match(multiply,
2821 m::MultiplyAnyOrder(
2822 m::Op(&operand),
2823 m::Convert(
2824 m::Op(&convert_operand)
2825 .WithShape(m::Shape().WithElementType(PRED)))))) {
2826 HloInstruction* zero_like_multiply =
2827 BroadcastZeros(computation_, multiply->shape().element_type(),
2828 multiply->shape().dimensions());
2829 return ReplaceWithNewInstruction(
2830 multiply, HloInstruction::CreateTernary(
2831 multiply->shape(), HloOpcode::kSelect, convert_operand,
2832 operand, zero_like_multiply));
2833 }
2834 }

  {
    HloInstruction *a, *b, *c1, *c2;
    // Mul(Mul(x, constant1), Mul(y, constant2)) => Mul(Mul(x, y),
    // constant1*constant2)
    if (Match(multiply,
              m::MultiplyAnyOrder(
                  m::MultiplyAnyOrder(m::NonConstant(&a), m::Constant(&c1)),
                  m::MultiplyAnyOrder(m::NonConstant(&b), m::Constant(&c2))))) {
      TF_ASSIGN_OR_RETURN(auto* product_of_constants,
                          MakeBinaryHlo(HloOpcode::kMultiply, c1, c2));
      if (ShapeUtil::IsScalar(product_of_constants->shape()) &&
          !ShapeUtil::IsScalar(multiply->shape())) {
        product_of_constants =
            computation_->AddInstruction(HloInstruction::CreateBroadcast(
                multiply->shape(), product_of_constants, {}));
      }

      return ReplaceWithNewInstruction(
          multiply,
          HloInstruction::CreateBinary(
              multiply->shape(), HloOpcode::kMultiply,
              computation_->AddInstruction(HloInstruction::CreateBinary(
                  multiply->shape(), HloOpcode::kMultiply, a, b)),
              product_of_constants));
    }
  }
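
  // Illustrative sketch (hypothetical HLO text, invented for this comment):
  //   multiply(multiply(%x, constant(2)), multiply(%y, constant(3)))
  // becomes
  //   multiply(multiply(%x, %y), multiply(constant(2), constant(3)))
  // and the constant-only multiply is then collapsed to constant(6) by the
  // constant-folding pass.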

  {
    HloInstruction *a, *c1, *c2;
    // Mul(Mul(a, constant1), constant2) => Mul(a, constant1*constant2)
    if (Match(multiply,
              m::MultiplyAnyOrder(
                  m::MultiplyAnyOrder(m::NonConstant(&a), m::Constant(&c1)),
                  m::Constant(&c2)))) {
      TF_ASSIGN_OR_RETURN(auto* product_of_constants,
                          MakeBinaryHlo(HloOpcode::kMultiply, c1, c2));
      if (ShapeUtil::IsScalar(product_of_constants->shape()) &&
          !ShapeUtil::IsScalar(multiply->shape())) {
        product_of_constants =
            computation_->AddInstruction(HloInstruction::CreateBroadcast(
                multiply->shape(), product_of_constants, {}));
      }

      return ReplaceWithNewInstruction(
          multiply,
          HloInstruction::CreateBinary(multiply->shape(), HloOpcode::kMultiply,
                                       a, product_of_constants));
    }
  }

  {
    HloInstruction *a, *b, *constant, *op;
    // Mul(Mul(a, constant1), Broadcast(b)) =>
    // Mul(Broadcast(Mul(b, constant1)), a)
    if (Match(multiply,
              m::MultiplyAnyOrder(m::MultiplyAnyOrder(m::NonConstant(&a),
                                                      m::Constant(&constant)),
                                  m::Op(&op))) ||
        Match(multiply,
              m::MultiplyAnyOrder(
                  m::MultiplyAnyOrder(m::NonConstant(&a),
                                      m::Broadcast(m::Constant(&constant))),
                  m::Op(&op)))) {
      // Check that the other side was a broadcast, and not of a constant.
      if (ShapeUtil::IsScalar(constant->shape()) &&
          Match(op, m::Broadcast(m::NonConstant()))) {
        auto dims = op->dimensions();
        b = op->mutable_operand(0);
        if (!ShapeUtil::IsScalar(b->shape())) {
          constant = computation_->AddInstruction(
              HloInstruction::CreateBroadcast(b->shape(), constant, {}));
        }

        auto new_mul =
            computation_->AddInstruction(HloInstruction::CreateBinary(
                b->shape(), HloOpcode::kMultiply, b, constant));

        return ReplaceWithNewInstruction(
            multiply,
            HloInstruction::CreateBinary(
                multiply->shape(), HloOpcode::kMultiply, a,
                computation_->AddInstruction(HloInstruction::CreateBroadcast(
                    multiply->shape(), new_mul, dims))));
      }
    }
  }

  VLOG(10) << "trying transform [(A * C1) * C2 => A * (C1 * C2)]";
  HloInstruction *a, *c1, *c2;
  if (Match(multiply,
            m::Multiply(m::Multiply(m::NonConstant(&a), m::Constant(&c1)),
                        m::Constant(&c2))) ||
      Match(multiply,
            m::Multiply(
                m::Multiply(m::Op(&a), m::Broadcast(m::ConstantScalar(&c1))),
                m::Broadcast(m::ConstantScalar(&c2))))) {
    TF_ASSIGN_OR_RETURN(auto* product_of_constants,
                        MakeBinaryHlo(HloOpcode::kMultiply, c1, c2));
    if (ShapeUtil::IsScalar(product_of_constants->shape()) &&
        !ShapeUtil::IsScalar(multiply->shape())) {
      product_of_constants =
          computation_->AddInstruction(HloInstruction::CreateBroadcast(
              multiply->shape(), product_of_constants, {}));
    }
    return ReplaceWithNewInstruction(
        multiply,
        HloInstruction::CreateBinary(multiply->shape(), HloOpcode::kMultiply, a,
                                     product_of_constants));
  }

  VLOG(10) << "trying to transform exp(LHS) * exp(RHS) => exp(LHS+RHS) "
           << multiply->ToString();
  if (Match(multiply, m::Multiply(m::Exp(m::Op(&lhs)), m::Exp(m::Op(&rhs))))) {
    auto add = computation_->AddInstruction(HloInstruction::CreateBinary(
        multiply->shape(), HloOpcode::kAdd, lhs, rhs));
    return ReplaceWithNewInstruction(
        multiply,
        HloInstruction::CreateUnary(multiply->shape(), HloOpcode::kExp, add));
  }
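
  // Illustrative sketch (hypothetical HLO text, invented for this comment):
  //   multiply(exponential(%a), exponential(%b))
  // becomes
  //   exponential(add(%a, %b))
  // trading two transcendentals for one plus an add, via e^a * e^b == e^(a+b).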

  VLOG(10) << "trying transform [rsqrt(B) * rsqrt(B) => 1/B] "
           << multiply->ToString();
  HloInstruction* b;
  if (Match(multiply, m::Multiply(m::Rsqrt(m::Op(&b)), m::Rsqrt(m::Op(&b)))) &&
      IsPositive(b, options_)) {
    return ReplaceWithNewInstruction(
        multiply,
        HloInstruction::CreateBinary(multiply->shape(), HloOpcode::kDivide,
                                     MakeScalarLike(b, 1), b));
  }
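
  // Illustrative sketch (hypothetical HLO text, invented for this comment):
  //   multiply(rsqrt(%b), rsqrt(%b))
  // becomes
  //   divide(constant(1), %b)
  // The IsPositive guard matters: for b < 0 the two sides differ (rsqrt(b) is
  // NaN while 1/b is a finite value), so the rewrite is only sound when b is
  // known to be positive.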

  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleNegate(HloInstruction* negate) {
  // negate(negate(x)) => x
  HloInstruction* x;
  if (Match(negate, m::Negate(m::Negate(m::Op(&x)))) &&
      ReplaceInstructionIfSameShape(negate, x)) {
    return Status::OK();
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleNot(HloInstruction* logical_not) {
  // not(not(x)) => x
  HloInstruction* x;
  if (Match(logical_not, m::Not(m::Not(m::Op(&x)))) &&
      ReplaceInstructionIfSameShape(logical_not, x)) {
    return Status::OK();
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleOr(HloInstruction* logical_or) {
  HloInstruction *lhs, *rhs;
  CHECK(Match(logical_or, m::Or(m::Op(&lhs), m::Op(&rhs))));

  // Simplify logical or
  if (ShapeUtil::HasPrimitiveType(lhs->shape(), xla::PRED) &&
      ShapeUtil::HasPrimitiveType(rhs->shape(), xla::PRED)) {
    // A || True => True
    VLOG(10) << "trying transform [A || True => True]: "
             << logical_or->ToString();
    if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(logical_or, rhs)) {
      return Status::OK();
    }
    // True || A => True
    VLOG(10) << "trying transform [True || A => True]: "
             << logical_or->ToString();
    if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(logical_or, lhs)) {
      return Status::OK();
    }
  }

  // A || False => A and A | 0 => A
  VLOG(10) << "trying transform [A || False => A]: " << logical_or->ToString();
  if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(logical_or, lhs)) {
    return Status::OK();
  }

  // False || A => A and 0 | A => A
  VLOG(10) << "trying transform [False || A => A]: " << logical_or->ToString();
  if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(logical_or, rhs)) {
    return Status::OK();
  }

  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log) {
  // ln(exp(A)) => A
  VLOG(10) << "trying transform [ln(exp(A)) => A]: " << log->ToString();
  HloInstruction *a, *b;
  if (Match(log, m::Log(m::Exp(m::Op(&a)))) &&
      ReplaceInstructionIfSameShape(log, a)) {
    return Status::OK();
  }

  // ln(pow(A,B)) => B*ln(abs(A))
  // or B*ln(A) if A is complex.
  if (Match(log, m::Log(m::Power(m::Op(&a), m::Op(&b))))) {
    auto abs_a = ShapeUtil::ElementIsComplex(a->shape())
                     ? a
                     : computation_->AddInstruction(HloInstruction::CreateUnary(
                           log->shape(), HloOpcode::kAbs, a));
    auto new_log = computation_->AddInstruction(
        HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, abs_a));
    return ReplaceWithNewInstruction(
        log, HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply,
                                          new_log, b));
  }
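
  // Illustrative sketch (hypothetical HLO text, invented for this comment):
  //   log(power(%a, %b))
  // becomes
  //   multiply(log(abs(%a)), %b)
  // For real types the abs() covers cases like a < 0 with even b, where
  // pow(a, b) == pow(|a|, b); where the original log was NaN the rewritten
  // value may differ, which algebraic simplification tolerates.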

  if (Match(log, m::Log(m::Sqrt(m::Op(&a))))) {
    auto new_log = computation_->AddInstruction(
        HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, a));
    return ReplaceWithNewInstruction(
        log, HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply,
                                          new_log, MakeScalarLike(log, 0.5)));
  }

  if (Match(log, m::Log(m::Rsqrt(m::Op(&a))))) {
    auto new_log = computation_->AddInstruction(
        HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, a));
    return ReplaceWithNewInstruction(
        log, HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply,
                                          new_log, MakeScalarLike(log, -0.5)));
  }

  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleGetTupleElement(
    HloInstruction* get_tuple_element) {
  auto operand = get_tuple_element->mutable_operand(0);
  if (operand->opcode() == HloOpcode::kTuple) {
    // get_tuple_element(make_tuple({A_0, A_1, ..., A_n}), i) => A_i
    VLOG(10) << "trying transform "
             << "[get_tuple_element(make_tuple({...,A_i,...}), i)] => A_i: "
             << get_tuple_element->ToString();
    if (ReplaceInstructionIfSameShape(
            get_tuple_element,
            operand->mutable_operand(get_tuple_element->tuple_index()))) {
      return Status::OK();
    }
  }
  return Status::OK();
}

namespace {

absl::optional<std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
    const HloInstruction* hlo, absl::Span<const int64> input_dim_indices) {
  CHECK_EQ(hlo->opcode(), HloOpcode::kReshape);
  return ShapeUtil::ReshapeLeavesDimensionsUnmodified(
      hlo->operand(0)->shape(), hlo->shape(), input_dim_indices);
}

// Returns true if the output of "instruction" is a permutation of the
// elements of "operand". Precondition: "operand" is an operand of
// "instruction".
bool OutputIsPermutationOfOperandElements(HloInstruction* instruction,
                                          HloInstruction* operand) {
  DCHECK(!instruction->OperandIndices(operand).empty());
  switch (instruction->opcode()) {
    case HloOpcode::kReshape:
    case HloOpcode::kReverse:
    case HloOpcode::kTranspose:
      return true;
    case HloOpcode::kSort:
      return (!instruction->shape().IsTuple());
    default:
      return false;
  }
}

// Returns true if the output of "instruction" is a subset of the elements of
// "operand". Precondition: "operand" is an operand of "instruction".
bool OutputIsSubsetOfOperandElements(HloInstruction* instruction,
                                     HloInstruction* operand) {
  const auto operand_indices = instruction->OperandIndices(operand);
  CHECK(!operand_indices.empty());
  if (operand_indices.size() != 1) {
    return false;
  }
  int64_t operand_index = operand_indices[0];
  switch (instruction->opcode()) {
    case HloOpcode::kSlice:
      CHECK_EQ(0, operand_index);
      return true;
    case HloOpcode::kDynamicSlice:
      return operand_index == 0;
    default:
      return false;
  }
}

}  // namespace

Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
  HloInstruction* operand;
  CHECK(Match(broadcast, m::Broadcast(m::Op(&operand))));
  auto dims = broadcast->dimensions();
  // A broadcast that does not change the number of elements and keeps the
  // dimensions in order is equivalent to a reshape.
  if (std::is_sorted(dims.begin(), dims.end()) &&
      ShapeUtil::ElementsIn(broadcast->shape()) ==
          ShapeUtil::ElementsIn(operand->shape())) {
    VLOG(10) << "transform broadcast(X) -> reshape(X) where "
                "n(broadcast(X)) == n(X)";
    return ReplaceWithNewInstruction(
        broadcast, HloInstruction::CreateReshape(broadcast->shape(), operand));
  }
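
  // Illustrative sketch (hypothetical HLO text, invented for this comment):
  //   f32[2,3,1] broadcast(f32[2,3] %x), dimensions={0,1}
  // has the same element count as %x and sorted broadcast dimensions, so it
  // becomes
  //   f32[2,3,1] reshape(f32[2,3] %x)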

  // A degenerate broadcast that has the same input and output rank can be
  // converted into a transpose.
  if (broadcast->shape().rank() == operand->shape().rank() &&
      ShapeUtil::ElementsIn(broadcast->shape()) ==
          ShapeUtil::ElementsIn(operand->shape())) {
    VLOG(10) << "transform broadcast(X) -> transpose(X) where "
                "n(broadcast(X)) == n(X)";
    return ReplaceWithNewInstruction(
        broadcast,
        HloInstruction::CreateTranspose(broadcast->shape(), operand, dims));
  }
  // A broadcast whose operand is a reshape that merely inserts 1-sized
  // dimensions can skip the reshape and broadcast the reshape's operand
  // directly.
  {
    bool merely_inserts_or_deletes_1_sized_dimensions;
    std::vector<int64> inserted_indices, deleted_indices;
    std::tie(merely_inserts_or_deletes_1_sized_dimensions, deleted_indices,
             inserted_indices) =
        operand->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
    if (merely_inserts_or_deletes_1_sized_dimensions &&
        deleted_indices.empty()) {
      std::reverse(inserted_indices.begin(), inserted_indices.end());
      for (auto inserted_index : inserted_indices) {
        dims.erase(dims.begin() + inserted_index);
      }
      return ReplaceWithNewInstruction(
          broadcast,
          HloInstruction::CreateBroadcast(broadcast->shape(),
                                          operand->mutable_operand(0), dims));
    }
  }

  // A Broadcast that feeds a unary element-wise operation can sink the
  // broadcast after the unary element-wise operation.
  TF_ASSIGN_OR_RETURN(
      bool sink_succeeded,
      TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(broadcast));
  changed_ |= sink_succeeded;
  if (sink_succeeded) {
    return Status::OK();
  }

  // A scalar broadcast feeding an instruction which only permutes (reshape,
  // transpose, sort, reverse) or selects a subset of operand elements (slice,
  // dynamic slice) can be replaced with a broadcast directly to the output
  // shape of the instruction.
  if (ShapeUtil::IsScalar(operand->shape())) {
    for (HloInstruction* user : broadcast->users()) {
      // Skip if the broadcast user has no uses itself.
      if (user->user_count() == 0 && user != computation_->root_instruction()) {
        continue;
      }
      if (OutputIsPermutationOfOperandElements(user, broadcast) ||
          OutputIsSubsetOfOperandElements(user, broadcast)) {
        VLOG(10) << "transform permuting/subset of a scalar broadcast into "
                 << "a single broadcast";
        HloInstruction* new_broadcast = computation_->AddInstruction(
            HloInstruction::CreateBroadcast(user->shape(), operand, {}));
        // Use HloInstruction::ReplaceAllUsesWith instead of
        // HloComputation::ReplaceWithNewInstruction because we are replacing
        // an instruction other than the visited instruction.
        changed_ = true;
        return user->ReplaceAllUsesWith(new_broadcast);
      }
    }
    return Status::OK();
  }

  // broadcast(iota) -> iota.
  if (operand->opcode() == HloOpcode::kIota) {
    return ReplaceWithNewInstruction(
        broadcast,
        HloInstruction::CreateIota(
            broadcast->shape(),
            dims[Cast<HloIotaInstruction>(operand)->iota_dimension()]));
  }
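
  // Illustrative sketch (hypothetical HLO text, invented for this comment):
  //   %i = s32[5] iota(), iota_dimension=0
  //   %b = s32[2,5] broadcast(%i), dimensions={1}
  // becomes
  //   %b = s32[2,5] iota(), iota_dimension=1
  // since broadcasting an iota just relocates its counting dimension.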

  // Merge two consecutive broadcasts into a single one.
  if (operand->opcode() == HloOpcode::kBroadcast) {
    std::vector<int64> new_dimensions;
    for (auto dim : operand->dimensions()) {
      new_dimensions.push_back(dims[dim]);
    }
    return ReplaceWithNewInstruction(
        broadcast,
        HloInstruction::CreateBroadcast(
            broadcast->shape(), operand->mutable_operand(0), new_dimensions));
  }
  if (options_.is_layout_sensitive()) {
    return Status::OK();
  }
  if (ShapeUtil::HasDegenerateDimensions(operand->shape())) {
    auto new_operand =
        operand->parent()->AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::DropDegenerateDimensions(operand->shape()), operand));
    std::vector<int64> new_dims;
    new_dims.reserve(new_operand->shape().rank());
    for (int64_t i = 0; i < operand->shape().rank(); ++i) {
      if (operand->shape().dimensions(i) != 1) {
        new_dims.push_back(dims[i]);
      }
    }
    return ReplaceWithNewInstruction(
        broadcast, HloInstruction::CreateBroadcast(broadcast->shape(),
                                                   new_operand, new_dims));
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleCompare(HloInstruction* compare) {
  HloInstruction* lhs;
  HloInstruction* rhs;
  CHECK(Match(compare, m::Compare(m::Op(&lhs), m::Op(&rhs))));
  {
    // compare(broadcast(a) + x, broadcast(b)) ==>
    // compare(x, broadcast(b-a)); only enabled for signed integral types,
    // where the subtraction cannot change the comparison result.
    HloInstruction *x, *a, *b;
    if (Match(compare,
              m::Compare(
                  m::AddAnyOrder(m::Op(&x), m::Broadcast(m::Op(&a).WithShape(
                                                m::Shape().IsScalar()))),
                  m::Broadcast(m::Op(&b).WithShape(m::Shape().IsScalar()))))) {
      if (ShapeUtil::ElementIsSigned(x->shape()) &&
          ShapeUtil::ElementIsIntegral(x->shape())) {
        HloInstruction* sub =
            computation_->AddInstruction(HloInstruction::CreateBinary(
                b->shape(), HloOpcode::kSubtract, b, a));
        HloInstruction* broadcast = computation_->AddInstruction(
            HloInstruction::CreateBroadcast(x->shape(), sub, {}));
        HloInstruction* new_compare = computation_->AddInstruction(
            HloInstruction::CreateCompare(compare->shape(), x, broadcast,
                                          compare->comparison_direction()));
        return ReplaceInstruction(compare, new_compare);
      }
    }
  }

  if (Cast<HloCompareInstruction>(compare)->type() ==
      Comparison::Type::kUnsigned) {
    // X u< 0 -> false
    if (compare->comparison_direction() == ComparisonDirection::kLt &&
        IsAll(rhs, 0)) {
      return ReplaceInstruction(compare, MakeScalarLike(compare, false));
    }
    // X u>= 0 -> true
    if (compare->comparison_direction() == ComparisonDirection::kGe &&
        IsAll(rhs, 0)) {
      return ReplaceInstruction(compare, MakeScalarLike(compare, true));
    }
    // 0 u> X -> false
    if (compare->comparison_direction() == ComparisonDirection::kGt &&
        IsAll(lhs, 0)) {
      return ReplaceInstruction(compare, MakeScalarLike(compare, false));
    }
    // 0 u<= X -> true
    if (compare->comparison_direction() == ComparisonDirection::kLe &&
        IsAll(lhs, 0)) {
      return ReplaceInstruction(compare, MakeScalarLike(compare, true));
    }
  }
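
  // Illustrative sketch (hypothetical HLO text, invented for this comment):
  //   compare(u32[4] %x, u32[4] broadcast(constant(0))), direction=LT
  // always yields false, because no unsigned value is less than zero, so it
  // folds to
  //   pred[4] broadcast(constant(false))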

  if (compare->comparison_direction() == ComparisonDirection::kLt &&
      lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) {
    return ReplaceInstruction(compare, MakeScalarLike(compare, false));
  } else if (compare->comparison_direction() == ComparisonDirection::kGt &&
             IsAll(lhs, 0) && rhs->opcode() == HloOpcode::kIota) {
    return ReplaceInstruction(compare, MakeScalarLike(compare, false));
  } else if (compare->comparison_direction() == ComparisonDirection::kGe &&
             lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) {
    return ReplaceInstruction(compare, MakeScalarLike(compare, true));
  } else if (compare->comparison_direction() == ComparisonDirection::kLe &&
             IsAll(lhs, 0) && rhs->opcode() == HloOpcode::kIota) {
    return ReplaceInstruction(compare, MakeScalarLike(compare, true));
  }
  if (lhs == rhs &&
      primitive_util::IsIntegralType(lhs->shape().element_type())) {
    switch (compare->comparison_direction()) {
      case ComparisonDirection::kGt:
      case ComparisonDirection::kLt:
      case ComparisonDirection::kNe:
        return ReplaceInstruction(compare, MakeScalarLike(compare, false));
      case ComparisonDirection::kEq:
      case ComparisonDirection::kGe:
      case ComparisonDirection::kLe:
        return ReplaceInstruction(compare, MakeScalarLike(compare, true));
    }
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) {
  PrimitiveType src_type = convert->operand(0)->shape().element_type();
  PrimitiveType dest_type = convert->shape().element_type();
  // A conversion to the same element type as the operand is a nop and can be
  // removed. A conversion of a constant can be simplified by making a new
  // constant.
  if (src_type == dest_type) {
    return ReplaceInstruction(convert, convert->mutable_operand(0));
  }

  // Eliminate a convert pair if it is a no-op. The following are a few
  // example cases that are being handled:
  // 1. convert(convert(A, $TYPE1), $TYPE2) is simplified to A if A is of $TYPE2
  //    and convert(A, $TYPE1) is an upcast
  // 2. convert(convert(A, $TYPE1),$TYPE2) is simplified to A if A is of $TYPE2
  //    and convert(A, $TYPE1) is an upcast and is an integral conversion from
  //    unsigned to signed (only signed to unsigned conversion is NOT allowed)
  // 3. Tuple(convert(A, $TYPE1) , floor(convert(convert(A, $TYPE1), $TYPE2)),
  //    convert(convert(A, $TYPE1), $TYPE2)) is simplified to Tuple(convert(A,
  //    $TYPE1) , floor(A), A) -> a case where the first convert has a
  //    fan-out
  if (convert->operand(0)->opcode() == HloOpcode::kConvert &&
      IsConvertPairNoOp(convert)) {
    return ReplaceInstruction(convert,
                              convert->mutable_operand(0)->mutable_operand(0));
  }
  return Status::OK();
}

// Complex(Real(c), Imag(c)) -> c
Status AlgebraicSimplifierVisitor::HandleComplex(HloInstruction* complex) {
  HloInstruction *c0, *c1;
  if (Match(complex, m::Complex(m::Real(m::Op(&c0)), m::Imag(m::Op(&c1)))) &&
      c0 == c1) {
    return ReplaceInstruction(complex, c0);
  }
  return Status::OK();
}

// Real(Complex(r, i)) -> r
Status AlgebraicSimplifierVisitor::HandleReal(HloInstruction* real) {
  HloInstruction* op;
  if (Match(real, m::Real(m::Complex(m::Op(&op), m::Op())))) {
    return ReplaceInstruction(real, op);
  }
  return Status::OK();
}

// Imag(Complex(r, i)) -> i
Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) {
  HloInstruction* op;
  if (Match(imag, m::Imag(m::Complex(m::Op(), m::Op(&op))))) {
    return ReplaceInstruction(imag, op);
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) {
  // iota -> zero if the iota dimension never produces an element other than
  // zero.
  auto* iota = Cast<HloIotaInstruction>(instruction);
  if (iota->shape().dimensions(iota->iota_dimension()) <= 1) {
    auto zero = computation_->AddInstruction(
        simplifier_->CreateConstantWithLayoutUpdated(
            LiteralUtil::Zero(iota->shape().element_type()).Clone()));
    return ReplaceWithNewInstruction(
        iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {}));
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
  if (ShapeUtil::IsZeroElementArray(pad->operand(0)->shape())) {
    return ReplaceWithNewInstruction(
        pad, HloInstruction::CreateBroadcast(pad->shape(),
                                             pad->mutable_operand(1), {}));
  }

  // Interior padding on one-sized dimensions has no effect. Clearing it makes
  // other simplifications possible, since they can assume there is no
  // interior padding.
  if (HasInteriorPadding(pad->padding_config())) {
    PaddingConfig padding_config = pad->padding_config();
    bool cleared_interior_padding = false;
    for (int64_t i = 0; i < pad->shape().rank(); ++i) {
      if (padding_config.dimensions(i).interior_padding() > 0 &&
          pad->operand(0)->shape().dimensions(i) == 1) {
        cleared_interior_padding = true;
        padding_config.mutable_dimensions(i)->set_interior_padding(0);
      }
    }
    if (cleared_interior_padding) {
      return ReplaceWithNewInstruction(
          pad,
          HloInstruction::CreatePad(pad->shape(), pad->mutable_operand(0),
                                    pad->mutable_operand(1), padding_config));
    }
  }

  // Eliminate nop pads (padding all zero), and replace a pad with negative
  // padding with a pad with non-negative padding followed by a slice.
  bool all_zero = true;
  bool has_negative = false;
  // Used to possibly split off the unchanged padding dimensions.
  std::vector<int64> padding_dimensions;
  int64_t dimension_index = 0;
  for (auto& padding_dimension : pad->padding_config().dimensions()) {
    if (padding_dimension.edge_padding_low() < 0 ||
        padding_dimension.edge_padding_high() < 0) {
      has_negative = true;
    }
    if (padding_dimension.edge_padding_low() != 0 ||
        padding_dimension.edge_padding_high() != 0) {
      all_zero = false;
      padding_dimensions.push_back(dimension_index);
    } else if (padding_dimension.interior_padding()) {
      padding_dimensions.push_back(dimension_index);
    }
    dimension_index++;
  }

  if (all_zero) {
    if (ReplaceInstructionIfSameShape(pad, pad->mutable_operand(0))) {
      return Status::OK();
    }
  }

  // The context of this optimization can be found at b/163617402. It captures
  // the case of pad(broadcast(x)), where x->shape().dimensions(), or
  // broadcast(x)->dimensions(), is a subset of the padded dimensions in
  // pad->config(), and the padded dimensions in pad->config() are in turn a
  // strict subset of broadcast->shape().dimensions(). The combined op can be
  // rewritten to broadcast2(pad(broadcast1(x))), where broadcast1 extends
  // x with the dimensions that need to be padded, and broadcast2 extends
  // the result of the padding to the full dimensions.
  // TODO(qyi): for future extensions: The condition that broadcast(x)
  // ->dimensions() be a subset of the padded dimensions in pad->config()
  // does not have to be strictly required, but it makes the calculation
  // for the optimization easier, so the current implementation requires it.
  // Only the second condition, between the padded dimensions and the
  // dimensions of the final shape, has to be enforced for the optimization
  // to make sense. To remove the first constraint, the shape calculations
  // across the implementation would need to be re-adjusted.
  auto pad_dims = padding_dimensions.size();
  if (pad_dims < dimension_index &&
      pad->operand(0)->opcode() == HloOpcode::kBroadcast &&
      pad->operand(0)->user_count() == 1 &&
      pad->operand(0)->operand(0)->shape().rank() <= pad_dims) {
    // Check that the broadcast's dimensions are a subset of
    // padding_dimensions. If not, skip the optimization.
    bool opt_is_valid = true;
    std::vector<int64> broadcast_dimensions;
    HloBroadcastInstruction* broadcast =
        static_cast<HloBroadcastInstruction*>(pad->mutable_operand(0));
    for (auto broadcast_index : broadcast->dimensions()) {
      bool found = false;
      for (int i = 0; i < pad_dims; ++i) {
        if (broadcast_index == padding_dimensions[i]) {
          broadcast_dimensions.push_back(i);
          found = true;
          break;
        }
      }
      if (!found) {
        opt_is_valid = false;
        break;
      }
    }
    if (opt_is_valid) {
      auto pad_shape = pad->shape();
      auto broadcast_shape = broadcast->shape();
      auto pad_shape1 = pad_shape;
      auto broadcast_shape1 = broadcast_shape;
      PaddingConfig pad_config;
      for (int i = padding_dimensions.size() - 1; i >= 0; --i) {
        int64_t j = padding_dimensions[i];
        while (--dimension_index > j) {
          broadcast_shape1.DeleteDimension(dimension_index);
          pad_shape1.DeleteDimension(dimension_index);
        }
      }
      while (--dimension_index >= 0) {
        broadcast_shape1.DeleteDimension(dimension_index);
        pad_shape1.DeleteDimension(dimension_index);
      }
      for (auto dimension_to_pad : padding_dimensions) {
        auto dimension = pad_config.add_dimensions();
        *dimension = pad->padding_config().dimensions(dimension_to_pad);
      }
      *broadcast->mutable_shape() = broadcast_shape1;
      *broadcast->mutable_dimensions() = broadcast_dimensions;
      simplifier_->UpdateLayout(broadcast->mutable_shape());
      auto pad2 =
          computation_->AddInstruction(pad->CloneWithNewShape(pad_shape1));
      *pad2->mutable_padding_config() = pad_config;
      simplifier_->UpdateLayout(pad2->mutable_shape());
      auto broadcast2 = computation_->AddInstruction(
          HloInstruction::CreateBroadcast(pad_shape, pad2, padding_dimensions));
      return ReplaceInstruction(pad, broadcast2);
    }
  }
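
  // Illustrative sketch (hypothetical HLO text, invented for this comment):
  //   %b = f32[8,10] broadcast(f32[10] %x), dimensions={1}
  //   %p = f32[8,12] pad(%b, %c), padding=0_0x1_1
  // Only dimension 1 is padded, so this becomes (up to a trivial inner
  // broadcast of %x to its own shape):
  //   %p2 = f32[12] pad(%x, %c), padding=1_1
  //   %b2 = f32[8,12] broadcast(%p2), dimensions={1}
  // which pads 10 elements instead of 80.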

  if (has_negative && options_.enable_negative_padding_replacement()) {
    // Pad has negative padding. Replace with a pad with the non-negative
    // padding followed by a slice which effectively performs the negative
    // padding.
    // TODO(b/34628603): Add support for negative padding in the backends, or
    // change kPad semantics to disallow negative padding and use slice
    // instead.

    // First, construct the padding config with non-negative entries and
    // compute the shape of the new pad instruction.
    PaddingConfig nonzero_padding = pad->padding_config();
    for (int i = 0; i < pad->padding_config().dimensions_size(); ++i) {
      PaddingConfig::PaddingConfigDimension* padding_dimension =
          nonzero_padding.mutable_dimensions(i);
      // Set negative padding to zero.
      if (padding_dimension->edge_padding_low() < 0) {
        padding_dimension->set_edge_padding_low(0);
      }
      if (padding_dimension->edge_padding_high() < 0) {
        padding_dimension->set_edge_padding_high(0);
      }
    }

    TF_ASSIGN_OR_RETURN(HloInstruction * nonzero_pad,
                        MakePadHlo(pad->mutable_operand(0),
                                   pad->mutable_operand(1), nonzero_padding));
    // Copy the layout from the original pad instruction. The new pad and the
    // slice instruction should all have the same layout.
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        pad->shape(), nonzero_pad->mutable_shape()));
    simplifier_->UpdateLayout(nonzero_pad->mutable_shape());

    // Second, construct the slice instruction to perform the negative
    // padding.
    std::vector<int64> start_indices;
    std::vector<int64> end_indices;
    std::vector<int64> strides;
    for (int64_t i = 0; i < pad->padding_config().dimensions_size(); ++i) {
      const PaddingConfig::PaddingConfigDimension& padding_dimension =
          pad->padding_config().dimensions(i);
      int64_t start = 0;
      if (padding_dimension.edge_padding_low() < 0) {
        start = -1 * padding_dimension.edge_padding_low();
      }
      int64_t end = nonzero_pad->shape().dimensions(i);
      if (padding_dimension.edge_padding_high() < 0) {
        end += padding_dimension.edge_padding_high();
      }
      start_indices.push_back(start);
      end_indices.push_back(end);
      strides.push_back(1);
    }
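
    // Illustrative sketch (hypothetical values, invented for this comment):
    // padding f32[10] %x with edge_padding_low=-2, edge_padding_high=1
    // becomes
    //   %p = f32[11] pad(%x, %c), padding=0_1   // negative low clamped to 0
    //   %s = f32[9]  slice(%p), slice={[2:11]}  // start 2 drops two elements
    // so the slice performs the negative part of the padding.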

    TF_ASSIGN_OR_RETURN(
        HloInstruction * slice,
        MakeSliceHlo(nonzero_pad, start_indices, end_indices, strides));
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        pad->shape(), slice->mutable_shape()));
    simplifier_->UpdateLayout(slice->mutable_shape());

    // Verify that the slice shape matches the pad shape.
    auto equal = Shape::Equal();
    if (!options_.is_layout_sensitive()) {
      equal.IgnoreTilesInLayout();
    }
    TF_RET_CHECK(equal(slice->shape(), pad->shape()));

    return ReplaceInstruction(pad, slice);
  }

  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
  VLOG(10) << "trying transform [pow(A, 0) => 1]: " << power->ToString();
  HloInstruction *lhs, *rhs;
  CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs))));
  if (IsAll(rhs, 0)) {
    return ReplaceInstruction(power, MakeScalarLike(power, 1));
  }

  VLOG(10) << "trying transform [pow(A, 1) => A]: " << power->ToString();
  if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(power, lhs)) {
    return Status::OK();
  }

  // pow(exp(A),B) => exp(A*B)
  HloInstruction *a, *b;
  if (Match(power, m::Power(m::Exp(m::Op(&a)), m::Op(&b)))) {
    auto a_times_b = computation_->AddInstruction(HloInstruction::CreateBinary(
        power->shape(), HloOpcode::kMultiply, a, b));
    return ReplaceWithNewInstruction(
        power, HloInstruction::CreateUnary(power->shape(), HloOpcode::kExp,
                                           a_times_b));
  }

  VLOG(10) << "trying transform [pow(A, 2) => A*A]: " << power->ToString();
  if (IsAll(rhs, 2)) {
    return ReplaceWithNewInstruction(
        power, HloInstruction::CreateBinary(power->shape(),
                                            HloOpcode::kMultiply, lhs, lhs));
  }

  // Pow(A, 3) is used in GELU.
  VLOG(10) << "trying transform [pow(A, 3) => A*A*A]: " << power->ToString();
  if (IsAll(rhs, 3)) {
    HloInstruction* tmp =
        computation_->AddInstruction(HloInstruction::CreateBinary(
            power->shape(), HloOpcode::kMultiply, lhs, lhs));
    return ReplaceWithNewInstruction(
        power, HloInstruction::CreateBinary(power->shape(),
                                            HloOpcode::kMultiply, lhs, tmp));
  }

  VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
  if (IsAll(rhs, -1)) {
    return ReplaceWithNewInstruction(
        power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide,
                                            MakeScalarLike(lhs, 1), lhs));
  }

  return Status::OK();
}

StatusOr<bool>
AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
    HloInstruction* broadcast) {
  TF_RET_CHECK(broadcast->opcode() == HloOpcode::kBroadcast);
  bool changed = false;
  if (ShapeUtil::IsScalar(broadcast->shape())) {
    return false;
  }
  HloInstruction* operand = broadcast->mutable_operand(0);
  auto is_scalar_broadcast = [](const HloInstruction* instruction) {
    return instruction->opcode() == HloOpcode::kBroadcast &&
           ShapeUtil::IsScalar(instruction->operand(0)->shape());
  };
  auto is_equal_broadcast = [operand,
                             broadcast](const HloInstruction* instruction) {
    return instruction->opcode() == HloOpcode::kBroadcast &&
           ShapeUtil::Equal(operand->shape(),
                            instruction->operand(0)->shape()) &&
           broadcast->dimensions() == instruction->dimensions();
  };
  auto is_compatible_broadcast = [&](const HloInstruction* instruction) {
    return is_scalar_broadcast(instruction) || is_equal_broadcast(instruction);
  };
  for (HloInstruction* user : broadcast->users()) {
    if (user->user_count() == 0 && user != computation_->root_instruction()) {
      continue;
    }
    // Do not move reshapes or broadcasts past copies since the shape the copy
    // will operate on will change.
    if (user->opcode() == HloOpcode::kCopy) {
      continue;
    }
    // Do not change the shape of fusion nodes, in case there are multiple
    // shapes inside the fusion node already.
    if (user->opcode() == HloOpcode::kFusion) {
      continue;
    }
    if (!user->IsElementwise()) {
      continue;
    }

    // Check if all the operands of the user are compatible broadcasts for
    // sinking. (They are either scalar broadcasts or broadcasts casting
    // from/to the same shape/dimensions.)
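    // Illustrative sketch (hypothetical HLO text, invented for this comment):
    //   %b = f32[8,128] broadcast(f32[128] %x), dimensions={1}
    //   %s = f32[8,128] broadcast(f32[] %c), dimensions={}
    //   %a = f32[8,128] add(%b, %s)
    // Both operands are compatible broadcasts, so the add can be performed on
    // the small shape and broadcast afterwards:
    //   %s2 = f32[128] broadcast(f32[] %c), dimensions={}
    //   %a2 = f32[128] add(%x, %s2)
    //   %b2 = f32[8,128] broadcast(%a2), dimensions={1}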
    int64_t compatible_broadcast_count = 0;
    int64_t broadcast_use_count = 0;
    for (HloInstruction* user_operand : user->operands()) {
      if (is_compatible_broadcast(user_operand)) {
        ++compatible_broadcast_count;
      } else if (broadcast == user_operand) {
        ++broadcast_use_count;
      }
    }
    if (compatible_broadcast_count + broadcast_use_count !=
        user->operand_count()) {
      continue;
    }
    std::vector<HloInstruction*> new_operands;
    new_operands.reserve(user->operand_count());

    Shape changed_shape;
    for (HloInstruction* user_operand : user->operands()) {
      // If this is a broadcast operand that is not our original broadcast
      // input to this function, then we might need to change the input.
      if (is_compatible_broadcast(user_operand)) {
        // If this is a broadcast from a scalar value, rewrite it as a
        // broadcast from that scalar to the new shape enforced by the other
        // broadcast operands.
        if (is_scalar_broadcast(user_operand)) {
          changed_shape = ShapeUtil::ChangeElementType(
              operand->shape(), user_operand->shape().element_type());
          simplifier_->UpdateLayout(&changed_shape);
          new_operands.push_back(
              computation_->AddInstruction(HloInstruction::CreateBroadcast(
                  changed_shape, user_operand->mutable_operand(0), {})));
          user_operand->SetupDerivedInstruction(new_operands.back());
        } else {
          // For non-scalar broadcasts, the operand of the broadcast is
          // guaranteed to already have a compatible shape.
          new_operands.push_back(user_operand->mutable_operand(0));
        }
      } else {
        CHECK_EQ(broadcast, user_operand);
        new_operands.push_back(operand);
      }
    }
    VLOG(4) << "Sinking broadcast after user:";
    VLOG(4) << "  old broadcast: " << broadcast->ToString();
    VLOG(4) << "  old user: " << user->ToString();
    changed_shape = ShapeUtil::ChangeElementType(operand->shape(),
                                                 user->shape().element_type());
    simplifier_->UpdateLayout(&changed_shape);
    HloInstruction* new_user = computation_->AddInstruction(
        user->CloneWithNewOperands(changed_shape, new_operands));
    VLOG(4) << "  new user: " << new_user->ToString();
    HloInstruction* new_broadcast =
        computation_->AddInstruction(HloInstruction::CreateBroadcast(
            user->shape(), new_user, broadcast->dimensions()));
    broadcast->SetupDerivedInstruction(new_broadcast);
    VLOG(4) << "  new broadcast: " << new_broadcast->ToString();
    TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_broadcast));
    changed = true;
  }
  return changed;
}

namespace {
template <typename T>
std::unique_ptr<HloInstruction> TryRemainderToAnd(
    HloInstruction* remainder, HloComputation* computation,
    AlgebraicSimplifier* simplifier) {
  HloInstruction *a, *b, *c;
  CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b))));

  if (ShapeUtil::ElementIsIntegral(remainder->shape()) &&
      !Match(b, m::ConstantEffectiveScalar(&c)) &&
      !Match(b, m::Broadcast(m::ConstantEffectiveScalar(&c)))) {
    return nullptr;
  }

  if (ShapeUtil::ElementIsSigned(remainder->shape())) {
    int64_t b_value = c->literal().GetFirstElement<T>();
    if (b_value > 0 && IsPowerOfTwo(static_cast<uint64>(b_value))) {
      // Handle negative dividends by negating the result of the division.
      HloInstruction* zero_like_a = BroadcastZeros(
          computation, a->shape().element_type(), a->shape().dimensions());

      Shape compare_shape = ShapeUtil::ChangeElementType(a->shape(), PRED);
      simplifier->UpdateLayout(&compare_shape);
      auto* dividend_is_negative =
          computation->AddInstruction(HloInstruction::CreateCompare(
              compare_shape, a, zero_like_a, ComparisonDirection::kLt));

      auto* negated_dividend = computation->AddInstruction(
          HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a));

      auto* abs_dividend =
          computation->AddInstruction(HloInstruction::CreateTernary(
              a->shape(), HloOpcode::kSelect, dividend_is_negative,
              negated_dividend, a));

      auto* quotient = computation->AddInstruction(HloInstruction::CreateBinary(
          remainder->shape(), HloOpcode::kAnd, abs_dividend,
          MakeScalarLike(abs_dividend, b_value - 1)));

      auto* negated_quotient =
          computation->AddInstruction(HloInstruction::CreateUnary(
              quotient->shape(), HloOpcode::kNegate, quotient));

      return HloInstruction::CreateTernary(
          remainder->shape(), HloOpcode::kSelect, dividend_is_negative,
          negated_quotient, quotient);
    }
  } else {
    uint64 b_value = c->literal().GetFirstElement<T>();
    if (IsPowerOfTwo(b_value)) {
      HloInstruction* mask_amount = computation->AddInstruction(
          simplifier->CreateConstantWithLayoutUpdated(
              LiteralUtil::CreateR0<T>(b_value - 1)));
      if (!ShapeUtil::IsScalar(b->shape())) {
        mask_amount = computation->AddInstruction(
            HloInstruction::CreateBroadcast(b->shape(), mask_amount, {}));
      }
      return HloInstruction::CreateBinary(remainder->shape(), HloOpcode::kAnd,
                                          a, mask_amount);
    }
  }
  return nullptr;
}
}  // namespace

Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) {
  HloInstruction *a, *b;
  CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b))));

  // (A % B) % B == A % B.
  if (Match(a, m::Remainder(m::Op(), m::Op().Is(b)))) {
    return ReplaceInstruction(remainder, a);
  }

  // A % B => A & (B - 1) if B is a power of 2.
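  // Illustrative sketch (hypothetical values, invented for this comment):
  // for unsigned u32, x % 8 == x & 7. For signed types the sign has to be
  // handled separately, as TryRemainderToAnd above does:
  //   -10 % 8 == -2, and indeed -(|-10| & 7) == -(10 & 7) == -2,
  // so the rewrite masks the absolute value and negates the result when the
  // dividend is negative.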
  switch (remainder->shape().element_type()) {
    case S8:
      if (std::unique_ptr<HloInstruction> shift =
              TryRemainderToAnd<int8>(remainder, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(remainder, std::move(shift));
      }
      break;
    case S16:
      if (std::unique_ptr<HloInstruction> shift =
              TryRemainderToAnd<int16>(remainder, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(remainder, std::move(shift));
      }
      break;
    case S32:
      if (std::unique_ptr<HloInstruction> shift =
              TryRemainderToAnd<int32>(remainder, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(remainder, std::move(shift));
      }
      break;
    case S64:
      if (std::unique_ptr<HloInstruction> shift =
              TryRemainderToAnd<int64>(remainder, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(remainder, std::move(shift));
      }
      break;
    case U8:
      if (std::unique_ptr<HloInstruction> shift =
              TryRemainderToAnd<uint8>(remainder, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(remainder, std::move(shift));
      }
      break;
    case U16:
      if (std::unique_ptr<HloInstruction> shift =
              TryRemainderToAnd<uint16>(remainder, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(remainder, std::move(shift));
      }
      break;
    case U32:
      if (std::unique_ptr<HloInstruction> shift =
              TryRemainderToAnd<uint32>(remainder, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(remainder, std::move(shift));
      }
      break;
    case U64:
      if (std::unique_ptr<HloInstruction> shift =
              TryRemainderToAnd<uint64>(remainder, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(remainder, std::move(shift));
      }
      break;
    default:
      break;
  }

  // If M < N, then {0, ..., M} % N ==> {0, ..., M}.
  //
  // Currently this only covers the case when N is a broadcasted constant
  // scalar. We could also cover the case when N is a non-broadcasted constant
  // with the same value repeated.
  HloInstruction* iota;
  HloInstruction* divisor;
  if (Match(remainder,
            m::Remainder(m::Iota(&iota),
                         m::Broadcast(m::ConstantEffectiveScalar(&divisor))))) {
    // The iota counts {0, ..., iota_upper_bound - 1}. (Actually this is
    // conservative; the iota may overflow and count up to a smaller value than
    // this. But that's OK for our purposes here.)
    int64_t iota_upper_bound = iota->shape().dimensions(
        Cast<HloIotaInstruction>(iota)->iota_dimension());
    // Note: the multi-index must be a vector of zeros with one entry per
    // dimension of the (effective-scalar) divisor literal.
    absl::optional<int64> divisor_val = divisor->literal().GetIntegralAsS64(
        std::vector<int64>(divisor->shape().dimensions_size(), 0));
    if (divisor_val && *divisor_val >= iota_upper_bound) {
      return ReplaceInstruction(remainder, iota);
    }
  }

  // (X + N) % N = X % N, so long as X + N does not overflow.
  //
  // We don't have range tracking in XLA that would let us know whether X + N
  // overflows, so for now we only do this simplification when X is an iota.
  // We could add other operations where it's easy to see a range, such as
  // remainder, convert, etc., though at some point we'd probably want a
  // range-tracking analysis.
  HloInstruction* bcast;
  HloInstruction* addend;
  if (Match(
          remainder,
          m::Remainder(
              m::AddAnyOrder(m::Iota(&iota),
                             m::Broadcast(m::ConstantEffectiveScalar(&addend))),
              m::Broadcast(&bcast, m::ConstantEffectiveScalar(&divisor)))) &&
      addend == divisor) {
    // The iota counts {0, ..., iota_upper_bound - 1}, with the same caveat as
    // above that iota_upper_bound is conservative, and the true upper bound
    // may be smaller.
    int64_t iota_upper_bound = iota->shape().dimensions(
        Cast<HloIotaInstruction>(iota)->iota_dimension());
    absl::optional<int64> divisor_val = divisor->literal().GetIntegralAsS64(
        std::vector<int64>(divisor->shape().dimensions_size(), 0));
    if (divisor_val) {
      // Check whether divisor_val + iota_upper_bound - 1 overflows.
      absl::optional<int64> max_val =
          OverflowSafeAdd(*divisor_val, iota_upper_bound);
      if (max_val.has_value() &&
          FitsInIntegralType(*max_val, iota->shape().element_type())) {
        return ReplaceWithNewInstruction(
            remainder,
            HloInstruction::CreateBinary(remainder->shape(),
                                         HloOpcode::kRemainder, iota, bcast));
      }
    }
  }
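
  // Illustrative sketch (hypothetical HLO text, invented for this comment):
  //   %i = s32[16] iota(), iota_dimension=0            // values 0..15
  //   %n = s32[16] broadcast(s32[] constant(16))
  //   %r = s32[16] remainder(add(%i, %n), %n)
  // Every element of add(%i, %n) is x + 16 with x in [0, 16), and
  // (x + 16) % 16 == x % 16, so this becomes remainder(%i, %n) -- valid here
  // because 15 + 16 comfortably fits in s32.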

  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
  auto operand = reshape->mutable_operand(0);

  // Reshape directly to empty constant if the shape contains a zero-element
  // dimension.
  if (ShapeUtil::IsZeroElementArray(reshape->shape())) {
    // If the instruction doesn't have a layout, use a default layout for
    // the literal result.
    Shape reshaped_shape = reshape->shape();
    if (!LayoutUtil::HasLayout(reshaped_shape)) {
      LayoutUtil::SetToDefaultLayout(&reshaped_shape);
    }
    auto empty_constant = simplifier_->CreateConstantWithLayoutUpdated(
        Literal::CreateFromShape(reshaped_shape));

    return ReplaceWithNewInstruction(reshape, std::move(empty_constant));
  }

  // Delete no-op reshapes, i.e. where shape = operand shape.
  if (SameShape(reshape, operand)) {
    VLOG(10) << "deleting no-op reshape";
    return ReplaceInstruction(reshape, operand);
  }

  // Merge reshapes.
  if (HloOpcode::kReshape == operand->opcode()) {
    return ReplaceWithNewInstruction(
        reshape, HloInstruction::CreateReshape(reshape->shape(),
                                               operand->mutable_operand(0)));
  }

  if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) {
    *operand->mutable_shape() = reshape->shape();
    return ReplaceInstruction(reshape, operand);
  }

  if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) {
    auto opt_dims = ReshapeLeavesDimensionsUnmodified(
        reshape, reshape->operand(0)->dimensions());
    if (opt_dims.has_value()) {
      return ReplaceWithNewInstruction(
          reshape,
          HloInstruction::CreateBroadcast(
              reshape->shape(), reshape->mutable_operand(0)->mutable_operand(0),
              *opt_dims));
    }
  }

  // reshape(iota) -> iota or a mixed radix calculation like
  // s32[2,3,4] reshape(s32[24] iota()) to
  // add(
  //    add(s32[2,3,4] iota() iota_dimension=2,
  //        4 * s32[2,3,4] iota() iota_dimension=1),
  //    12 * s32[2,3,4] iota() iota_dimension=0).
  if (operand->opcode() == HloOpcode::kIota) {
    auto* iota = Cast<HloIotaInstruction>(operand);
    auto common_factors =
        CommonFactors(reshape->operand(0)->shape().dimensions(),
                      reshape->shape().dimensions());
    auto iota_dim = absl::c_find_if(
        common_factors, [&](const std::pair<int64, int64>& dim_pair) {
          return dim_pair.first == iota->iota_dimension() &&
                 reshape->shape().dimensions(dim_pair.second) > 1;
        });
    auto next_dim = absl::c_find_if(
        common_factors, [&](const std::pair<int64, int64>& dim_pair) {
          return dim_pair.first == iota->iota_dimension() + 1;
        });
    if (iota_dim != common_factors.end() && next_dim != common_factors.end()) {
      int64_t multiplier = 1;
      HloInstruction* new_reshape = nullptr;

      for (int64_t dim = (iota_dim + 1)->second - 1; dim >= iota_dim->second;
           --dim) {
        HloInstruction* new_iota = computation_->AddInstruction(
            HloInstruction::CreateIota(reshape->shape(), dim));
        iota->SetupDerivedInstruction(new_iota);
        if (new_reshape) {
          new_reshape =
              computation_->AddInstruction(HloInstruction::CreateBinary(
                  reshape->shape(), HloOpcode::kAdd, new_reshape,
                  computation_->AddInstruction(HloInstruction::CreateBinary(
                      reshape->shape(), HloOpcode::kMultiply, new_iota,
                      MakeScalarLike(reshape, multiplier)))));
          reshape->SetupDerivedInstruction(new_reshape);
        } else {
          new_reshape = new_iota;
        }
        multiplier *= reshape->shape().dimensions(dim);
      }
      reshape->SetupDerivedInstruction(new_reshape);
      return ReplaceInstruction(reshape, new_reshape);
    }
  }
  // Move the reshape in reshape(dus(...)) before the dus so that it can
  // enable other optimizations, e.g., merging with broadcast, and sparse
  // update (add(x, dus(broadcast(0), y, ...)) -> dus(x, add(ds(x), y), ...)).
  if (!options_.is_layout_sensitive()) {
    bool trivial_reshape;
    std::vector<int64> deleted_dims;
    std::vector<int64> inserted_dims;

    HloInstruction* dus;
    HloInstruction* slice;
    std::tie(trivial_reshape, deleted_dims, inserted_dims) =
        reshape->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
    // Dimensions that are merely added or removed as 1-sized dimensions will
    // be 1-sized in both the update slice and the dynamic-update-slice result.
    if (trivial_reshape &&
        Match(reshape->mutable_operand(0),
              m::Op(&dus)
                  .WithOpcode(HloOpcode::kDynamicUpdateSlice)
                  .WithOperand(1, m::Op(&slice))) &&
        !dus->has_sharding() && !dus->operand(0)->has_sharding()) {
      auto new_operand =
          computation_->AddInstruction(HloInstruction::CreateReshape(
              reshape->shape(), dus->mutable_operand(0)));
      std::vector<int64> new_slice_shape;
      std::vector<HloInstruction*> new_dus_operands;
      new_dus_operands.push_back(new_operand);
      new_dus_operands.push_back(nullptr);
      auto zero = MakeScalarLike(dus->mutable_operand(2), 0);
      const Shape& old_slice_shape = dus->operand(1)->shape();
      for (int64 i = 0; i <= old_slice_shape.rank(); ++i) {
        if (absl::c_linear_search(deleted_dims, i)) {
          continue;
        }
        while (absl::c_linear_search(inserted_dims, new_slice_shape.size())) {
          new_slice_shape.push_back(1);
          new_dus_operands.push_back(zero);
        }
        if (i < old_slice_shape.rank()) {
          new_slice_shape.push_back(old_slice_shape.dimensions(i));
          new_dus_operands.push_back(dus->mutable_operand(2 + i));
        }
      }
      auto new_slice =
          computation_->AddInstruction(HloInstruction::CreateReshape(
              ShapeUtil::MakeShape(old_slice_shape.element_type(),
                                   new_slice_shape),
              slice));
      new_dus_operands[1] = new_slice;
      auto new_dus =
          dus->CloneWithNewOperands(reshape->shape(), new_dus_operands);
      return ReplaceWithNewInstruction(reshape, std::move(new_dus));
    }
  }

  // Make this a bitcast if possible.
  if (HloInstruction* bitcast_operand =
          BitcastingOperandOfReshapeOrCopyChain(reshape, options_)) {
    ReplaceWithBitcast(reshape, bitcast_operand);
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) {
  // When all the dimensions to reverse are trivial (i.e. the bound is 1),
  // there is nothing to be done.
  auto dim_is_one = [&](int64_t i) -> bool {
    return reverse->shape().dimensions(i) == 1;
  };
  if (absl::c_all_of(reverse->dimensions(), dim_is_one)) {
    return ReplaceInstruction(reverse, reverse->mutable_operand(0));
  }
  return Status::OK();
}

StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyScalarSlice(
    HloInstruction* slice) {
  // Only try to do this for effective scalars. We could do the same for
  // slicing out larger pieces of padding (replacing with a broadcast of the
  // padding value), but this is probably not worth it.
  if (!ShapeUtil::IsEffectiveScalar(slice->shape())) {
    return false;
  }

  if (slice->operand(0)->opcode() == HloOpcode::kConcatenate) {
    VLOG(10) << "Trying to simplify scalar slice of concat";
    // Only do this for R1, there's no chance of this being useful otherwise.
    if (slice->shape().rank() != 1) {
      VLOG(10) << "Not folding, slice is not rank 1";
      return false;
    }
    HloConcatenateInstruction* concat =
        Cast<HloConcatenateInstruction>(slice->mutable_operand(0));
    int64_t operand_start = 0;
    int64_t operand_num = 0;
    // Weird loop structure to avoid annoying off-by-one errors.
    while (true) {
      TF_RET_CHECK(operand_num < concat->operand_count());
      const HloInstruction* operand = concat->operand(operand_num);
      int64_t next_operand_start =
          operand_start + operand->shape().dimensions(0);
      if (next_operand_start > slice->slice_starts(0)) {
        break;
      }
      operand_start = next_operand_start;
      operand_num++;
    }
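    // Illustrative sketch (hypothetical values, invented for this comment):
    // for concat operands of sizes {3, 2, 4} and slice_starts(0) == 4, the
    // loop stops with operand_num == 1 and operand_start == 3, i.e. the
    // sliced element lives at offset 4 - 3 = 1 inside the second operand.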

    bool replaced = ReplaceInstructionIfSameShape(
        slice, concat->mutable_operand(operand_num));
    if (replaced) {
      VLOG(10) << "Folding scalar slice of concat into concat operand";
    } else {
      VLOG(10) << "Folding scalar slice of concat into slice of concat "
                  "operand";
      TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
          slice, HloInstruction::CreateSlice(
                     slice->shape(), concat->mutable_operand(operand_num),
                     {slice->slice_starts(0) - operand_start},
                     {slice->slice_starts(0) - operand_start + 1},
                     slice->slice_strides())));
    }
    return true;
  }

  return false;
}

StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape(
    HloInstruction* slice) {
  CHECK_EQ(slice->opcode(), HloOpcode::kSlice);
  if (!IsUnstridedSlice(slice)) {
    return false;
  }
  HloInstruction* reshape = slice->mutable_operand(0);
  if (reshape->opcode() != HloOpcode::kReshape) {
    return false;
  }
  HloInstruction* new_slice_operand = reshape->mutable_operand(0);
  int64_t slice_rank = slice->shape().rank();
  std::vector<int64> sliced_dims;
  for (int64_t i = 0; i < slice_rank; ++i) {
    if (slice->slice_starts(i) != 0 ||
        slice->slice_limits(i) != reshape->shape().dimensions(i)) {
      sliced_dims.push_back(i);
    }
  }

  if (sliced_dims.size() == 1 && sliced_dims[0] == 0 &&
      slice->slice_starts(0) == 0) {
    const Shape& new_slice_shape = new_slice_operand->shape();
    const int64_t rank = new_slice_shape.rank();
    std::vector<int64> new_slice_starts(rank, 0);
    std::vector<int64> new_slice_strides(rank, 1);
    std::vector<int64> new_slice_limits(new_slice_shape.dimensions().begin(),
                                        new_slice_shape.dimensions().end());
    int64_t slice_elements = ShapeUtil::ElementsIn(slice->shape());
    for (int64_t i = rank - 1; i >= 0; --i) {
      if (slice_elements >= new_slice_limits[i]) {
        if (slice_elements % new_slice_limits[i] != 0) {
          return false;
        }
        slice_elements /= new_slice_limits[i];
      } else {
        new_slice_limits[i] = slice_elements;
        slice_elements = 1;
      }
    }
    HloInstruction* new_slice =
        computation_->AddInstruction(HloInstruction::CreateSlice(
            ShapeUtil::MakeShape(new_slice_shape.element_type(),
                                 new_slice_limits),
            new_slice_operand, new_slice_starts, new_slice_limits,
            new_slice_strides));
    simplifier_->UpdateLayout(new_slice->mutable_shape());
    TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
        slice, HloInstruction::CreateReshape(slice->shape(), new_slice)));
    return true;
  }
  return false;
}

// Allow a slice to move through a reverse, updating the slice configuration
// as needed.
4240 StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReverse(
4241 HloInstruction* slice) {
4242 VLOG(2) << "Entered TryToReorderSliceAndReverse for slice:"
4243 << slice->ToString();
4244 if (Match(slice, m::Slice(m::Reverse()))) {
4245 HloInstruction* reverse = slice->mutable_operand(0);
4246 HloInstruction* reverse_operand = reverse->mutable_operand(0);
4247 std::vector<int64> new_starts = slice->slice_starts();
4248 std::vector<int64> new_limits = slice->slice_limits();
4249 std::vector<int64> new_strides = slice->slice_strides();
4250 for (auto rdim : reverse->dimensions()) {
4251 int64_t start = slice->slice_starts(rdim);
4252 int64_t limit = slice->slice_limits(rdim);
4253 int64_t stride = slice->slice_strides(rdim);
4254 // find_nth lets us compute the appropriate index at which to begin the
4255 // reverse, even in the presence of non-unit strides.
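// For example (an illustrative walk-through, not from the original
// comments): with start=1, limit=9, stride=3 the slice reads indices
// {1, 4, 7}; find_nth works out to 7, so the effective limit is
// tightened from 9 to 8, just past the last element actually read.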
4256 int64_t find_nth = (limit - start - 1) / stride;
4257 find_nth = start + find_nth * stride;
4258 limit = find_nth + 1;
4259 new_starts[rdim] =
4260 (reverse->shape().dimensions(rdim) - start) - (limit - start);
4261 new_limits[rdim] = reverse->shape().dimensions(rdim) - start;
4262 VLOG(2) << "Analyzing dim:" << rdim << " (start,limit):" << start << ","
4263 << limit << " and new (start, limit):" << new_starts[rdim] << ","
4264 << new_limits[rdim];
4265 }
4266 // New slice formed from the reverse_operand, but strides and shape of the
4267 // slice output remains the same. New slice's starts and limits are updated
4268 // for ONLY the reversed dimensions as indicated above.
4269 HloInstruction* new_slice = computation_->AddInstruction(
4270 HloInstruction::CreateSlice(slice->shape(), reverse_operand, new_starts,
4271 new_limits, new_strides));
4272 simplifier_->UpdateLayout(new_slice->mutable_shape());
4273 TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
4274 slice, HloInstruction::CreateReverse(new_slice->shape(), new_slice,
4275 reverse->dimensions())));
4276 // We do not delete the old reverse, since there might be another
4277 // consumer of that reverse (i.e., full reverse output). DCE should take
4278 // care of any deletion that is necessary if there was no use of reverse.
4279 return true;
4280 }
4281 return false;
4282 }
4283
4284 Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
4285 // Delete no-op slices, i.e. where shape = operand shape.
4286 if (ReplaceInstructionIfSameShape(slice, slice->mutable_operand(0))) {
4287 return Status::OK();
4288 }
4289
4290 HloInstruction* pad;
4291 HloInstruction* pad_operand;
4292 if (Match(slice, m::Slice(m::Pad(&pad, m::Op(&pad_operand), m::Op())))) {
4293 // Is the result of the slice exactly the pad operand?
4294 bool slice_undoes_pad = true;
4295 // Can the slice be moved to the pad_operand without any padding being read?
4296 bool slice_inside_pad = true;
4297 // Does this slice read out padding only?
4298 bool slice_in_padding = false;
4299 std::vector<int64> new_starts = slice->slice_starts();
4300 std::vector<int64> new_limits = slice->slice_limits();
4301 for (int64_t i = 0; i < slice->shape().rank(); ++i) {
4302 const int64_t start = slice->slice_starts(i);
4303 const int64_t stride = slice->slice_strides(i);
4304 const int64_t limit = slice->slice_limits(i);
4305 const int64_t size = pad->shape().dimensions(i);
4306
4307 const auto& dim = pad->padding_config().dimensions(i);
4308 const int64_t low = dim.edge_padding_low();
4309 const int64_t high = dim.edge_padding_high();
4310 const int64_t interior = dim.interior_padding();
4311 const int64_t edge = size - high;
4312
4313 if (limit <= low || start >= edge) {
4314 slice_in_padding = true;
4315 break;
4316 }
4317
4318 if (start != low || stride - 1 != interior) {
4319 slice_undoes_pad = false;
4320 }
4321
4322 if (start < low || limit > edge || interior != 0 || stride != 1) {
4323 slice_inside_pad = false;
4324 }
4325 new_starts[i] -= low;
4326 new_limits[i] -= low;
4327 }
4328 if (slice_in_padding) {
4329 HloInstruction* broadcast =
4330 MakeBroadcastHlo(pad->mutable_operand(1), {}, slice->shape());
4331 *(broadcast->mutable_shape()) = slice->shape();
4332 return ReplaceInstruction(slice, broadcast);
4333 }
4334 if (slice_undoes_pad && ReplaceInstructionIfSameShape(slice, pad_operand)) {
4335 return Status::OK();
4336 }
4337 if (slice_inside_pad) {
4338 TF_ASSIGN_OR_RETURN(HloInstruction * new_slice,
4339 MakeSliceHlo(pad_operand, new_starts, new_limits,
4340 slice->slice_strides()));
4341 *(new_slice->mutable_shape()) = slice->shape();
4342 return ReplaceInstruction(slice, new_slice);
4343 }
4344 }
4345
4346 if (slice->operand(0)->opcode() == HloOpcode::kSlice &&
4347 IsUnstridedSlice(slice) && IsUnstridedSlice(slice->operand(0))) {
4348 HloInstruction* operand_slice = slice->mutable_operand(0);
4349 std::vector<int64> new_slice_starts = slice->slice_starts();
4350 std::vector<int64> new_slice_limits = slice->slice_limits();
4351 for (int64_t i = 0; i < new_slice_starts.size(); ++i) {
4352 new_slice_starts[i] += operand_slice->slice_starts(i);
4353 new_slice_limits[i] += operand_slice->slice_starts(i);
4354 }
4355 return ReplaceWithNewInstruction(
4356 slice, HloInstruction::CreateSlice(
4357 slice->shape(), operand_slice->mutable_operand(0),
4358 new_slice_starts, new_slice_limits, slice->slice_strides()));
4359 }
4360
4361 auto only_broadcast_dims_sliced = [&] {
4362 if (slice->operand(0)->opcode() != HloOpcode::kBroadcast) {
4363 return false;
4364 }
4365 for (int64_t dim : slice->operand(0)->dimensions()) {
4366 if (slice->slice_starts(dim) != 0 || slice->slice_strides(dim) != 1 ||
4367 slice->slice_limits(dim) !=
4368 slice->operand(0)->shape().dimensions(dim)) {
4369 return false;
4370 }
4371 }
4372 return true;
4373 };
4374 if (only_broadcast_dims_sliced()) {
4375 return ReplaceWithNewInstruction(
4376 slice,
4377 HloInstruction::CreateBroadcast(
4378 slice->shape(), slice->mutable_operand(0)->mutable_operand(0),
4379 slice->mutable_operand(0)->dimensions()));
4380 }
4381
4382 TF_ASSIGN_OR_RETURN(bool replaced, TrySimplifyScalarSlice(slice));
4383 if (replaced) {
4384 return Status::OK();
4385 }
4386
4387 HloInstruction* broadcast;
4388 HloInstruction* broadcast_operand;
4389 if (Match(slice,
4390 m::Slice(m::Broadcast(&broadcast, m::Op(&broadcast_operand))))) {
4391 std::vector<int64> new_slice_starts;
4392 std::vector<int64> new_slice_strides;
4393 std::vector<int64> new_slice_limits;
4394 new_slice_starts.reserve(broadcast_operand->shape().rank());
4395 new_slice_strides.reserve(broadcast_operand->shape().rank());
4396 new_slice_limits.reserve(broadcast_operand->shape().rank());
4397 for (int64_t dim : broadcast->dimensions()) {
4398 new_slice_starts.push_back(slice->slice_starts(dim));
4399 new_slice_strides.push_back(slice->slice_strides(dim));
4400 new_slice_limits.push_back(slice->slice_limits(dim));
4401 }
4402 VLOG(3) << "Sink broadcast through slice";
4403 VLOG(3) << "Original slice: " << slice->ToString();
4404 VLOG(3) << "Original broadcast: " << broadcast->ToString();
4405 auto new_slice_shape = broadcast_operand->shape();
4406 for (int64_t i = 0; i < broadcast_operand->shape().rank(); ++i) {
4407 int64_t size_i = (new_slice_limits[i] - new_slice_starts[i] +
4408 new_slice_strides[i] - 1) /
4409 new_slice_strides[i];
4410 new_slice_shape.set_dimensions(i, size_i);
4411 }
4412 simplifier_->UpdateLayout(&new_slice_shape);
4413 HloComputation* computation = broadcast_operand->parent();
4414 auto new_slice = computation->AddInstruction(HloInstruction::CreateSlice(
4415 new_slice_shape, broadcast_operand, new_slice_starts, new_slice_limits,
4416 new_slice_strides));
4417 auto new_broadcast = HloInstruction::CreateBroadcast(
4418 slice->shape(), new_slice, broadcast->dimensions());
4419 VLOG(3) << "New slice: " << slice->ToString();
4420 VLOG(3) << "New broadcast: " << new_broadcast->ToString();
4421 return ReplaceWithNewInstruction(slice, std::move(new_broadcast));
4422 }
4423
4424 // Try to simplify concat -> slice to an operand of concat.
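// For example (an illustrative sketch): with
//   c = concatenate(a[8], b[8]) along dimension 0,
//   slice(c, starts={8}, limits={16}) ==> b
// The unstrided slice must exactly cover one concat operand.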
4425 if (slice->operand(0)->opcode() == HloOpcode::kConcatenate &&
4426 IsUnstridedSlice(slice)) {
4427 auto concat = slice->operand(0);
4428 int64_t concat_dim = concat->concatenate_dimension();
4429 int64_t piece_start = 0;
4430 for (auto piece : concat->operands()) {
4431 if (!SameShape(piece, slice)) {
4432 piece_start += piece->shape().dimensions(concat_dim);
4433 continue;
4434 }
4435 if (slice->slice_starts(concat_dim) == piece_start) {
4436 return ReplaceInstruction(slice, piece);
4437 }
4438 piece_start += piece->shape().dimensions(concat_dim);
4439 }
4440 }
4441
4442 // Do not try to reorder slices and reshapes after layout assignment as it may
4443 // be invalid.
4444 if (!options_.is_layout_sensitive()) {
4445 TF_ASSIGN_OR_RETURN(replaced, TryToReorderSliceAndReshape(slice));
4446 }
4447 if (replaced) {
4448 return Status::OK();
4449 }
4450
4451 bool reversed = false;
4452 if (Match(slice, m::Slice(m::Reverse(m::Op())))) {
4453 TF_ASSIGN_OR_RETURN(reversed, TryToReorderSliceAndReverse(slice));
4454 }
4455 if (reversed) {
4456 return Status::OK();
4457 }
4458
4459 return Status::OK();
4460 }
4461
4462 Status AlgebraicSimplifierVisitor::HandleRsqrt(HloInstruction* rsqrt) {
4463 VLOG(10) << "trying transform [rsqrt(Pow(A, -2)) => |A|] "
4464 << rsqrt->ToString();
4465 HloInstruction* rsqrt_operand = rsqrt->mutable_operand(0);
4466 if (rsqrt_operand->opcode() == HloOpcode::kPower &&
4467 IsAll(rsqrt_operand->operand(1), -2) &&
4468 IsPositive(rsqrt_operand, options_)) {
4469 return ReplaceWithNewInstruction(
4470 rsqrt, HloInstruction::CreateUnary(rsqrt->shape(), HloOpcode::kAbs,
4471 rsqrt_operand->mutable_operand(0)));
4472 }
4473
4474 VLOG(10) << "trying transform [rsqrt(Divide(1, A)) => sqrt(A)] "
4475 << rsqrt->ToString();
4476 if (rsqrt_operand->opcode() == HloOpcode::kDivide &&
4477 IsAll(rsqrt_operand->operand(0), 1) &&
4478 IsPositive(rsqrt_operand->operand(1), options_)) {
4479 return ReplaceWithNewInstruction(
4480 rsqrt, HloInstruction::CreateUnary(rsqrt->shape(), HloOpcode::kSqrt,
4481 rsqrt_operand->mutable_operand(1)));
4482 }
4483
4484 return Status::OK();
4485 }
4486
4487 Status AlgebraicSimplifierVisitor::HandleDynamicSlice(
4488 HloInstruction* dynamic_slice) {
4489 auto operand = dynamic_slice->mutable_operand(0);
4490 if (ShapeUtil::IsScalar(dynamic_slice->shape())) {
4491 return ReplaceInstruction(dynamic_slice, operand);
4492 }
4493 // DynamicSlice where operand has the same size as the output is simply equal
4494 // to operand.
4495 if (SameShape(operand, dynamic_slice)) {
4496 return ReplaceInstruction(dynamic_slice, operand);
4497 }
4498
4499 HloInstruction* broadcast_operand;
4500 if (Match(operand, m::Broadcast(m::Op(&broadcast_operand)))) {
4501 std::vector<HloInstruction*> new_indices;
4502 new_indices.reserve(broadcast_operand->shape().rank());
4503 std::vector<int64> new_slice_sizes;
4504 new_slice_sizes.reserve(broadcast_operand->shape().rank());
4505
4506 for (int64_t dim : operand->dimensions()) {
4507 new_indices.push_back(dynamic_slice->mutable_operand(1 + dim));
4508 new_slice_sizes.push_back(dynamic_slice->slice_sizes(dim));
4509 }
4510
4511 VLOG(3) << "Sink broadcast through dynamic slice";
4512 VLOG(3) << "Original dynamic slice: " << dynamic_slice->ToString();
4513 VLOG(3) << "Original broadcast: " << operand->ToString();
4514 HloInstruction* new_dynamic_slice = broadcast_operand;
4515 if (!new_slice_sizes.empty()) {
4516 auto new_ds_shape = broadcast_operand->shape();
4517 for (int64_t i = 0; i < broadcast_operand->shape().rank(); ++i) {
4518 new_ds_shape.set_dimensions(i, new_slice_sizes[i]);
4519 }
4520 simplifier_->UpdateLayout(&new_ds_shape);
4521 HloComputation* computation = broadcast_operand->parent();
4522 new_dynamic_slice =
4523 computation->AddInstruction(HloInstruction::CreateDynamicSlice(
4524 new_ds_shape, broadcast_operand, new_indices, new_slice_sizes));
4525 }
4526 auto new_broadcast = HloInstruction::CreateBroadcast(
4527 dynamic_slice->shape(), new_dynamic_slice, operand->dimensions());
4528 VLOG(3) << "New dynamic slice: " << dynamic_slice->ToString();
4529 VLOG(3) << "New broadcast: " << new_broadcast->ToString();
4530 return ReplaceWithNewInstruction(dynamic_slice, std::move(new_broadcast));
4531 }
4532
4533 // Convert a dynamic slice into a slice if all offsets are constant and the
4534 // operand is not constant.
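// For example (illustrative):
//   dynamic-slice(p[16], c4) with slice size {8}
//     ==> slice(p), starts={4}, limits={12}, strides={1}
// The constant offset is first clamped into [0, 16 - 8], mirroring the
// runtime clamping semantics of dynamic-slice.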
4535 if (operand->opcode() != HloOpcode::kConstant &&
4536 absl::c_all_of(absl::MakeSpan(dynamic_slice->operands().begin() + 1,
4537 dynamic_slice->operands().end()),
4538 [](HloInstruction* operand) {
4539 return operand->opcode() == HloOpcode::kConstant &&
4540 ShapeUtil::ElementIsIntegral(operand->shape());
4541 })) {
4542 const int64_t rank = operand->shape().rank();
4543 std::vector<int64> slice_starts(rank);
4544 std::vector<int64> slice_limits(rank);
4545 std::vector<int64> slice_strides(rank, 1);
4546
4547 for (int64_t i = 0; i < rank; ++i) {
4548 absl::optional<int64> offset =
4549 dynamic_slice->operand(i + 1)->literal().GetFirstInteger();
4550 if (!offset || *offset < 0) {
4551 return Status::OK();
4552 }
4553 const int64_t max_offset =
4554 dynamic_slice->operand(0)->shape().dimensions(i) -
4555 dynamic_slice->shape().dimensions(i);
4556 slice_starts[i] = std::min(max_offset, *offset);
4557 slice_limits[i] =
4558 std::min(max_offset, *offset) + dynamic_slice->shape().dimensions(i);
4559 }
4560 return ReplaceWithNewInstruction(
4561 dynamic_slice,
4562 HloInstruction::CreateSlice(dynamic_slice->shape(), operand,
4563 slice_starts, slice_limits, slice_strides));
4564 }
4565
4566 // Convert the dynamic slice of an iota to just a reference to the index
4567 // (possibly clamped and scaled). Index is always a scalar integer. Output
4568 // should be a rank 1 array of size 1 with element type matching that of the
4569 // scalar index (except the signedness).
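// For example (an illustrative sketch): with an s32 scalar index i,
//   dynamic-slice(iota(s32[8]), i) with slice size {1}
//     ==> reshape(clamp(0, i, 7), s32[1])
// If the iota was scaled via multiply(iota, broadcast(k)), the clamped
// index is additionally multiplied by k before the reshape.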
4570 const PrimitiveType element_type = dynamic_slice->shape().element_type();
4571 if (operand->shape().rank() == 1 && dynamic_slice->shape().rank() == 1 &&
4572 dynamic_slice->shape().dimensions(0) == 1 &&
4573 (element_type == S32 || element_type == U32)) {
4574 // Match multiply(x, broadcast(scalar)) and return the scalar
4575 // constant.
4576 auto match_multiply_with_scalar =
4577 [&](HloInstruction* hlo) -> HloInstruction* {
4578 if (hlo->opcode() != HloOpcode::kMultiply) {
4579 return nullptr;
4580 }
4581 HloInstruction* broadcast = hlo->mutable_operand(1);
4582 if (broadcast->opcode() == HloOpcode::kBroadcast &&
4583 broadcast->dimensions().empty() &&
4584 ShapeUtil::IsScalar(broadcast->operand(0)->shape())) {
4585 return broadcast->mutable_operand(0);
4586 }
4587 return nullptr;
4588 };
4589
4590 HloInstruction* multiplier = match_multiply_with_scalar(operand);
4591 if (multiplier) {
4592 operand = operand->mutable_operand(0);
4593 }
4594
4595 if (operand->opcode() == HloOpcode::kIota) {
4596 // This dynamic_slice will have a single start_index operand (since its
4597 // operand is rank 1).
4598 HloInstruction* index = dynamic_slice->mutable_operand(1);
4599 const PrimitiveType index_type = index->shape().element_type();
4600
4601 auto create_constant = [&](int64_t value) {
4602 if (index_type == S32) {
4603 return MakeScalarLike<int32_t>(index, value);
4604 } else {
4605 return MakeScalarLike<uint32_t>(index, value);
4606 }
4607 };
4608
4609 if (index_type == S32 || index_type == U32) {
4610 // Clamp the index to the range of the iota.
4611 int64_t iota_size = operand->shape().dimensions(0);
4612 HloInstruction* low = create_constant(0);
4613 HloInstruction* high = create_constant(iota_size - 1);
4614 HloInstruction* clamped =
4615 computation_->AddInstruction(HloInstruction::CreateTernary(
4616 index->shape(), HloOpcode::kClamp, low, index, high));
4617
4618 // Convert the clamped index from index_type to element_type and
4619 // multiply with the multiplier.
4620 HloInstruction* result = clamped;
4621 if (index_type != element_type) {
4622 result = computation_->AddInstruction(HloInstruction::CreateConvert(
4623 ShapeUtil::MakeScalarShape(element_type), clamped));
4624 }
4625
4626 if (multiplier) {
4627 result = computation_->AddInstruction(HloInstruction::CreateBinary(
4628 result->shape(), HloOpcode::kMultiply, result, multiplier));
4629 }
4630
4631 return ReplaceWithNewInstruction(
4632 dynamic_slice,
4633 HloInstruction::CreateReshape(
4634 ShapeUtil::MakeShape(element_type, {1}), result));
4635 }
4636 }
4637 }
4638
4639 return Status::OK();
4640 }
4641
4642 Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
4643 HloInstruction* dynamic_update_slice) {
4644 // Rewrite a DynamicUpdateSlice that matches
4645 // dynamic_update_slice(broadcast(constant),data,constant_index0,...)
4646 // into a Pad(data, constant).
4647 // Only broadcast is considered currently; other ops need to be considered
4648 // in the future.
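// For example (an illustrative sketch):
//   dynamic-update-slice(broadcast(0.0) into f32[10], u[4], c3)
//     ==> pad(u, 0.0) with edge_padding_low=3, edge_padding_high=3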
4649 HloInstruction* updated = dynamic_update_slice->mutable_operand(0);
4650 HloInstruction* dus_update = dynamic_update_slice->mutable_operand(1);
4651 HloInstruction* pad_value;
4652 if (Match(updated,
4653 m::Broadcast(m::Op(&pad_value).WithShape(m::Shape().IsScalar())))) {
4654 auto updated_shape = updated->shape();
4655 auto update_shape = dus_update->shape();
4656 auto update_start_indx = dynamic_update_slice->operand(2);
4657 int64_t offset = 0;
4658 bool compatible = true;
4659 // Whether the start indices of the dynamic update slice are a list of
4660 // scalars or the output of a tuple/concatenate, set up update_start_indx
4661 // appropriately.
4662 if (ShapeUtil::IsScalar(update_start_indx->shape())) {
4663 update_start_indx = dynamic_update_slice;
4664 offset = 2;
4665 } else {
4666 if (update_start_indx->opcode() == HloOpcode::kTuple ||
4667 update_start_indx->opcode() == HloOpcode::kConcatenate) {
4668 offset = 0;
4669 } else {
4670 compatible = false;
4671 }
4672 }
4673 PaddingConfig padding_config;
4674 if (compatible) {
4675 for (int64_t dim = 0; dim < updated_shape.rank(); ++dim) {
4676 auto padding_config_dim = padding_config.add_dimensions();
4677 auto slice_dim_start = update_start_indx->operand(dim + offset);
4678 if (!Match(slice_dim_start, m::ConstantScalar())) {
4679 compatible = false;
4680 break;
4681 }
4682 VLOG(2) << "slice: " << slice_dim_start->ToString();
4683 absl::optional<int64> beg =
4684 slice_dim_start->literal().GetFirstInteger();
4685 if (!beg) {
4686 compatible = false;
4687 break;
4688 }
4689 VLOG(2) << "beg value: " << *beg;
4690 auto update_width = ShapeUtil::GetDimension(update_shape, dim);
4691 auto bcast_width = ShapeUtil::GetDimension(updated_shape, dim);
4692 // Clamp beg so that it is non-negative.
4693 *beg = std::max<int64>(0, *beg);
4694 // Clamp beg so that it is in-bounds.
4695 *beg = std::min<int64>(bcast_width - update_width, *beg);
4696 VLOG(2) << "adjusted beg value: " << *beg;
4697 padding_config_dim->set_edge_padding_low(*beg);
4698 padding_config_dim->set_edge_padding_high(bcast_width -
4699 (*beg + update_width));
4700 // dynamic_update_slice does not specify a stride
4701 padding_config_dim->set_interior_padding(0);
4702 }
4703 }
4704
4705 if (compatible) {
4706 HloInstruction* pad =
4707 computation_->AddInstruction(HloInstruction::CreatePad(
4708 updated_shape, dus_update, pad_value, padding_config));
4709 VLOG(2) << dynamic_update_slice->ToString();
4710 VLOG(2) << " with pad:" << pad->ToString();
4711 VLOG(2) << " Computation before rewrite is: "
4712 << dynamic_update_slice->parent()->ToString();
4713 return ReplaceInstruction(dynamic_update_slice, pad);
4714 }
4715 }
4716
4717 // DynamicUpdateSlice where operand and dus_update have the same size is
4718 // equal to dus_update.
4719 if (SameShape(dynamic_update_slice, dus_update)) {
4720 return ReplaceInstruction(dynamic_update_slice, dus_update);
4721 }
4722
4723 // If any dimension of dus_update is 0, elide the DynamicUpdateSlice. This
4724 // optimization becomes invalid should we later prefer to warn about out of
4725 // bound indices.
4726 if (ShapeUtil::IsZeroElementArray(dus_update->shape())) {
4727 return ReplaceInstruction(dynamic_update_slice, updated);
4728 }
4729 return Status::OK();
4730 }
4731
4732 Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
4733 HloReduceInstruction* reduce = Cast<HloReduceInstruction>(hlo);
4734 bool multi_output_reduce = reduce->shape().IsTuple();
4735 // For tuple reduce, we require all reduce shapes to be the same, up to the
4736 // element types, so we can just use the first operand and the first result
4737 // as a representative.
4738 auto arg = reduce->inputs()[0];
4739 auto init_value = reduce->init_values()[0];
4740 const Shape& reduce_result_shape =
4741 multi_output_reduce ? reduce->shape().tuple_shapes(0) : reduce->shape();
4742
4743 absl::Span<const int64> dimensions(reduce->dimensions());
4744 HloComputation* function = reduce->to_apply();
4745 if (ShapeUtil::IsZeroElementArray(arg->shape()) ||
4746 ShapeUtil::IsZeroElementArray(reduce_result_shape)) {
4747 if (multi_output_reduce) {
4748 std::vector<HloInstruction*> broadcast_inits;
4749 int64_t inputs = reduce->input_count();
4750 for (int64_t i = 0; i < inputs; ++i) {
4751 broadcast_inits.push_back(computation_->AddInstruction(
4752 HloInstruction::CreateBroadcast(reduce->shape().tuple_shapes(i),
4753 reduce->init_values()[i], {})));
4754 }
4755 return ReplaceWithNewInstruction(
4756 reduce, HloInstruction::CreateTuple(broadcast_inits));
4757 } else {
4758 return ReplaceWithNewInstruction(
4759 reduce,
4760 HloInstruction::CreateBroadcast(reduce_result_shape, init_value, {}));
4761 }
4762 }
4763
4764 // Turn trivial variadic reductions into normal reductions.
4765 if (multi_output_reduce && reduce->shape().tuple_shapes_size() == 1 &&
4766 reduce->input_count() == 1 &&
4767 Match(function->root_instruction(), m::Tuple())) {
4768 absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
4769 replacements;
4770 replacements[function->root_instruction()] = nullptr;
4771 auto new_function = computation_->parent()->AddEmbeddedComputation(
4772 function->CloneWithReplacements(
4773 std::move(replacements), /*extra_parameters=*/{},
4774 /*context=*/nullptr,
4775 /*suffix=*/"clone",
4776 /*new_root=*/function->root_instruction()->operand(0)));
4777 auto new_reduce = computation_->AddInstruction(
4778 HloInstruction::CreateReduce(reduce_result_shape, arg, init_value,
4779 reduce->dimensions(), new_function));
4780 return ReplaceWithNewInstruction(reduce,
4781 HloInstruction::CreateTuple({new_reduce}));
4782 }
4783
4784 if (options_.is_layout_sensitive()) {
4785 return Status::OK();
4786 }
4787
4788 // If the reduction results in the same number of elements, then the only
4789 // possible side effect would be a reshape. Since the init_value is an
4790 // identity of the reduction function, we can therefore replace the reduce
4791 // with a simple reshape, ignoring the reduction function completely.
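// For example (illustrative):
//   reduce(a[4,1,5], init, dims={1}) ==> reshape(a, [4,5])
// because the reduced dimension has size 1, so the element count is
// unchanged.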
4792 if (ShapeUtil::ElementsIn(reduce_result_shape) ==
4793 ShapeUtil::ElementsIn(arg->shape())) {
4794 if (multi_output_reduce) {
4795 std::vector<HloInstruction*> reshaped_args;
4796 int64_t inputs = reduce->input_count();
4797 for (int64_t i = 0; i < inputs; ++i) {
4798 reshaped_args.push_back(
4799 computation_->AddInstruction(HloInstruction::CreateReshape(
4800 reduce->shape().tuple_shapes(i), reduce->inputs()[i])));
4801 }
4802 return ReplaceWithNewInstruction(
4803 reduce, HloInstruction::CreateTuple(reshaped_args));
4804 } else {
4805 return ReplaceWithNewInstruction(
4806 reduce, HloInstruction::CreateReshape(reduce_result_shape, arg));
4807 }
4808 }
4809
4810 // TODO(b/131122694): Most of those optimizations below can be done for
4811 // multi-output reduces.
4812 if (multi_output_reduce) {
4813 return Status::OK();
4814 }
4815
4816 // A Transpose feeding a reduce can simply permute the reduction dimensions
4817 // field if the output of the reduce is a vector or scalar. Higher ranked
4818 // result may require a transpose of the output.
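// For example (an illustrative sketch):
//   reduce(transpose(a[4,8], perm={1,0}), dims={0})
//     ==> reduce(a, dims={1})
// since dimension 0 of the transpose is dimension perm[0] == 1 of a, and
// the rank-1 result needs no output transpose.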
4819 if (arg->opcode() == HloOpcode::kTranspose &&
4820 (reduce->shape().rank() < 2 || arg->user_count() == 1 ||
4821 absl::c_all_of(arg->users(), [](HloInstruction* use) {
4822 return use->opcode() == HloOpcode::kReduce;
4823 }))) {
4824 auto transpose_dimensions = arg->dimensions();
4825 std::vector<int64> new_reduce_dimensions;
4826 for (auto dim : dimensions) {
4827 new_reduce_dimensions.push_back(transpose_dimensions[dim]);
4828 }
4829
4830 Shape new_reduce_result_shape = ShapeUtil::FilterDimensions(
4831 [&](const int64_t dim) {
4832 return !absl::c_linear_search(new_reduce_dimensions, dim);
4833 },
4834 arg->mutable_operand(0)->shape());
4835 HloInstruction* new_reduce =
4836 computation_->AddInstruction(HloInstruction::CreateReduce(
4837 new_reduce_result_shape, arg->mutable_operand(0), init_value,
4838 new_reduce_dimensions, function));
4839 reduce->SetupDerivedInstruction(new_reduce);
4840 std::vector<int64> new_transpose_dimensions;
4841 for (auto dim : transpose_dimensions) {
4842 if (absl::c_linear_search(new_reduce_dimensions, dim)) {
4843 continue;
4844 }
4845 new_transpose_dimensions.push_back(dim);
4846 }
4847
4848 // If new transpose dimensions are sorted, then there is no need to
4849 // transpose reduce result.
4850 if (absl::c_is_sorted(new_transpose_dimensions)) {
4851 return ReplaceInstruction(reduce, new_reduce);
4852 }
4853 for (auto& d : new_transpose_dimensions) {
4854 auto old_dim = d;
4855 for (auto reduced_dim : new_reduce_dimensions) {
4856 if (old_dim > reduced_dim) {
4857 --d;
4858 }
4859 }
4860 }
4861 TF_ASSIGN_OR_RETURN(HloInstruction * new_transpose,
4862 MakeTransposeHlo(new_reduce, new_transpose_dimensions));
4863 return ReplaceInstruction(reduce, new_transpose);
4864 }
4865
4866 // If a reduce feeds a reduce with the same computation and initial value,
4867 // they can be combined into a single reduce.
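// For example (illustrative):
//   reduce(reduce(a[4,5,6], dims={1}), dims={1}) ==> reduce(a, dims={1,2})
// The inner reduce removes dimension 1 of a, so the outer dimension 1
// refers to dimension 2 of a.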
4868 if (arg->opcode() == HloOpcode::kReduce &&
4869 init_value->Identical(*arg->operand(1)) &&
4870 *function == *arg->to_apply()) {
4871 // Create a new reduce with the combined reduction dimensions of both
4872 // reduces.
4873 std::vector<int64> arg_dims = arg->dimensions();
4874 absl::c_sort(arg_dims);
4875 std::vector<int64> reduce_dims = reduce->dimensions();
4876 absl::c_sort(reduce_dims);
4877 // Shift reduce_dims so they index into the shape of the operand's operand.
4878 for (int64_t arg_dim : arg_dims) {
4879 for (int64& dim : reduce_dims) {
4880 if (dim >= arg_dim) {
4881 ++dim;
4882 }
4883 }
4884 }
4885 std::vector<int64> new_dimensions;
4886 new_dimensions.reserve(arg->dimensions().size() +
4887 reduce->dimensions().size());
4888 std::merge(arg_dims.begin(), arg_dims.end(), reduce_dims.begin(),
4889 reduce_dims.end(), std::back_inserter(new_dimensions));
4890 return ReplaceWithNewInstruction(
4891 reduce, HloInstruction::CreateReduce(
4892 reduce_result_shape, arg->mutable_operand(0), init_value,
4893 new_dimensions, function));
4894 }
4895
4896 // A reshape that collapses multiple dimensions into a dimension being
4897 // reduced can just reduce all of those dimensions instead of doing a
4898 // collapsing reshape before a reduction.
4899 if (options_.enable_reduce_of_reshape() &&
4900 arg->opcode() == HloOpcode::kReshape) {
4901 std::vector<std::pair<int64, int64>> unmodified_dims =
4902 ShapeUtil::DimensionsUnmodifiedByReshape(arg->operand(0)->shape(),
4903 arg->shape());
4904 std::vector<bool> arg_dim_in_output(arg->shape().rank(), true);
4905 std::vector<bool> arg_dim_unmodified(arg->shape().rank(), false);
4906 for (auto dim : dimensions) {
4907 arg_dim_in_output[dim] = false;
4908 }
4909 for (auto dim_pair : unmodified_dims) {
4910 arg_dim_unmodified[dim_pair.second] = true;
4911 }
4912 // The goal is to verify that all dimensions that are not removed in the
4913 // reduce are unmodified by the reshape. For example:
4914 // reduce(reshape([A,B*C], a[A,B,C]),[1]) = reduce(a[A, B, C], [1, 2])
4915 bool can_move_reshape_into_reduce = true;
4916 for (int64_t i = 0; i < arg_dim_in_output.size(); ++i) {
4917 if (arg_dim_in_output[i] && !arg_dim_unmodified[i]) {
4918 can_move_reshape_into_reduce = false;
4919 }
4920 }
4921 if (can_move_reshape_into_reduce) {
4922 changed_ = true;
4923 absl::flat_hash_set<int64> dimensions_not_to_reduce;
4924 for (auto dim_pair : unmodified_dims) {
4925 if (arg_dim_in_output[dim_pair.second]) {
4926 dimensions_not_to_reduce.insert(dim_pair.first);
4927 }
4928 }
4929 std::vector<int64> new_reduce_dimensions;
4930 for (int64_t i = 0; i < arg->operand(0)->shape().rank(); ++i) {
4931 if (!dimensions_not_to_reduce.contains(i)) {
4932 new_reduce_dimensions.push_back(i);
4933 }
4934 }
4935 return ReplaceWithNewInstruction(
4936 reduce, HloInstruction::CreateReduce(
4937 reduce_result_shape, arg->mutable_operand(0), init_value,
4938 new_reduce_dimensions, function));
4939 }
4940 }
4941 // Convert Reduce(concat({a,b,...})) to
4942 // map(reduce(a),map(reduce(b),...,))
4943 //
4944 // This should make fusion easier or use less memory bandwidth in the unfused
4945 // case.
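// For example (an illustrative sketch), with every reduce and map applying
// the same binary reduction computation:
//   reduce(concatenate(a, b, c), dims={concat_dim})
//     ==> map(map(reduce(a), reduce(b)), reduce(c))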
4946 if (arg->opcode() == HloOpcode::kConcatenate &&
4947 absl::c_linear_search(reduce->dimensions(),
4948 arg->concatenate_dimension())) {
4949 HloInstruction* old_reduce = nullptr;
4950 for (HloInstruction* operand : arg->operands()) {
4951 HloInstruction* new_reduce = computation_->AddInstruction(
4952 HloInstruction::CreateReduce(reduce_result_shape, operand, init_value,
4953 reduce->dimensions(), function));
4954 if (old_reduce != nullptr) {
4955 new_reduce = computation_->AddInstruction(HloInstruction::CreateMap(
4956 reduce_result_shape, {old_reduce, new_reduce}, function));
4957 }
4958 old_reduce = new_reduce;
4959 }
4960 return ReplaceInstruction(reduce, old_reduce);
4961 }
4962
4963 HloInstruction *dot, *lhs, *rhs;
4964 // Convert Reduce(Dot(X,Y)) to Dot(X,Y) if any of the dimensions reduced were
4965 // batch dimensions of the dot. The transformation supports reducing other
4966 // dimensions as well.
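// For example (an illustrative sketch), with batch dimension 0 on both
// operands:
//   reduce-add(dot(x[B,M,K], y[B,K,N]), dims={0}) ==> dot(x, y)
// where the new dot contracts over {0,2} of x and {0,1} of y, i.e. the
// reduced batch dimension becomes an extra contracting dimension.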
4967 if (options_.enable_dot_strength_reduction() &&
4968 Match(arg, m::Dot(&dot, m::Op(&lhs), m::Op(&rhs)).WithOneUser()) &&
4969 Match(reduce->to_apply()->root_instruction(),
4970 m::Add(m::Parameter(), m::Parameter())) &&
4971 absl::c_any_of(reduce->dimensions(), [&](int64_t dim) {
4972 return dim < dot->dot_dimension_numbers().lhs_batch_dimensions_size();
4973 })) {
4974 const auto& dnums = dot->dot_dimension_numbers();
4975 DotDimensionNumbers new_dnums = dnums;
4976 new_dnums.clear_lhs_batch_dimensions();
4977 new_dnums.clear_rhs_batch_dimensions();
4978 int64_t removed_dims = 0;
4979 for (int64_t batch_dim = 0; batch_dim < dnums.lhs_batch_dimensions_size();
4980 ++batch_dim) {
4981 if (absl::c_linear_search(reduce->dimensions(), batch_dim)) {
4982 new_dnums.add_rhs_contracting_dimensions(
4983 dnums.rhs_batch_dimensions(batch_dim));
4984 new_dnums.add_lhs_contracting_dimensions(
4985 dnums.lhs_batch_dimensions(batch_dim));
4986 ++removed_dims;
4987 } else {
4988 new_dnums.add_rhs_batch_dimensions(
4989 dnums.rhs_batch_dimensions(batch_dim));
4990 new_dnums.add_lhs_batch_dimensions(
4991 dnums.lhs_batch_dimensions(batch_dim));
4992 }
4993 }
4994 std::vector<int64> reduce_dims;
4995 for (int64_t dim : reduce->dimensions()) {
4996 if (dim >= dnums.lhs_batch_dimensions_size()) {
4997 reduce_dims.push_back(dim - removed_dims);
4998 }
4999 }
5000 TF_ASSIGN_OR_RETURN(
5001 auto new_dot,
5002 MakeDotHlo(lhs, rhs, new_dnums, dot->precision_config(),
5003 /*preferred_element_type=*/dot->shape().element_type()));
5004 dot->SetupDerivedInstruction(new_dot);
5005 if (reduce_dims.empty()) {
5006 return ReplaceInstruction(hlo, new_dot);
5007 }
5008 TF_ASSIGN_OR_RETURN(
5009 auto new_reduce,
5010 MakeReduceHlo(new_dot, init_value, reduce_dims, HloOpcode::kAdd));
5011 reduce->SetupDerivedInstruction(new_reduce);
5012 return ReplaceInstruction(hlo, new_reduce);
5013 }
5014 return Status::OK();
5015 }
5016
5017 Status AlgebraicSimplifierVisitor::HandleReduceWindow(HloInstruction* hlo) {
5018 auto* reduce_window = Cast<HloReduceWindowInstruction>(hlo);
5019 const bool multi_output_reduce_window = reduce_window->shape().IsTuple();
5020 auto inputs = reduce_window->inputs();
5021 auto init_values = reduce_window->init_values();
5022 auto input_count = reduce_window->input_count();
5023 auto input_shapes = reduce_window->input_shapes();
5024 auto output_shapes = reduce_window->output_shapes();
5025 auto replace_with_span = [&](const std::vector<HloInstruction*>& elements) {
5026 CHECK(multi_output_reduce_window || elements.size() == 1);
5027 if (multi_output_reduce_window) {
5028 return ReplaceWithNewInstruction(reduce_window,
5029 HloInstruction::CreateTuple(elements));
5030 }
5031 return ReplaceInstruction(reduce_window, elements[0]);
5032 };
5033 // For tuple reduce, we require all reduce shapes to be the same, up to the
5034 // element types, so we can use just the first operand and the first result as
5035 // a representative.
5036 if (ShapeUtil::IsZeroElementArray(*input_shapes[0]) ||
5037 ShapeUtil::IsZeroElementArray(*output_shapes[0])) {
5038 std::vector<HloInstruction*> broadcast_inits;
5039 for (int64_t i = 0; i < input_count; ++i) {
5040 broadcast_inits.push_back(
5041 computation_->AddInstruction(HloInstruction::CreateBroadcast(
5042 *output_shapes[i], init_values[i], {})));
5043 }
5044 return replace_with_span(broadcast_inits);
5045 }
5046 if (ShapeUtil::IsScalar(*input_shapes[0]) &&
5047 (!multi_output_reduce_window ||
5048 reduce_window->to_apply()->root_instruction()->opcode() ==
5049 HloOpcode::kTuple)) {
5050 std::vector<HloInstruction*> maps;
5051 for (int64_t i = 0; i < input_count; ++i) {
5052 TF_RET_CHECK(ShapeUtil::IsScalar(*input_shapes[i]));
5053 TF_RET_CHECK(ShapeUtil::IsScalar(*output_shapes[i]));
5054 HloInstruction* map_computation_root;
5055 absl::flat_hash_map<const HloInstruction*,
5056 std::unique_ptr<HloInstruction>>
5057 replacements;
5058 if (multi_output_reduce_window) {
5059 map_computation_root =
5060 reduce_window->to_apply()->root_instruction()->mutable_operand(i);
5061 replacements[reduce_window->to_apply()->root_instruction()] = nullptr;
5062 } else {
5063 map_computation_root = reduce_window->to_apply()->root_instruction();
5064 }
5065 auto map_computation = computation_->parent()->AddEmbeddedComputation(
5066 reduce_window->to_apply()->CloneWithReplacements(
5067 std::move(replacements),
5068 /*extra_parameters=*/{}, nullptr, "clone", map_computation_root));
5069 auto map = computation_->AddInstruction(HloInstruction::CreateMap(
5070 reduce_window->shape(), {init_values[i], inputs[i]},
5071 map_computation));
5072 maps.push_back(map);
5073 }
5074 return replace_with_span(maps);
5075 }
5076 // Turn trivial variadic reduce windows into normal reduce windows.
5077 auto reduce_function_root = reduce_window->to_apply()->root_instruction();
5078 if (multi_output_reduce_window && input_count == 1 &&
5079 Match(reduce_function_root, m::Tuple())) {
5080 // Make a new reducer which is identical but does not have a tuple
5081 // instruction at the bottom.
5082 absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
5083 replacements;
5084 replacements[reduce_function_root] = nullptr;
5085 auto new_function = computation_->parent()->AddEmbeddedComputation(
5086 reduce_window->to_apply()->CloneWithReplacements(
5087 std::move(replacements), /*extra_parameters=*/{},
5088 /*context=*/nullptr,
5089 /*suffix=*/"clone",
5090 /*new_root=*/reduce_function_root->operand(0)));
5091 auto new_reduce =
5092 computation_->AddInstruction(HloInstruction::CreateReduceWindow(
5093 *output_shapes[0], inputs[0], init_values[0],
5094 reduce_window->window(), new_function));
5095 return ReplaceWithNewInstruction(reduce_window,
5096 HloInstruction::CreateTuple({new_reduce}));
5097 }
5098 // TODO(b/73062247) Variadic reduce window is not yet supported in simplifier.
5099 if (multi_output_reduce_window) {
5100 return Status::OK();
5101 }
5102 auto operand = reduce_window->mutable_operand(0);
5103 auto function = reduce_window->to_apply();
5104 const Window& window = reduce_window->window();
5105
5106 if (options_.enable_window_reduce_to_reduce_replacement()) {
5107 // A reduce window can be expressed as a reduce and a reshape if all
5108 // dimensions either have a window size of one or the entire dimension. If
5109 // there is no stride, dilation, or padding, this is as easy as checking the
5110 // size of the output shape and window dimension.
5111 //
5112 // The reshape is a bitcast since it adds one-sized dimensions. Often these
5113 // ones are immediately removed as well with another reshape. The
5114 // implementation of reduce tends to be slightly more efficient at reducing
5115 // entire dimensions compared to reduce window.
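// For example (illustrative):
//   reduce-window(a[4,8], window={1x8}) producing f32[4,1]
//     ==> reshape(reduce(a, dims={1}), f32[4,1])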
5116 auto effective_reduce_dims = [&] {
5117 if (window_util::HasStride(window) || window_util::HasDilation(window) ||
5118 window_util::HasPadding(window)) {
5119 return absl::InlinedVector<int64, 8>{};
5120 }
5121 absl::InlinedVector<int64, 8> reduce_dims;
5122 for (int64_t i = 0; i < window.dimensions_size(); ++i) {
5123 if (window.dimensions(i).size() == 1) {
5124 continue;
5125 } else if (reduce_window->shape().dimensions(i) == 1) {
5126 reduce_dims.push_back(i);
5127 } else {
5128 return absl::InlinedVector<int64, 8>{};
5129 }
5130 }
5131 return reduce_dims;
5132 }();
5133
5134 // If a reduce window can be expressed as a reduce, do so and reshape the
5135 // output.
5136 if (!effective_reduce_dims.empty()) {
5137 Shape reduce_shape = ShapeUtil::FilterDimensions(
5138 [&](int64_t dim) {
5139 return !absl::c_linear_search(effective_reduce_dims, dim);
5140 },
5141 reduce_window->shape());
5142 simplifier_->UpdateLayout(&reduce_shape);
5143 HloInstruction* reduce =
5144 computation_->AddInstruction(HloInstruction::CreateReduce(
5145 /*shape=*/reduce_shape,
5146 /*operand=*/operand,
5147 /*init_value=*/reduce_window->mutable_operand(1),
5148 /*dimensions_to_reduce=*/effective_reduce_dims,
5149 /*reduce_computation=*/function));
5150 return ReplaceWithNewInstruction(
5151 reduce_window,
5152 HloInstruction::CreateReshape(reduce_window->shape(), reduce));
5153 }
5154 }
5155
5156 // This optimization folds a pad op into reduce_window.
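// For example (an illustrative sketch):
//   reduce-window(pad(a, low=1, high=1), window={size=3})
//     ==> reduce-window(a, window={size=3, pad_low=1, pad_high=1})
// provided the pad value matches the reduce-window's init value.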
5157 HloInstruction* pad;
5158 const HloInstruction* convert = nullptr;
5159 if (operand->opcode() == HloOpcode::kPad) {
5160 pad = operand;
5161 } else if (operand->opcode() == HloOpcode::kConvert &&
5162 operand->operand(0)->opcode() == HloOpcode::kPad) {
5163 convert = operand;
5164 pad = operand->mutable_operand(0);
5165 } else {
5166 VLOG(10) << "Not folding pad into reduce-window as there is no pad.";
5167 return Status::OK();
5168 }
5169
5170 VLOG(10) << "Considering folding Pad: " << pad->ToString()
5171 << "\ninto reduce-window: " << reduce_window->ToString()
5172 << (convert != nullptr
5173 ? absl::StrCat("\nvia convert: ", convert->ToString())
5174 : "");
5175
5176 // Do not fold interior padding into an already base-dilated ReduceWindow,
5177 // since composing the two dilations is not supported.
5178 const PaddingConfig& pad_config = pad->padding_config();
5179 if (HasInteriorPadding(pad_config) && window_util::HasBaseDilation(window)) {
5180 VLOG(10) << "Not folding interior pad into base-dilated reduce-window.";
5181 return Status::OK();
5182 }
5183
5184 // The pad value of the pad op and the init value of reduce_window must
5185 // match to allow folding the pad.
5186 const HloInstruction* pad_value = pad->operand(1);
5187 const HloInstruction* reduce_init_value = reduce_window->operand(1);
5188 if (pad_value != reduce_init_value) {
5189 auto literals_are_equivalent = [&] {
5190 auto& pad_literal = pad_value->literal();
5191 auto& reduce_init_literal = reduce_init_value->literal();
5192 if (pad_literal == reduce_init_literal) {
5193 return true;
5194 }
5195 auto converted_pad_literal =
5196 pad_literal.ConvertToShape(reduce_init_value->shape());
5197 if (!converted_pad_literal.ok()) {
5198 return false;
5199 }
5200 return converted_pad_literal.ValueOrDie() == reduce_init_literal;
5201 };
5202 // The pad value is usually a constant, so we handle that case and do not
5203 // try to get more fancy about proving equivalence in cases beyond that.
5204 if (pad_value->opcode() != HloOpcode::kConstant ||
5205 reduce_init_value->opcode() != HloOpcode::kConstant ||
5206 !literals_are_equivalent()) {
5207 VLOG(10) << "Not folding pad into reduce-window due to different pad "
5208 "values.";
5209 return Status::OK();
5210 }
5211 }
5212
5213 // If the pad puts a single non-identity value in each window that we're
5214 // reducing, then this is a broadcast.
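// For example (illustrative): padding a[1] to a[5] with low=2, high=2 and
// reducing with a window of size 3 puts exactly the one unpadded element
// of a into every window, so the result is just a broadcast of a to the
// output shape.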
5215 HloInstruction* pad_operand = pad->mutable_operand(0);
5216 auto is_effective_broadcast = [&] {
5217 if (window_util::HasStride(window)) {
5218 VLOG(10) << "Window has stride.";
5219 return false;
5220 }
5221 if (!window_util::HasSymmetricPadding(pad_config)) {
5222 VLOG(10) << "Window has uneven padding.";
5223 return false;
5224 }
5225 if (HasInteriorPadding(pad_config)) {
5226 VLOG(10) << "Window has interior padding.";
5227 return false;
5228 }
5229 for (int64_t i = 0; i < pad_config.dimensions_size(); ++i) {
5230 const auto& pad_dimension = pad_config.dimensions(i);
5231 if ((pad_dimension.edge_padding_low() != 0 ||
5232 pad_dimension.edge_padding_high() != 0) &&
5233 pad_operand->shape().dimensions(i) != 1) {
5234 VLOG(10) << "Found non-trivial dimension being padded: " << i;
5235 return false;
5236 }
5237 }
5238 VLOG(10) << "Found to be padding trivial dimensions only.";
5239
5240 for (int64_t i = 0; i < window.dimensions_size(); ++i) {
5241 const auto& pad_dimension = pad_config.dimensions(i);
5242 const WindowDimension& window_dimension = window.dimensions(i);
5243 bool dimension_has_padding = (pad_dimension.edge_padding_low() != 0 ||
5244 pad_dimension.edge_padding_high() != 0);
5245 if (dimension_has_padding &&
5246 window_dimension.size() < pad_dimension.edge_padding_low() + 1) {
5247 VLOG(10) << "Found window did not cover single unpadded element in "
5248 "dimension: "
5249 << i;
5250 return false;
5251 }
5252 if (pad_operand->shape().dimensions(i) != 1 &&
5253 window_dimension.size() != 1) {
5254 VLOG(10) << "Found window covers more than one element in non-trivial "
5255 "dimension: "
5256 << i;
5257 return false;
5258 }
5259 }
5260 VLOG(10) << "Found window covers a single unpadded element.";
5261 return true;
5262 };
5263
5264 HloInstruction* new_reduce_window_operand;
5265 if (convert != nullptr) {
5266 Shape changed_shape = ShapeUtil::ChangeElementType(
5267 pad_operand->shape(), convert->shape().element_type());
5268 simplifier_->UpdateLayout(&changed_shape);
5269 new_reduce_window_operand = computation_->AddInstruction(
5270 HloInstruction::CreateConvert(changed_shape, pad_operand));
5271 } else {
5272 new_reduce_window_operand = pad_operand;
5273 }
5274
5275 if (is_effective_broadcast()) {
5276 VLOG(10) << "Replacing pad/reduce-window with broadcast.";
5277 auto fadd = [this](std::unique_ptr<HloInstruction> x) {
5278 return computation_->AddInstruction(std::move(x));
5279 };
5280 return ReplaceWithNewInstruction(
5281 reduce_window, HloInstruction::CreateBroadcastSequence(
5282 /*output_shape=*/reduce_window->shape(),
5283 /*operand=*/new_reduce_window_operand, fadd));
5284 }
5285
5286 // Carry out the folding of the pad into reduce_window.
5287 VLOG(10) << "Folding pad into reduce-window.";
5288 Window new_window = window;
5289 const int64_t rank = reduce_window->shape().rank();
5290 TF_RET_CHECK(pad_config.dimensions_size() == rank);
5291 TF_RET_CHECK(window.dimensions_size() == rank);
5292 for (int64_t i = 0; i < rank; ++i) {
5293 const auto& pad_dim = pad_config.dimensions(i);
5294 auto& window_dim = *new_window.mutable_dimensions(i);
5295 window_dim.set_padding_low(window_dim.padding_low() +
5296 pad_dim.edge_padding_low());
5297 window_dim.set_padding_high(window_dim.padding_high() +
5298 pad_dim.edge_padding_high());
5299 if (pad_dim.interior_padding() != 0) {
5300 CHECK_EQ(window_dim.base_dilation(), 1);
5301 window_dim.set_base_dilation(1 + pad_dim.interior_padding());
5302 }
5303 }
5304
5305 return ReplaceWithNewInstruction(
5306 reduce_window, HloInstruction::CreateReduceWindow(
5307 /*shape=*/reduce_window->shape(),
5308 /*operand=*/new_reduce_window_operand,
5309 /*init_value=*/reduce_window->mutable_operand(1),
5310 /*window=*/new_window,
5311 /*reduce_computation=*/function));
5312 }
5313
5314 Status AlgebraicSimplifierVisitor::HandleSelect(HloInstruction* select) {
5315 // select(x, y, y) -> y.
5316 if (select->operand(1) == select->operand(2)) {
5317 return ReplaceInstruction(select, select->mutable_operand(1));
5318 }
5319 // select(true, x, y) -> x.
5320 if (IsAll(select->operand(0), true)) {
5321 return ReplaceInstruction(select, select->mutable_operand(1));
5322 }
5323 // select(false, x, y) -> y.
5324 if (IsAll(select->operand(0), false)) {
5325 return ReplaceInstruction(select, select->mutable_operand(2));
5326 }
5327 return Status::OK();
5328 }
5329
5330 Status AlgebraicSimplifierVisitor::HandleScatter(HloInstruction* scatter) {
5331 if (ShapeUtil::IsZeroElementArray(scatter->operand(2)->shape()) &&
5332 ReplaceInstructionIfSameShape(scatter, scatter->mutable_operand(0))) {
5333 return Status::OK();
5334 }
5335 if (ShapeUtil::IsZeroElementArray(scatter->operand(1)->shape()) &&
5336 SameShape(scatter, scatter->operand(0)) &&
5337 SameShape(scatter, scatter->operand(2))) {
5338 return ReplaceWithNewInstruction(
5339 scatter, HloInstruction::CreateMap(
5340 scatter->shape(),
5341 {scatter->mutable_operand(0), scatter->mutable_operand(2)},
5342 scatter->to_apply()));
5343 }
5344 return Status::OK();
5345 }

5346 Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) {
5347 auto operand = sort->mutable_operand(0);
5348 int64_t dimension_to_sort = sort->dimensions(0);
5349 if (ShapeUtil::IsZeroElementArray(operand->shape()) ||
5350 operand->shape().dimensions(dimension_to_sort) <= 1) {
5351 if (sort->operand_count() == 1) {
5352 return ReplaceInstruction(sort, operand);
5353 }
5354 // If it is key/value sort, the output of sort is a tuple.
5355 return ReplaceWithNewInstruction(
5356 sort, HloInstruction::CreateTuple(sort->operands()));
5357 }
5358 return Status::OK();
5359 }
5360
5361 Status AlgebraicSimplifierVisitor::HandleSqrt(HloInstruction* sqrt) {
5362 VLOG(10) << "trying transform [sqrt(A*A) => |A|] " << sqrt->ToString();
5363 HloInstruction* sqrt_operand = sqrt->mutable_operand(0);
5364 if (sqrt_operand->opcode() == HloOpcode::kMultiply &&
5365 sqrt_operand->operand(0) == sqrt_operand->operand(1)) {
5366 return ReplaceWithNewInstruction(
5367 sqrt, HloInstruction::CreateUnary(
5368 sqrt_operand->mutable_operand(0)->shape(), HloOpcode::kAbs,
5369 sqrt_operand->mutable_operand(0)));
5370 }
5371 return Status::OK();
5372 }
5373
5374 namespace {
5375 bool OnlyPermutesDegenerateDims(const Shape& shape,
5376 absl::Span<const int64> perm) {
5377 std::vector<int64> new_permutation;
5378 int64_t degenerate_count = 0;
5379 for (int64_t i = 0; i < perm.size(); ++i) {
5380 if (shape.dimensions(i) != 1) {
5381 new_permutation.push_back(perm[i]);
5382 } else {
5383 ++degenerate_count;
5384 }
5385 }
5386 return degenerate_count > 0 && absl::c_is_sorted(new_permutation);
5387 }
5388 } // namespace
5389
5390 Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) {
5391 auto operand = transpose->mutable_operand(0);
5392 if (std::is_sorted(transpose->dimensions().begin(),
5393 transpose->dimensions().end())) {
5394 VLOG(10) << "deleting no-op transpose";
5395 return ReplaceInstruction(transpose, operand);
5396 }
5397
5398 if (HloOpcode::kTranspose == operand->opcode()) {
5399 return ReplaceWithNewInstruction(
5400 transpose, HloInstruction::CreateTranspose(
5401 transpose->shape(), operand->mutable_operand(0),
5402 ComposePermutations(operand->dimensions(),
5403 transpose->dimensions())));
5404 }
5405
5406 // Convert transpose(dot(a,b)) to dot(b,a).
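// For example (illustrative), in the 2D case this is (AB)^T = B^T A^T:
//   transpose(dot(a[M,K], b[K,N]), perm={1,0}) ==> dot(b, a)
// where swapping the contracting dimension numbers expresses the implicit
// transposes of the operands.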
5407 if (operand->opcode() == HloOpcode::kDot && operand->user_count() == 1 &&
5408 operand->shape().rank() == 2) {
5409 TF_ASSIGN_OR_RETURN(bool did_transform, [&]() -> StatusOr<bool> {
5410 const auto& dnums = operand->dot_dimension_numbers();
5411 if (dnums.lhs_batch_dimensions_size() != 0) {
5412 return false;
5413 }
5414 HloInstruction* lhs = operand->mutable_operand(0);
5415 if (lhs->shape().rank() != 1 + dnums.lhs_contracting_dimensions_size()) {
5416 return false;
5417 }
5418 HloInstruction* rhs = operand->mutable_operand(1);
5419 if (rhs->shape().rank() != 1 + dnums.rhs_contracting_dimensions_size()) {
5420 return false;
5421 }
5422 DotDimensionNumbers new_dnums;
5423 *new_dnums.mutable_lhs_contracting_dimensions() =
5424 dnums.rhs_contracting_dimensions();
5425 *new_dnums.mutable_rhs_contracting_dimensions() =
5426 dnums.lhs_contracting_dimensions();
5427 TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
5428 transpose, HloInstruction::CreateDot(transpose->shape(), /*lhs=*/rhs,
5429 /*rhs=*/lhs, new_dnums,
5430 operand->precision_config())));
5431 return true;
5432 }());
5433 if (did_transform) {
5434 return Status::OK();
5435 }
5436 }
5437
5438 // Replace the transpose with a reshape if it permutes only degenerate
5439 // (size-1) dimensions.
5440 if (OnlyPermutesDegenerateDims(transpose->shape(), transpose->dimensions())) {
5441 return ReplaceWithNewInstruction(
5442 transpose, HloInstruction::CreateReshape(
5443 transpose->shape(), transpose->mutable_operand(0)));
5444 }
5445
5446 if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) {
5447 *operand->mutable_shape() = transpose->shape();
5448 return ReplaceInstruction(transpose, operand);
5449 }
5450
5451 if (options_.is_layout_sensitive() &&
5452 options_.replace_transpose_with_bitcast() &&
5453 TransposeIsBitcast(transpose)) {
5454 ReplaceWithBitcast(transpose);
5455 return Status::OK();
5456 }
5457
5458 return Status::OK();
5459 }
5460
5461 StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvInputPad(
5462 HloInstruction* convolution) {
5463 HloInstruction *lhs, *a, *b;
5464 if (Match(convolution,
5465 m::Convolution(m::Pad(&lhs, m::Op(&a), m::ConstantScalar(0)),
5466 m::Op(&b)))) {
5467 const auto& window = convolution->window();
5468 const ConvolutionDimensionNumbers& dnums =
5469 convolution->convolution_dimension_numbers();
5470
5471 const auto& padding = lhs->padding_config();
5472
5473 // Can't pad batch or feature dims.
5474 for (int64_t dim :
5475 {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) {
5476 const auto& p = padding.dimensions(dim);
5477 if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
5478 p.interior_padding() != 0) {
5479 return false;
5480 }
5481 }
5482
5483 // Compute the window which is the result of merging the kPad and the
5484 // convolution's existing window.
5485 Window new_window = window;
5486 for (int64_t dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) {
5487 auto& w = *new_window.mutable_dimensions(dim);
5488 const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim));
5489 // Edge padding composes with itself in the straightforward way, but
5490 // composing interior padding is nontrivial, and we cowardly refuse to
5491 // think about it. If we see interior padding in either the kPad or conv,
5492 // bail if there's any sort of padding in the other.
5493 if (p.interior_padding() != 0 &&
5494 (w.padding_low() != 0 || w.padding_high() != 0 ||
5495 w.base_dilation() != 1)) {
5496 return false;
5497 }
5498 if (w.base_dilation() != 1 &&
5499 (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
5500 p.interior_padding() != 0)) {
5501 return false;
5502 }
5503
5504 w.set_padding_low(w.padding_low() + p.edge_padding_low());
5505 w.set_padding_high(w.padding_high() + p.edge_padding_high());
5506 if (p.interior_padding() != 0) {
5507 CHECK_EQ(w.base_dilation(), 1);
5508 w.set_base_dilation(1 + p.interior_padding());
5509 }
5510 }
5511
5512 auto new_conv =
5513 convolution->CloneWithNewOperands(convolution->shape(), {a, b});
5514 new_conv->set_window(new_window);
5515 TF_RETURN_IF_ERROR(
5516 ReplaceWithNewInstruction(convolution, std::move(new_conv)));
5517 return true;
5518 }
5519 return false;
5520 }
5521
5522 StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvFilterPad(
5523 HloInstruction* convolution) {
5524 auto* lhs = convolution->mutable_operand(0);
5525 auto* rhs = convolution->mutable_operand(1);
5526 const ConvolutionDimensionNumbers& dnums =
5527 convolution->convolution_dimension_numbers();
5528
5529 if (rhs->opcode() != HloOpcode::kPad) {
5530 return false;
5531 }
5532
5533 // Convolution's padding is always zero, so bail if the kPad is adding
5534 // something other than zero.
5535 if (!IsAll(rhs->operand(1), 0)) {
5536 return false;
5537 }
5538
5539 const auto& padding = rhs->padding_config();
5540
5541 // Can't pad or dilate feature dims.
5542 for (int64_t dim : {dnums.kernel_input_feature_dimension(),
5543 dnums.kernel_output_feature_dimension()}) {
5544 const auto& p = padding.dimensions(dim);
5545 if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
5546 p.interior_padding() != 0) {
5547 return false;
5548 }
5549 }
5550
5551 // Compute the window which is the result of merging the kPad and the
5552 // convolution's existing window.
5553 Window new_window = convolution->window();
5554 for (int64_t dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) {
5555 auto& w = *new_window.mutable_dimensions(dim);
5556 const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim));
5557
5558 // We can only do this transformation if p adds dilation to the filter --
5559 // edge padding on the filter is not supported in conv.
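// For example (illustrative): a kernel spatial dimension padded with
// interior_padding=1 (a zero between adjacent taps) is equivalent to
// using the unpadded kernel with window_dilation=2 on that dimension.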
5560 if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) {
5561 return false;
5562 }
5563
5564 // Nothing to do if the kPad for this dim is entirely a nop.
5565 if (p.interior_padding() == 0) {
5566 continue;
5567 }
5568
5569 // We cowardly refuse to think about how dilation composes with itself;
5570 // bail if both the kPad and conv have dilation on this dimension.
5571 if (w.window_dilation() > 1) {
5572 return false;
5573 }
5574 CHECK_EQ(w.window_dilation(), 1);
5575 w.set_window_dilation(1 + p.interior_padding());
5576 w.set_size(rhs->operand(0)->shape().dimensions(
5577 dnums.kernel_spatial_dimensions(dim)));
5578 }
5579
5580 auto new_conv = convolution->CloneWithNewOperands(
5581 convolution->shape(), {lhs, rhs->mutable_operand(0)});
5582 new_conv->set_window(new_window);
5583 TF_RETURN_IF_ERROR(
5584 ReplaceWithNewInstruction(convolution, std::move(new_conv)));
5585 return true;
5586 }
5587
5588 StatusOr<bool> AlgebraicSimplifierVisitor::SwapConvOperands(
5589 HloInstruction* convolution) {
5590 if (!options_.enable_conv_operand_swap() || options_.is_layout_sensitive()) {
5591 return false;
5592 }
5593 if (convolution->feature_group_count() > 1 ||
5594 convolution->batch_group_count() > 1) {
5595 return false;
5596 }
5597
5598 const auto& dnums = convolution->convolution_dimension_numbers();
5599 const auto& window_dims = convolution->window().dimensions();
5600 Window swapped_window;
5601
5602 HloInstruction *input = convolution->mutable_operand(0),
5603 *kernel = convolution->mutable_operand(1);
5604 int64_t kernel_product = 1;
5605 int64_t swapped_kernel_product = 1;
5606 DimensionVector reverse_dimensions;
5607 for (int64_t spatial_dim = 0;
5608 spatial_dim < dnums.input_spatial_dimensions_size(); ++spatial_dim) {
5609 const int64_t kernel_size = window_dims[spatial_dim].size();
    const bool can_be_group_or_contraction =
        !window_dims[spatial_dim].window_reversal() &&
        window_dims[spatial_dim].padding_low() == 0 &&
        window_dims[spatial_dim].padding_high() == 0 &&
        window_dims[spatial_dim].window_dilation() == 1;
    const bool is_group_dim =
        can_be_group_or_contraction &&
        window_dims[spatial_dim].base_dilation() == kernel_size &&
        window_dims[spatial_dim].stride() == kernel_size - 1;
    const int64_t input_size =
        input->shape().dimensions(dnums.input_spatial_dimensions(spatial_dim));
    const bool is_pure_contraction_dim =
        kernel_size == input_size && can_be_group_or_contraction &&
        window_dims[spatial_dim].base_dilation() == 1 &&
        window_dims[spatial_dim].stride() == 1;
    if (is_group_dim || is_pure_contraction_dim) {
      *(swapped_window.add_dimensions()) = window_dims[spatial_dim];
      continue;
    }

    const int64_t dilated_kernel_size =
        1 + (kernel_size - 1) * window_dims[spatial_dim].window_dilation();
    const int64_t dilated_input_size =
        1 + (input_size - 1) * window_dims[spatial_dim].base_dilation();
    // Don't let an input size of one argue in favor of swapping, since many
    // convolution implementations can handle that special case efficiently.
    kernel_product *= kernel_size;
    swapped_kernel_product *= input_size == 1 ? kernel_size : input_size;

    auto new_dim = swapped_window.add_dimensions();
    new_dim->set_size(input_size);
    // If the window is not already reversed, the kernel (which becomes the
    // new activations) must be reversed explicitly.
    if (!window_dims[spatial_dim].window_reversal()) {
      reverse_dimensions.push_back(
          dnums.kernel_spatial_dimensions(spatial_dim));
    }
    // The original input, which becomes the new kernel, is not reversed, so
    // mark the new window as reversed.
    new_dim->set_window_reversal(true);
    // Base dilation and window dilation switch places.
    new_dim->set_base_dilation(window_dims[spatial_dim].window_dilation());
    new_dim->set_window_dilation(window_dims[spatial_dim].base_dilation());
    new_dim->set_stride(window_dims[spatial_dim].stride());
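    // The new padding equals the old padding shifted by the difference
    // between the dilated input and dilated kernel sizes, which keeps the
    // output extent of the swapped convolution identical to the original.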
    new_dim->set_padding_low(dilated_input_size +
                             window_dims[spatial_dim].padding_low() -
                             dilated_kernel_size);
    new_dim->set_padding_high(dilated_input_size +
                              window_dims[spatial_dim].padding_high() -
                              dilated_kernel_size);
  }

  // Don't transform if a naive convolution implementation would not have fewer
  // flops.
  if (kernel_product <= swapped_kernel_product) {
    return false;
  }
  ConvolutionDimensionNumbers swapped_dnums;
  *swapped_dnums.mutable_output_spatial_dimensions() =
      dnums.output_spatial_dimensions();
  // Swap batch and output feature of the output.
  swapped_dnums.set_output_batch_dimension(dnums.output_feature_dimension());
  swapped_dnums.set_output_feature_dimension(dnums.output_batch_dimension());

  // Swap input dnums with kernel dnums.
  *swapped_dnums.mutable_input_spatial_dimensions() =
      dnums.kernel_spatial_dimensions();
  swapped_dnums.set_input_batch_dimension(
      dnums.kernel_output_feature_dimension());
  swapped_dnums.set_input_feature_dimension(
      dnums.kernel_input_feature_dimension());

  // Swap kernel dnums with input dnums.
  *swapped_dnums.mutable_kernel_spatial_dimensions() =
      dnums.input_spatial_dimensions();
  swapped_dnums.set_kernel_output_feature_dimension(
      dnums.input_batch_dimension());
  swapped_dnums.set_kernel_input_feature_dimension(
      dnums.input_feature_dimension());

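  // The operands trade places, so their precisions must trade places too.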
  PrecisionConfig precision_config;
  precision_config.add_operand_precision(
      convolution->precision_config().operand_precision(1));
  precision_config.add_operand_precision(
      convolution->precision_config().operand_precision(0));
  if (!reverse_dimensions.empty()) {
    TF_ASSIGN_OR_RETURN(kernel, MakeReverseHlo(kernel, reverse_dimensions));
  }
  TF_ASSIGN_OR_RETURN(
      HloInstruction * new_convolution,
      MakeConvolveHlo(
          kernel, input, /*feature_group_count=*/1,
          /*batch_group_count=*/1, swapped_window, swapped_dnums,
          precision_config,
          /*preferred_element_type=*/convolution->shape().element_type()));

  convolution->SetupDerivedInstruction(new_convolution);
  TF_RETURN_IF_ERROR(ReplaceInstruction(convolution, new_convolution));

  return true;
}

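// Rewrites a convolution whose kernel spatial dimensions are all of size one
// (and which has no stride, padding, or base dilation) as a dot wrapped in
// bitcasts, provided the operand and result layouts line up with a row-major
// matmul.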
StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot(
    HloInstruction* convolution) {
  auto* lhs = convolution->mutable_operand(0);
  auto* rhs = convolution->mutable_operand(1);
  const auto& window = convolution->window();
  const ConvolutionDimensionNumbers& dnums =
      convolution->convolution_dimension_numbers();

  if (!options_.enable_conv_simplification()) {
    return false;
  }

  // TODO(b/31337498): For now, we cowardly refuse to do this optimization in
  // layout-insensitive mode, for fear of adding nontrivial reshapes.
  if (!options_.is_layout_sensitive()) {
    return false;
  }

  const Shape& input_shape = lhs->shape();
  const Shape& filter_shape = rhs->shape();
  const Shape& convolution_shape = convolution->shape();
  TF_RET_CHECK(LayoutUtil::HasLayout(input_shape));
  TF_RET_CHECK(LayoutUtil::HasLayout(filter_shape));
  TF_RET_CHECK(LayoutUtil::HasLayout(convolution_shape));

  // Require the spatial dimensions in the kernel to have a bound of one.
  for (int64_t i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) {
    if (filter_shape.dimensions(dnums.kernel_spatial_dimensions(i)) != 1) {
      return false;
    }
  }

  // Stride ignores part of the output, which matrix multiplication does not
  // do, so require no stride. Padding and base (lhs) dilation both implicitly
  // extend the data, which matrix multiplication also does not do, so require
  // no padding and no base (lhs) dilation. Window (rhs) dilation has no effect
  // for a 1x1 window, so window dilation is no problem.
  if (window_util::HasStride(window) || window_util::HasPadding(window) ||
      window_util::HasBaseDilation(window)) {
    return false;
  }

  // Also, the shapes must align for a row-major matmul:
  // - the input and output have the same layout.
  // - for input/output, the channel dimension must be the most minor. Other
  //   spatial dims can be in any order.
  // - for filters, the input channel dimension must be more major than the
  //   output channel dimension. The width+height don't matter because
  //   they are 1.
  //
  // These constraints are harsh. If the channel dimension is the most major
  // and/or the layout of input/output feature dimensions are reversed, we can
  // still convert Conv into a more efficient Matmul with operand
  // transposition (such as the transposition flags in cuBLAS SGEMM).
  if (!LayoutUtil::Equal(input_shape.layout(), convolution_shape.layout()) ||
      LayoutUtil::Minor(input_shape.layout(), 0) !=
          dnums.input_feature_dimension() ||
      LayoutUtil::Minor(convolution_shape.layout(), 0) !=
          dnums.output_feature_dimension() ||
      // The filter's input feature dimension should come later in the
      // minor-to-major order.
      (PositionInContainer(LayoutUtil::MinorToMajor(filter_shape),
                           dnums.kernel_input_feature_dimension()) <
       PositionInContainer(LayoutUtil::MinorToMajor(filter_shape),
                           dnums.kernel_output_feature_dimension()))) {
    return false;
  }

  auto add_bitcast = [&](Shape shape, HloInstruction* operand) {
    return computation_->AddInstruction(
        HloInstruction::CreateBitcast(shape, operand));
  };

  // Replace it with a dot, with bitcasts around it to get the right shape.
  const int64_t input_channels =
      input_shape.dimensions(dnums.input_feature_dimension());
  const int64_t output_channels =
      filter_shape.dimensions(dnums.kernel_output_feature_dimension());

  // Computes the product of the non-feature dimensions.
  int64_t conv_width = 1;
  for (int i = 0; i < input_shape.dimensions_size(); ++i) {
    if (i != dnums.input_feature_dimension()) {
      conv_width *= input_shape.dimensions(i);
    }
  }

  // We already checked feature_dimension is most minor, so data in input_shape
  // and row-major {conv_width,input_channels} are bitwise identical.
  Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout(
      input_shape.element_type(), {conv_width, input_channels});
  simplifier_->UpdateLayout(&new_input_shape);
  // We already checked input_feature_dimension is more major than
  // output_feature_dimension, so data in filter_shape and row-major
  // {input_channels,output_channels} are bitwise identical.
  Shape new_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout(
      filter_shape.element_type(), {input_channels, output_channels});
  simplifier_->UpdateLayout(&new_filter_shape);
  Shape dot_output_shape = ShapeUtil::MakeShapeWithDescendingLayout(
      convolution_shape.element_type(), {conv_width, output_channels});
  simplifier_->UpdateLayout(&dot_output_shape);

  auto new_lhs = add_bitcast(new_input_shape, lhs);
  auto new_rhs = add_bitcast(new_filter_shape, rhs);
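  // The dot contracts the channel dimension: lhs is {conv_width,
  // input_channels} and rhs is {input_channels, output_channels}, so contract
  // lhs dimension 1 with rhs dimension 0.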
  DotDimensionNumbers dot_dimension_numbers;
  dot_dimension_numbers.add_lhs_contracting_dimensions(1);
  dot_dimension_numbers.add_rhs_contracting_dimensions(0);
  auto dot = computation_->AddInstruction(HloInstruction::CreateDot(
      dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers,
      convolution->precision_config()));

  TF_RETURN_IF_ERROR(
      ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)));
  return true;
}

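// Visitor entry point for convolutions: tries each rewrite in turn and
// returns after the first one that changes the graph.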
Status AlgebraicSimplifierVisitor::HandleConvolution(
    HloInstruction* convolution) {
  if (options_.enable_scalar_multiply_reduction()) {
    TF_RETURN_IF_ERROR(ScalarMultiplyReduction(convolution));
  }

  // Zero-sized input or filter.
  if (ShapeUtil::IsZeroElementArray(convolution->operand(0)->shape()) ||
      ShapeUtil::IsZeroElementArray(convolution->operand(1)->shape())) {
    return ReplaceInstruction(convolution, MakeScalarLike(convolution, 0));
  }

  // Try to merge padding/dilation of the input with the convolution's window.
  TF_ASSIGN_OR_RETURN(bool folded_input_pad, FoldConvInputPad(convolution));
  if (folded_input_pad) {
    return Status::OK();
  }

  // Try to merge dilation of the filter with the convolution's window.
  TF_ASSIGN_OR_RETURN(bool folded_filter_pad, FoldConvFilterPad(convolution));
  if (folded_filter_pad) {
    return Status::OK();
  }

  // Try to swap convolution operands.
  TF_ASSIGN_OR_RETURN(bool swapped, SwapConvOperands(convolution));
  if (swapped) {
    return Status::OK();
  }

  // Try to replace the convolution with a kDot instruction.
  TF_ASSIGN_OR_RETURN(bool replaced_with_dot, SimplifyConvToDot(convolution));
  if (replaced_with_dot) {
    return Status::OK();
  }

  return Status::OK();
}

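// Replaces a matched min/max tree rooted at `root` with a single kClamp,
// using the max's operand as the clamp's lower bound and the min's operand
// as its upper bound; bails out unless the two bounds have the same shape.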
bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape(
    HloInstruction* root, HloInstruction* min, HloInstruction* min_operand,
    HloInstruction* operand, HloInstruction* max,
    HloInstruction* max_operand) {
  // Ensure the shapes of the min and max operands are equal, to match current
  // shape inference.
  if (!SameShape(min_operand, max_operand)) {
    return false;
  }

  auto clamp = HloInstruction::CreateTernary(root->shape(), HloOpcode::kClamp,
                                             max_operand, operand, min_operand);
  TF_CHECK_OK(ReplaceWithNewInstruction(root, std::move(clamp)));
  return true;
}

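// Simplifies a kMap whose mapped computation is a bare parameter (the map is
// a no-op), a scalar constant (the map is a broadcast of that constant), or a
// single elementwise operation applied directly to parameters (the map can be
// inlined).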
Status AlgebraicSimplifierVisitor::HandleMap(HloInstruction* map) {
  auto* map_computation = map->to_apply();
  auto* map_root = map_computation->root_instruction();
  if (map_root->opcode() == HloOpcode::kParameter) {
    ReplaceInstructionIfSameShape(
        map, map->mutable_operand(map_root->parameter_number()));
    return Status::OK();
  }
  if (map_root->opcode() == HloOpcode::kConstant) {
    if (!ShapeUtil::IsScalar(map_root->shape())) {
      return Status::OK();
    }
    auto clone = map_root->CloneWithNewOperands(map_root->shape(), {});
    if (ShapeUtil::IsScalar(map->shape())) {
      return ReplaceWithNewInstruction(map, std::move(clone));
    }
    return ReplaceWithNewInstruction(
        map,
        HloInstruction::CreateBroadcast(
            map->shape(), computation_->AddInstruction(std::move(clone)), {}));
  }
  // Inline the map if the map computation only contains an elementwise
  // operation that can accept arbitrary shapes.
  if (map_root->opcode() == HloOpcode::kFusion || !map_root->IsElementwise()) {
    return Status::OK();
  }
  std::vector<HloInstruction*> new_operands;
  for (auto* root_operand : map_root->operands()) {
    if (root_operand->opcode() != HloOpcode::kParameter) {
      return Status::OK();
    }
    new_operands.push_back(
        map->mutable_operand(root_operand->parameter_number()));
  }
  auto clone = map_root->CloneWithNewOperands(map->shape(), new_operands);
  return ReplaceWithNewInstruction(map, std::move(clone));
}

StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) {
  XLA_VLOG_LINES(2,
                 "AlgebraicSimplifier::Run(), before:\n" + module->ToString());
  bool changed = false;
  AlgebraicSimplifierVisitor visitor(options_, this);
  for (auto* comp : module->MakeNonfusionComputations()) {
    if (visitor.Run(comp, options_, this)) {
      changed = true;
    }
  }
  XLA_VLOG_LINES(2,
                 "AlgebraicSimplifier::Run(), after:\n" + module->ToString());
  return changed;
}

}  // namespace xla