/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"

#include <algorithm>
#include <cmath>
#include <functional>
#include <iterator>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/overflow_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace xla {

namespace {

namespace m = match;

bool IsAll(const HloInstruction* op, int8 value) {
  switch (op->opcode()) {
    case HloOpcode::kBroadcast:
      return IsAll(op->operand(0), value);
    case HloOpcode::kConstant:
      return op->literal().IsAll(value);
    default:
      return false;
  }
}

bool IsAnyOperandComplex(const HloInstruction* hlo) {
  for (auto operand : hlo->operands()) {
    if (ShapeUtil::ElementIsComplex(operand->shape())) {
      return true;
    }
  }
  return false;
}

bool IsPositive(const HloInstruction* hlo,
                const AlgebraicSimplifierOptions& options) {
  // Utility only handles real types.
  if (IsAnyOperandComplex(hlo)) {
    return false;
  }
  switch (hlo->opcode()) {
    case HloOpcode::kGetTupleElement: {
      const HloInstruction* gte_operand = hlo->operand(0);
      switch (gte_operand->opcode()) {
        case HloOpcode::kCustomCall: {
          const auto& target = gte_operand->custom_call_target();
          return target ==
                     options.get_cudnn_batchnorm_forward_training_metadata() &&
                 hlo->tuple_index() == 2;
        }
        default:
          return false;
      }
    }
    case HloOpcode::kPower:
    case HloOpcode::kAbs:
    case HloOpcode::kRsqrt:
    case HloOpcode::kSqrt:
      return IsPositive(hlo->operand(0), options);

    case HloOpcode::kMultiply: {
      return hlo->operand(0) == hlo->operand(1) &&
             IsPositive(hlo->operand(0), options);
    }
    default:
      return false;
  }
}

bool IsNonNegative(const HloInstruction* hlo,
                   const AlgebraicSimplifierOptions& options) {
  // Utility only handles real types.
  if (IsAnyOperandComplex(hlo)) {
    return false;
  }
  switch (hlo->opcode()) {
    case HloOpcode::kMultiply: {
      return hlo->operand(0) == hlo->operand(1);
    }
    case HloOpcode::kAbs: {
      return true;
    }
    default:
      return IsPositive(hlo, options);
  }
}

// Checks whether `op` is a floating-point constant or broadcast of a constant
// of the form +/- 2^k for some integer k positive, negative, or zero. Such
// values are interesting because multiplying by a power of 2 just moves the
// exponent.
bool IsAllFpConstantPowerOf2(const HloInstruction* op) {
  // Unwrap the broadcast if necessary.
  const HloInstruction* c;
  if (!Match(op, m::ConstantEffectiveScalar(&c)) &&
      !Match(op, m::Broadcast(m::Constant(&c).WithShape(
                     m::Shape().IsEffectiveScalar())))) {
    return false;
  }
  auto val = [&]() -> absl::optional<double> {
    switch (c->shape().element_type()) {
      case BF16:
        return static_cast<double>(c->literal().GetFirstElement<bfloat16>());
      case F16:
        return static_cast<double>(
            c->literal().GetFirstElement<Eigen::half>());
      case F32:
        return c->literal().GetFirstElement<float>();
      case F64:
        return c->literal().GetFirstElement<double>();
      default:
        // Cowardly refuse to consider complex types.
        return absl::nullopt;
    }
  }();
  if (!val) {
    return false;
  }

  int exp;
  double mantissa = std::frexp(*val, &exp);
  // frexp returns a value in the range (-1, -0.5] U [0.5, 1). A return value
  // of +/-0.5 therefore indicates that the floating point value is a power of
  // 2.
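  //
  // For example (illustrative): 8.0 == 0.5 * 2^4, so frexp(8.0, &exp) returns
  // 0.5 (a power of 2), while 3.0 == 0.75 * 2^2 returns 0.75 and is rejected.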
  return mantissa == 0.5 || mantissa == -0.5;
}

// Returns whether the given transpose produces a result which is bit-wise
// identical to its operand and thus may be replaced with a bitcast.
bool TransposeIsBitcast(const HloInstruction* transpose) {
  CHECK_EQ(HloOpcode::kTranspose, transpose->opcode());
  const HloInstruction* operand = transpose->operand(0);
  return ShapeUtil::TransposeIsBitcast(operand->shape(), transpose->shape(),
                                       transpose->dimensions());
}

// Recursive helper for the method below.
HloInstruction* BitcastingOperandOfReshapeOrCopyChainHelper(
    HloInstruction* instr, HloInstruction* operand,
    const AlgebraicSimplifierOptions& options) {
  // We can't replace a chain of copies and reshapes with a bitcast if the
  // compiler used a memory layout which isn't compatible.
  if (options.ReshapeIsBitcast(operand->shape(), instr->shape())) {
    return operand;
  }

  // If the operand is a copy or reshape, check whether the operand's operand
  // would produce a bitcast with the initial instruction.
  if (HloOpcode::kReshape == operand->opcode() ||
      HloOpcode::kCopy == operand->opcode()) {
    return BitcastingOperandOfReshapeOrCopyChainHelper(
        instr, operand->mutable_operand(0), options);
  }
  return nullptr;
}

// Returns an operand of a chain of reshapes and copies that is bit-wise
// identical to the first reshape or copy in the chain.
HloInstruction* BitcastingOperandOfReshapeOrCopyChain(
    HloInstruction* instr, const AlgebraicSimplifierOptions& options) {
  if (!options.is_layout_sensitive()) {
    return nullptr;
  }
  CHECK(HloOpcode::kReshape == instr->opcode() ||
        HloOpcode::kCopy == instr->opcode());
  return BitcastingOperandOfReshapeOrCopyChainHelper(
      instr, instr->mutable_operand(0), options);
}

bool IsUnstridedSlice(const HloInstruction* hlo) {
  return absl::c_all_of(hlo->slice_strides(),
                        [](int64 stride) { return stride == 1; });
}

// Returns whether a pair of converts can be eliminated.
bool IsConvertPairNoOp(const HloInstruction* convert) {
  //    [operand_convert]         [convert]
  // (src)->convert-(intermediate)->convert-(dest)
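  //
  // For example (illustrative): convert(convert(x: f16 -> f32) -> f16)
  // round-trips exactly, since f32 can represent every f16 value, whereas
  // convert(convert(x: f32 -> f16) -> f32) may lose precision and is not a
  // no-op.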
  const HloInstruction* operand_convert = convert->operand(0);
  CHECK_EQ(operand_convert->opcode(), HloOpcode::kConvert);
  const Shape& src_shape = operand_convert->operand(0)->shape();
  const Shape& intermediate_shape = operand_convert->shape();
  const Shape& dest_shape = convert->shape();

  const PrimitiveType src_type = src_shape.element_type();
  const PrimitiveType intermediate_type = intermediate_shape.element_type();
  const PrimitiveType dest_type = dest_shape.element_type();

  // src_type must be equal to dest_type.
  if (src_type != dest_type) {
    return false;
  }

  // intermediate_type must be a strictly larger container than src_type, so
  // that the first conversion cannot lose information.
  if (ShapeUtil::ByteSizeOfPrimitiveType(intermediate_type) <=
      ShapeUtil::ByteSizeOfPrimitiveType(src_type)) {
    return false;
  }

  // Both src_type and intermediate_type must be either floating or integral.
  bool is_conversion_floating =
      ShapeUtil::ElementIsFloating(src_shape) &&
      ShapeUtil::ElementIsFloating(intermediate_shape);
  bool is_conversion_integral =
      ShapeUtil::ElementIsIntegral(src_shape) &&
      ShapeUtil::ElementIsIntegral(intermediate_shape);

  return is_conversion_floating || is_conversion_integral;
}

// AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain
// algebraic expressions to simplified forms. Note: This only supports
// simplifications that simply look at the operands of an instruction. For the
// more general case a worklist based approach would be needed.
class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
 public:
  explicit AlgebraicSimplifierVisitor(const AlgebraicSimplifierOptions& options,
                                      AlgebraicSimplifier* simplifier)
      : options_(options), simplifier_(simplifier) {}

  Status HandleAbs(HloInstruction* abs) override;

  Status HandleAdd(HloInstruction* add) override;

  Status HandleAnd(HloInstruction* logical_and) override;

  Status HandleBitcast(HloInstruction* bitcast) override;

  Status HandleBitcastConvert(HloInstruction* bitcast) override;

  Status HandleBroadcast(HloInstruction* broadcast) override;

  Status HandleCompare(HloInstruction* compare) override;

  Status HandleConcatenate(HloInstruction* concatenate) override;

  Status HandleConstant(HloInstruction* constant) override;

  Status HandleCopy(HloInstruction* copy) override;

  Status HandleConvert(HloInstruction* convert) override;

  Status HandleComplex(HloInstruction* complex) override;

  Status HandleReal(HloInstruction* real) override;

  Status HandleImag(HloInstruction* imag) override;

  Status HandleIota(HloInstruction* instruction) override;

  Status HandleConvolution(HloInstruction* convolution) override;

  Status HandleDivide(HloInstruction* divide) override;

  Status HandleDot(HloInstruction* dot) override;

  Status HandleGather(HloInstruction* gather) override;

  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;

  Status HandleLog(HloInstruction* log) override;

  Status HandleMaximum(HloInstruction* maximum) override;

  Status HandleMinimum(HloInstruction* minimum) override;

  Status HandleClamp(HloInstruction* clamp) override;

  Status HandleMultiply(HloInstruction* multiply) override;

  Status HandleNegate(HloInstruction* negate) override;

  Status HandleNot(HloInstruction* logical_not) override;

  Status HandleOr(HloInstruction* logical_or) override;

  Status HandlePad(HloInstruction* pad) override;

  Status HandlePower(HloInstruction* power) override;

  Status HandleRemainder(HloInstruction* remainder) override;

  Status HandleReshape(HloInstruction* reshape) override;

  Status HandleReduce(HloInstruction* hlo) override;

  Status HandleReduceWindow(HloInstruction* reduce_window) override;

  Status HandleReverse(HloInstruction* reverse) override;

  Status HandleRsqrt(HloInstruction* rsqrt) override;

  Status HandleSlice(HloInstruction* slice) override;

  Status HandleSqrt(HloInstruction* sqrt) override;

  Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;

  Status HandleDynamicUpdateSlice(
      HloInstruction* dynamic_update_slice) override;
  Status HandleScatter(HloInstruction* scatter) override;

  Status HandleSelect(HloInstruction* select) override;

  Status HandleSort(HloInstruction* sort) override;

  Status HandleTranspose(HloInstruction* transpose) override;

  Status HandleSubtract(HloInstruction* sub) override;

  Status HandleMap(HloInstruction* map) override;

  // Runs the visitor on a computation.
  bool Run(HloComputation* computation,
           const AlgebraicSimplifierOptions& options,
           AlgebraicSimplifier* simplifier);

 private:
  // Removes degenerate dimension from dot.
  StatusOr<bool> RemoveDegenerateDimensionFromDot(HloInstruction* dot);

  // Converts the input HLO to the given primitive type if it is not already
  // of that type; otherwise returns the original HLO.
  HloInstruction* AsType(HloInstruction* hlo,
                         const PrimitiveType element_type) {
    if (hlo->shape().element_type() == element_type) {
      return hlo;
    }
    Shape changed_shape =
        ShapeUtil::ChangeElementType(hlo->shape(), element_type);
    simplifier_->UpdateLayout(&changed_shape);
    return computation_->AddInstruction(
        HloInstruction::CreateConvert(changed_shape, hlo));
  }

  // Transposes a dot operand such that the batch dimensions are the most
  // major, and the contracting dimensions are most minor.
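  //
  // For example (illustrative): a rank-4 operand with batch_dimensions = {2}
  // and contracting_dimensions = {0} is transposed with permutation
  // {2, 1, 3, 0}: batch dims first, then the remaining (free) dims in order,
  // then the contracting dims.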
  StatusOr<HloInstruction*> NormalizeDotOperandToBatchMajorAndContractingMinor(
      HloInstruction* dot_operand, absl::Span<const int64> batch_dimensions,
      absl::Span<const int64> contracting_dimensions) {
    std::vector<int64> transpose_dimensions(batch_dimensions.begin(),
                                            batch_dimensions.end());
    for (int64 i = 0; i < dot_operand->shape().rank(); ++i) {
      if (!(absl::c_linear_search(batch_dimensions, i) ||
            absl::c_linear_search(contracting_dimensions, i))) {
        transpose_dimensions.push_back(i);
      }
    }
    transpose_dimensions.insert(transpose_dimensions.end(),
                                contracting_dimensions.begin(),
                                contracting_dimensions.end());
    if (absl::c_is_sorted(transpose_dimensions)) {
      return dot_operand;
    }
    return MakeTransposeHlo(dot_operand, transpose_dimensions);
  }

  // Helper method to perform an add reduction over a list of dimensions.
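  // For example (illustrative): AddReduce(x, /*dims=*/{0, 2}, F32) with x of
  // shape f32[4,5,6] sums away dimensions 0 and 2, producing an f32[5]
  // result.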
  HloInstruction* AddReduce(HloInstruction* hlo, absl::Span<const int64> dims,
                            PrimitiveType type) {
    HloInstruction* zero = computation_->AddInstruction(
        simplifier_->CreateConstantWithLayoutUpdated(
            LiteralUtil::Zero(hlo->shape().element_type()).Clone()));
    HloComputation* AddReduce_computation =
        GetOrCreateScalarAddComputation(type);
    Shape shape = ShapeUtil::FilterDimensions(
        [&](int64 dim) { return !absl::c_linear_search(dims, dim); },
        hlo->shape());
    simplifier_->UpdateLayout(&shape);
    return computation_->AddInstruction(HloInstruction::CreateReduce(
        shape, hlo, zero, dims, AddReduce_computation));
  }

  // Moves a scalar multiply to the smallest side of a dot to reduce the
  // number of multiply computations.
  Status ScalarMultiplyReduction(HloInstruction* dot);

  // Convenience method for replacing an instruction with a bitcast. If operand
  // is not null, then the bitcast will use the specified operand instead of
  // the operand of the instruction.
  void ReplaceWithBitcast(HloInstruction* instruction,
                          HloInstruction* operand = nullptr);

  // Replaces the old instruction with the new instruction if they have the
  // same shape. Updates uses and the root instruction. Returns whether a
  // replacement was made.
  bool ReplaceInstructionIfSameShape(HloInstruction* old_instruction,
                                     HloInstruction* new_instruction);

  // Returns whether the shapes of the outputs of the given instructions are
  // the same for the purposes of simplification. If
  // options_.is_layout_sensitive() is true, then this tests shape equality
  // including layout (ShapeUtil::Equal). If options_.is_layout_sensitive() is
  // false, then this tests shape compatibility (ShapeUtil::Compatible).
  bool SameShape(const HloInstruction* lhs, const HloInstruction* rhs) const;

  // Returns whether it was possible to transform `root` into a clamp
  // instruction, with min a minimum instruction, max a maximum instruction,
  // min_operand an operand of min, and max_operand an operand of max.
  // Precondition: root is either a minimum or a maximum.
  bool TransformToClampIfSameShape(HloInstruction* root, HloInstruction* min,
                                   HloInstruction* min_operand,
                                   HloInstruction* operand, HloInstruction* max,
                                   HloInstruction* max_operand);

  // A Broadcast that feeds an element-wise operation with a unique non-scalar
  // operand can sink to after the operation.
  StatusOr<bool> TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
      HloInstruction* broadcast);

  StatusOr<HloInstruction*> OptimizeDotOfConcat(HloInstruction* dot);
  StatusOr<HloInstruction*> OptimizeDotOfConcatHelper(
      const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim,
      HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped);

  StatusOr<HloInstruction*> OptimizeDotOfGather(HloInstruction* dot);

  StatusOr<HloInstruction*> OptimizeDotOfReorderContractingDims(
      HloInstruction* dot);

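  // Returns a cached scalar add computation for the given primitive type,
  // building it on first use; AddReduce above uses it as the reduction
  // computation.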
  HloComputation* GetOrCreateScalarAddComputation(PrimitiveType type) {
    HloComputation*& scalar_add_computation = scalar_add_computations_[type];
    if (scalar_add_computation) {
      return scalar_add_computation;
    }

    HloComputation::Builder b("scalar_add_computation");
    Shape shape = ShapeUtil::MakeShape(type, {});
    simplifier_->UpdateLayout(&shape);
    auto scalar_lhs = b.AddInstruction(
        HloInstruction::CreateParameter(0, shape, "scalar_lhs"));
    auto scalar_rhs = b.AddInstruction(
        HloInstruction::CreateParameter(1, shape, "scalar_rhs"));
    auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs));
    scalar_add_computation =
        computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
    return scalar_add_computation;
  }

  // Tries to fold a kPad in the input or filter into the convolution
  // instruction's window.
  StatusOr<bool> FoldConvInputPad(HloInstruction* convolution);
  StatusOr<bool> FoldConvFilterPad(HloInstruction* convolution);

  // Tries to swap convolution operands if they would result in a more
  // efficient convolution.
  StatusOr<bool> SwapConvOperands(HloInstruction* convolution);

  // Tries to use a kDot in place of the given convolution.
  StatusOr<bool> SimplifyConvToDot(HloInstruction* convolution);

  // Tries to simplify a slice where the result of the slice is a scalar.
  StatusOr<bool> TrySimplifyScalarSlice(HloInstruction* slice);

  // Tries to convert slice(reshape(X)) into reshape(slice(X)).
  StatusOr<bool> TryToReorderSliceAndReshape(HloInstruction* slice);

  // Tries to convert slice(reverse(X)) into reverse(slice(X)).
  StatusOr<bool> TryToReorderSliceAndReverse(HloInstruction* slice);

  // Tries to simplify `(and (< a N) (< a K))` in cases where `N <= K` into
  // `(< a N)`. This is crucial for being able to figure out the loop trip
  // count.
  //
  // Assumes that the input is a conjunction.
  StatusOr<bool> TrySimplifyTautologicalCompare(HloInstruction* conjunction);

  // Useful when we want to use the same visitor over multiple computations.
  void ResetState(HloComputation* computation);

  // Current HloComputation instance the AlgebraicSimplifierVisitor is
  // traversing.
  HloComputation* computation_;

  // The backend-specific options selected for the algebraic simplifier.
  const AlgebraicSimplifierOptions& options_;

  // Whether algebraic simplification has occurred.
  bool changed_ = false;

  // Cached computation for adding two scalars of a given type.
  absl::flat_hash_map<PrimitiveType, HloComputation*> scalar_add_computations_;

  AlgebraicSimplifier* simplifier_ = nullptr;
};

}  // namespace

void AlgebraicSimplifierVisitor::ResetState(HloComputation* computation) {
  changed_ = false;
  ResetVisitStates();
  computation_ = computation;
}

bool AlgebraicSimplifierVisitor::Run(HloComputation* computation,
                                     const AlgebraicSimplifierOptions& options,
                                     AlgebraicSimplifier* simplifier) {
  ResetState(computation);
  TF_CHECK_OK(computation->Accept(this));
  return changed_ || changed();
}

bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs,
                                           const HloInstruction* rhs) const {
  if (options_.is_layout_sensitive()) {
    return ShapeUtil::Equal(lhs->shape(), rhs->shape());
  } else {
    return ShapeUtil::Compatible(lhs->shape(), rhs->shape());
  }
}

namespace {

float GetConstantValue(HloInstruction* inst) {
  switch (inst->shape().element_type()) {
    case BF16:
      return static_cast<float>(inst->literal().GetFirstElement<bfloat16>());
    case F32:
      return inst->literal().GetFirstElement<float>();
    default:
      LOG(FATAL) << "Unsupported data type: " << inst->shape().element_type();
  }
}

bool IsOpCodeMultiplyCommutative(HloOpcode opcode) {
  switch (opcode) {
    case HloOpcode::kMultiply:
    case HloOpcode::kTranspose:
    case HloOpcode::kReshape:
    case HloOpcode::kSelect:
      return true;
    default:
      return false;
  }
}

std::unique_ptr<HloInstruction> MakeScalarInstruction(HloInstruction* target,
                                                      float multiplier) {
  switch (target->shape().element_type()) {
    case BF16:
      return HloInstruction::CreateConstant(LiteralUtil::ConvertF32ToBF16(
          LiteralUtil::CreateR0<float>(multiplier)));
      break;
    case F32:
      return HloInstruction::CreateConstant(
          LiteralUtil::CreateR0<float>(multiplier));
      break;
    default:
      LOG(FATAL) << "Unsupported data type: "
                 << target->shape().element_type();
  }
}

}  // namespace

Status AlgebraicSimplifierVisitor::ScalarMultiplyReduction(
    HloInstruction* dot) {
  // We only process bfloat16 and float32 for now.
  if (dot->shape().element_type() != BF16 &&
      dot->shape().element_type() != F32) {
    return Status::OK();
  }

  auto lhs = dot->mutable_operand(0);
  auto rhs = dot->mutable_operand(1);

  const int64 dot_size = ShapeUtil::ElementsIn(dot->shape());
  const int64 lhs_size = ShapeUtil::ElementsIn(lhs->shape());
  const int64 rhs_size = ShapeUtil::ElementsIn(rhs->shape());

  HloInstruction* target = nullptr;
  // (current node, user, operand_index)
  std::vector<std::tuple<HloInstruction*, HloInstruction*, int64>> operands;
  std::vector<HloInstruction*> users;

  // Find which side of the dot has the smallest size:
  // operand 0, operand 1, or the output.
  if (dot_size <= std::min(lhs_size, rhs_size)) {
    target = dot;
    if (dot_size < lhs_size) {
      operands.emplace_back(lhs, dot, 0);
    }
    if (dot_size < rhs_size) {
      operands.emplace_back(rhs, dot, 1);
    }
  } else if (lhs_size <= rhs_size) {
    target = lhs;
    if (lhs_size < rhs_size) {
      operands.emplace_back(rhs, dot, 1);
    }
    if (lhs_size < dot_size && dot->user_count() == 1) {
      users.push_back(dot->users().front());
    }
  } else {
    target = rhs;
    if (rhs_size < lhs_size) {
      operands.emplace_back(lhs, dot, 0);
    }
    if (rhs_size < dot_size && dot->user_count() == 1) {
      users.push_back(dot->users().front());
    }
  }

  std::vector<float> values;

  // DFS to find scalar multiply ops from the operands.
  while (!operands.empty()) {
    HloInstruction* inst;
    HloInstruction* user;
    int64 index;
    std::tie(inst, user, index) = operands.back();
    operands.pop_back();

    // Skip the op types that are not commutative with multiply.
    if (!IsOpCodeMultiplyCommutative(inst->opcode())) {
      continue;
    }

    HloInstruction* operand;
    HloInstruction* multiplier;
    // Pattern match a scalar multiply.
    if (Match(inst, m::MultiplyAnyOrder(
                        m::Op(&operand),
                        m::Broadcast(m::ConstantScalar(&multiplier))))) {
      CHECK_LT(index, user->operand_count());
      CHECK_EQ(inst, user->operands()[index]);

      // When a scalar multiply is found, save its scalar value.
      values.push_back(GetConstantValue(multiplier));
      // And remove the scalar multiply op.
      TF_RETURN_IF_ERROR(user->ReplaceOperandWith(index, operand));
      inst = operand;
    }

    // Push the operands of inst.
    int64 i = 0;
    for (auto* operand : inst->operands()) {
      operands.emplace_back(operand, inst, i++);
    }
  }

  // DFS to find scalar multiply ops from the users.
  while (!users.empty()) {
    auto inst = users.back();
    users.pop_back();

    if (!IsOpCodeMultiplyCommutative(inst->opcode())) {
      continue;
    }

    HloInstruction* operand;
    HloInstruction* multiplier;
    if (Match(inst, m::MultiplyAnyOrder(
                        m::Op(&operand),
                        m::Broadcast(m::ConstantScalar(&multiplier))))) {
      values.push_back(GetConstantValue(multiplier));

      TF_RETURN_IF_ERROR(inst->ReplaceAllUsesWith(operand));
      inst = operand;
    }

    // Only continue through instructions with a single user; otherwise moving
    // the scalar multiply to the operands would change the values seen by the
    // other users.
    if (inst->user_count() == 1) {
      users.push_back(inst->users().front());
    }
  }

  if (values.empty()) {
    return Status::OK();
  }

  changed_ = true;

  // Combine all constant multipliers.
  float multiplier = 1.0;
  for (const float v : values) {
    multiplier *= v;
  }

  // Create a new const scalar multiply instruction.
  HloInstruction* new_const_inst;
  new_const_inst =
      computation_->AddInstruction(MakeScalarInstruction(target, multiplier));

  // Broadcast the scalar multiplier.
  HloInstruction* new_broadcast = computation_->AddInstruction(
      HloInstruction::CreateBroadcast(target->shape(), new_const_inst, {}));
  // Create a new scalar multiply instruction.
  HloInstruction* new_multiply =
      computation_->AddInstruction(HloInstruction::CreateBinary(
          target->shape(), HloOpcode::kMultiply, target, new_broadcast));
  CHECK_EQ(new_multiply->shape(), target->shape());

  // Update the dependency with the rest of the instructions.
  if (target == lhs) {
    return dot->ReplaceOperandWith(0, new_multiply);
  } else if (target == rhs) {
    return dot->ReplaceOperandWith(1, new_multiply);
  } else {
    CHECK_EQ(target, dot);
    return dot->ReplaceAllUsesWith(new_multiply);
  }
}

void AlgebraicSimplifierVisitor::ReplaceWithBitcast(HloInstruction* instruction,
                                                    HloInstruction* operand) {
  CHECK_EQ(1, instruction->operand_count());
  if (operand == nullptr) {
    operand = instruction->mutable_operand(0);
  }
  CHECK_EQ(ShapeUtil::ElementsIn(instruction->shape()),
           ShapeUtil::ElementsIn(operand->shape()));
  CHECK_EQ(ShapeUtil::ByteSizeOf(instruction->shape()),
           ShapeUtil::ByteSizeOf(operand->shape()));

  auto bitcast = computation_->AddInstruction(
      HloInstruction::CreateBitcast(instruction->shape(), operand));
  bitcast->set_metadata(instruction->metadata());
  TF_CHECK_OK(ReplaceInstruction(instruction, bitcast));
}

bool AlgebraicSimplifierVisitor::ReplaceInstructionIfSameShape(
    HloInstruction* old_instruction, HloInstruction* new_instruction) {
  if (!SameShape(old_instruction, new_instruction)) {
    return false;
  }
  TF_CHECK_OK(ReplaceInstruction(old_instruction, new_instruction));
  return true;
}

Status AlgebraicSimplifierVisitor::HandleAbs(HloInstruction* abs) {
  HloInstruction* abs_operand = abs->mutable_operand(0);
  VLOG(10) << "trying transform [Abs(A) => A] " << abs->ToString()
           << " Abs operand is: " << abs_operand->ToString();
  if (IsNonNegative(abs->operand(0), options_)) {
    return ReplaceInstruction(abs, abs_operand);
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
  HloInstruction *lhs, *rhs;
  CHECK(Match(add, m::Add(m::Op(&lhs), m::Op(&rhs))));

  // A + 0 => A
  VLOG(10) << "trying transform [A + 0 => A]: " << add->ToString();
  if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(add, lhs)) {
    return Status::OK();
  }
  // 0 + A => A
  VLOG(10) << "trying transform [0 + A => A]: " << add->ToString();
  if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(add, rhs)) {
    return Status::OK();
  }

  // Canonicalization: Put constants on the right. This makes the
  // reassociation rules below simpler.
  VLOG(10) << "trying transform [Const + A => A + Const]";
  if (Match(add, m::Add(m::Constant(), m::NonConstant()))) {
    return ReplaceWithNewInstruction(
        add,
        HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, rhs, lhs));
  }

  // Reassociate to allow constant folding.
  //
  // Note: This is not general. For example, we won't reassociate
  //
  //   (A + C1) + (B + C2) => A + B + (C1 + C2).
  //
  VLOG(10) << "trying transform [(A + C1) + C2 => A + (C1 + C2)]";
  HloInstruction *a, *c1, *c2;
  if (Match(add, m::Add(m::Add(m::NonConstant(&a), m::Constant(&c1)),
                        m::Constant(&c2))) ||
      Match(add, m::Add(m::Add(m::NonConstant(&a),
                               m::Broadcast(m::ConstantScalar(&c1))),
                        m::Broadcast(m::ConstantScalar(&c2))))) {
    TF_ASSIGN_OR_RETURN(auto* sum_of_constants,
                        MakeBinaryHlo(HloOpcode::kAdd, c1, c2));
    if (ShapeUtil::IsScalar(sum_of_constants->shape()) &&
        !ShapeUtil::IsScalar(add->shape())) {
      sum_of_constants = computation_->AddInstruction(
          HloInstruction::CreateBroadcast(add->shape(), sum_of_constants, {}));
    }
    return ReplaceWithNewInstruction(
        add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, a,
                                          sum_of_constants));
  }

  // Convert an add over a full shape into an add over a partial shape when
  // only a portion of the add is effective:
  //
  //             zero (fullshape)   rhs (partialshape)
  // .           |                  |
  // . lhs .    dynamic_update_slice (fullshape)
  // . |         |
  // Add (fullshape)
  //
  // to:
  //              lhs
  //              |
  //             dynamic_slice (partialshape)   rhs (partialshape)
  // .           |                              |
  // . lhs .    add (partial_shape)+-----------+
  // . |         |
  // dynamic_update_slice (fullshape)
  //
  // This pattern shows up in control flow V2 gradient updates.
  if (Match(add,
            m::Add(m::Op(&lhs),
                   m::Op(&rhs)
                       .WithOpcode(HloOpcode::kDynamicUpdateSlice)
                       .WithOperand(
                           0, m::Broadcast(m::ConstantEffectiveScalar(0)))))) {
    const Shape& partial_shape = rhs->operand(1)->shape();
    auto sliced_lhs =
        computation_->AddInstruction(HloInstruction::CreateDynamicSlice(
            partial_shape, lhs, absl::MakeSpan(rhs->operands()).subspan(2),
            partial_shape.dimensions()));

    auto add_partial = computation_->AddInstruction(
        HloInstruction::CreateBinary(rhs->operand(1)->shape(), HloOpcode::kAdd,
                                     sliced_lhs, rhs->mutable_operand(1)));

    auto dynamic_update_slice_full = HloInstruction::CreateDynamicUpdateSlice(
        lhs->shape(), lhs, add_partial,
        absl::MakeSpan(rhs->operands()).subspan(2));

    return ReplaceWithNewInstruction(add,
                                     std::move(dynamic_update_slice_full));
  }

  // A*C + B*C => (A+B)*C
  //
  //  - If A, B, and C are integers, do this unconditionally. Proof of
  //    correctness: https://rise4fun.com/Alive/u9X.
  //
  //  - If A, B, and C are floating point, do this if C is a scalar constant or
  //    broadcast of scalar constant and is equal to +/- 2^k for some (possibly
  //    negative) integer k.
  //
  //    Multiplying by a power of 2 just moves the exponent, so our answer is
  //    exact modulo rounding of intermediate results so long as
  //
  //     - none of the three products has an exponent which underflows (so the
  //       result is 0 or denormal), and
  //     - none of the three products overflows to inf.
  //
  //    Proof: See algebraic_simplifier_proof_distributive_property.py.
  //
  //    We deem these differences in rounding, underflow, and overflow
  //    acceptable in the ML context.
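  //
  //    For example (illustrative): with a = 3.0f, b = 5.0f, and c = 0.25f
  //    (i.e. 2^-2), a*c + b*c == 0.75f + 1.25f == 2.0f and (a+b)*c ==
  //    8.0f * 0.25f == 2.0f; multiplying by 2^-2 only shifts exponents, so
  //    both forms are exact.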
  HloInstruction *b, *c;
  if (((Match(lhs, m::Multiply(m::Op(&a), m::Op(&c))) &&
        Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b)))) ||
       (Match(lhs, m::Multiply(m::Op(&c), m::Op(&a))) &&
        Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b))))) &&
      (ShapeUtil::ElementIsIntegral(add->shape()) ||
       options_.enable_floats_are_real() || IsAllFpConstantPowerOf2(c))) {
    return ReplaceWithNewInstruction(
        add, HloInstruction::CreateBinary(
                 add->shape(), HloOpcode::kMultiply,
                 computation_->AddInstruction(HloInstruction::CreateBinary(
                     add->shape(), HloOpcode::kAdd, a, b)),
                 c));
  }

  if (options_.is_layout_sensitive()) {
    return Status::OK();
  }

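  // Try to combine scatters whose combiner computation is a plain add:
  //
  //   add(scatter(A, i_a, u_a), scatter(B, i_b, u_b))
  //     => scatter(add(A, B), concat(i_a, i_b), concat(u_a, u_b))
  //
  // When only one side is such a scatter, the other addend is instead folded
  // into the scatter's operand:
  //
  //   add(X, scatter(B, i, u)) => scatter(add(X, B), i, u)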
  HloInstruction* lhs_scatter_operand = nullptr;
  HloInstruction* rhs_scatter_operand = nullptr;
  HloInstruction* lhs_scatter_update = nullptr;
  HloInstruction* rhs_scatter_update = nullptr;
  HloInstruction* lhs_scatter_index = nullptr;
  HloInstruction* rhs_scatter_index = nullptr;
  bool lhs_scatter = Match(lhs, m::Scatter(m::Op(&lhs_scatter_operand),
                                           m::Op(&lhs_scatter_index),
                                           m::Op(&lhs_scatter_update))
                                    .WithOneUse()) &&
                     Match(lhs->to_apply()->root_instruction(),
                           m::Add(m::Parameter(), m::Parameter()));
  bool rhs_scatter = Match(rhs, m::Scatter(m::Op(&rhs_scatter_operand),
                                           m::Op(&rhs_scatter_index),
                                           m::Op(&rhs_scatter_update))
                                    .WithOneUse()) &&
                     Match(rhs->to_apply()->root_instruction(),
                           m::Add(m::Parameter(), m::Parameter()));
  if (rhs_scatter && lhs_scatter) {
    const auto& lhs_dnums = lhs->scatter_dimension_numbers();
    const auto& rhs_dnums = rhs->scatter_dimension_numbers();
    absl::optional<int64> index_concat_dimension;
    absl::optional<int64> update_concat_dimension;
    // Don't try to combine scatters of different ranks.
    if (lhs_scatter_index->shape().rank() !=
        rhs_scatter_index->shape().rank()) {
      return Status::OK();
    }

    int64 first_index_dim = lhs_scatter_index->shape().rank();
    int64 first_update_dim = lhs_scatter_update->shape().rank();
    // Find a dimension where it is possible to concatenate the indices and
    // updates. This is the first and only non-equal dimension or the first
    // equally sized dimension.
    for (int64 d = lhs_scatter_index->shape().rank() - 1,
               update_dim = lhs_scatter_update->shape().rank() - 1;
         d >= 0; --d) {
      if (d == lhs_dnums.index_vector_dim()) {
        continue;
      }
      while (
          absl::c_linear_search(lhs_dnums.update_window_dims(), update_dim)) {
        --update_dim;
      }
      if (lhs_scatter_index->shape().dimensions(d) ==
          rhs_scatter_index->shape().dimensions(d)) {
        first_index_dim = d;
        first_update_dim = update_dim--;
        continue;
      }
      // More than one dimension of unequal size was found, bail out.
      if (index_concat_dimension) {
        return Status::OK();
      }
      index_concat_dimension = d;
      update_concat_dimension = update_dim--;
    }
    if (!index_concat_dimension) {
      index_concat_dimension = first_index_dim;
      update_concat_dimension = first_update_dim;
    }

    // A scalar scatter will require additional reshapes of the index and
    // update.
    if (*index_concat_dimension == lhs_scatter_index->shape().rank()) {
      return Status::OK();
    }
    const bool update_concat_is_cheap =
        ShapeUtil::ElementsIn(rhs_scatter_update->shape()) +
            ShapeUtil::ElementsIn(lhs_scatter_update->shape()) <
        ShapeUtil::ElementsIn(lhs->shape());
    if (!update_concat_is_cheap) {
      return Status::OK();
    }
    const bool same_dimension_numbers =
        lhs_dnums.index_vector_dim() == rhs_dnums.index_vector_dim() &&
        absl::c_equal(lhs_dnums.scatter_dims_to_operand_dims(),
                      rhs_dnums.scatter_dims_to_operand_dims()) &&
        absl::c_equal(lhs_dnums.inserted_window_dims(),
                      rhs_dnums.inserted_window_dims()) &&
        absl::c_equal(lhs_dnums.update_window_dims(),
                      rhs_dnums.update_window_dims());
    const bool index_concat_is_safe =
        !lhs->unique_indices() && !rhs->unique_indices() &&
        !DynCast<HloScatterInstruction>(lhs)->indices_are_sorted() &&
        !DynCast<HloScatterInstruction>(rhs)->indices_are_sorted();

    Shape lhs_update_window = ShapeUtil::FilterDimensions(
        [&](int64 dim) {
          return absl::c_linear_search(lhs_dnums.update_window_dims(), dim);
        },
        lhs_scatter_update->shape());
    Shape rhs_update_window = ShapeUtil::FilterDimensions(
        [&](int64 dim) {
          return absl::c_linear_search(rhs_dnums.update_window_dims(), dim);
        },
        rhs_scatter_update->shape());
    // Concatenate the indices and updates.
    if (index_concat_is_safe && same_dimension_numbers &&
        index_concat_dimension &&
        lhs_scatter_index->shape().element_type() ==
            rhs_scatter_index->shape().element_type() &&
        ShapeUtil::SameDimensions(lhs_update_window, rhs_update_window)) {
      TF_ASSIGN_OR_RETURN(HloInstruction * new_operand,
                          MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand,
                                        rhs_scatter_operand));
      TF_ASSIGN_OR_RETURN(HloInstruction * new_index,
                          MakeConcatHlo({lhs_scatter_index, rhs_scatter_index},
                                        *index_concat_dimension));
      TF_ASSIGN_OR_RETURN(
          HloInstruction * new_update,
          MakeConcatHlo({lhs_scatter_update, rhs_scatter_update},
                        *update_concat_dimension));
      return ReplaceWithNewInstruction(
          add, HloInstruction::CreateScatter(
                   add->shape(), new_operand, new_index, new_update,
                   lhs->to_apply(), lhs_dnums, false, false));
    }
    TF_ASSIGN_OR_RETURN(HloInstruction * new_operand,
                        MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand,
                                      rhs_scatter_operand));
    TF_RETURN_IF_ERROR(rhs->ReplaceOperandWith(0, new_operand));
    TF_RETURN_IF_ERROR(lhs->ReplaceOperandWith(0, rhs));
    return ReplaceInstruction(add, lhs);
  } else if (rhs_scatter) {
    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_operand,
        MakeBinaryHlo(HloOpcode::kAdd, lhs, rhs_scatter_operand));
    TF_RETURN_IF_ERROR(rhs->ReplaceOperandWith(0, new_operand));
    return ReplaceInstruction(add, rhs);
  } else if (lhs_scatter) {
    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_operand,
        MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand, rhs));
    TF_RETURN_IF_ERROR(lhs->ReplaceOperandWith(0, new_operand));
    return ReplaceInstruction(add, lhs);
  }
  return Status::OK();
}

StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyTautologicalCompare(
    HloInstruction* conjunction) {
  HloInstruction *lhs, *rhs;
  if (!Match(conjunction, m::And(m::Op(&lhs), m::Op(&rhs)))) {
    return false;
  }
  struct LessThanCompareInfo {  // (LT var constant)
    HloInstruction* var;
    int64 constant;
  };

  auto get_compare_info =
      [&](HloInstruction* cmp) -> absl::optional<LessThanCompareInfo> {
    HloInstruction *lhs, *rhs;
    auto scalar_shape_matcher =
        m::Shape().IsEffectiveScalar().WithElementType(PrimitiveType::S32);
    if (Match(cmp, m::Compare(m::Op(&lhs),
                              m::Constant(&rhs).WithShape(scalar_shape_matcher))
                       .WithComparisonDirection(ComparisonDirection::kLt))) {
      return {LessThanCompareInfo{lhs, *rhs->literal().GetFirstInteger()}};
    } else if (Match(
                   cmp,
                   m::Compare(m::Constant(&lhs).WithShape(scalar_shape_matcher),
                              m::Op(&rhs))
                       .WithComparisonDirection(ComparisonDirection::kGt))) {
      return {LessThanCompareInfo{rhs, *lhs->literal().GetFirstInteger()}};
    }
    return absl::nullopt;
  };

  absl::optional<LessThanCompareInfo> lhs_info = get_compare_info(lhs);
  absl::optional<LessThanCompareInfo> rhs_info = get_compare_info(rhs);
  if (lhs_info && rhs_info && lhs_info->var == rhs_info->var) {
    int64 new_bound = std::min(lhs_info->constant, rhs_info->constant);
    TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
        conjunction,
        HloInstruction::CreateCompare(lhs->shape(), lhs_info->var,
                                      MakeScalarLike(lhs_info->var, new_bound),
                                      ComparisonDirection::kLt)));
    return true;
  }
  return false;
}

Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) {
  HloInstruction *lhs, *rhs;
  CHECK(Match(logical_and, m::And(m::Op(&lhs), m::Op(&rhs))));
  // Simplify logical and.
  if (ShapeUtil::HasPrimitiveType(lhs->shape(), xla::PRED) &&
      ShapeUtil::HasPrimitiveType(rhs->shape(), xla::PRED)) {
    // A && True => A
    VLOG(10) << "trying transform [A && True => A]: "
             << logical_and->ToString();
    if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(logical_and, lhs)) {
      return Status::OK();
    }
    // True && A => A
    VLOG(10) << "trying transform [True && A => A]: "
             << logical_and->ToString();
    if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(logical_and, rhs)) {
      return Status::OK();
    }
  }

  // A && False => False or A & 0 => 0
  VLOG(10) << "trying transform [A && False => False]: "
           << logical_and->ToString();
  if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(logical_and, rhs)) {
    return Status::OK();
  }

  // False && A => False or A & 0 => 0
  VLOG(10) << "trying transform [False && A => False]: "
           << logical_and->ToString();
  if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(logical_and, lhs)) {
    return Status::OK();
  }

  // Simplify tautological conjunctions.
  TF_ASSIGN_OR_RETURN(bool found_tautological_compare,
                      TrySimplifyTautologicalCompare(logical_and));
  if (found_tautological_compare) {
    return Status::OK();
  }

  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleBitcast(HloInstruction* bitcast) {
  // If a bitcast feeds a bitcast, make it a single bitcast.
  HloInstruction* op;
  if (Match(bitcast, m::Bitcast(m::Bitcast(m::Op(&op))))) {
    return ReplaceWithNewInstruction(
        bitcast, HloInstruction::CreateBitcast(bitcast->shape(), op));
  }
  // All bitcasts can be eliminated (assuming layout constraints are
  // satisfied).
  ReplaceInstructionIfSameShape(bitcast, bitcast->mutable_operand(0));
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleBitcastConvert(
    HloInstruction* bitcast) {
  // Eliminate bitcast converts between same shape.
  ReplaceInstructionIfSameShape(bitcast, bitcast->mutable_operand(0));
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) {
  // If a copy feeds a copy, make it a single copy.
  HloInstruction* op;
  if (Match(copy, m::Copy(m::Copy(m::Op(&op))))) {
    return ReplaceWithNewInstruction(
        copy, HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy, op));
  }
  // All copies can be eliminated (assuming layout constraints are satisfied).
  if (ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0))) {
    return Status::OK();
  }

  if (HloInstruction* bitcast_operand =
          BitcastingOperandOfReshapeOrCopyChain(copy, options_)) {
    ReplaceWithBitcast(copy, bitcast_operand);
    return Status::OK();
  }

  // Replace Copy(Reshape()) with Reshape() if the Reshape is a logical
  // bitcast.
  if (copy->operand(0)->opcode() == HloOpcode::kReshape &&
      copy->operand(0)->user_count() == 1 &&
      ShapeUtil::ReshapeIsBitcast(copy->operand(0)->shape(), copy->shape())) {
    return ReplaceWithNewInstruction(
        copy,
        copy->operand(0)->CloneWithNewOperands(
            copy->shape(), {copy->mutable_operand(0)->mutable_operand(0)}));
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleConcatenate(
    HloInstruction* concatenate) {
  absl::Span<HloInstruction* const> operands(concatenate->operands());
  if (operands.size() == 1) {
    // Unary concatenates are useless.
    ReplaceInstructionIfSameShape(concatenate, operands[0]);
    return Status::OK();
  }
  // Filter out and remove empty operands.
  std::vector<HloInstruction*> nonempty_operands;
  for (HloInstruction* operand : operands) {
    if (!ShapeUtil::IsZeroElementArray(operand->shape())) {
      nonempty_operands.push_back(operand);
    }
  }
  if (nonempty_operands.size() < operands.size()) {
    HloInstruction* replacement;
    if (nonempty_operands.empty()) {
      replacement = operands[0];
    } else if (nonempty_operands.size() == 1) {
      replacement = nonempty_operands[0];
    } else {
      replacement =
          computation_->AddInstruction(concatenate->CloneWithNewOperands(
              concatenate->shape(), nonempty_operands));
    }
    VLOG(10) << "trying to replace " << concatenate->ToString() << " with "
             << replacement->ToString();
    ReplaceInstructionIfSameShape(concatenate, replacement);
    return Status::OK();
  }

  if (options_.is_layout_sensitive()) {
    return Status::OK();
  }

  // Check if we can merge "adjacent" slice operands which take slices from the
  // same other op. For simplicity we only merge unstrided slices.
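  //
  // For example (illustrative): concat(slice(x, [0:3]), slice(x, [3:7]))
  // along the sliced dimension merges into slice(x, [0:7]), provided the
  // slices agree on their starts in every other dimension.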
  int64 concatenate_dimension = concatenate->concatenate_dimension();
  for (int64 i = 0; i < operands.size(); ++i) {
    if (operands[i]->opcode() != HloOpcode::kSlice ||
        !IsUnstridedSlice(operands[i])) {
      continue;
    }
    int64 slice_end = operands[i]->slice_limits(concatenate_dimension);
    HloInstruction* slice_operand = operands[i]->mutable_operand(0);
    int64 j = i + 1;
    while (j < operands.size() && operands[j]->opcode() == HloOpcode::kSlice &&
           IsUnstridedSlice(operands[j]) &&
           operands[j]->operand(0) == slice_operand &&
           operands[j]->slice_starts(concatenate_dimension) == slice_end) {
      // Check that all the slice_start values are the same in all other
      // dimensions. This implies that the slice_limit values are also the
      // same, because operands of concatenate need to have the same shape,
      // and we already checked that the slices are unstrided.
      bool same_other_starts = true;
      for (int64 k = 0; k < operands[j]->slice_starts().size(); ++k) {
        if (k == concatenate_dimension) {
          continue;
        }
        if (operands[i]->slice_starts(k) != operands[j]->slice_starts(k)) {
          same_other_starts = false;
          break;
        }
      }
      if (!same_other_starts) {
        break;
      }
      slice_end = operands[j]->slice_limits(concatenate_dimension);
      ++j;
    }
    if (j - i > 1) {
      Shape new_slice_shape = operands[i]->shape();
      new_slice_shape.set_dimensions(
          concatenate_dimension,
          slice_end - operands[i]->slice_starts(concatenate_dimension));
      simplifier_->UpdateLayout(&new_slice_shape);
      auto new_limit_indices = operands[i]->slice_limits();
      new_limit_indices[concatenate_dimension] = slice_end;
      auto new_slice_op =
          computation_->AddInstruction(HloInstruction::CreateSlice(
              new_slice_shape, slice_operand,
              /*start_indices=*/operands[i]->slice_starts(),
              /*limit_indices=*/new_limit_indices,
              /*strides=*/operands[i]->slice_strides()));
      std::vector<HloInstruction*> new_operands;
      for (int64 k = 0; k < i; ++k) {
        new_operands.push_back(operands[k]);
      }
      new_operands.push_back(new_slice_op);
      for (int64 k = j; k < operands.size(); ++k) {
        new_operands.push_back(operands[k]);
      }
      auto replacement =
          computation_->AddInstruction(concatenate->CloneWithNewOperands(
              concatenate->shape(), new_operands));

      // Recurse to handle multiple disjoint sequences of inputs: the logic
      // above merges only one sequential series of inputs per invocation.
      // Otherwise, this can lead to the FixPass optimization hitting its
      // threshold.
      if (ReplaceInstructionIfSameShape(concatenate, replacement)) {
        return HandleConcatenate(replacement);
      }

      return Status::OK();
    }
  }

  if (operands.size() == 2) {
    // A binary concat with a broadcasted scalar as an operand can be converted
    // into a pad which is simpler to fold into other operations.
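    //
    // For example (illustrative): concat(broadcast(c)[4,8], y[6,8]) along
    // dimension 0 is equivalent to pad(y, c) with edge_padding_low == 4 on
    // that dimension.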
    bool is_effective_low_pad = Match(
        operands[0], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar())));
    bool is_effective_high_pad = Match(
        operands[1], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar())));
    if (!is_effective_low_pad && !is_effective_high_pad) {
      return Status::OK();
    }
    PaddingConfig padding_config;
    for (int64 dim = 0; dim < operands[0]->shape().rank(); ++dim) {
      auto padding_config_dim = padding_config.add_dimensions();
      padding_config_dim->set_edge_padding_high(0);
      padding_config_dim->set_edge_padding_low(0);
      padding_config_dim->set_interior_padding(0);
      if (dim == concatenate_dimension) {
        if (is_effective_low_pad) {
          padding_config_dim->set_edge_padding_low(
              operands[0]->shape().dimensions(dim));
        } else {
          padding_config_dim->set_edge_padding_high(
              operands[1]->shape().dimensions(dim));
        }
      }
    }
    int64 operand_to_pad = is_effective_low_pad ? 1 : 0;
    int64 pad_value_operand = is_effective_low_pad ? 0 : 1;
    HloInstruction* pad =
        computation_->AddInstruction(HloInstruction::CreatePad(
            concatenate->shape(), operands[operand_to_pad],
            operands[pad_value_operand]->mutable_operand(0), padding_config));
    return ReplaceInstruction(concatenate, pad);
  }

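  // Concatenating the same operand N times along a dimension of size 1 is a
  // broadcast. For example (illustrative): concat({x, x, x}, dim=0) with
  // x: f32[1,4] is equivalent to broadcasting reshape(x) of shape f32[4] into
  // f32[3,4], mapping its single dimension to output dimension 1.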
1352 if (absl::c_count(operands, operands[0]) == operands.size() &&
1353 operands[0]->shape().dimensions(concatenate_dimension) == 1) {
1354 Shape new_shape = operands[0]->shape();
1355 absl::InlinedVector<int64, 8> broadcast_dims;
1356 for (int64 i = 0; i < new_shape.rank(); ++i) {
1357 if (i == concatenate_dimension) {
1358 continue;
1359 }
1360 broadcast_dims.push_back(i);
1361 }
1362 new_shape.DeleteDimension(concatenate_dimension);
1363 return ReplaceInstruction(
1364 concatenate,
1365 MakeBroadcastHlo(MakeReshapeHlo(new_shape, operands[0]).ValueOrDie(),
1366 broadcast_dims, concatenate->shape()));
1367 }
1368 return Status::OK();
1369 }
1370
BuildTupleConstant(HloComputation * computation,const LiteralSlice & literal,AlgebraicSimplifier * simplifier)1371 static HloInstruction* BuildTupleConstant(HloComputation* computation,
1372 const LiteralSlice& literal,
1373 AlgebraicSimplifier* simplifier) {
1374 if (literal.shape().IsTuple()) {
1375 std::vector<HloInstruction*> elems;
1376 elems.reserve(ShapeUtil::TupleElementCount(literal.shape()));
1377 for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) {
1378 elems.push_back(BuildTupleConstant(
1379 computation, LiteralSlice(literal, {i}), simplifier));
1380 }
1381 return computation->AddInstruction(HloInstruction::CreateTuple(elems));
1382 } else {
1383 return computation->AddInstruction(
1384 simplifier->CreateConstantWithLayoutUpdated(literal.Clone()));
1385 }
1386 }
1387
HandleConstant(HloInstruction * constant)1388 Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
1389 // Tuple constants aren't directly supported by any backend. Expand them into
1390 // explicit Tuple instructions.
1391 if (constant->shape().IsTuple()) {
1392 return ReplaceInstruction(
1393 constant,
1394 BuildTupleConstant(computation_, constant->literal(), simplifier_));
1395 }
1396
1397 if (constant->shape().element_type() == TOKEN) {
1398 return Status::OK();
1399 }
1400
1401 // If a literal is all the same element replace it with a scalar broadcast.
1402 if (ShapeUtil::ElementsIn(constant->shape()) > 1 &&
1403 constant->literal().IsAllFirst()) {
1404 Literal unique_scalar(
1405 LiteralUtil::GetFirstScalarLiteral(constant->literal()));
1406 HloInstruction* scalar = computation_->AddInstruction(
1407 simplifier_->CreateConstantWithLayoutUpdated(std::move(unique_scalar)));
1408 return ReplaceWithNewInstruction(
1409 constant,
1410 HloInstruction::CreateBroadcast(constant->shape(), scalar, {}));
1411 }
1412
1413 // If a literal is an increasing sequence from zero, replace it with an iota.
1414 if (constant->shape().rank() == 1 &&
1415 ShapeUtil::ElementsIn(constant->shape()) > 1 &&
1416 constant->literal().IsR1Iota()) {
1417 return ReplaceWithNewInstruction(
1418 constant, HloInstruction::CreateIota(constant->shape(), 0));
1419 }
1420 return Status::OK();
1421 }
1422
HandleSubtract(HloInstruction * sub)1423 Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) {
1424 HloInstruction *lhs, *rhs;
1425 CHECK(Match(sub, m::Subtract(m::Op(&lhs), m::Op(&rhs))));
1426 // A - 0 => A
1427 VLOG(10) << "trying transform [A - 0 => A]: " << sub->ToString();
1428 if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(sub, lhs)) {
1429 return Status::OK();
1430 }
1431
1432 // Canonicalize subtraction of a constant to addition.
1433 VLOG(10) << "trying transform [A - Const => A + (-Const)]";
  if (Match(sub, m::Subtract(m::NonConstant(&lhs), m::Constant(&rhs))) ||
      Match(sub, m::Subtract(m::NonConstant(&lhs),
                             m::Broadcast(m::Constant(&rhs))))) {
    HloInstruction* negative_const = computation_->AddInstruction(
        HloInstruction::CreateUnary(rhs->shape(), HloOpcode::kNegate, rhs));
    if (const HloInstruction* broadcast =
            DynCast<HloBroadcastInstruction>(sub->operand(1))) {
      negative_const =
          computation_->AddInstruction(HloInstruction::CreateBroadcast(
              broadcast->shape(), negative_const, broadcast->dimensions()));
    }
    return ReplaceWithNewInstruction(
        sub, HloInstruction::CreateBinary(sub->shape(), HloOpcode::kAdd, lhs,
                                          negative_const));
  }

  return Status::OK();
}
namespace {
template <typename T>
Status InvertConstant(const HloInstruction& constant, Literal* result) {
  return result->Populate<T>([&](absl::Span<const int64> indices) {
    return T{1.0} / constant.literal().Get<T>(indices);
  });
}

template <typename T>
std::unique_ptr<HloInstruction> TryDivideToShift(
    HloInstruction* divide, HloComputation* computation,
    AlgebraicSimplifier* simplifier) {
  HloInstruction *a, *b, *c;
  CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b))));

  if (ShapeUtil::ElementIsIntegral(divide->shape()) &&
      !Match(b, m::ConstantEffectiveScalar(&c)) &&
      !Match(b, m::Broadcast(m::ConstantEffectiveScalar(&c)))) {
    return nullptr;
  }

  if (ShapeUtil::ElementIsSigned(divide->shape())) {
    int64 b_value = c->literal().GetFirstElement<T>();
    if (b_value > 0 && IsPowerOfTwo(static_cast<uint64>(b_value))) {
      // Handle negative dividends by negating the result of the division.
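      // A plain arithmetic shift would round toward negative infinity, while
      // HLO integer division truncates toward zero; e.g. -5 / 4 is -1, but
      // -5 >> 2 is -2. Shifting |dividend| and negating the result restores
      // the truncating semantics.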
      HloInstruction* zero_like_a = MakeScalarLike(a, 0);

      Shape changed_shape = ShapeUtil::ChangeElementType(a->shape(), PRED);
      simplifier->UpdateLayout(&changed_shape);
      auto* dividend_is_negative =
          computation->AddInstruction(HloInstruction::CreateCompare(
              changed_shape, a, zero_like_a, ComparisonDirection::kLt));

      auto* negated_dividend = computation->AddInstruction(
          HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a));

      auto* abs_dividend =
          computation->AddInstruction(HloInstruction::CreateTernary(
              a->shape(), HloOpcode::kSelect, dividend_is_negative,
              negated_dividend, a));

      auto* quotient = computation->AddInstruction(HloInstruction::CreateBinary(
          divide->shape(), HloOpcode::kShiftRightLogical, abs_dividend,
          MakeScalarLike(abs_dividend, tensorflow::Log2Floor64(b_value))));

      auto* negated_quotient =
          computation->AddInstruction(HloInstruction::CreateUnary(
              quotient->shape(), HloOpcode::kNegate, quotient));

      return HloInstruction::CreateTernary(divide->shape(), HloOpcode::kSelect,
                                           dividend_is_negative,
                                           negated_quotient, quotient);
    }
  } else {
    uint64 b_value = c->literal().GetFirstElement<T>();
    if (IsPowerOfTwo(b_value)) {
      return HloInstruction::CreateBinary(
          divide->shape(), HloOpcode::kShiftRightLogical, a,
          MakeScalarLike(a, tensorflow::Log2Floor64(b_value)));
    }
  }

  return nullptr;
}
}  // namespace

Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
  HloInstruction *a, *b, *c, *d;
  CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b))));
  // A/1 => A
  VLOG(10) << "trying transform [A/1 => A]: " << divide->ToString();
  if (IsAll(b, 1) && ReplaceInstructionIfSameShape(divide, a)) {
    return Status::OK();
  }

  // A / B => A >> log2(B) if B is a power of 2.
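  // For example (illustrative only): for an unsigned u32 dividend, A / 8
  // becomes A >> 3; for signed types, TryDivideToShift additionally selects
  // on the sign of the dividend so the shifted result matches truncating
  // division.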
  switch (divide->shape().element_type()) {
    case S8:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<int8>(divide, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    case S16:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<int16>(divide, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    case S32:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<int32>(divide, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    case S64:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<int64>(divide, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    case U8:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<uint8>(divide, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    case U16:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<uint16>(divide, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    case U32:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<uint32>(divide, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    case U64:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<uint64>(divide, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    default:
      break;
  }

  Shape* shape;
  // exp(A)/exp(B) => exp(A-B)
  if (Match(divide, m::Divide(m::Exp(m::Op(&a)), m::Exp(m::Op(&b)))
                        .WithShape(m::Shape(&shape)))) {
    VLOG(10) << "transform [exp(A)/exp(B) => exp(A-B)]: " << divide->ToString();
    HloInstruction* subtract = computation_->AddInstruction(
        HloInstruction::CreateBinary(*shape, HloOpcode::kSubtract, a, b));
    return ReplaceWithNewInstruction(
        divide, HloInstruction::CreateUnary(*shape, HloOpcode::kExp, subtract));
  }

  // A/exp(B) => A*exp(-B)
  if (Match(divide, m::Divide(m::Op(&a), m::Exp(m::Op(&b))))) {
    VLOG(10) << "transform [A/exp(B) => A*exp(-B)]: " << divide->ToString();
    HloInstruction* negate = computation_->AddInstruction(
        HloInstruction::CreateUnary(divide->shape(), HloOpcode::kNegate, b));
    HloInstruction* new_exp = computation_->AddInstruction(
        HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp, negate));
    return ReplaceWithNewInstruction(
        divide, HloInstruction::CreateBinary(divide->shape(),
                                             HloOpcode::kMultiply, a, new_exp));
  }

  // A/pow(B,C) => A*pow(B,-C)
  if (Match(divide, m::Divide(m::Op(&a), m::Power(m::Op(&b), m::Op(&c))))) {
    VLOG(10) << "transform [A/pow(B,C) => A*pow(B,-C)]: " << divide->ToString();
    // The output shape of the created negate operator should be the same as
    // the input.
    const Shape& negate_shape = c->shape();
    HloInstruction* negate = computation_->AddInstruction(
        HloInstruction::CreateUnary(negate_shape, HloOpcode::kNegate, c));
    // And the power operator should retain the output shape of the old one.
    const Shape& new_power_shape = b->shape();
    HloInstruction* new_power =
        computation_->AddInstruction(HloInstruction::CreateBinary(
            new_power_shape, HloOpcode::kPower, b, negate));
    return ReplaceWithNewInstruction(
        divide, HloInstruction::CreateBinary(
                    divide->shape(), HloOpcode::kMultiply, a, new_power));
  }

  // A/sqrt(B) => A*rsqrt(B).
  if (Match(divide, m::Divide(m::Op(&a), m::Sqrt(m::Op(&b))))) {
    auto* rsqrt = computation_->AddInstruction(
        HloInstruction::CreateUnary(divide->shape(), HloOpcode::kRsqrt, b));
    return ReplaceWithNewInstruction(
        divide, HloInstruction::CreateBinary(rsqrt->shape(),
                                             HloOpcode::kMultiply, a, rsqrt));
  }

  // A/rsqrt(B) => A*sqrt(B).
  if (Match(divide, m::Divide(m::Op(&a), m::Rsqrt(m::Op(&b))))) {
    auto* sqrt = computation_->AddInstruction(
        HloInstruction::CreateUnary(divide->shape(), HloOpcode::kSqrt, b));
    return ReplaceWithNewInstruction(
        divide, HloInstruction::CreateBinary(sqrt->shape(),
                                             HloOpcode::kMultiply, a, sqrt));
  }

  // Simplifying integral division would produce unexpected results.
  if (ShapeUtil::ElementIsIntegral(divide->shape())) {
    return Status::OK();
  }

  // A / Const => A * (1 / Const)
  //
  // (Backends can do this transformation, but generally only if the constant
  // is a scalar.)
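  // Multiplication is typically cheaper than division, and 1/Const folds at
  // compile time. For example (illustrative only): A / broadcast(4.0) becomes
  // A * broadcast(0.25).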
  if (Match(divide, m::Divide(m::NonConstant(&a), m::Op(&b))) &&
      (Match(b, m::Constant(&c)) || Match(b, m::Broadcast(m::Constant(&c))))) {
    Shape result_shape = c->literal().shape();
    Literal new_literal(result_shape);
    switch (result_shape.element_type()) {
      case F16:
        TF_RETURN_IF_ERROR(InvertConstant<half>(*c, &new_literal));
        break;
      case F32:
        TF_RETURN_IF_ERROR(InvertConstant<float>(*c, &new_literal));
        break;
      case BF16:
        TF_RETURN_IF_ERROR(InvertConstant<bfloat16>(*c, &new_literal));
        break;
      case F64:
        TF_RETURN_IF_ERROR(InvertConstant<double>(*c, &new_literal));
        break;
      case C64:
        TF_RETURN_IF_ERROR(InvertConstant<complex64>(*c, &new_literal));
        break;
      case C128:
        TF_RETURN_IF_ERROR(InvertConstant<complex128>(*c, &new_literal));
        break;
      default:
        return Status::OK();
    }
    auto inverse = computation_->AddInstruction(
        simplifier_->CreateConstantWithLayoutUpdated(new_literal.Clone()));
    if (b != c) {
      inverse = computation_->AddInstruction(HloInstruction::CreateBroadcast(
          b->shape(), inverse, b->dimensions()));
    }
    TF_ASSIGN_OR_RETURN(auto new_divide,
                        MakeBinaryHlo(HloOpcode::kMultiply, a, inverse));
    return ReplaceInstruction(divide, new_divide);
  }

  // (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C)
  if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)),
                              m::Divide(m::Op(&c), m::Op(&d))))) {
    TF_ASSIGN_OR_RETURN(auto a_times_d,
                        MakeBinaryHlo(HloOpcode::kMultiply, a, d));
    TF_ASSIGN_OR_RETURN(auto b_times_c,
                        MakeBinaryHlo(HloOpcode::kMultiply, b, c));
    TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kDivide,
                                                       a_times_d, b_times_c));

    return ReplaceInstruction(divide, new_divide);
  }

  // (A / B) / C => A / (B * C)
  if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), m::Op(&c)))) {
    TF_ASSIGN_OR_RETURN(auto b_times_c,
                        MakeBinaryHlo(HloOpcode::kMultiply, b, c));
    TF_ASSIGN_OR_RETURN(auto new_divide,
                        MakeBinaryHlo(HloOpcode::kDivide, a, b_times_c));
    return ReplaceInstruction(divide, new_divide);
  }

  // A / (B / C) => (A*C) / B
  if (Match(divide, m::Divide(m::Op(&a), m::Divide(m::Op(&b), m::Op(&c))))) {
    TF_ASSIGN_OR_RETURN(auto a_times_c,
                        MakeBinaryHlo(HloOpcode::kMultiply, a, c));
    TF_ASSIGN_OR_RETURN(auto new_divide,
                        MakeBinaryHlo(HloOpcode::kDivide, a_times_c, b));
    return ReplaceInstruction(divide, new_divide);
  }

  // If X is a convert from pred, then
  // X / broadcast(Y) => broadcast(1/Y) * X
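  // Reasoning sketch: X is 0 or 1, and in IEEE arithmetic both 1 * (1/Y) and
  // 0 * (1/Y) agree with X / Y, so one scalar reciprocal plus an elementwise
  // multiply can stand in for the elementwise divide.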
  if (Match(divide,
            m::Divide(
                m::Convert(&a,
                           m::Op().WithShape(m::Shape().WithElementType(PRED))),
                m::Broadcast(m::Op(&b).WithShape(m::Shape().IsScalar()))))) {
    TF_ASSIGN_OR_RETURN(
        auto recip, MakeBinaryHlo(HloOpcode::kDivide, MakeScalarLike(b, 1), b));
    auto recip_bcast = computation_->AddInstruction(
        HloInstruction::CreateBroadcast(divide->shape(), recip, {}));
    TF_ASSIGN_OR_RETURN(auto mul,
                        MakeBinaryHlo(HloOpcode::kMultiply, recip_bcast, a));
    return ReplaceInstruction(divide, mul);
  }

  return Status::OK();
}

StatusOr<bool> AlgebraicSimplifierVisitor::RemoveDegenerateDimensionFromDot(
    HloInstruction* dot) {
  const Shape& lhs_shape = dot->operand(0)->shape();
  int64 num_degenerate_lhs_dims = 0;
  std::vector<int64> lhs_dimension_map(lhs_shape.rank(), -1);
  for (int64 i = 0; i < lhs_shape.rank(); ++i) {
    if (lhs_shape.dimensions(i) == 1) {
      ++num_degenerate_lhs_dims;
    } else {
      lhs_dimension_map[i] = i - num_degenerate_lhs_dims;
    }
  }

  const Shape& rhs_shape = dot->operand(1)->shape();
  int64 num_degenerate_rhs_dims = 0;
  std::vector<int64> rhs_dimension_map(rhs_shape.rank(), -1);
  for (int64 i = 0; i < rhs_shape.rank(); ++i) {
    if (rhs_shape.dimensions(i) == 1) {
      ++num_degenerate_rhs_dims;
    } else {
      rhs_dimension_map[i] = i - num_degenerate_rhs_dims;
    }
  }
  if (num_degenerate_lhs_dims == 0 && num_degenerate_rhs_dims == 0) {
    return false;
  }
  const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
  DotDimensionNumbers new_dnums;
  for (int64 dim : dnums.lhs_batch_dimensions()) {
    int64 new_dim = lhs_dimension_map[dim];
    if (new_dim != -1) {
      new_dnums.add_lhs_batch_dimensions(new_dim);
    }
  }
  for (int64 dim : dnums.lhs_contracting_dimensions()) {
    int64 new_dim = lhs_dimension_map[dim];
    if (new_dim != -1) {
      new_dnums.add_lhs_contracting_dimensions(new_dim);
    }
  }

  for (int64 dim : dnums.rhs_batch_dimensions()) {
    int64 new_dim = rhs_dimension_map[dim];
    if (new_dim != -1) {
      new_dnums.add_rhs_batch_dimensions(new_dim);
    }
  }
  for (int64 dim : dnums.rhs_contracting_dimensions()) {
    int64 new_dim = rhs_dimension_map[dim];
    if (new_dim != -1) {
      new_dnums.add_rhs_contracting_dimensions(new_dim);
    }
  }

  HloInstruction* new_lhs =
      num_degenerate_lhs_dims > 0
          ? dot->parent()->AddInstruction(HloInstruction::CreateReshape(
                ShapeUtil::DropDegenerateDimensions(lhs_shape),
                dot->mutable_operand(0)))
          : dot->mutable_operand(0);
  HloInstruction* new_rhs =
      num_degenerate_rhs_dims > 0
          ? dot->parent()->AddInstruction(HloInstruction::CreateReshape(
                ShapeUtil::DropDegenerateDimensions(rhs_shape),
                dot->mutable_operand(1)))
          : dot->mutable_operand(1);
  TF_ASSIGN_OR_RETURN(
      auto new_dot,
      MakeDotHlo(new_lhs, new_rhs, new_dnums, dot->precision_config(),
                 /*preferred_element_type=*/dot->shape().element_type()));
  if (ShapeUtil::Compatible(dot->shape(), new_dot->shape())) {
    TF_RETURN_IF_ERROR(ReplaceInstruction(dot, new_dot));
  } else {
    TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
        dot, HloInstruction::CreateReshape(dot->shape(), new_dot)));
  }
  return true;
}

StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcat(
    HloInstruction* dot) {
  const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
  if (dnums.lhs_contracting_dimensions_size() != 1 ||
      dnums.lhs_batch_dimensions_size() != 0 ||
      dot->shape().dimensions_size() != 2) {  // dot output 2D
    return nullptr;
  }

  const int64 lhs_contracting_dim = dnums.lhs_contracting_dimensions(0);
  const int64 rhs_contracting_dim = dnums.rhs_contracting_dimensions(0);
  HloInstruction *lhs, *rhs;
  CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));

  TF_ASSIGN_OR_RETURN(
      HloInstruction * optimized_lhs_concat,
      OptimizeDotOfConcatHelper(*dot, lhs, lhs_contracting_dim, rhs,
                                rhs_contracting_dim, /*swapped=*/false));
  if (optimized_lhs_concat) {
    return optimized_lhs_concat;
  }

  return OptimizeDotOfConcatHelper(*dot, rhs, rhs_contracting_dim, lhs,
                                   lhs_contracting_dim, /*swapped=*/true);
}

StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper(
    const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim,
    HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped) {
  bool can_optimize = lhs->opcode() == HloOpcode::kConcatenate &&
                      lhs->concatenate_dimension() == lhs_contracting_dim &&
                      rhs->opcode() == HloOpcode::kConstant;
  if (!can_optimize) {
    return nullptr;
  }

  // We're replacing this:
  //
  //   +-----+-----+-----+   +-------------------+
  //   |     |     |     |   |                   |
  //   |     |     |     |   |        R_0        |
  //   |     |     |     |   |                   |
  //   |     |     |     |   +-------------------+
  //   |     |     |     |   |                   |
  //   | L_0 | L_1 | L_2 | * |        R_1        |
  //   |     |     |     |   |                   |
  //   |     |     |     |   +-------------------+
  //   |     |     |     |   |                   |
  //   |     |     |     |   |        R_2        |
  //   |     |     |     |   |                   |
  //   +-----+-----+-----+   +-------------------+
  //
  // with this:
  //
  // [Sum over i]
  //
  //   +-----+   +-------------------+
  //   |     |   |                   |
  //   |     | * |        R_i        |
  //   |     |   |                   |
  //   |     |   +-------------------+
  //   |     |
  //   | L_i |
  //   |     |
  //   |     |
  //   |     |
  //   |     |
  //   |     |
  //   +-----+
  //
  // where the LHS is a concatenate operation (so we can "split" the LHS tensor
  // for free) and the RHS is a constant tensor (and thus can be split at
  // compile time). In the future, we may also want to do this when both the
  // LHS and the RHS are concatenate operations that line up along the
  // dimension being contracted over.
  //
  // We should be able to generalize this transform to work on a non-constant
  // RHS when/if we have in-place slices or support input-fusing slices into
  // Dots.
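  //
  // In HLO terms the rewrite is roughly (illustrative shapes only):
  //   dot(concat(L_0 [m,k0], L_1 [m,k1]) along dim 1, R [k0+k1, n])
  //     =>  dot(L_0, slice(R, rows [0, k0)))
  //       + dot(L_1, slice(R, rows [k0, k0+k1)))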

  // Dimension numbers for the new dot instructions we'll create (L_i * R_i in
  // the diagram above).
  DotDimensionNumbers new_dot_dnums;
  new_dot_dnums.add_lhs_contracting_dimensions(swapped ? rhs_contracting_dim
                                                       : lhs_contracting_dim);
  new_dot_dnums.add_rhs_contracting_dimensions(swapped ? lhs_contracting_dim
                                                       : rhs_contracting_dim);

  // Here we use the MKN notation, where the contracted dimension has K
  // elements and the two non-contracted dimensions have M and N elements.
  HloInstruction* add_result = nullptr;
  int64 rhs_contracting_dim_offset = 0;
  int64 n = rhs->shape().dimensions(1 - rhs_contracting_dim);
  for (HloInstruction* concat_op : lhs->operands()) {
    int64 sub_k = concat_op->shape().dimensions(lhs_contracting_dim);
    Shape rhs_slice_shape(rhs->shape());
    rhs_slice_shape.set_dimensions(rhs_contracting_dim, sub_k);
    simplifier_->UpdateLayout(&rhs_slice_shape);

    std::array<int64, 2> start_indices;
    start_indices[rhs_contracting_dim] = rhs_contracting_dim_offset;
    start_indices[1 - rhs_contracting_dim] = 0;

    std::array<int64, 2> limit_indices;
    limit_indices[rhs_contracting_dim] = rhs_contracting_dim_offset + sub_k;
    limit_indices[1 - rhs_contracting_dim] = n;

    HloInstruction* rhs_slice =
        computation_->AddInstruction(HloInstruction::CreateSlice(
            rhs_slice_shape, rhs, /*start_indices=*/start_indices,
            /*limit_indices=*/limit_indices, /*strides=*/{1, 1}));

    // TODO(b/69062148): We can get rid of `swapped` once all backends support
    // "non-canonical" contraction dimensions (that contracts dimension 1 of
    // the LHS with dimension 0 of the RHS). But for now we keep the same
    // contraction dimensions as the incoming dot operation to ensure the new
    // dot operations can be lowered.
    HloInstruction *new_dot_lhs, *new_dot_rhs;
    if (swapped) {
      new_dot_lhs = rhs_slice;
      new_dot_rhs = concat_op;
    } else {
      new_dot_lhs = concat_op;
      new_dot_rhs = rhs_slice;
    }

    auto* new_dot = computation_->AddInstruction(
        HloInstruction::CreateDot(dot.shape(), new_dot_lhs, new_dot_rhs,
                                  new_dot_dnums, dot.precision_config()));

    if (add_result) {
      add_result = computation_->AddInstruction(HloInstruction::CreateBinary(
          dot.shape(), HloOpcode::kAdd, add_result, new_dot));
    } else {
      add_result = new_dot;
    }

    rhs_contracting_dim_offset += sub_k;
  }

  return add_result;
}

StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
    HloInstruction* dot) {
  const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
  if (dnums.lhs_contracting_dimensions_size() != 1 ||
      dnums.rhs_contracting_dimensions_size() != 1 ||
      dnums.lhs_batch_dimensions_size() != 0 ||
      dnums.rhs_batch_dimensions_size() != 0 ||
      dot->shape().dimensions_size() != 2) {  // dot output 2D
    VLOG(10) << "DotOfGather: Can only optimize 2D, non-batch dot operations.";
    return nullptr;
  }

  // Optimize either dot(DS(ctA), ctB) or dot(ctB, DS(ctA)).
  // Currently a Gather is a DynamicSlice.
  auto is_dynamic_slice_constant_combination =
      [](HloInstruction* a, HloInstruction* b, int a_contracting_dimension) {
        // First operand is a DynamicSlice(Constant).
        if (a->opcode() != HloOpcode::kDynamicSlice) {
          return false;
        }
        auto* dynamic_slice_op = a->operand(0);
        if (dynamic_slice_op->opcode() != HloOpcode::kConstant) {
          return false;
        }
        // Second operand is a Constant.
        if (b->opcode() != HloOpcode::kConstant) {
          return false;
        }
        // The DynamicSlice output is a vector.
        const Shape& dynamic_slice_shape = a->shape();
        if (dynamic_slice_shape.dimensions(1 - a_contracting_dimension) != 1) {
          return false;
        }
        // Constant size is the same before and after slice in the contracting
        // dimension, otherwise we either must precompute for all possible
        // slice indices or dot is invalid.
        const Shape& dynamic_slice_op_shape = dynamic_slice_op->shape();
        if (dynamic_slice_op_shape.dimensions(a_contracting_dimension) !=
            dynamic_slice_shape.dimensions(a_contracting_dimension)) {
          return false;
        }
        return true;
      };

  HloInstruction* lhs = dot->mutable_operand(0);
  HloInstruction* rhs = dot->mutable_operand(1);
  int lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0);
  int rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0);

  if (!is_dynamic_slice_constant_combination(
          lhs, rhs, /*a_contracting_dimension=*/lhs_contracting_dimension) &&
      !is_dynamic_slice_constant_combination(
          rhs, lhs, /*a_contracting_dimension=*/rhs_contracting_dimension)) {
    VLOG(10) << "DotOfGather: Can only optimize dot(DS(ctA), ctB) or "
                "dot(ctB, DS(ctA)), where the two constants have equal "
                "contracting dimensions.";
    return nullptr;
  }

  // LHS is DynamicSlice:
  // input: dot(DS(ctA), ctB)
  // where DS(ctA) = DS({M x K}, {start, 0}, {1, K}) and ctB = {K x N}.
  // => input dimensions: dot({1 x K}, {K x N}) => {1 x N}.
  // output: DS(dot(ctA, ctB))
  // => output dimensions: DS ({M x N}, {start, 0}, {1, N}) => {1 x N}.

  // RHS is DynamicSlice:
  // input: dot(ctA, DS(ctB))
  // where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, start}, {K, 1}).
  // => input dimensions: dot({M x K}, {K x 1}) => {M x 1}.
  // output: DS(dot(ctA, ctB))
  // => output dimensions: DS ({M x N}, {0, start}, {M, 1}) => {M x 1}.

  bool lhs_is_dynamic_slice = lhs->opcode() == HloOpcode::kDynamicSlice;
  HloDynamicSliceInstruction* dynamic_slice =
      lhs_is_dynamic_slice ? Cast<HloDynamicSliceInstruction>(lhs)
                           : Cast<HloDynamicSliceInstruction>(rhs);

  // ctA:
  HloInstruction* left_operand =
      lhs_is_dynamic_slice ? lhs->mutable_operand(0) : lhs;
  // ctB:
  HloInstruction* right_operand =
      lhs_is_dynamic_slice ? rhs : rhs->mutable_operand(0);
  // Build ctA x ctB.
  const int m = left_operand->shape().dimensions(1 - lhs_contracting_dimension);
  const int n =
      right_operand->shape().dimensions(1 - rhs_contracting_dimension);
  auto memoized_shape =
      ShapeUtil::MakeShape(dot->shape().element_type(), {m, n});
  simplifier_->UpdateLayout(&memoized_shape);
  auto* memoized_inst = computation_->AddInstruction(
      HloInstruction::CreateDot(memoized_shape, left_operand, right_operand,
                                dnums, dot->precision_config()));
  // Get pair {start, 0} or {0, start}.
  // Position of start:
  int index_of_non_zero_start = lhs_is_dynamic_slice
                                    ? 1 - lhs_contracting_dimension
                                    : 1 - rhs_contracting_dimension;
  // Position of zero:
  int index_of_zero_start = 1 - index_of_non_zero_start;

  // Slice out start and 0 components and reorder if necessary.
  auto indices_type = dynamic_slice->operand(1)->shape().element_type();
  Shape s_shape = ShapeUtil::MakeShape(indices_type, {1});
  simplifier_->UpdateLayout(&s_shape);
  Shape d_shape = ShapeUtil::MakeShape(indices_type, {2});
  simplifier_->UpdateLayout(&d_shape);
  HloInstruction* non_zero_start =
      dynamic_slice->mutable_operand(1 + index_of_non_zero_start);
  HloInstruction* zero_start =
      dynamic_slice->mutable_operand(1 + index_of_zero_start);
  std::vector<HloInstruction*> new_start_indices;
  if (lhs_is_dynamic_slice) {
    new_start_indices = {non_zero_start, zero_start};
  } else {
    new_start_indices = {zero_start, non_zero_start};
  }

  // Build DynamicSlice(ctA x ctB).
  const int new_slice_m = lhs_is_dynamic_slice ? 1 : m;
  const int new_slice_n = lhs_is_dynamic_slice ? n : 1;
  auto* memoized_lookup =
      computation_->AddInstruction(HloInstruction::CreateDynamicSlice(
          dot->shape(), memoized_inst, new_start_indices,
          {new_slice_m, new_slice_n}));

  return memoized_lookup;
}

// This function tries to transform
// dot(reshape(transpose(A)), Const) to
// dot(reshape(A), reshape(transpose(reshape(Const)))),
// so that the reshape and transpose on the Const side can be constant folded.
//
// The basic idea is that the accumulation in the dot operation is
// associative, so as long as we permute the elements of the contracting
// dimensions on both sides of the dot in the same way, the result of the
// dot is not affected.
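//
// Concretely (illustrative): for a single contracting dimension of size 3 and
// any permutation sigma of {0, 1, 2},
//   sum_k A[k] * B[k] = A[sigma(0)]B[sigma(0)] + A[sigma(1)]B[sigma(1)] +
//                       A[sigma(2)]B[sigma(2)],
// so reordering the contracting elements of both operands identically leaves
// the dot unchanged (up to floating-point reassociation of the sum).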
StatusOr<HloInstruction*>
AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims(
    HloInstruction* dot) {
  // This transformation assumes layout is not assigned yet.
  if (options_.is_layout_sensitive()) {
    return nullptr;
  }

  // Canonicalize dot(<constant>, rhs) to dot(rhs, <constant>) to make the
  // remainder of this function easier.
  auto dnums = dot->dot_dimension_numbers();
  auto lhs_contracting_dims = dnums.lhs_contracting_dimensions();
  auto rhs_contracting_dims = dnums.rhs_contracting_dimensions();
  auto* lhs = dot->mutable_operand(0);
  auto* rhs = dot->mutable_operand(1);
  if (dot->operand(0)->IsConstant()) {
    std::swap(lhs, rhs);
    std::swap(lhs_contracting_dims, rhs_contracting_dims);
  }

  // Require single contracting dim to make the implementation easier to
  // track contracting dims.
  if (dnums.lhs_contracting_dimensions_size() != 1) {
    return nullptr;
  }

  // Pattern match dot(reshape(transpose(input)), constant).
  HloInstruction* reshape;
  HloInstruction* transpose;
  HloInstruction* input;
  HloInstruction* constant;
  if (!Match(lhs,
             m::Reshape(&reshape, m::Transpose(&transpose, m::Op(&input)))) ||
      !Match(rhs, m::Constant(&constant))) {
    return nullptr;
  }

  // Check that reshape squishes some dims into one dim and that this one
  // dim is the dot's lhs contracting dim. The size of unmodified_dims should
  // be N - 1, where N is the rank of the reshape output. This means that the
  // reshape squishes some dims into one dim. lhs contracting dim should not
  // be in unmodified_dims. This means that the squishing target dim is the
  // lhs contracting dim.
  auto unmodified_dims = ShapeUtil::DimensionsUnmodifiedByReshape(
      reshape->operand(0)->shape(), reshape->shape());
  CHECK_EQ(lhs_contracting_dims.size(), 1);
  if ((unmodified_dims.size() != reshape->shape().rank() - 1) ||
      absl::c_any_of(unmodified_dims, [&](const std::pair<int64, int64>& p) {
        return p.second == lhs_contracting_dims[0];
      })) {
    return nullptr;
  }

  // Virtually pull the reshape into the dot so the dot operates on the
  // transpose, with "unsquished" lhs contracting dims. The new contracting
  // dims are all of the dims that are modified by the reshape -- that is,
  // every dimension that's not in `unmodified_dims[i].first`.
  //
  // (We don't need to actually create a new dot instruction. We can just keep
  // track of lhs and lhs_contracting_dims.)
  absl::flat_hash_set<int64> unmodified_transpose_dims;
  for (const auto& pair : unmodified_dims) {
    unmodified_transpose_dims.insert(pair.first);
  }
  lhs_contracting_dims.Clear();
  for (int64 i = 0; i < transpose->shape().dimensions_size(); ++i) {
    if (!unmodified_transpose_dims.contains(i)) {
      lhs_contracting_dims.Add(i);
    }
  }
  // We require the "unsquished" lhs contracting dims to be consecutive.
  auto is_iota = [](absl::Span<const int64> dims) {
    return absl::c_adjacent_find(dims, [](const int64 a, const int64 b) {
             return (b != a + 1);
           }) == dims.end();
  };
  if (!is_iota(AsInt64Slice(lhs_contracting_dims))) {
    return nullptr;
  }
  lhs = lhs->mutable_operand(0);

  // Check that the transpose only permutes the contracting dims.
  const auto& transpose_dims = transpose->dimensions();
  for (int64 i = 0; i < transpose_dims.size(); ++i) {
    if (transpose_dims[i] != i &&
        !absl::c_linear_search(lhs_contracting_dims, i)) {
      return nullptr;
    }
  }
  // Virtually pull the transpose into the dot. Now the dot is equivalent to
  // a new dot with "permuted" lhs contracting dims.
  std::vector<int64> permutation;
  for (auto dim : lhs_contracting_dims) {
    permutation.push_back(transpose_dims[dim] - lhs_contracting_dims[0]);
  }
  CHECK(IsPermutation(permutation));
  auto new_lhs_contracting_dims =
      ComposePermutations(AsInt64Slice(lhs_contracting_dims), permutation);
  lhs_contracting_dims.Clear();
  for (auto dim : new_lhs_contracting_dims) {
    lhs_contracting_dims.Add(dim);
  }
  lhs = lhs->mutable_operand(0);

  // All checks are passed at this point.
  //
  // Transform lhs. Remove the transpose and reshape by sorting the lhs
  // contracting dims and squishing them into a single one. We don't actually
  // squish the lhs_contracting_dims here because we still need the unsquished
  // contracting dims to invert reshape and transpose.
  absl::c_sort(lhs_contracting_dims);
  lhs = computation_->AddInstruction(
      HloInstruction::CreateReshape(reshape->shape(), lhs));

  // Transform rhs. Say the input HLO is:
  //
  //   t0 = f32[2, 2, 3] parameter(0)
  //   t1 = f32[2, 3, 2] transpose(t0) dimensions={0, 2, 1}
  //   t2 = f32[2, 6] reshape(t1)
  //   t3 = f32[6, 2] constant(...)
  //   dot = f32[2, 2] dot(t2, t3) lhs_contracting_dims={1},
  //                               rhs_contracting_dims={0}
  //
  // At this point in the function, we have decided that the second and third
  // dims of t0 can be switched to remove the transpose, and we have
  // "virtually decomposed" the input HLO to:
  //
  //   t0 = f32[2, 2, 3] parameter(0)
  //   t2' = f32[2, 6] reshape(t0)
  //   t3' = f32[6, 2] ops-to-be-filled ...
  //   dot = f32[2, 2] dot(t2', t3') lhs_contracting_dims={1},
  //                                 rhs_contracting_dims={0}
  //
  // The rest of this function is to fill in the ops of t3'. To do this, we
  // unsquish the contracting dimensions in t3 and then apply the inverse of
  // the transpose from t1.

  // Invert reshape.
  CHECK_EQ(rhs_contracting_dims.size(), 1);
  std::vector<int64> rhs_unsquished_shape_dims =
      SpanToVector(constant->shape().dimensions());
  auto it = rhs_unsquished_shape_dims.erase(rhs_unsquished_shape_dims.begin() +
                                            rhs_contracting_dims[0]);
  for (auto dim : lhs_contracting_dims) {
    it = rhs_unsquished_shape_dims.insert(it,
                                          transpose->shape().dimensions(dim));
    ++it;
  }
  HloInstruction* rhs_reshape =
      computation_->AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(constant->shape().element_type(),
                               rhs_unsquished_shape_dims),
          constant));
  rhs = rhs_reshape;

  // Rhs reshape "unsquishes" the single rhs contracting dim into multiple
  // dims.
  rhs_contracting_dims.Resize(lhs_contracting_dims.size(),
                              rhs_contracting_dims[0]);
  absl::c_iota(rhs_contracting_dims, rhs_contracting_dims[0]);

  // Invert transpose. First compute the shape.
  std::vector<int64> rhs_transpose_shape_dims =
      SpanToVector(rhs_reshape->shape().dimensions());
  it = rhs_transpose_shape_dims.erase(
      rhs_transpose_shape_dims.begin() + rhs_contracting_dims[0],
      rhs_transpose_shape_dims.begin() + rhs_contracting_dims[0] +
          rhs_contracting_dims.size());
  for (auto dim : lhs_contracting_dims) {
    it = rhs_transpose_shape_dims.insert(it, input->shape().dimensions(dim));
    ++it;
  }
  // Then compute the transpose dims.
  std::vector<int64> rhs_transpose_dims(rhs_reshape->shape().rank());
  absl::c_iota(rhs_transpose_dims, 0);
  it = rhs_transpose_dims.erase(
      rhs_transpose_dims.begin() + rhs_contracting_dims[0],
      rhs_transpose_dims.begin() + rhs_contracting_dims[0] +
          rhs_contracting_dims.size());
  auto inverse_lhs_transpose_dims = InversePermutation(transpose_dims);
  for (auto dim : lhs_contracting_dims) {
    it = rhs_transpose_dims.insert(it, inverse_lhs_transpose_dims[dim] -
                                           lhs_contracting_dims[0] +
                                           rhs_contracting_dims[0]);
    ++it;
  }
  HloInstruction* rhs_transpose =
      computation_->AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(constant->shape().element_type(),
                               rhs_transpose_shape_dims),
          rhs_reshape, rhs_transpose_dims));
  rhs = rhs_transpose;

  // Squish the multiple rhs contracting dims into a single one.
  rhs = computation_->AddInstruction(
      HloInstruction::CreateReshape(constant->shape(), rhs));

  // If we virtually swapped lhs and rhs, we need to swap it back before
  // creating new dot.
  if (dot->operand(0)->IsConstant()) {
    std::swap(lhs, rhs);
  }

  HloInstruction* new_dot =
      computation_->AddInstruction(HloInstruction::CreateDot(
          dot->shape(), lhs, rhs, dnums, dot->precision_config()));
  return new_dot;
}

Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
  CHECK(computation_ == dot->parent());
  HloInstruction *lhs, *rhs;
  CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
  if (options_.is_layout_sensitive()) {
    return Status::OK();
  }
  // Replace a zero element dot with a broadcast of the constant 0.
  if (ShapeUtil::IsZeroElementArray(dot->shape()) ||
      ShapeUtil::IsZeroElementArray(lhs->shape()) ||
      ShapeUtil::IsZeroElementArray(rhs->shape())) {
    auto zero = computation_->AddInstruction(
        simplifier_->CreateConstantWithLayoutUpdated(
            LiteralUtil::Zero(dot->shape().element_type())));
    return ReplaceWithNewInstruction(
        dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {}));
  }

  // If there are no contracting dimensions, a dot can be rewritten as
  // mul(broadcast(transpose(x)),broadcast(transpose(y)))
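  // For example (illustrative shapes): dot(x: f32[2], y: f32[3]) with no
  // contracting dimensions is an outer product, equal to
  // mul(broadcast(x into [2,3] along dim 0), broadcast(y into [2,3] along
  // dim 1)).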
  if (options_.enable_dot_to_multiply_rewrite() &&
      dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 0) {
    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_lhs,
        NormalizeDotOperandToBatchMajorAndContractingMinor(
            lhs,
            AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()),
            AsInt64Slice(
                dot->dot_dimension_numbers().lhs_contracting_dimensions())));
    if (!ShapeUtil::SameElementType(dot->shape(), new_lhs->shape())) {
      new_lhs = MakeConvertToHlo(new_lhs, dot->shape().element_type());
    }
    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_rhs,
        NormalizeDotOperandToBatchMajorAndContractingMinor(
            rhs,
            AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()),
            AsInt64Slice(
                dot->dot_dimension_numbers().rhs_contracting_dimensions())));
    if (!ShapeUtil::SameElementType(dot->shape(), new_rhs->shape())) {
      new_rhs = MakeConvertToHlo(new_rhs, dot->shape().element_type());
    }
    if (dot->shape().rank() != lhs->shape().rank()) {
      std::vector<int64> lhs_broadcast_dims(lhs->shape().rank());
      absl::c_iota(lhs_broadcast_dims, 0);
      new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
          dot->shape(), new_lhs, lhs_broadcast_dims));
    }
    if (dot->shape().rank() != rhs->shape().rank()) {
      std::vector<int64> rhs_broadcast_dims(
          dot->dot_dimension_numbers().lhs_batch_dimensions_size());
      absl::c_iota(rhs_broadcast_dims, 0);
      for (int64 i = lhs->shape().rank(); i < dot->shape().rank(); ++i) {
        rhs_broadcast_dims.push_back(i);
      }
      new_rhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
          dot->shape(), new_rhs, rhs_broadcast_dims));
    }
    return ReplaceWithNewInstruction(
        dot, HloInstruction::CreateBinary(dot->shape(), HloOpcode::kMultiply,
                                          new_lhs, new_rhs));
  }

  // If the lhs or rhs have only batch and contracting dimensions, a dot can be
  // rewritten as reduce(mul(broadcast(transpose(x)),broadcast(transpose(y))))
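  // For example (illustrative shapes): dot(x: f32[k], y: f32[k]) with
  // contracting dims {0}/{0} can become reduce(mul(x, y), sum, dims={0}),
  // trading the dot for an elementwise multiply plus a reduction.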
  if (options_.enable_dot_strength_reduction() &&
      ((dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
            dot->dot_dimension_numbers().lhs_contracting_dimensions_size() ==
        lhs->shape().rank()) ||
       (dot->dot_dimension_numbers().rhs_contracting_dimensions_size() +
            dot->dot_dimension_numbers().rhs_batch_dimensions_size() ==
        rhs->shape().rank()))) {
    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_lhs,
        NormalizeDotOperandToBatchMajorAndContractingMinor(
            lhs,
            AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()),
            AsInt64Slice(
                dot->dot_dimension_numbers().lhs_contracting_dimensions())));
    if (!ShapeUtil::SameElementType(dot->shape(), new_lhs->shape())) {
      new_lhs = MakeConvertToHlo(new_lhs, dot->shape().element_type());
    }

    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_rhs,
        NormalizeDotOperandToBatchMajorAndContractingMinor(
            rhs,
            AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()),
            AsInt64Slice(
                dot->dot_dimension_numbers().rhs_contracting_dimensions())));
    if (!ShapeUtil::SameElementType(dot->shape(), new_rhs->shape())) {
      new_rhs = MakeConvertToHlo(new_rhs, dot->shape().element_type());
    }

    int64 lhs_outer_dims =
        lhs->shape().rank() -
        (dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
         dot->dot_dimension_numbers().lhs_contracting_dimensions_size());
    int64 rhs_outer_dims =
        rhs->shape().rank() -
        (dot->dot_dimension_numbers().rhs_batch_dimensions_size() +
         dot->dot_dimension_numbers().rhs_contracting_dimensions_size());
    CHECK(lhs_outer_dims == 0 || rhs_outer_dims == 0);
    if (rhs_outer_dims > 0) {
      std::vector<int64> lhs_broadcast_dims(
          dot->dot_dimension_numbers().lhs_batch_dimensions_size());
      absl::c_iota(lhs_broadcast_dims, 0);
      lhs_broadcast_dims.resize(lhs->shape().rank());
      std::iota(lhs_broadcast_dims.begin() +
                    dot->dot_dimension_numbers().lhs_batch_dimensions_size(),
                lhs_broadcast_dims.end(),
                dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
                    rhs_outer_dims);
      new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
          new_rhs->shape(), new_lhs, lhs_broadcast_dims));
    } else if (lhs_outer_dims > 0) {
      std::vector<int64> rhs_broadcast_dims(
          dot->dot_dimension_numbers().rhs_batch_dimensions_size());
      absl::c_iota(rhs_broadcast_dims, 0);
      rhs_broadcast_dims.resize(rhs->shape().rank());
      std::iota(rhs_broadcast_dims.begin() +
                    dot->dot_dimension_numbers().rhs_batch_dimensions_size(),
                rhs_broadcast_dims.end(),
                dot->dot_dimension_numbers().rhs_batch_dimensions_size() +
                    lhs_outer_dims);
      new_rhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
          new_lhs->shape(), new_rhs, rhs_broadcast_dims));
    }

    TF_ASSIGN_OR_RETURN(HloInstruction * new_dot,
                        MakeBinaryHlo(HloOpcode::kMultiply, new_lhs, new_rhs));
    std::vector<int64> reduce_dims(
        dot->dot_dimension_numbers().lhs_contracting_dimensions_size());
    PrimitiveType dot_type =
        ShapeUtil::ElementIsFloating(dot->shape())
            ? (dot->shape().element_type() == F64 ? F64 : F32)
            : dot->shape().element_type();
    new_dot = AsType(new_dot, dot_type);
    const int64 outer_dims = std::max(rhs_outer_dims, lhs_outer_dims);
    absl::c_iota(
        reduce_dims,
        outer_dims + dot->dot_dimension_numbers().lhs_batch_dimensions_size());
    new_dot = AddReduce(new_dot, reduce_dims, dot_type);
    new_dot = AsType(new_dot, dot->shape().element_type());
    return ReplaceInstruction(dot, new_dot);
  }

  // Simplify dot(reshape(transpose(A)), Const) to:
  // dot(reshape(A), reshape(transpose(reshape(Const)))), so that the reshape
  // and transpose on the Const side can be constant folded.
  TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_reorder_optimized,
                      OptimizeDotOfReorderContractingDims(dot));
  if (dot_of_reorder_optimized) {
    VLOG(10) << " Replaced dot " << dot->ToString()
             << " with new dot operation: "
             << dot_of_reorder_optimized->ToString();
    return ReplaceInstruction(dot, dot_of_reorder_optimized);
  }

  TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_concat_optimized,
                      OptimizeDotOfConcat(dot));
  if (dot_of_concat_optimized) {
    VLOG(10) << "Replaced dot(concat(...), constant) with add(dot(..., "
                "constant)...)";
    return ReplaceInstruction(dot, dot_of_concat_optimized);
  }

  // Simplify dot(ConstA, Gather(Index, ConstB)) to:
  // Gather(Index, dot*(ConstA, ConstB)), where dot* is an appropriately
  // batched version of dot.
  TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_gather_optimized,
                      OptimizeDotOfGather(dot));
  if (dot_of_gather_optimized) {
    VLOG(10) << "Replaced dot(constA, gather(i, constB)) with "
                "gather(i, dot*(constA, constB))";
    return ReplaceInstruction(dot, dot_of_gather_optimized);
  }

  TF_ASSIGN_OR_RETURN(bool removed_degenerate_dimensions,
                      RemoveDegenerateDimensionFromDot(dot));
  if (removed_degenerate_dimensions) {
    return Status::OK();
  }

  // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)).
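  // This is the rank-2 identity transpose(a) * transpose(b) = (b * a)^T,
  // i.e. A^T B^T = (B A)^T (illustrative algebra for the rewrite below).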
  if (dot->dot_dimension_numbers().lhs_batch_dimensions_size() == 0 &&
      dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 1 &&
      dot->dot_dimension_numbers().lhs_contracting_dimensions(0) == 1 &&
      dot->dot_dimension_numbers().rhs_contracting_dimensions(0) == 0 &&
      lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) {
    DotDimensionNumbers dot_dimension_numbers;
    dot_dimension_numbers.add_lhs_contracting_dimensions(1);
    dot_dimension_numbers.add_rhs_contracting_dimensions(0);
    auto new_dot = computation_->AddInstruction(HloInstruction::CreateDot(
        ShapeUtil::PermuteDimensions({1, 0}, dot->shape()),
        rhs->mutable_operand(0), lhs->mutable_operand(0), dot_dimension_numbers,
        dot->precision_config()));
    return ReplaceWithNewInstruction(
        dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0}));
  }

  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleGather(HloInstruction* gather) {
  const Shape& operand_shape = gather->operand(0)->shape();
  if (ShapeUtil::IsZeroElementArray(operand_shape)) {
    return ReplaceInstruction(gather, MakeScalarLike(gather, 0));
  }

  // Gathering from a scalar operand is simply a broadcast of that scalar.
  if (ShapeUtil::IsEffectiveScalar(operand_shape)) {
    HloInstruction* new_operand = gather->mutable_operand(0);
    if (operand_shape.rank()) {
      TF_ASSIGN_OR_RETURN(new_operand,
                          MakeReshapeHlo(ShapeUtil::MakeScalarShape(
                                             operand_shape.element_type()),
                                         new_operand));
    }
    HloInstruction* new_gather =
        MakeBroadcastHlo(new_operand, {}, gather->shape());
    return ReplaceInstruction(gather, new_gather);
  }
  // If the operand of a gather is very small, it is easier to fuse a
  // sequence of selects.
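  // Sketch of the rewrite for a 4-element operand {v0, v1, v2, v3}
  // (illustrative only):
  //   result = select(i >= 3, v3, select(i >= 2, v2, select(i >= 1, v1, v0)))
  // which the loop below builds up one comparison at a time.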
  const Shape& index_shape = gather->operand(1)->shape();
  if (operand_shape.rank() == 1 &&
      operand_shape.dimensions(0) <= options_.very_small_gather_size() &&
      gather->gather_dimension_numbers().index_vector_dim() ==
          index_shape.rank() &&
      gather->gather_dimension_numbers().collapsed_slice_dims_size() == 1) {
    const int64 operand_elements = operand_shape.dimensions(0);
    auto get_value = [&](int64 i) {
      auto slice = computation_->AddInstruction(HloInstruction::CreateSlice(
          ShapeUtil::MakeShape(operand_shape.element_type(), {1}),
          gather->mutable_operand(0), {i}, {i + 1}, {1}));
      auto scalar = computation_->AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(operand_shape.element_type(), {}), slice));
      return computation_->AddInstruction(
          HloInstruction::CreateBroadcast(gather->shape(), scalar, {}));
    };
    auto result = get_value(0);
    auto pred_shape = ShapeUtil::ChangeElementType(gather->shape(), PRED);
    auto iter_shape = ShapeUtil::ChangeElementType(gather->shape(),
                                                   index_shape.element_type());
    for (int64 i = 0; i < operand_elements; ++i) {
      auto index_mask =
          computation_->AddInstruction(HloInstruction::CreateCompare(
              pred_shape, gather->mutable_operand(1),
              MakeScalarLike(gather->mutable_operand(1), i),
              ComparisonDirection::kGe));
      result = computation_->AddInstruction(
          HloInstruction::CreateTernary(gather->shape(), HloOpcode::kSelect,
                                        index_mask, get_value(i), result));
    }
    return ReplaceInstruction(gather, result);
  }
  return Status::OK();
}

namespace {
StatusOr<std::unique_ptr<HloInstruction>> MinMaxToClamp(
    HloInstruction* clamp_lower_bound_bcast, HloInstruction* to_clamp,
    HloInstruction* clamp_upper_bound_bcast) {
  HloInstruction* clamp_lower_bound;
  CHECK(Match(clamp_lower_bound_bcast,
              m::Broadcast(m::ConstantEffectiveScalar(&clamp_lower_bound))))
      << clamp_lower_bound_bcast->ToString();

  HloInstruction* clamp_upper_bound;
  CHECK(Match(clamp_upper_bound_bcast,
              m::Broadcast(m::ConstantEffectiveScalar(&clamp_upper_bound))))
      << clamp_upper_bound_bcast->ToString();

  const Literal& lower_bound =
      Cast<HloConstantInstruction>(clamp_lower_bound)->literal();
  const Literal& upper_bound =
      Cast<HloConstantInstruction>(clamp_upper_bound)->literal();

  std::unique_ptr<HloInstruction> lower_bound_instr =
      HloInstruction::CreateConstant(lower_bound.Clone());
  std::unique_ptr<HloInstruction> upper_bound_instr =
      HloInstruction::CreateConstant(upper_bound.Clone());

  std::unique_ptr<HloInstruction> cloned_instruction =
      HloInstruction::CreateCompare(
          ShapeUtil::ChangeElementType(lower_bound_instr->shape(), PRED),
          lower_bound_instr.get(), upper_bound_instr.get(),
          ComparisonDirection::kLt);

  HloEvaluator evaluator;
  TF_ASSIGN_OR_RETURN(auto result,
                      evaluator.Evaluate(cloned_instruction.get()));
  if (result.IsAll(true)) {
    return HloInstruction::CreateTernary(to_clamp->shape(), HloOpcode::kClamp,
                                         clamp_lower_bound_bcast, to_clamp,
                                         clamp_upper_bound_bcast);
  }
  return std::unique_ptr<HloInstruction>();
}
}  // namespace

Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) {
  HloInstruction *lhs, *rhs;
  CHECK(Match(maximum, m::Maximum(m::Op(&lhs), m::Op(&rhs))));

  HloInstruction* clamp_upper_bound_bcast;
  HloInstruction* clamp_lower_bound_bcast;
  HloInstruction* to_clamp;
  if (Match(maximum, m::MaximumAnyOrder(
                         m::Broadcast(&clamp_lower_bound_bcast,
                                      m::ConstantEffectiveScalar()),
                         m::MinimumAnyOrder(
                             m::Op(&to_clamp),
                             m::Broadcast(&clamp_upper_bound_bcast,
                                          m::ConstantEffectiveScalar()))))) {
    TF_ASSIGN_OR_RETURN(auto clamp,
                        MinMaxToClamp(clamp_lower_bound_bcast, to_clamp,
                                      clamp_upper_bound_bcast));
    if (clamp) {
      return ReplaceWithNewInstruction(maximum, std::move(clamp));
    }
  }

  HloInstruction* clamp_lower_bound;
  HloInstruction* clamp_upper_bound;
  HloInstruction* max_operand;
  HloInstruction* clamp;
  if (Match(maximum,
            m::MaximumAnyOrder(
                m::Op(&max_operand),
                m::Clamp(&clamp, m::Op(&clamp_lower_bound), m::Op(&to_clamp),
                         m::Op(&clamp_upper_bound))))) {
    if (max_operand == clamp_lower_bound &&
        ReplaceInstructionIfSameShape(maximum, clamp)) {
      return Status::OK();
    }
  }

  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum) {
  HloInstruction *lhs, *rhs;
  CHECK(Match(minimum, m::Minimum(m::Op(&lhs), m::Op(&rhs))));

  HloInstruction* clamp_upper_bound_bcast;
  HloInstruction* clamp_lower_bound_bcast;
  HloInstruction* to_clamp;
  if (Match(minimum, m::MinimumAnyOrder(
                         m::Broadcast(&clamp_upper_bound_bcast,
                                      m::ConstantEffectiveScalar()),
                         m::MaximumAnyOrder(
                             m::Op(&to_clamp),
                             m::Broadcast(&clamp_lower_bound_bcast,
                                          m::ConstantEffectiveScalar()))))) {
    TF_ASSIGN_OR_RETURN(auto clamp,
                        MinMaxToClamp(clamp_lower_bound_bcast, to_clamp,
                                      clamp_upper_bound_bcast));
    if (clamp) {
      return ReplaceWithNewInstruction(minimum, std::move(clamp));
    }
  }

  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleClamp(HloInstruction* clamp) {
  HloInstruction* clamp_lower_bound;
  HloInstruction* clamp_upper_bound;
  HloInstruction* to_clamp;
  CHECK(Match(clamp, m::Clamp(m::Op(&clamp_lower_bound), m::Op(&to_clamp),
                              m::Op(&clamp_upper_bound))));

  // clamp(a, clamp(a, x, b), b) -> clamp(a, x, b)
  if (Match(to_clamp, m::Clamp(m::Op().Is(clamp_lower_bound), m::Op(),
                               m::Op().Is(clamp_upper_bound))) &&
      ReplaceInstructionIfSameShape(clamp, to_clamp)) {
    return Status::OK();
  }

  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
  HloInstruction *lhs, *rhs;
  CHECK(Match(multiply, m::Multiply(m::Op(&lhs), m::Op(&rhs))));
  // LHS*1 => LHS
  VLOG(10) << "trying transform [LHS*1 => LHS]: " << multiply->ToString();
  if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(multiply, lhs)) {
    return Status::OK();
  }
  // 1*RHS => RHS
  VLOG(10) << "trying transform [1*RHS => RHS]: " << multiply->ToString();
  if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(multiply, rhs)) {
    return Status::OK();
  }

  // 0*RHS => 0. Only applies for integral types for correct NaN-handling.
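  // (For floating point this would be unsound: 0 * NaN is NaN and 0 * inf is
  // NaN, not 0.)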
  if (IsAll(lhs, 0) &&
      primitive_util::IsIntegralType(multiply->shape().element_type()) &&
      ReplaceInstructionIfSameShape(multiply, lhs)) {
    return Status::OK();
  }
  // LHS*0 => 0
  if (IsAll(rhs, 0) &&
      primitive_util::IsIntegralType(multiply->shape().element_type()) &&
      ReplaceInstructionIfSameShape(multiply, rhs)) {
    return Status::OK();
  }

  {
    // abs(x)*abs(x) => x*x, since the square is non-negative anyway (for
    // non-complex x).
    HloInstruction* abs_operand;
    if (lhs == rhs && Match(lhs, m::Abs(m::Op(&abs_operand))) &&
        !ShapeUtil::ElementIsComplex(abs_operand->shape())) {
      TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(0, abs_operand));
      TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(1, abs_operand));
      changed_ = true;
      return Status::OK();
    }
  }

  {
    HloInstruction *convert_operand, *operand;
    // Mul(Convert(Pred), operand) => select(pred, operand, 0)
    if (Match(multiply,
              m::MultiplyAnyOrder(
                  m::Op(&operand),
                  m::Convert(
                      m::Op(&convert_operand)
                          .WithShape(m::Shape().WithElementType(PRED)))))) {
      HloInstruction* zero_like_multiply =
          BroadcastZeros(computation_, multiply->shape().element_type(),
                         multiply->shape().dimensions());
      return ReplaceWithNewInstruction(
          multiply, HloInstruction::CreateTernary(
                        multiply->shape(), HloOpcode::kSelect, convert_operand,
                        operand, zero_like_multiply));
    }
  }

  {
    HloInstruction *a, *b, *c1, *c2;
    // Mul(Mul(x, constant1), Mul(y, constant2)) => Mul(Mul(x, y),
    // constant1*constant2)
    if (Match(multiply,
              m::MultiplyAnyOrder(
                  m::MultiplyAnyOrder(m::NonConstant(&a), m::Constant(&c1)),
                  m::MultiplyAnyOrder(m::NonConstant(&b), m::Constant(&c2))))) {
      TF_ASSIGN_OR_RETURN(auto* product_of_constants,
                          MakeBinaryHlo(HloOpcode::kMultiply, c1, c2));
      if (ShapeUtil::IsScalar(product_of_constants->shape()) &&
          !ShapeUtil::IsScalar(multiply->shape())) {
        product_of_constants =
            computation_->AddInstruction(HloInstruction::CreateBroadcast(
                multiply->shape(), product_of_constants, {}));
      }

      return ReplaceWithNewInstruction(
          multiply,
          HloInstruction::CreateBinary(
              multiply->shape(), HloOpcode::kMultiply,
              computation_->AddInstruction(HloInstruction::CreateBinary(
                  multiply->shape(), HloOpcode::kMultiply, a, b)),
              product_of_constants));
    }
  }

  {
    HloInstruction *a, *c1, *c2;
    // Mul(Mul(a, constant1), constant2) => Mul(a, constant1*constant2)
    if (Match(multiply,
              m::MultiplyAnyOrder(
                  m::MultiplyAnyOrder(m::NonConstant(&a), m::Constant(&c1)),
                  m::Constant(&c2)))) {
      TF_ASSIGN_OR_RETURN(auto* product_of_constants,
                          MakeBinaryHlo(HloOpcode::kMultiply, c1, c2));
      if (ShapeUtil::IsScalar(product_of_constants->shape()) &&
          !ShapeUtil::IsScalar(multiply->shape())) {
        product_of_constants =
            computation_->AddInstruction(HloInstruction::CreateBroadcast(
                multiply->shape(), product_of_constants, {}));
      }

      return ReplaceWithNewInstruction(
          multiply,
          HloInstruction::CreateBinary(multiply->shape(), HloOpcode::kMultiply,
                                       a, product_of_constants));
    }
  }

  {
    HloInstruction *a, *b, *constant, *op;
    // Mul(Mul(a, constant1), Broadcast(b)) =>
    // Mul(a, Broadcast(Mul(b, constant1)))
2800 if (Match(multiply,
2801 m::MultiplyAnyOrder(m::MultiplyAnyOrder(m::NonConstant(&a),
2802 m::Constant(&constant)),
2803 m::Op(&op))) ||
2804 Match(multiply,
2805 m::MultiplyAnyOrder(
2806 m::MultiplyAnyOrder(m::NonConstant(&a),
2807 m::Broadcast(m::Constant(&constant))),
2808 m::Op(&op)))) {
2809 // Check that the other side was a broadcast, and not of a constant.
2810 if (ShapeUtil::IsScalar(constant->shape()) &&
2811 Match(op, m::Broadcast(m::NonConstant()))) {
2812 auto dims = op->dimensions();
2813 b = op->mutable_operand(0);
2814 if (!ShapeUtil::IsScalar(b->shape())) {
2815 constant = computation_->AddInstruction(
2816 HloInstruction::CreateBroadcast(b->shape(), constant, {}));
2817 }
2818
2819 auto new_mul =
2820 computation_->AddInstruction(HloInstruction::CreateBinary(
2821 b->shape(), HloOpcode::kMultiply, b, constant));
2822
2823 return ReplaceWithNewInstruction(
2824 multiply,
2825 HloInstruction::CreateBinary(
2826 multiply->shape(), HloOpcode::kMultiply, a,
2827 computation_->AddInstruction(HloInstruction::CreateBroadcast(
2828 multiply->shape(), new_mul, dims))));
2829 }
2830 }
2831 }

  VLOG(10) << "trying transform [(A * C1) * C2 => A * (C1 * C2)]";
  HloInstruction *a, *c1, *c2;
  if (Match(multiply,
            m::Multiply(m::Multiply(m::NonConstant(&a), m::Constant(&c1)),
                        m::Constant(&c2))) ||
      Match(multiply,
            m::Multiply(
                m::Multiply(m::Op(&a), m::Broadcast(m::ConstantScalar(&c1))),
                m::Broadcast(m::ConstantScalar(&c2))))) {
    TF_ASSIGN_OR_RETURN(auto* product_of_constants,
                        MakeBinaryHlo(HloOpcode::kMultiply, c1, c2));
    if (ShapeUtil::IsScalar(product_of_constants->shape()) &&
        !ShapeUtil::IsScalar(multiply->shape())) {
      product_of_constants =
          computation_->AddInstruction(HloInstruction::CreateBroadcast(
              multiply->shape(), product_of_constants, {}));
    }
    return ReplaceWithNewInstruction(
        multiply,
        HloInstruction::CreateBinary(multiply->shape(), HloOpcode::kMultiply, a,
                                     product_of_constants));
  }

  VLOG(10) << "trying to transform exp(LHS) * exp(RHS) => exp(LHS+RHS) "
           << multiply->ToString();
  if (Match(multiply, m::Multiply(m::Exp(m::Op(&lhs)), m::Exp(m::Op(&rhs))))) {
    auto add = computation_->AddInstruction(HloInstruction::CreateBinary(
        multiply->shape(), HloOpcode::kAdd, lhs, rhs));
    return ReplaceWithNewInstruction(
        multiply,
        HloInstruction::CreateUnary(multiply->shape(), HloOpcode::kExp, add));
  }
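  // Illustrative sketch (hand-written): exp(%x) * exp(%y) ==> exp(%x + %y).
  // This replaces two transcendental ops and a multiply with one add and one
  // exp; it is exact in real arithmetic, though floating-point results can
  // differ in the last few bits.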

  VLOG(10) << "trying transform [rsqrt(B) * rsqrt(B) => 1/B] "
           << multiply->ToString();
  HloInstruction* b;
  if (Match(multiply, m::Multiply(m::Rsqrt(m::Op(&b)), m::Rsqrt(m::Op(&b)))) &&
      IsPositive(b, options_)) {
    return ReplaceWithNewInstruction(
        multiply,
        HloInstruction::CreateBinary(multiply->shape(), HloOpcode::kDivide,
                                     MakeScalarLike(b, 1), b));
  }
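  // Illustrative sketch (hand-written): rsqrt(%b) * rsqrt(%b) ==> 1 / %b,
  // valid only when %b is known to be positive (checked via IsPositive
  // above), since rsqrt of a negative input is NaN while the division is not.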

  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleNegate(HloInstruction* negate) {
  // negate(negate(x)) => x
  HloInstruction* x;
  if (Match(negate, m::Negate(m::Negate(m::Op(&x)))) &&
      ReplaceInstructionIfSameShape(negate, x)) {
    return Status::OK();
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleNot(HloInstruction* logical_not) {
  // not(not(x)) => x
  HloInstruction* x;
  if (Match(logical_not, m::Not(m::Not(m::Op(&x)))) &&
      ReplaceInstructionIfSameShape(logical_not, x)) {
    return Status::OK();
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleOr(HloInstruction* logical_or) {
  HloInstruction *lhs, *rhs;
  CHECK(Match(logical_or, m::Or(m::Op(&lhs), m::Op(&rhs))));

  // Simplify logical or
  if (ShapeUtil::HasPrimitiveType(lhs->shape(), xla::PRED) &&
      ShapeUtil::HasPrimitiveType(rhs->shape(), xla::PRED)) {
    // A || True => True
    VLOG(10) << "trying transform [A || True => True]: "
             << logical_or->ToString();
    if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(logical_or, rhs)) {
      return Status::OK();
    }
    // True || A => True
    VLOG(10) << "trying transform [True || A => True]: "
             << logical_or->ToString();
    if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(logical_or, lhs)) {
      return Status::OK();
    }
  }

  // A || False => A and A | 0 => A
  VLOG(10) << "trying transform [A || False => A]: " << logical_or->ToString();
  if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(logical_or, lhs)) {
    return Status::OK();
  }

  // False || A => A and 0 | A => A
  VLOG(10) << "trying transform [False || A => A]: " << logical_or->ToString();
  if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(logical_or, rhs)) {
    return Status::OK();
  }

  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log) {
  // ln(exp(A)) => A
  VLOG(10) << "trying transform [ln(exp(A)) => A]: " << log->ToString();
  HloInstruction *a, *b;
  if (Match(log, m::Log(m::Exp(m::Op(&a)))) &&
      ReplaceInstructionIfSameShape(log, a)) {
    return Status::OK();
  }

  // ln(pow(A,B)) => B*ln(abs(A))
  // or B*ln(A) if A is complex.
  if (Match(log, m::Log(m::Power(m::Op(&a), m::Op(&b))))) {
    auto abs_a = ShapeUtil::ElementIsComplex(a->shape())
                     ? a
                     : computation_->AddInstruction(HloInstruction::CreateUnary(
                           log->shape(), HloOpcode::kAbs, a));
    auto new_log = computation_->AddInstruction(
        HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, abs_a));
    return ReplaceWithNewInstruction(
        log, HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply,
                                          new_log, b));
  }
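  // Illustrative sketch (hand-written): log(pow(%a, %b)) ==> %b * log(abs(%a))
  // for real %a. The abs() matters because pow(A, B) can be well-defined for
  // negative A (e.g. even integral B) while log(A) alone would be NaN; for
  // complex A the identity is applied without the abs.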

  if (Match(log, m::Log(m::Sqrt(m::Op(&a))))) {
    auto new_log = computation_->AddInstruction(
        HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, a));
    return ReplaceWithNewInstruction(
        log, HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply,
                                          new_log, MakeScalarLike(log, 0.5)));
  }

  if (Match(log, m::Log(m::Rsqrt(m::Op(&a))))) {
    auto new_log = computation_->AddInstruction(
        HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, a));
    return ReplaceWithNewInstruction(
        log, HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply,
                                          new_log, MakeScalarLike(log, -0.5)));
  }

  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleGetTupleElement(
    HloInstruction* get_tuple_element) {
  auto operand = get_tuple_element->mutable_operand(0);
  if (operand->opcode() == HloOpcode::kTuple) {
    // get_tuple_element(make_tuple({A_0, A_1, ..., A_n}), i) => A_i
    VLOG(10) << "trying transform "
             << "[get_tuple_element(make_tuple({...,A_i,...}), i)] => A_i: "
             << get_tuple_element->ToString();
    if (ReplaceInstructionIfSameShape(
            get_tuple_element,
            operand->mutable_operand(get_tuple_element->tuple_index()))) {
      return Status::OK();
    }
  }
  return Status::OK();
}

namespace {

absl::optional<std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
    const HloInstruction* hlo, absl::Span<const int64> input_dim_indices) {
  CHECK_EQ(hlo->opcode(), HloOpcode::kReshape);
  return ShapeUtil::ReshapeLeavesDimensionsUnmodified(
      hlo->operand(0)->shape(), hlo->shape(), input_dim_indices);
}

// Returns true if the output of "instruction" is a permutation of the
// elements of "operand". Precondition: "operand" is an operand of
// "instruction".
bool OutputIsPermutationOfOperandElements(HloInstruction* instruction,
                                          HloInstruction* operand) {
  DCHECK(!instruction->OperandIndices(operand).empty());
  switch (instruction->opcode()) {
    case HloOpcode::kReshape:
    case HloOpcode::kReverse:
    case HloOpcode::kTranspose:
      return true;
    case HloOpcode::kSort:
      return (!instruction->shape().IsTuple());
    default:
      return false;
  }
}

// Returns true if the output of "instruction" is a subset of the elements of
// "operand". Precondition: "operand" is an operand of "instruction".
bool OutputIsSubsetOfOperandElements(HloInstruction* instruction,
                                     HloInstruction* operand) {
  const auto operand_indices = instruction->OperandIndices(operand);
  CHECK(!operand_indices.empty());
  if (operand_indices.size() != 1) {
    return false;
  }
  int64 operand_index = operand_indices[0];
  switch (instruction->opcode()) {
    case HloOpcode::kSlice:
      CHECK_EQ(0, operand_index);
      return true;
    case HloOpcode::kDynamicSlice:
      return operand_index == 0;
    default:
      return false;
  }
}

}  // namespace

Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
  HloInstruction* operand;
  CHECK(Match(broadcast, m::Broadcast(m::Op(&operand))));
  auto dims = broadcast->dimensions();
  // A degenerate broadcast that does not change the number of elements can be
  // replaced by a reshape.
  if (std::is_sorted(dims.begin(), dims.end()) &&
      ShapeUtil::ElementsIn(broadcast->shape()) ==
          ShapeUtil::ElementsIn(operand->shape())) {
    VLOG(10) << "transform broadcast(X) -> reshape(X) where "
                "n(broadcast(X)) == n(X)";
    return ReplaceWithNewInstruction(
        broadcast, HloInstruction::CreateReshape(broadcast->shape(), operand));
  }

  // A degenerate broadcast that has the same input and output rank can be
  // converted into a transpose.
  if (broadcast->shape().rank() == operand->shape().rank() &&
      ShapeUtil::ElementsIn(broadcast->shape()) ==
          ShapeUtil::ElementsIn(operand->shape())) {
    VLOG(10) << "transform broadcast(X) -> transpose(X) where "
                "n(broadcast(X)) == n(X)";
    return ReplaceWithNewInstruction(
        broadcast,
        HloInstruction::CreateTranspose(broadcast->shape(), operand, dims));
  }
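  // Illustrative sketch (hand-written): f32[2,3] broadcast(f32[3,2] %x),
  // dimensions={1,0} keeps the element count and rank but permutes the
  // dimensions, so it is exactly f32[2,3] transpose(%x), dimensions={1,0}.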

  // A broadcast whose operand is a reshape that merely inserts 1-sized
  // dimensions can skip the reshape and broadcast the reshape's operand
  // directly.
  {
    bool merely_inserts_or_deletes_1_sized_dimensions;
    std::vector<int64> inserted_indices, deleted_indices;
    std::tie(merely_inserts_or_deletes_1_sized_dimensions, deleted_indices,
             inserted_indices) =
        operand->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
    if (merely_inserts_or_deletes_1_sized_dimensions &&
        deleted_indices.empty()) {
      std::reverse(inserted_indices.begin(), inserted_indices.end());
      for (auto inserted_index : inserted_indices) {
        dims.erase(dims.begin() + inserted_index);
      }
      return ReplaceWithNewInstruction(
          broadcast,
          HloInstruction::CreateBroadcast(broadcast->shape(),
                                          operand->mutable_operand(0), dims));
    }
  }

  // A Broadcast that feeds a unary element-wise operation can sink the
  // broadcast after the unary element-wise operation.
  TF_ASSIGN_OR_RETURN(
      bool sink_succeeded,
      TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(broadcast));
  changed_ |= sink_succeeded;
  if (sink_succeeded) {
    return Status::OK();
  }

  // A scalar broadcast feeding an instruction which only permutes (reshape,
  // transpose, sort, reverse) or selects a subset of operand elements (slice,
  // dynamic slice) can be replaced with a broadcast directly to the output
  // shape of the instruction.
  if (ShapeUtil::IsScalar(operand->shape())) {
    for (HloInstruction* user : broadcast->users()) {
      // Skip if the broadcast user has no uses itself.
      if (user->user_count() == 0 && user != computation_->root_instruction()) {
        continue;
      }
      if (OutputIsPermutationOfOperandElements(user, broadcast) ||
          OutputIsSubsetOfOperandElements(user, broadcast)) {
        VLOG(10) << "transform permuting/subset of a scalar broadcast into "
                 << "a single broadcast";
        HloInstruction* new_broadcast = computation_->AddInstruction(
            HloInstruction::CreateBroadcast(user->shape(), operand, {}));
        // Use HloInstruction::ReplaceAllUsesWith instead of
        // HloComputation::ReplaceWithNewInstruction because we are replacing
        // an instruction other than the visited instruction.
        changed_ = true;
        return user->ReplaceAllUsesWith(new_broadcast);
      }
    }
    return Status::OK();
  }

  // broadcast(iota) -> iota.
  if (operand->opcode() == HloOpcode::kIota) {
    return ReplaceWithNewInstruction(
        broadcast,
        HloInstruction::CreateIota(
            broadcast->shape(),
            dims[Cast<HloIotaInstruction>(operand)->iota_dimension()]));
  }

  // Merge two consecutive broadcasts into a single one.
  if (operand->opcode() == HloOpcode::kBroadcast) {
    std::vector<int64> new_dimensions;
    for (auto dim : operand->dimensions()) {
      new_dimensions.push_back(dims[dim]);
    }
    return ReplaceWithNewInstruction(
        broadcast,
        HloInstruction::CreateBroadcast(
            broadcast->shape(), operand->mutable_operand(0), new_dimensions));
  }
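  // Illustrative sketch (hand-written) of the merge above: if %x is f32[3]
  // and broadcast1 = f32[3,4] broadcast(%x), dimensions={0}, then
  // broadcast2 = f32[2,3,4] broadcast(broadcast1), dimensions={1,2} becomes
  // a single f32[2,3,4] broadcast(%x), dimensions={1}: each operand dimension
  // is remapped through the outer broadcast's dimension map.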
  if (options_.is_layout_sensitive()) {
    return Status::OK();
  }
  if (ShapeUtil::HasDegenerateDimensions(operand->shape())) {
    auto new_operand =
        operand->parent()->AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::DropDegenerateDimensions(operand->shape()), operand));
    std::vector<int64> new_dims;
    new_dims.reserve(new_operand->shape().rank());
    for (int64 i = 0; i < operand->shape().rank(); ++i) {
      if (operand->shape().dimensions(i) != 1) {
        new_dims.push_back(dims[i]);
      }
    }
    return ReplaceWithNewInstruction(
        broadcast, HloInstruction::CreateBroadcast(broadcast->shape(),
                                                   new_operand, new_dims));
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleCompare(HloInstruction* compare) {
  HloInstruction* lhs;
  HloInstruction* rhs;
  CHECK(Match(compare, m::Compare(m::Op(&lhs), m::Op(&rhs))));
  {
    // compare(broadcast(a) + x, broadcast(b)) ==>
    // compare(x, broadcast(b-a)), only enabled for signed integral types.
    HloInstruction *x, *a, *b;
    if (Match(compare,
              m::Compare(
                  m::AddAnyOrder(m::Op(&x), m::Broadcast(m::Op(&a).WithShape(
                                                m::Shape().IsScalar()))),
                  m::Broadcast(m::Op(&b).WithShape(m::Shape().IsScalar()))))) {
      if (ShapeUtil::ElementIsSigned(x->shape()) &&
          ShapeUtil::ElementIsIntegral(x->shape())) {
        HloInstruction* sub =
            computation_->AddInstruction(HloInstruction::CreateBinary(
                b->shape(), HloOpcode::kSubtract, b, a));
        HloInstruction* broadcast = computation_->AddInstruction(
            HloInstruction::CreateBroadcast(x->shape(), sub, {}));
        HloInstruction* new_compare = computation_->AddInstruction(
            HloInstruction::CreateCompare(compare->shape(), x, broadcast,
                                          compare->comparison_direction()));
        return ReplaceInstruction(compare, new_compare);
      }
    }
  }
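  // Illustrative sketch (hand-written): for s32 %x with scalar constants,
  //   compare(%x + broadcast(3), broadcast(7), LT)
  //     ==> compare(%x, broadcast(7 - 3), LT)
  // This is restricted to integral element types; for floats, rounding in
  // the addition could change the comparison result.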

  if (Cast<HloCompareInstruction>(compare)->type() ==
      Comparison::Type::kUnsigned) {
    // X u< 0 -> false
    if (compare->comparison_direction() == ComparisonDirection::kLt &&
        IsAll(rhs, 0)) {
      return ReplaceInstruction(compare, MakeScalarLike(compare, false));
    }
    // X u>= 0 -> true
    if (compare->comparison_direction() == ComparisonDirection::kGe &&
        IsAll(rhs, 0)) {
      return ReplaceInstruction(compare, MakeScalarLike(compare, true));
    }
    // 0 u> X -> false
    if (compare->comparison_direction() == ComparisonDirection::kGt &&
        IsAll(lhs, 0)) {
      return ReplaceInstruction(compare, MakeScalarLike(compare, false));
    }
    // 0 u<= X -> true
    if (compare->comparison_direction() == ComparisonDirection::kLe &&
        IsAll(lhs, 0)) {
      return ReplaceInstruction(compare, MakeScalarLike(compare, true));
    }
  }

  if (compare->comparison_direction() == ComparisonDirection::kLt &&
      lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) {
    return ReplaceInstruction(compare, MakeScalarLike(compare, false));
  } else if (compare->comparison_direction() == ComparisonDirection::kGt &&
             IsAll(lhs, 0) && rhs->opcode() == HloOpcode::kIota) {
    return ReplaceInstruction(compare, MakeScalarLike(compare, false));
  } else if (compare->comparison_direction() == ComparisonDirection::kGe &&
             lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) {
    return ReplaceInstruction(compare, MakeScalarLike(compare, true));
  } else if (compare->comparison_direction() == ComparisonDirection::kLe &&
             IsAll(lhs, 0) && rhs->opcode() == HloOpcode::kIota) {
    return ReplaceInstruction(compare, MakeScalarLike(compare, true));
  }
  if (lhs == rhs &&
      primitive_util::IsIntegralType(lhs->shape().element_type())) {
    switch (compare->comparison_direction()) {
      case ComparisonDirection::kGt:
      case ComparisonDirection::kLt:
      case ComparisonDirection::kNe:
        return ReplaceInstruction(compare, MakeScalarLike(compare, false));
      case ComparisonDirection::kEq:
      case ComparisonDirection::kGe:
      case ComparisonDirection::kLe:
        return ReplaceInstruction(compare, MakeScalarLike(compare, true));
    }
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) {
  PrimitiveType src_type = convert->operand(0)->shape().element_type();
  PrimitiveType dest_type = convert->shape().element_type();
  // A conversion to the same element type as the operand is a nop and can be
  // removed. A conversion of a constant can be simplified by making a new
  // constant.
  if (src_type == dest_type) {
    return ReplaceInstruction(convert, convert->mutable_operand(0));
  }

  // Eliminate a convert pair if it is a no-op. The following are a few
  // example cases that are being handled:
  // 1. convert(convert(A, $TYPE1), $TYPE2) is simplified to A if A is of $TYPE2
  //    and convert(A, $TYPE1) is an upcast
  // 2. convert(convert(A, $TYPE1),$TYPE2) is simplified to A if A is of $TYPE2
  //    and convert(A, $TYPE1) is an upcast and is an integral conversion from
  //    unsigned to signed (only signed to unsigned conversion is NOT allowed)
  // 3. Tuple(convert(A, $TYPE1) , floor(convert(convert(A, $TYPE1), $TYPE2)),
  //    convert(convert(A, $TYPE1), $TYPE2)) is simplified to Tuple(convert(A,
  //    $TYPE1) , floor(A), A) -> a case where the first convert has a
  //    fan-out
  if (convert->operand(0)->opcode() == HloOpcode::kConvert &&
      IsConvertPairNoOp(convert)) {
    return ReplaceInstruction(convert,
                              convert->mutable_operand(0)->mutable_operand(0));
  }
  return Status::OK();
}

// Complex(Real(c), Imag(c)) -> c
Status AlgebraicSimplifierVisitor::HandleComplex(HloInstruction* complex) {
  HloInstruction *c0, *c1;
  if (Match(complex, m::Complex(m::Real(m::Op(&c0)), m::Imag(m::Op(&c1)))) &&
      c0 == c1) {
    return ReplaceInstruction(complex, c0);
  }
  return Status::OK();
}

// Real(Complex(r, i)) -> r
Status AlgebraicSimplifierVisitor::HandleReal(HloInstruction* real) {
  HloInstruction* op;
  if (Match(real, m::Real(m::Complex(m::Op(&op), m::Op())))) {
    return ReplaceInstruction(real, op);
  }
  return Status::OK();
}

// Imag(Complex(r, i)) -> i
Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) {
  HloInstruction* op;
  if (Match(imag, m::Imag(m::Complex(m::Op(), m::Op(&op))))) {
    return ReplaceInstruction(imag, op);
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) {
  // iota -> zero if the iota dimension never produces an element other than
  // zero.
  auto* iota = Cast<HloIotaInstruction>(instruction);
  if (iota->shape().dimensions(iota->iota_dimension()) <= 1) {
    auto zero = computation_->AddInstruction(
        simplifier_->CreateConstantWithLayoutUpdated(
            LiteralUtil::Zero(iota->shape().element_type()).Clone()));
    return ReplaceWithNewInstruction(
        iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {}));
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
  if (ShapeUtil::IsZeroElementArray(pad->operand(0)->shape())) {
    return ReplaceWithNewInstruction(
        pad, HloInstruction::CreateBroadcast(pad->shape(),
                                             pad->mutable_operand(1), {}));
  }

  // Interior padding on size-one dimensions has no effect, so clearing it
  // makes other simplifications possible once no interior padding remains.
  if (HasInteriorPadding(pad->padding_config())) {
    PaddingConfig padding_config = pad->padding_config();
    bool cleared_interior_padding = false;
    for (int64 i = 0; i < pad->shape().rank(); ++i) {
      if (padding_config.dimensions(i).interior_padding() > 0 &&
          pad->operand(0)->shape().dimensions(i) == 1) {
        cleared_interior_padding = true;
        padding_config.mutable_dimensions(i)->set_interior_padding(0);
      }
    }
    if (cleared_interior_padding) {
      return ReplaceWithNewInstruction(
          pad,
          HloInstruction::CreatePad(pad->shape(), pad->mutable_operand(0),
                                    pad->mutable_operand(1), padding_config));
    }
  }

  // Eliminate nop pads (padding all zero), and replace a pad with negative
  // padding with a pad with non-negative padding followed by a slice.
  bool all_zero = true;
  bool has_negative = false;
  // Used to possibly split off the unchanged padding dimensions.
  std::vector<int64> padding_dimensions;
  int64 dimension_index = 0;
  for (auto& padding_dimension : pad->padding_config().dimensions()) {
    if (padding_dimension.edge_padding_low() < 0 ||
        padding_dimension.edge_padding_high() < 0) {
      has_negative = true;
    }
    if (padding_dimension.edge_padding_low() != 0 ||
        padding_dimension.edge_padding_high() != 0) {
      all_zero = false;
      padding_dimensions.push_back(dimension_index);
    } else if (padding_dimension.interior_padding()) {
      padding_dimensions.push_back(dimension_index);
    }
    dimension_index++;
  }

  if (all_zero) {
    if (ReplaceInstructionIfSameShape(pad, pad->mutable_operand(0))) {
      return Status::OK();
    }
  }

  // The context of this optimization can be found at b/163617402.
  // It tries to capture the case of pad(broadcast(x)), where
  // x->shape().dimensions(), or broadcast(x)->dimensions(), is
  // a subset of the padded dimensions in pad->config(), and the padded
  // dimensions in pad->config() are in turn a strict subset of
  // broadcast->shape().dimensions(). The combined op can be rewritten to
  // broadcast2(pad(broadcast1(x))), where broadcast1 extends x with the
  // dimensions that need to be padded, and broadcast2 extends the result of
  // the padding to the full dimensions.
  // TODO(qyi): for future extensions: The condition that
  // broadcast(x)->dimensions() be a subset of the padded dimensions in
  // pad->config() does not have to be strictly required, but it makes the
  // calculation for the optimization easier, so it is required by the current
  // implementation. Only the second condition, between the padded dimensions
  // and the dimensions of the final shape, has to be enforced for the
  // optimization to make sense. To remove the first constraint, the shape
  // calculations across the implementation would need to be re-adjusted.
  auto pad_dims = padding_dimensions.size();
  if (pad_dims < dimension_index &&
      pad->operand(0)->opcode() == HloOpcode::kBroadcast &&
      pad->operand(0)->user_count() == 1 &&
      pad->operand(0)->operand(0)->shape().rank() <= pad_dims) {
    // Check that the broadcast operand's dimensions are a subset of
    // padding_dimensions. If not, skip the optimization.
    bool opt_is_valid = true;
    std::vector<int64> broadcast_dimensions;
    HloBroadcastInstruction* broadcast =
        static_cast<HloBroadcastInstruction*>(pad->mutable_operand(0));
    for (auto broadcast_index : broadcast->dimensions()) {
      bool found = false;
      for (int i = 0; i < pad_dims; ++i) {
        if (broadcast_index == padding_dimensions[i]) {
          broadcast_dimensions.push_back(i);
          found = true;
          break;
        }
      }
      if (!found) {
        opt_is_valid = false;
        break;
      }
    }
    if (opt_is_valid) {
      auto pad_shape = pad->shape();
      auto broadcast_shape = broadcast->shape();
      auto pad_shape1 = pad_shape;
      auto broadcast_shape1 = broadcast_shape;
      PaddingConfig pad_config;
      for (int i = padding_dimensions.size() - 1; i >= 0; --i) {
        int64 j = padding_dimensions[i];
        while (--dimension_index > j) {
          broadcast_shape1.DeleteDimension(dimension_index);
          pad_shape1.DeleteDimension(dimension_index);
        }
      }
      while (--dimension_index >= 0) {
        broadcast_shape1.DeleteDimension(dimension_index);
        pad_shape1.DeleteDimension(dimension_index);
      }
      for (auto dimension_to_pad : padding_dimensions) {
        auto dimension = pad_config.add_dimensions();
        *dimension = pad->padding_config().dimensions(dimension_to_pad);
      }
      *broadcast->mutable_shape() = broadcast_shape1;
      *broadcast->mutable_dimensions() = broadcast_dimensions;
      simplifier_->UpdateLayout(broadcast->mutable_shape());
      auto pad2 =
          computation_->AddInstruction(pad->CloneWithNewShape(pad_shape1));
      *pad2->mutable_padding_config() = pad_config;
      simplifier_->UpdateLayout(pad2->mutable_shape());
      auto broadcast2 = computation_->AddInstruction(
          HloInstruction::CreateBroadcast(pad_shape, pad2, padding_dimensions));
      return ReplaceInstruction(pad, broadcast2);
    }
  }

  if (has_negative && options_.enable_negative_padding_replacement()) {
    // Pad has negative padding. Replace with a pad with the non-negative
    // padding followed by a slice which effectively performs the negative
    // padding.
    // TODO(b/34628603): Add support for negative padding in the backends, or
    // change kPad semantics to disallow negative padding and use slice
    // instead.

    // First construct the padding config with non-negative entries and then
    // compute the shape of this new pad instruction.
    PaddingConfig nonzero_padding = pad->padding_config();
    for (int i = 0; i < pad->padding_config().dimensions_size(); ++i) {
      PaddingConfig::PaddingConfigDimension* padding_dimension =
          nonzero_padding.mutable_dimensions(i);
      // Set negative padding to zero.
      if (padding_dimension->edge_padding_low() < 0) {
        padding_dimension->set_edge_padding_low(0);
      }
      if (padding_dimension->edge_padding_high() < 0) {
        padding_dimension->set_edge_padding_high(0);
      }
    }

    TF_ASSIGN_OR_RETURN(HloInstruction * nonzero_pad,
                        MakePadHlo(pad->mutable_operand(0),
                                   pad->mutable_operand(1), nonzero_padding));
    // Copy the layout from the original pad instruction. The new pad and the
    // slice instruction should all have the same layout.
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        pad->shape(), nonzero_pad->mutable_shape()));
    simplifier_->UpdateLayout(nonzero_pad->mutable_shape());

    // Second, construct the slice instruction to perform the negative
    // padding.
    std::vector<int64> start_indices;
    std::vector<int64> end_indices;
    std::vector<int64> strides;
    for (int64 i = 0; i < pad->padding_config().dimensions_size(); ++i) {
      const PaddingConfig::PaddingConfigDimension& padding_dimension =
          pad->padding_config().dimensions(i);
      int64 start = 0;
      if (padding_dimension.edge_padding_low() < 0) {
        start = -1 * padding_dimension.edge_padding_low();
      }
      int64 end = nonzero_pad->shape().dimensions(i);
      if (padding_dimension.edge_padding_high() < 0) {
        end += padding_dimension.edge_padding_high();
      }
      start_indices.push_back(start);
      end_indices.push_back(end);
      strides.push_back(1);
    }

    TF_ASSIGN_OR_RETURN(
        HloInstruction * slice,
        MakeSliceHlo(nonzero_pad, start_indices, end_indices, strides));
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        pad->shape(), slice->mutable_shape()));
    simplifier_->UpdateLayout(slice->mutable_shape());

    // Verify that the slice shape matches the pad shape.
    auto equal = Shape::Equal();
    if (!options_.is_layout_sensitive()) {
      equal.IgnoreTilesInLayout();
    }
    TF_RET_CHECK(equal(slice->shape(), pad->shape()));

    return ReplaceInstruction(pad, slice);
  }

  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
  VLOG(10) << "trying transform [pow(A, 0) => 1]: " << power->ToString();
  HloInstruction *lhs, *rhs;
  CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs))));
  if (IsAll(rhs, 0)) {
    return ReplaceInstruction(power, MakeScalarLike(power, 1));
  }

  VLOG(10) << "trying transform [pow(A, 1) => A]: " << power->ToString();
  if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(power, lhs)) {
    return Status::OK();
  }

  // pow(exp(A),B) => exp(A*B)
  HloInstruction *a, *b;
  if (Match(power, m::Power(m::Exp(m::Op(&a)), m::Op(&b)))) {
    auto a_times_b = computation_->AddInstruction(HloInstruction::CreateBinary(
        power->shape(), HloOpcode::kMultiply, a, b));
    return ReplaceWithNewInstruction(
        power, HloInstruction::CreateUnary(power->shape(), HloOpcode::kExp,
                                           a_times_b));
  }

  VLOG(10) << "trying transform [pow(A, 2) => A*A]: " << power->ToString();
  if (IsAll(rhs, 2)) {
    return ReplaceWithNewInstruction(
        power, HloInstruction::CreateBinary(power->shape(),
                                            HloOpcode::kMultiply, lhs, lhs));
  }

  // Pow(A, 3) is used in GELU.
  VLOG(10) << "trying transform [pow(A, 3) => A*A*A]: " << power->ToString();
  if (IsAll(rhs, 3)) {
    HloInstruction* tmp =
        computation_->AddInstruction(HloInstruction::CreateBinary(
            power->shape(), HloOpcode::kMultiply, lhs, lhs));
    return ReplaceWithNewInstruction(
        power, HloInstruction::CreateBinary(power->shape(),
                                            HloOpcode::kMultiply, lhs, tmp));
  }
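  // Illustrative sketch (hand-written): pow(%a, 3) ==> mul(%a, mul(%a, %a)).
  // Small integral exponents are expanded into multiplies because a generic
  // pow() call is far more expensive than one or two multiplications; the
  // GELU activation, for example, contains a pow(x, 3).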

  VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
  if (IsAll(rhs, -1)) {
    return ReplaceWithNewInstruction(
        power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide,
                                            MakeScalarLike(lhs, 1), lhs));
  }

  return Status::OK();
}

StatusOr<bool>
AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
    HloInstruction* broadcast) {
  TF_RET_CHECK(broadcast->opcode() == HloOpcode::kBroadcast);
  bool changed = false;
  if (ShapeUtil::IsScalar(broadcast->shape())) {
    return false;
  }
  HloInstruction* operand = broadcast->mutable_operand(0);
  auto is_scalar_broadcast = [](const HloInstruction* instruction) {
    return instruction->opcode() == HloOpcode::kBroadcast &&
           ShapeUtil::IsScalar(instruction->operand(0)->shape());
  };
  auto is_equal_broadcast = [operand,
                             broadcast](const HloInstruction* instruction) {
    return instruction->opcode() == HloOpcode::kBroadcast &&
           ShapeUtil::Equal(operand->shape(),
                            instruction->operand(0)->shape()) &&
           broadcast->dimensions() == instruction->dimensions();
  };
  auto is_compatible_broadcast = [&](const HloInstruction* instruction) {
    return is_scalar_broadcast(instruction) || is_equal_broadcast(instruction);
  };
  for (HloInstruction* user : broadcast->users()) {
    if (user->user_count() == 0 && user != computation_->root_instruction()) {
      continue;
    }
    // Do not move reshapes or broadcasts past copies since the shape the copy
    // will operate on will change.
    if (user->opcode() == HloOpcode::kCopy) {
      continue;
    }
    // Do not change the shape of fusion nodes in case there are already
    // multiple shapes inside the fusion node.
    if (user->opcode() == HloOpcode::kFusion) {
      continue;
    }
    if (!user->IsElementwise()) {
      continue;
    }

    // Check if all the operands of the user are compatible broadcasts for
    // sinking. (They are either scalar broadcasts or broadcasts casting
    // from/to the same shape/dimensions.)
    int64 compatible_broadcast_count = 0;
    int64 broadcast_use_count = 0;
    for (HloInstruction* user_operand : user->operands()) {
      if (is_compatible_broadcast(user_operand)) {
        ++compatible_broadcast_count;
      } else if (broadcast == user_operand) {
        ++broadcast_use_count;
      }
    }
    if (compatible_broadcast_count + broadcast_use_count !=
        user->operand_count()) {
      continue;
    }
    std::vector<HloInstruction*> new_operands;
    new_operands.reserve(user->operand_count());

    Shape changed_shape;
    for (HloInstruction* user_operand : user->operands()) {
      // If this is a broadcast operand that is not our original broadcast
      // input to this function, then we might need to change the input.
      if (is_compatible_broadcast(user_operand)) {
        // If this is a broadcast from a scalar value, rewrite it as a
        // broadcast from the scalar to the new shape enforced by the other
        // broadcast operands.
        if (is_scalar_broadcast(user_operand)) {
          changed_shape = ShapeUtil::ChangeElementType(
              operand->shape(), user_operand->shape().element_type());
          simplifier_->UpdateLayout(&changed_shape);
          new_operands.push_back(
              computation_->AddInstruction(HloInstruction::CreateBroadcast(
                  changed_shape, user_operand->mutable_operand(0), {})));
          user_operand->SetupDerivedInstruction(new_operands.back());
        } else {
          // For the non-scalar broadcasts we guarantee that the shape of the
          // operand of the broadcast is already a compatible shape.
          new_operands.push_back(user_operand->mutable_operand(0));
        }
      } else {
        CHECK_EQ(broadcast, user_operand);
        new_operands.push_back(operand);
      }
    }
    VLOG(4) << "Sinking broadcast after user:";
    VLOG(4) << "  old broadcast: " << broadcast->ToString();
    VLOG(4) << "  old user: " << user->ToString();
    changed_shape = ShapeUtil::ChangeElementType(operand->shape(),
                                                 user->shape().element_type());
    simplifier_->UpdateLayout(&changed_shape);
    HloInstruction* new_user = computation_->AddInstruction(
        user->CloneWithNewOperands(changed_shape, new_operands));
    VLOG(4) << "  new user: " << new_user->ToString();
    HloInstruction* new_broadcast =
        computation_->AddInstruction(HloInstruction::CreateBroadcast(
            user->shape(), new_user, broadcast->dimensions()));
    broadcast->SetupDerivedInstruction(new_broadcast);
    VLOG(4) << "  new broadcast: " << new_broadcast->ToString();
    TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_broadcast));
    changed = true;
  }
  return changed;
}

namespace {
template <typename T>
std::unique_ptr<HloInstruction> TryRemainderToAnd(
    HloInstruction* remainder, HloComputation* computation,
    AlgebraicSimplifier* simplifier) {
  HloInstruction *a, *b, *c;
  CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b))));

  if (ShapeUtil::ElementIsIntegral(remainder->shape()) &&
      !Match(b, m::ConstantEffectiveScalar(&c)) &&
      !Match(b, m::Broadcast(m::ConstantEffectiveScalar(&c)))) {
    return nullptr;
  }

  if (ShapeUtil::ElementIsSigned(remainder->shape())) {
    int64 b_value = c->literal().GetFirstElement<T>();
    if (b_value > 0 && IsPowerOfTwo(static_cast<uint64>(b_value))) {
      // Handle negative dividends by computing the remainder of the absolute
      // value and then negating the result.
      HloInstruction* zero_like_a = BroadcastZeros(
          computation, a->shape().element_type(), a->shape().dimensions());

      auto* dividend_is_negative =
          computation->AddInstruction(HloInstruction::CreateCompare(
              ShapeUtil::ChangeElementType(a->shape(), PRED), a, zero_like_a,
              ComparisonDirection::kLt));

      auto* negated_dividend = computation->AddInstruction(
          HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a));

      auto* abs_dividend =
          computation->AddInstruction(HloInstruction::CreateTernary(
              a->shape(), HloOpcode::kSelect, dividend_is_negative,
              negated_dividend, a));

      auto* quotient = computation->AddInstruction(HloInstruction::CreateBinary(
          remainder->shape(), HloOpcode::kAnd, abs_dividend,
          MakeScalarLike(abs_dividend, b_value - 1)));

      auto* negated_quotient =
          computation->AddInstruction(HloInstruction::CreateUnary(
              quotient->shape(), HloOpcode::kNegate, quotient));

      return HloInstruction::CreateTernary(
          remainder->shape(), HloOpcode::kSelect, dividend_is_negative,
          negated_quotient, quotient);
    }
  } else {
    uint64 b_value = c->literal().GetFirstElement<T>();
    if (IsPowerOfTwo(b_value)) {
      HloInstruction* mask_amount = computation->AddInstruction(
          simplifier->CreateConstantWithLayoutUpdated(
              LiteralUtil::CreateR0<T>(b_value - 1)));
      if (!ShapeUtil::IsScalar(b->shape())) {
        mask_amount = computation->AddInstruction(
            HloInstruction::CreateBroadcast(b->shape(), mask_amount, {}));
      }
      return HloInstruction::CreateBinary(remainder->shape(), HloOpcode::kAnd,
                                          a, mask_amount);
    }
  }
  return nullptr;
}
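// Illustrative sketch (hand-written) of the signed power-of-two case above,
// for a % 8:
//   is_neg  = a < 0
//   masked  = select(is_neg, -a, a) & 7
//   result  = select(is_neg, -masked, masked)
// This matches C-style truncated remainder semantics, where the result takes
// the sign of the dividend; the unsigned case is simply a & 7.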
}  // namespace

Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) {
  HloInstruction *a, *b;
  CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b))));

  // (A % B) % B == A % B.
  if (Match(a, m::Remainder(m::Op(), m::Op().Is(b)))) {
    return ReplaceInstruction(remainder, a);
  }

  // A % B => A & (B - 1) if B is a power of 2.
  switch (remainder->shape().element_type()) {
    case S8:
      if (std::unique_ptr<HloInstruction> shift =
              TryRemainderToAnd<int8>(remainder, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(remainder, std::move(shift));
      }
      break;
    case S16:
      if (std::unique_ptr<HloInstruction> shift =
              TryRemainderToAnd<int16>(remainder, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(remainder, std::move(shift));
      }
      break;
    case S32:
      if (std::unique_ptr<HloInstruction> shift =
              TryRemainderToAnd<int32>(remainder, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(remainder, std::move(shift));
      }
      break;
    case S64:
      if (std::unique_ptr<HloInstruction> shift =
              TryRemainderToAnd<int64>(remainder, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(remainder, std::move(shift));
      }
      break;
    case U8:
      if (std::unique_ptr<HloInstruction> shift =
              TryRemainderToAnd<uint8>(remainder, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(remainder, std::move(shift));
      }
      break;
    case U16:
      if (std::unique_ptr<HloInstruction> shift =
              TryRemainderToAnd<uint16>(remainder, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(remainder, std::move(shift));
      }
      break;
    case U32:
      if (std::unique_ptr<HloInstruction> shift =
              TryRemainderToAnd<uint32>(remainder, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(remainder, std::move(shift));
      }
      break;
    case U64:
      if (std::unique_ptr<HloInstruction> shift =
              TryRemainderToAnd<uint64>(remainder, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(remainder, std::move(shift));
      }
      break;
    default:
      break;
  }

  // If M < N, then {0, ..., M} % N ==> {0, ..., M}.
  //
  // Currently this only covers the case when N is a broadcasted constant
  // scalar. We could also cover the case when N is a non-broadcasted constant
  // with the same value repeated.
  HloInstruction* iota;
  HloInstruction* divisor;
  if (Match(remainder,
            m::Remainder(m::Iota(&iota),
                         m::Broadcast(m::ConstantEffectiveScalar(&divisor))))) {
    // The iota counts {0, ..., iota_upper_bound - 1}. (Actually this is
    // conservative; the iota may overflow and count up to a smaller value than
    // this. But that's OK for our purposes here.)
    int64 iota_upper_bound = iota->shape().dimensions(
        Cast<HloIotaInstruction>(iota)->iota_dimension());
    absl::optional<int64> divisor_val = divisor->literal().GetIntegralAsS64(
        std::vector<int64>(divisor->shape().dimensions_size(), 0));
    if (divisor_val && *divisor_val >= iota_upper_bound) {
      return ReplaceInstruction(remainder, iota);
    }
  }
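  // Illustrative sketch (hand-written): an iota over a dimension of size 16
  // produces values {0, ..., 15}, so iota % broadcast(16) (or any divisor
  // >= 16) is the iota itself and the remainder op can be dropped.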

  // (X + N) % N = X % N, so long as X + N does not overflow.
  //
  // We don't have range tracking in XLA that would let us know whether X + N
  // overflows, so for now we only do this simplification when X is an iota. We
  // could add other operations where it's easy to see a range, such as
  // remainder, convert, etc., though at some point we'd probably want a
  // range-tracking analysis.
  HloInstruction* bcast;
  HloInstruction* addend;
  if (Match(
          remainder,
          m::Remainder(
              m::AddAnyOrder(m::Iota(&iota),
                             m::Broadcast(m::ConstantEffectiveScalar(&addend))),
              m::Broadcast(&bcast, m::ConstantEffectiveScalar(&divisor)))) &&
      addend == divisor) {
    // The iota counts {0, ..., iota_upper_bound - 1}, with the same caveat
    // above that iota_upper_bound is conservative, and the true upper bound
    // may be smaller.
    int64 iota_upper_bound = iota->shape().dimensions(
        Cast<HloIotaInstruction>(iota)->iota_dimension());
    absl::optional<int64> divisor_val = divisor->literal().GetIntegralAsS64(
        std::vector<int64>(divisor->shape().dimensions_size(), 0));
    if (divisor_val) {
      // Check whether divisor_val + iota_upper_bound overflows; this
      // conservatively bounds divisor_val + iota_upper_bound - 1, the largest
      // value the sum can take.
      absl::optional<int64> max_val =
          OverflowSafeAdd(*divisor_val, iota_upper_bound);
      if (max_val.has_value() &&
          FitsInIntegralType(*max_val, iota->shape().element_type())) {
        return ReplaceWithNewInstruction(
            remainder,
            HloInstruction::CreateBinary(remainder->shape(),
                                         HloOpcode::kRemainder, iota, bcast));
      }
    }
  }
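  // Illustrative sketch (hand-written): (iota() + broadcast(8)) % broadcast(8)
  // ==> iota() % broadcast(8), provided the add cannot overflow the element
  // type; adding one full period of the modulus never changes the remainder.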

  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
  auto operand = reshape->mutable_operand(0);

  // Reshape directly to an empty constant if the shape contains a zero-element
  // dimension.
  if (ShapeUtil::IsZeroElementArray(reshape->shape())) {
    // If the instruction doesn't have a layout, use a default layout for
    // the literal result.
    Shape reshaped_shape = reshape->shape();
    if (!LayoutUtil::HasLayout(reshaped_shape)) {
      LayoutUtil::SetToDefaultLayout(&reshaped_shape);
    }
    auto empty_constant = simplifier_->CreateConstantWithLayoutUpdated(
        Literal::CreateFromShape(reshaped_shape));

    return ReplaceWithNewInstruction(reshape, std::move(empty_constant));
  }

  // Delete no-op reshapes, i.e. where shape = operand shape.
  if (SameShape(reshape, operand)) {
    VLOG(10) << "deleting no-op reshape";
    return ReplaceInstruction(reshape, operand);
  }

  // Merge reshapes.
  if (HloOpcode::kReshape == operand->opcode()) {
    return ReplaceWithNewInstruction(
        reshape, HloInstruction::CreateReshape(reshape->shape(),
                                               operand->mutable_operand(0)));
  }

  if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) {
    *operand->mutable_shape() = reshape->shape();
    return ReplaceInstruction(reshape, operand);
  }

  if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) {
    auto opt_dims = ReshapeLeavesDimensionsUnmodified(
        reshape, reshape->operand(0)->dimensions());
    if (opt_dims.has_value()) {
      return ReplaceWithNewInstruction(
          reshape,
          HloInstruction::CreateBroadcast(
              reshape->shape(), reshape->mutable_operand(0)->mutable_operand(0),
              *opt_dims));
    }
  }

  // reshape(iota) -> iota or a mixed radix calculation like
  // s32[2,3,4] reshape(s32[24] iota()) to
  // add(
  //   add(s32[2,3,4] iota() iota_dimension=2,
  //       4 * s32[2,3,4] iota() iota_dimension=1),
  //   12 * s32[2,3,4] iota() iota_dimension=0).
  if (operand->opcode() == HloOpcode::kIota) {
    auto* iota = Cast<HloIotaInstruction>(operand);
    auto common_factors =
        CommonFactors(reshape->operand(0)->shape().dimensions(),
                      reshape->shape().dimensions());
    auto iota_dim = absl::c_find_if(
        common_factors, [&](const std::pair<int64, int64>& dim_pair) {
          return dim_pair.first == iota->iota_dimension() &&
                 reshape->shape().dimensions(dim_pair.second) > 1;
        });
    auto next_dim = absl::c_find_if(
        common_factors, [&](const std::pair<int64, int64>& dim_pair) {
          return dim_pair.first == iota->iota_dimension() + 1;
        });
    if (iota_dim != common_factors.end() && next_dim != common_factors.end()) {
      int64 multiplier = 1;
      HloInstruction* new_reshape = nullptr;

      for (int64 dim = (iota_dim + 1)->second - 1; dim >= iota_dim->second;
           --dim) {
        HloInstruction* new_iota = computation_->AddInstruction(
            HloInstruction::CreateIota(reshape->shape(), dim));
        iota->SetupDerivedInstruction(new_iota);
        if (new_reshape) {
          new_reshape =
              computation_->AddInstruction(HloInstruction::CreateBinary(
                  reshape->shape(), HloOpcode::kAdd, new_reshape,
                  computation_->AddInstruction(HloInstruction::CreateBinary(
                      reshape->shape(), HloOpcode::kMultiply, new_iota,
                      MakeScalarLike(reshape, multiplier)))));
          reshape->SetupDerivedInstruction(new_reshape);
        } else {
          new_reshape = new_iota;
        }
        multiplier *= reshape->shape().dimensions(dim);
      }
      reshape->SetupDerivedInstruction(new_reshape);
      return ReplaceInstruction(reshape, new_reshape);
    }
  }
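  // Illustrative walk-through (hand-written) of the mixed-radix expansion
  // above for s32[24] iota() reshaped to s32[2,3,4]: the element at (i, j, k)
  // held linear index 12*i + 4*j + k, so the reshape is rebuilt as
  // iota(dim=2) + 4 * iota(dim=1) + 12 * iota(dim=0), all in the new shape.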

  // Make this a bitcast if possible.
  if (HloInstruction* bitcast_operand =
          BitcastingOperandOfReshapeOrCopyChain(reshape, options_)) {
    ReplaceWithBitcast(reshape, bitcast_operand);
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) {
  // When all the dimensions to reverse are trivial (i.e. the bound is 1),
  // there is nothing to be done.
  auto dim_is_one = [&](int64 i) -> bool {
    return reverse->shape().dimensions(i) == 1;
  };
  if (absl::c_all_of(reverse->dimensions(), dim_is_one)) {
    return ReplaceInstruction(reverse, reverse->mutable_operand(0));
  }
  return Status::OK();
}

StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyScalarSlice(
    HloInstruction* slice) {
  // Only try to do this for effective scalars. We could do the same for
  // slicing out larger pieces of padding (replacing with a broadcast of the
  // padding value), but this is probably not worth it.
  if (!ShapeUtil::IsEffectiveScalar(slice->shape())) {
    return false;
  }

  if (slice->operand(0)->opcode() == HloOpcode::kConcatenate) {
    VLOG(10) << "Trying to simplify scalar slice of concat";
    // Only do this for R1; there's no chance of this being useful otherwise.
    if (slice->shape().rank() != 1) {
      VLOG(10) << "Not folding, slice is not rank 1";
      return false;
    }
    HloConcatenateInstruction* concat =
        Cast<HloConcatenateInstruction>(slice->mutable_operand(0));
    int64 operand_start = 0;
    int64 operand_num = 0;
    // Weird loop structure to avoid annoying off-by-one errors.
    while (true) {
      TF_RET_CHECK(operand_num < concat->operand_count());
      const HloInstruction* operand = concat->operand(operand_num);
      int64 next_operand_start = operand_start + operand->shape().dimensions(0);
      if (next_operand_start > slice->slice_starts(0)) {
        break;
      }
      operand_start = next_operand_start;
      operand_num++;
    }

    bool replaced = ReplaceInstructionIfSameShape(
        slice, concat->mutable_operand(operand_num));
    if (replaced) {
      VLOG(10) << "Folding scalar slice of concat into concat operand";
    } else {
      VLOG(10) << "Folding scalar slice of concat into slice of concat operand";
      TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
          slice, HloInstruction::CreateSlice(
                     slice->shape(), concat->mutable_operand(operand_num),
                     {slice->slice_starts(0) - operand_start},
                     {slice->slice_starts(0) - operand_start + 1},
                     slice->slice_strides())));
    }
    return true;
  }

  return false;
}

StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape(
    HloInstruction* slice) {
  CHECK_EQ(slice->opcode(), HloOpcode::kSlice);
  if (!IsUnstridedSlice(slice)) {
    return false;
  }
  HloInstruction* reshape = slice->mutable_operand(0);
  if (reshape->opcode() != HloOpcode::kReshape) {
    return false;
  }
  HloInstruction* new_slice_operand = reshape->mutable_operand(0);
  int64 slice_rank = slice->shape().rank();
  std::vector<int64> sliced_dims;
  for (int64 i = 0; i < slice_rank; ++i) {
    if (slice->slice_starts(i) != 0 ||
        slice->slice_limits(i) != reshape->shape().dimensions(i)) {
      sliced_dims.push_back(i);
    }
  }

  if (sliced_dims.size() == 1 && sliced_dims[0] == 0 &&
      slice->slice_starts(0) == 0) {
    const Shape& new_slice_shape = new_slice_operand->shape();
    const int64 rank = new_slice_shape.rank();
    std::vector<int64> new_slice_starts(rank, 0);
    std::vector<int64> new_slice_strides(rank, 1);
    std::vector<int64> new_slice_limits(new_slice_shape.dimensions().begin(),
                                        new_slice_shape.dimensions().end());
    int64 slice_elements = ShapeUtil::ElementsIn(slice->shape());
    for (int64 i = rank - 1; i >= 0; --i) {
      if (slice_elements >= new_slice_limits[i]) {
        if (slice_elements % new_slice_limits[i] != 0) {
          return false;
        }
        slice_elements /= new_slice_limits[i];
      } else {
        new_slice_limits[i] = slice_elements;
        slice_elements = 1;
      }
    }
    HloInstruction* new_slice =
        computation_->AddInstruction(HloInstruction::CreateSlice(
            ShapeUtil::MakeShape(new_slice_shape.element_type(),
                                 new_slice_limits),
            new_slice_operand, new_slice_starts, new_slice_limits,
            new_slice_strides));
    simplifier_->UpdateLayout(new_slice->mutable_shape());
    TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
        slice, HloInstruction::CreateReshape(slice->shape(), new_slice)));
    return true;
  }
  return false;
}

// Allow a slice to move through a reverse, with any necessary updates to the
// slice config.
StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReverse(
    HloInstruction* slice) {
  VLOG(2) << "Entered TryToReorderSliceAndReverse for slice:"
          << slice->ToString();
  if (Match(slice, m::Slice(m::Reverse()))) {
    HloInstruction* reverse = slice->mutable_operand(0);
    HloInstruction* reverse_operand = reverse->mutable_operand(0);
    std::vector<int64> new_starts = slice->slice_starts();
    std::vector<int64> new_limits = slice->slice_limits();
    std::vector<int64> new_strides = slice->slice_strides();
    for (auto rdim : reverse->dimensions()) {
      int64 start = slice->slice_starts(rdim);
      int64 limit = slice->slice_limits(rdim);
      int64 stride = slice->slice_strides(rdim);
      // find_nth allows us to compute the appropriate index to begin
      // with during reverse even in the presence of non-unit strides.
      int64 find_nth = (limit - start - 1) / stride;
      find_nth = start + find_nth * stride;
      limit = find_nth + 1;
      new_starts[rdim] =
          (reverse->shape().dimensions(rdim) - start) - (limit - start);
      new_limits[rdim] = reverse->shape().dimensions(rdim) - start;
      VLOG(2) << "Analyzing dim:" << rdim << " (start,limit):" << start << ","
              << limit << " and new (start, limit):" << new_starts[rdim] << ","
              << new_limits[rdim];
    }
    // A new slice is formed from the reverse_operand, but the strides and
    // shape of the slice output remain the same. The new slice's starts and
    // limits are updated for ONLY the reversed dimensions, as indicated above.
    HloInstruction* new_slice = computation_->AddInstruction(
        HloInstruction::CreateSlice(slice->shape(), reverse_operand, new_starts,
                                    new_limits, new_strides));
    simplifier_->UpdateLayout(new_slice->mutable_shape());
    TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
        slice, HloInstruction::CreateReverse(new_slice->shape(), new_slice,
                                             reverse->dimensions())));
    // We do not delete the old reverse, since there might be another
    // consumer of that reverse (i.e., the full reverse output). DCE should
    // take care of any deletion that is necessary if there was no other use
    // of the reverse.
    return true;
  }
  return false;
}
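// Illustrative sketch (hand-written) of TryToReorderSliceAndReverse above:
// for %x = f32[10], slice(reverse(%x)) over indices [2, 5) becomes
// reverse(slice(%x, [5, 8))); reversing a slice of the original data yields
// the same elements as slicing the reversed data.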

Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
  // Delete no-op slices, i.e. where shape = operand shape.
  if (ReplaceInstructionIfSameShape(slice, slice->mutable_operand(0))) {
    return Status::OK();
  }

  HloInstruction* pad;
  HloInstruction* pad_operand;
  if (Match(slice, m::Slice(m::Pad(&pad, m::Op(&pad_operand), m::Op())))) {
    // Is the result of the slice exactly the pad operand?
    bool slice_undoes_pad = true;
    // Can the slice be moved to the pad_operand without any padding being
    // read?
    bool slice_inside_pad = true;
    // Does this slice slice out padding only?
    bool slice_in_padding = false;
    std::vector<int64> new_starts = slice->slice_starts();
    std::vector<int64> new_limits = slice->slice_limits();
    for (int64 i = 0; i < slice->shape().rank(); ++i) {
      const int64 start = slice->slice_starts(i);
      const int64 stride = slice->slice_strides(i);
      const int64 limit = slice->slice_limits(i);
      const int64 size = pad->shape().dimensions(i);

      const auto& dim = pad->padding_config().dimensions(i);
      const int64 low = dim.edge_padding_low();
      const int64 high = dim.edge_padding_high();
      const int64 interior = dim.interior_padding();
      const int64 edge = size - high;

      if (limit <= low || start >= edge) {
        slice_in_padding = true;
        break;
      }

      if (start != low || stride - 1 != interior) {
        slice_undoes_pad = false;
      }

      if (start < low || limit > edge || interior != 0 || stride != 1) {
        slice_inside_pad = false;
      }
      new_starts[i] -= low;
      new_limits[i] -= low;
    }
    if (slice_in_padding) {
      HloInstruction* broadcast =
          MakeBroadcastHlo(pad->mutable_operand(1), {}, slice->shape());
      *(broadcast->mutable_shape()) = slice->shape();
      return ReplaceInstruction(slice, broadcast);
    }
    if (slice_undoes_pad && ReplaceInstructionIfSameShape(slice, pad_operand)) {
      return Status::OK();
    }
    if (slice_inside_pad) {
      TF_ASSIGN_OR_RETURN(HloInstruction * new_slice,
                          MakeSliceHlo(pad_operand, new_starts, new_limits,
                                       slice->slice_strides()));
      *(new_slice->mutable_shape()) = slice->shape();
      return ReplaceInstruction(slice, new_slice);
    }
  }

  if (slice->operand(0)->opcode() == HloOpcode::kSlice &&
      IsUnstridedSlice(slice) && IsUnstridedSlice(slice->operand(0))) {
    HloInstruction* operand_slice = slice->mutable_operand(0);
    std::vector<int64> new_slice_starts = slice->slice_starts();
    std::vector<int64> new_slice_limits = slice->slice_limits();
    for (int64 i = 0; i < new_slice_starts.size(); ++i) {
      new_slice_starts[i] += operand_slice->slice_starts(i);
      new_slice_limits[i] += operand_slice->slice_starts(i);
    }
    return ReplaceWithNewInstruction(
        slice, HloInstruction::CreateSlice(
                   slice->shape(), operand_slice->mutable_operand(0),
                   new_slice_starts, new_slice_limits, slice->slice_strides()));
  }

  auto only_broadcast_dims_sliced = [&] {
    if (slice->operand(0)->opcode() != HloOpcode::kBroadcast) {
      return false;
    }
    for (int64 dim : slice->operand(0)->dimensions()) {
      if (slice->slice_starts(dim) != 0 || slice->slice_strides(dim) != 1 ||
          slice->slice_limits(dim) !=
              slice->operand(0)->shape().dimensions(dim)) {
        return false;
      }
    }
    return true;
  };
  if (only_broadcast_dims_sliced()) {
    return ReplaceWithNewInstruction(
        slice,
        HloInstruction::CreateBroadcast(
            slice->shape(), slice->mutable_operand(0)->mutable_operand(0),
            slice->mutable_operand(0)->dimensions()));
  }
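  // Illustrative sketch (hand-written): if %x is f32[3] and
  // %b = f32[8,3] broadcast(%x), dimensions={1}, then a slice that keeps
  // dimension 1 whole, e.g. slice(%b, [0:4, 0:3]), only narrows a broadcast
  // dimension, so it can be emitted directly as f32[4,3] broadcast(%x).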
4235
4236 TF_ASSIGN_OR_RETURN(bool replaced, TrySimplifyScalarSlice(slice));
4237 if (replaced) {
4238 return Status::OK();
4239 }
4240
4241 HloInstruction* broadcast;
4242 HloInstruction* broadcast_operand;
4243 if (Match(slice,
4244 m::Slice(m::Broadcast(&broadcast, m::Op(&broadcast_operand))))) {
4245 std::vector<int64> new_slice_starts;
4246 std::vector<int64> new_slice_strides;
4247 std::vector<int64> new_slice_limits;
4248 new_slice_starts.reserve(broadcast_operand->shape().rank());
4249 new_slice_strides.reserve(broadcast_operand->shape().rank());
4250 new_slice_limits.reserve(broadcast_operand->shape().rank());
4251 for (int64 dim : broadcast->dimensions()) {
4252 new_slice_starts.push_back(slice->slice_starts(dim));
4253 new_slice_strides.push_back(slice->slice_strides(dim));
4254 new_slice_limits.push_back(slice->slice_limits(dim));
4255 }
4256 VLOG(3) << "Sink broadcast through slice";
4257 VLOG(3) << "Original slice: " << slice->ToString();
4258 VLOG(3) << "Original broadcast: " << broadcast->ToString();
4259 auto new_slice_shape = broadcast_operand->shape();
4260 for (int64 i = 0; i < broadcast_operand->shape().rank(); ++i) {
4261 int64 size_i = (new_slice_limits[i] - new_slice_starts[i] +
4262 new_slice_strides[i] - 1) /
4263 new_slice_strides[i];
4264 new_slice_shape.set_dimensions(i, size_i);
4265 }
4266 simplifier_->UpdateLayout(&new_slice_shape);
4267 HloComputation* computation = broadcast_operand->parent();
4268 auto new_slice = computation->AddInstruction(HloInstruction::CreateSlice(
4269 new_slice_shape, broadcast_operand, new_slice_starts, new_slice_limits,
4270 new_slice_strides));
4271 auto new_broadcast = HloInstruction::CreateBroadcast(
4272 slice->shape(), new_slice, broadcast->dimensions());
4273 VLOG(3) << "New slice: " << slice->ToString();
4274 VLOG(3) << "New broadcast: " << new_broadcast->ToString();
4275 return ReplaceWithNewInstruction(slice, std::move(new_broadcast));
4276 }
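// Illustrative example: slice(broadcast(f32[10] x, dims={1}) of shape
// f32[5,10], [0:5, 2:8:2]) becomes broadcast(slice(x, [2:8:2]) of shape
// f32[3], dims={1}) with shape f32[5,3]; slicing the small operand first is
// strictly cheaper.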
4277
4278 // Try to simplify concat -> slice to an operand of concat.
4279 if (slice->operand(0)->opcode() == HloOpcode::kConcatenate &&
4280 IsUnstridedSlice(slice)) {
4281 auto concat = slice->operand(0);
4282 int64 concat_dim = concat->concatenate_dimension();
4283 int64 piece_start = 0;
4284 for (auto piece : concat->operands()) {
4285 if (!SameShape(piece, slice)) {
4286 piece_start += piece->shape().dimensions(concat_dim);
4287 continue;
4288 }
4289 if (slice->slice_starts(concat_dim) == piece_start) {
4290 return ReplaceInstruction(slice, piece);
4291 }
4292 piece_start += piece->shape().dimensions(concat_dim);
4293 }
4294 }
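// Illustrative example: slice(concat({f32[2] a, f32[3] b}, dim 0), [2:5])
// covers exactly the elements contributed by b, so the slice is replaced by
// b itself.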
4295
4296 // Do not try to reorder slices and reshapes after layout assignment as it may
4297 // be invalid.
4298 if (!options_.is_layout_sensitive()) {
4299 TF_ASSIGN_OR_RETURN(replaced, TryToReorderSliceAndReshape(slice));
4300 }
4301 if (replaced) {
4302 return Status::OK();
4303 }
4304
4305 bool reversed = false;
4306 if (Match(slice, m::Slice(m::Reverse(m::Op())))) {
4307 TF_ASSIGN_OR_RETURN(reversed, TryToReorderSliceAndReverse(slice));
4308 }
4309 if (reversed) {
4310 return Status::OK();
4311 }
4312
4313 return Status::OK();
4314 }
4315
4316 Status AlgebraicSimplifierVisitor::HandleRsqrt(HloInstruction* rsqrt) {
4317 VLOG(10) << "trying transform [rsqrt(Pow(A, -2)) => |A|] "
4318 << rsqrt->ToString();
4319 HloInstruction* rsqrt_operand = rsqrt->mutable_operand(0);
4320 if (rsqrt_operand->opcode() == HloOpcode::kPower &&
4321 IsAll(rsqrt_operand->operand(1), -2) &&
4322 IsPositive(rsqrt_operand, options_)) {
4323 return ReplaceWithNewInstruction(
4324 rsqrt, HloInstruction::CreateUnary(rsqrt->shape(), HloOpcode::kAbs,
4325 rsqrt_operand->mutable_operand(0)));
4326 }
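// Note: pow(A, -2) is non-negative for real A, so rsqrt(pow(A, -2)) =
// sqrt(pow(A, 2)) = |A|; kAbs (rather than A itself) is required because A
// may be negative.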
4327
4328 VLOG(10) << "trying transform [rsqrt(Divide(1, A)) => sqrt(A)] "
4329 << rsqrt->ToString();
4330 if (rsqrt_operand->opcode() == HloOpcode::kDivide &&
4331 IsAll(rsqrt_operand->operand(0), 1) &&
4332 IsPositive(rsqrt_operand->operand(1), options_)) {
4333 return ReplaceWithNewInstruction(
4334 rsqrt, HloInstruction::CreateUnary(rsqrt->shape(), HloOpcode::kSqrt,
4335 rsqrt_operand->mutable_operand(1)));
4336 }
4337
4338 return Status::OK();
4339 }
4340
4341 Status AlgebraicSimplifierVisitor::HandleDynamicSlice(
4342 HloInstruction* dynamic_slice) {
4343 auto operand = dynamic_slice->mutable_operand(0);
4344 if (ShapeUtil::IsScalar(dynamic_slice->shape())) {
4345 return ReplaceInstruction(dynamic_slice, operand);
4346 }
4347 // DynamicSlice where operand has the same size as the output is simply equal
4348 // to operand.
4349 if (SameShape(operand, dynamic_slice)) {
4350 return ReplaceInstruction(dynamic_slice, operand);
4351 }
4352
4353 HloInstruction* broadcast_operand;
4354 if (Match(operand, m::Broadcast(m::Op(&broadcast_operand)))) {
4355 std::vector<HloInstruction*> new_indices;
4356 new_indices.reserve(broadcast_operand->shape().rank());
4357 std::vector<int64> new_slice_sizes;
4358 new_slice_sizes.reserve(broadcast_operand->shape().rank());
4359
4360 for (int64 dim : operand->dimensions()) {
4361 new_indices.push_back(dynamic_slice->mutable_operand(1 + dim));
4362 new_slice_sizes.push_back(dynamic_slice->slice_sizes(dim));
4363 }
4364
4365 VLOG(3) << "Sink broadcast through dynamic slice";
4366 VLOG(3) << "Original dynamic slice: " << dynamic_slice->ToString();
4367 VLOG(3) << "Original broadcast: " << operand->ToString();
4368 HloInstruction* new_dynamic_slice = broadcast_operand;
4369 if (!new_slice_sizes.empty()) {
4370 auto new_ds_shape = broadcast_operand->shape();
4371 for (int64 i = 0; i < broadcast_operand->shape().rank(); ++i) {
4372 new_ds_shape.set_dimensions(i, new_slice_sizes[i]);
4373 }
4374 simplifier_->UpdateLayout(&new_ds_shape);
4375 HloComputation* computation = broadcast_operand->parent();
4376 new_dynamic_slice =
4377 computation->AddInstruction(HloInstruction::CreateDynamicSlice(
4378 new_ds_shape, broadcast_operand, new_indices, new_slice_sizes));
4379 }
4380 auto new_broadcast = HloInstruction::CreateBroadcast(
4381 dynamic_slice->shape(), new_dynamic_slice, operand->dimensions());
4382 VLOG(3) << "New dynamic slice: " << dynamic_slice->ToString();
4383 VLOG(3) << "New broadcast: " << new_broadcast->ToString();
4384 return ReplaceWithNewInstruction(dynamic_slice, std::move(new_broadcast));
4385 }
4386
4387 // Convert a dynamic slice into a slice if all offsets are constant and the
4388 // operand is not constant.
4389 if (operand->opcode() != HloOpcode::kConstant &&
4390 absl::c_all_of(absl::MakeSpan(dynamic_slice->operands().begin() + 1,
4391 dynamic_slice->operands().end()),
4392 [](HloInstruction* operand) {
4393 return operand->opcode() == HloOpcode::kConstant &&
4394 ShapeUtil::ElementIsIntegral(operand->shape());
4395 })) {
4396 const int64 rank = operand->shape().rank();
4397 std::vector<int64> slice_starts(rank);
4398 std::vector<int64> slice_limits(rank);
4399 std::vector<int64> slice_strides(rank, 1);
4400
4401 for (int64 i = 0; i < rank; ++i) {
4402 absl::optional<int64> offset =
4403 dynamic_slice->operand(i + 1)->literal().GetFirstInteger();
4404 if (!offset || *offset < 0) {
4405 return Status::OK();
4406 }
4407 const int64 max_offset =
4408 dynamic_slice->operand(0)->shape().dimensions(i) -
4409 dynamic_slice->shape().dimensions(i);
4410 slice_starts[i] = std::min(max_offset, *offset);
4411 slice_limits[i] =
4412 std::min(max_offset, *offset) + dynamic_slice->shape().dimensions(i);
4413 }
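// Illustrative example: dynamic-slice(f32[10] x, s32[] constant(8)) with
// slice size 4 has max_offset = 10 - 4 = 6, so the start is clamped to 6
// and the op becomes slice(x, [6:10]), matching dynamic-slice's clamping
// semantics.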
4414 return ReplaceWithNewInstruction(
4415 dynamic_slice,
4416 HloInstruction::CreateSlice(dynamic_slice->shape(), operand,
4417 slice_starts, slice_limits, slice_strides));
4418 }
4419 return Status::OK();
4420 }
4421
4422 Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
4423 HloInstruction* dynamic_update_slice) {
4424 // Rewrite DynamicUpdateSlice when it matches
4425 // dynamic_update_slice(broadcast(constant),data,constant_index0,...)
4426 // to Pad(data, constant).
4427 // Only a broadcast base is handled currently; other ops may be considered
4428 // in the future.
4429 HloInstruction* updated = dynamic_update_slice->mutable_operand(0);
4430 HloInstruction* dus_update = dynamic_update_slice->mutable_operand(1);
4431 HloInstruction* pad_value;
4432 if (Match(updated,
4433 m::Broadcast(m::Op(&pad_value).WithShape(m::Shape().IsScalar())))) {
4434 auto updated_shape = updated->shape();
4435 auto update_shape = dus_update->shape();
4436 auto update_start_indx = dynamic_update_slice->operand(2);
4437 int64 offset = 0;
4438 bool compatible = true;
4439 // Depending on whether the start indices are given as separate scalar
4440 // operands or as the output of a tuple/concatenate, set up
4441 // update_start_indx appropriately.
4442 if (ShapeUtil::IsScalar(update_start_indx->shape())) {
4443 update_start_indx = dynamic_update_slice;
4444 offset = 2;
4445 } else {
4446 if (update_start_indx->opcode() == HloOpcode::kTuple ||
4447 update_start_indx->opcode() == HloOpcode::kConcatenate) {
4448 offset = 0;
4449 } else {
4450 compatible = false;
4451 }
4452 }
4453 PaddingConfig padding_config;
4454 if (compatible) {
4455 for (int64 dim = 0; dim < updated_shape.rank(); ++dim) {
4456 auto padding_config_dim = padding_config.add_dimensions();
4457 auto slice_dim_start = update_start_indx->operand(dim + offset);
4458 if (!Match(slice_dim_start, m::ConstantScalar())) {
4459 compatible = false;
4460 break;
4461 }
4462 VLOG(2) << "slice: " << slice_dim_start->ToString();
4463 absl::optional<int64> beg =
4464 slice_dim_start->literal().GetFirstInteger();
4465 if (!beg) {
4466 compatible = false;
4467 break;
4468 }
4469 VLOG(2) << "beg value: " << *beg;
4470 auto update_width = ShapeUtil::GetDimension(update_shape, dim);
4471 auto bcast_width = ShapeUtil::GetDimension(updated_shape, dim);
4472 // Clamp beg so that it is non-negative.
4473 *beg = std::max<int64>(0, *beg);
4474 // Clamp beg so that it is in-bounds.
4475 *beg = std::min<int64>(bcast_width - update_width, *beg);
4476 VLOG(2) << "adjusted beg value: " << *beg;
4477 padding_config_dim->set_edge_padding_low(*beg);
4478 padding_config_dim->set_edge_padding_high(bcast_width -
4479 (*beg + update_width));
4480 // dynamic_update_slice does not specify a stride
4481 padding_config_dim->set_interior_padding(0);
4482 }
4483 }
4484
4485 if (compatible) {
4486 HloInstruction* pad =
4487 computation_->AddInstruction(HloInstruction::CreatePad(
4488 updated_shape, dus_update, pad_value, padding_config));
4489 VLOG(2) << dynamic_update_slice->ToString();
4490 VLOG(2) << " with pad:" << pad->ToString();
4491 VLOG(2) << " Computation before rewrite is: "
4492 << dynamic_update_slice->parent()->ToString();
4493 return ReplaceInstruction(dynamic_update_slice, pad);
4494 }
4495 }
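// Illustrative example: dynamic-update-slice(broadcast(f32[] c) of shape
// f32[8], f32[3] u, s32[] constant(2)) becomes pad(u, c, low=2, high=3,
// interior=0): the update occupies [2, 5) and the broadcast value fills the
// rest.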
4496
4497 // DynamicUpdateSlice where operand and dus_update have the same size is
4498 // equal to dus_update.
4499 if (SameShape(dynamic_update_slice, dus_update)) {
4500 return ReplaceInstruction(dynamic_update_slice, dus_update);
4501 }
4502
4503 // If any dimension of dus_update is 0, elide the DynamicUpdateSlice. This
4504 // optimization becomes invalid should we later prefer to warn about
4505 // out-of-bounds indices.
4506 if (ShapeUtil::IsZeroElementArray(dus_update->shape())) {
4507 return ReplaceInstruction(dynamic_update_slice, updated);
4508 }
4509 return Status::OK();
4510 }
4511
4512 Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
4513 HloReduceInstruction* reduce = Cast<HloReduceInstruction>(hlo);
4514 bool multi_output_reduce = reduce->shape().IsTuple();
4515 // For tuple reduce, we require all reduce shapes to be the same, up to the
4516 // element types, so we can just use the first operand and the first result
4517 // as representatives.
4518 auto arg = reduce->inputs()[0];
4519 auto init_value = reduce->init_values()[0];
4520 const Shape& reduce_result_shape =
4521 multi_output_reduce ? reduce->shape().tuple_shapes(0) : reduce->shape();
4522
4523 absl::Span<const int64> dimensions(reduce->dimensions());
4524 HloComputation* function = reduce->to_apply();
4525 if (ShapeUtil::IsZeroElementArray(arg->shape()) ||
4526 ShapeUtil::IsZeroElementArray(reduce_result_shape)) {
4527 if (multi_output_reduce) {
4528 std::vector<HloInstruction*> broadcast_inits;
4529 int64 inputs = reduce->input_count();
4530 for (int64 i = 0; i < inputs; ++i) {
4531 broadcast_inits.push_back(computation_->AddInstruction(
4532 HloInstruction::CreateBroadcast(reduce->shape().tuple_shapes(i),
4533 reduce->init_values()[i], {})));
4534 }
4535 return ReplaceWithNewInstruction(
4536 reduce, HloInstruction::CreateTuple(broadcast_inits));
4537 } else {
4538 return ReplaceWithNewInstruction(
4539 reduce,
4540 HloInstruction::CreateBroadcast(reduce_result_shape, init_value, {}));
4541 }
4542 }
4543
4544 // Turn trivial variadic reductions into normal reductions.
4545 if (multi_output_reduce && reduce->shape().tuple_shapes_size() == 1 &&
4546 reduce->input_count() == 1 &&
4547 Match(function->root_instruction(), m::Tuple())) {
4548 absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
4549 replacements;
4550 replacements[function->root_instruction()] = nullptr;
4551 auto new_function = computation_->parent()->AddEmbeddedComputation(
4552 function->CloneWithReplacements(
4553 std::move(replacements), /*extra_parameters=*/{},
4554 /*context=*/nullptr,
4555 /*suffix=*/"clone",
4556 /*new_root=*/function->root_instruction()->operand(0)));
4557 auto new_reduce = computation_->AddInstruction(
4558 HloInstruction::CreateReduce(reduce_result_shape, arg, init_value,
4559 reduce->dimensions(), new_function));
4560 return ReplaceWithNewInstruction(reduce,
4561 HloInstruction::CreateTuple({new_reduce}));
4562 }
4563
4564 if (options_.is_layout_sensitive()) {
4565 return Status::OK();
4566 }
4567
4568 // If the reduction results in the same number of elements, then the only
4569 // possible side effect would be a reshape. Since the init_value is an
4570 // identity of the reduction function, we can therefore replace the reduce
4571 // with a simple reshape, ignoring the reduction function completely.
4572 if (ShapeUtil::ElementsIn(reduce_result_shape) ==
4573 ShapeUtil::ElementsIn(arg->shape())) {
4574 if (multi_output_reduce) {
4575 std::vector<HloInstruction*> reshaped_args;
4576 int64 inputs = reduce->input_count();
4577 for (int64 i = 0; i < inputs; ++i) {
4578 reshaped_args.push_back(
4579 computation_->AddInstruction(HloInstruction::CreateReshape(
4580 reduce->shape().tuple_shapes(i), reduce->inputs()[i])));
4581 }
4582 return ReplaceWithNewInstruction(
4583 reduce, HloInstruction::CreateTuple(reshaped_args));
4584 } else {
4585 return ReplaceWithNewInstruction(
4586 reduce, HloInstruction::CreateReshape(reduce_result_shape, arg));
4587 }
4588 }
4589
4590 // TODO(b/131122694): Most of those optimizations below can be done for
4591 // multi-output reduces.
4592 if (multi_output_reduce) {
4593 return Status::OK();
4594 }
4595
4596 // A Transpose feeding a reduce can simply permute the reduction dimensions
4597 // field if the output of the reduce is a vector or scalar. A higher-ranked
4598 // result may require a transpose of the output.
4599 if (reduce_result_shape.rank() <= 1 &&
4600 arg->opcode() == HloOpcode::kTranspose) {
4601 auto transpose_dimensions = arg->dimensions();
4602 std::vector<int64> new_reduce_dimensions;
4603 for (auto dim : dimensions) {
4604 new_reduce_dimensions.push_back(transpose_dimensions[dim]);
4605 }
4606 return ReplaceWithNewInstruction(
4607 reduce, HloInstruction::CreateReduce(
4608 reduce_result_shape, arg->mutable_operand(0), init_value,
4609 new_reduce_dimensions, function));
4610 }
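// Illustrative example: reduce(transpose(f32[3,4] x, {1,0}), dims={0}) over
// the transposed output's dimension 0 is the same as reduce(x, dims={1}),
// since transpose output dimension 0 came from input dimension 1.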
4611
4612 // If a reduce feeds a reduce with the same computation and initial value,
4613 // they can be combined into a single reduce.
4614 if (arg->opcode() == HloOpcode::kReduce &&
4615 init_value->Identical(*arg->operand(1)) &&
4616 *function == *arg->to_apply()) {
4617 // Create a new reduce with the combined reduction dimensions of both
4618 // reduces.
4619 std::vector<int64> arg_dims = arg->dimensions();
4620 absl::c_sort(arg_dims);
4621 std::vector<int64> reduce_dims = reduce->dimensions();
4622 absl::c_sort(reduce_dims);
4623 // Remap reduce_dims so they index dimensions of the inner reduce's operand.
4624 for (int64 arg_dim : arg_dims) {
4625 for (int64& dim : reduce_dims) {
4626 if (dim >= arg_dim) {
4627 ++dim;
4628 }
4629 }
4630 }
4631 std::vector<int64> new_dimensions;
4632 new_dimensions.reserve(arg->dimensions().size() +
4633 reduce->dimensions().size());
4634 std::merge(arg_dims.begin(), arg_dims.end(), reduce_dims.begin(),
4635 reduce_dims.end(), std::back_inserter(new_dimensions));
4636 return ReplaceWithNewInstruction(
4637 reduce, HloInstruction::CreateReduce(
4638 reduce_result_shape, arg->mutable_operand(0), init_value,
4639 new_dimensions, function));
4640 }
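// Illustrative example: reduce(reduce(f32[4,5,6] x, dims={1}), dims={1})
// first remaps the outer dims past the already-removed dimension ({1} -> {2})
// and then merges, producing reduce(x, dims={1,2}).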
4641
4642 // A reshape that collapses multiple dimensions into a dimension being
4643 // reduced can just reduce all of those dimensions instead of doing a
4644 // collapsing reshape before a reduction.
4645 if (options_.enable_reduce_of_reshape() &&
4646 arg->opcode() == HloOpcode::kReshape) {
4647 std::vector<std::pair<int64, int64>> unmodified_dims =
4648 ShapeUtil::DimensionsUnmodifiedByReshape(arg->operand(0)->shape(),
4649 arg->shape());
4650 std::vector<bool> arg_dim_in_output(arg->shape().rank(), true);
4651 std::vector<bool> arg_dim_unmodified(arg->shape().rank(), false);
4652 for (auto dim : dimensions) {
4653 arg_dim_in_output[dim] = false;
4654 }
4655 for (auto dim_pair : unmodified_dims) {
4656 arg_dim_unmodified[dim_pair.second] = true;
4657 }
4658 // The goal is to verify that all dimensions that are not removed in the
4659 // reduce are unmodified by the reshape. For example:
4660 // reduce(reshape([A,B*C], a[A,B,C]),[1]) = reduce(a[A, B, C], [1, 2])
4661 bool can_move_reshape_into_reduce = true;
4662 for (int64 i = 0; i < arg_dim_in_output.size(); ++i) {
4663 if (arg_dim_in_output[i] && !arg_dim_unmodified[i]) {
4664 can_move_reshape_into_reduce = false;
4665 }
4666 }
4667 if (can_move_reshape_into_reduce) {
4668 changed_ = true;
4669 absl::flat_hash_set<int64> dimensions_not_to_reduce;
4670 for (auto dim_pair : unmodified_dims) {
4671 if (arg_dim_in_output[dim_pair.second]) {
4672 dimensions_not_to_reduce.insert(dim_pair.first);
4673 }
4674 }
4675 std::vector<int64> new_reduce_dimensions;
4676 for (int64 i = 0; i < arg->operand(0)->shape().rank(); ++i) {
4677 if (!dimensions_not_to_reduce.contains(i)) {
4678 new_reduce_dimensions.push_back(i);
4679 }
4680 }
4681 return ReplaceWithNewInstruction(
4682 reduce, HloInstruction::CreateReduce(
4683 reduce_result_shape, arg->mutable_operand(0), init_value,
4684 new_reduce_dimensions, function));
4685 }
4686 }
4687 // Convert Reduce(concat({a,b,...})) to
4688 // map(reduce(a),map(reduce(b),...,))
4689 //
4690 // This should make fusion easier or use less memory bandwidth in the unfused
4691 // case.
4692 if (arg->opcode() == HloOpcode::kConcatenate &&
4693 absl::c_linear_search(reduce->dimensions(),
4694 arg->concatenate_dimension())) {
4695 HloInstruction* old_reduce = nullptr;
4696 for (HloInstruction* operand : arg->operands()) {
4697 HloInstruction* new_reduce = computation_->AddInstruction(
4698 HloInstruction::CreateReduce(reduce_result_shape, operand, init_value,
4699 reduce->dimensions(), function));
4700 if (old_reduce != nullptr) {
4701 new_reduce = computation_->AddInstruction(HloInstruction::CreateMap(
4702 reduce_result_shape, {old_reduce, new_reduce}, function));
4703 }
4704 old_reduce = new_reduce;
4705 }
4706 return ReplaceInstruction(reduce, old_reduce);
4707 }
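// Illustrative example: reduce(concat({a, b}, dim 0), dims={0}, add) becomes
// map(add)(reduce(a, {0}, add), reduce(b, {0}, add)); each piece is reduced
// independently and the partial results are combined elementwise.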
4708
4709 HloInstruction *dot, *lhs, *rhs;
4710 // Convert Reduce(Dot(X,Y)) to Dot(X,Y) if any of the dimensions reduced were
4711 // batch dimensions of the dot. The transformation supports reducing other
4712 // dimensions as well.
4713 if (options_.enable_dot_strength_reduction() &&
4714 Match(arg, m::Dot(&dot, m::Op(&lhs), m::Op(&rhs)).WithOneUser()) &&
4715 Match(reduce->to_apply()->root_instruction(),
4716 m::Add(m::Parameter(), m::Parameter())) &&
4717 absl::c_any_of(reduce->dimensions(), [&](int64 dim) {
4718 return dim < dot->dot_dimension_numbers().lhs_batch_dimensions_size();
4719 })) {
4720 const auto& dnums = dot->dot_dimension_numbers();
4721 DotDimensionNumbers new_dnums = dnums;
4722 new_dnums.clear_lhs_batch_dimensions();
4723 new_dnums.clear_rhs_batch_dimensions();
4724 int64 removed_dims = 0;
4725 for (int64 batch_dim = 0; batch_dim < dnums.lhs_batch_dimensions_size();
4726 ++batch_dim) {
4727 if (absl::c_linear_search(reduce->dimensions(), batch_dim)) {
4728 new_dnums.add_rhs_contracting_dimensions(
4729 dnums.rhs_batch_dimensions(batch_dim));
4730 new_dnums.add_lhs_contracting_dimensions(
4731 dnums.lhs_batch_dimensions(batch_dim));
4732 ++removed_dims;
4733 } else {
4734 new_dnums.add_rhs_batch_dimensions(
4735 dnums.rhs_batch_dimensions(batch_dim));
4736 new_dnums.add_lhs_batch_dimensions(
4737 dnums.lhs_batch_dimensions(batch_dim));
4738 }
4739 }
4740 std::vector<int64> reduce_dims;
4741 for (int64 dim : reduce->dimensions()) {
4742 if (dim >= dnums.lhs_batch_dimensions_size()) {
4743 reduce_dims.push_back(dim - removed_dims);
4744 }
4745 }
4746 TF_ASSIGN_OR_RETURN(
4747 auto new_dot,
4748 MakeDotHlo(lhs, rhs, new_dnums, dot->precision_config(),
4749 /*preferred_element_type=*/dot->shape().element_type()));
4750 dot->SetupDerivedInstruction(new_dot);
4751 if (reduce_dims.empty()) {
4752 return ReplaceInstruction(hlo, new_dot);
4753 }
4754 TF_ASSIGN_OR_RETURN(
4755 auto new_reduce,
4756 MakeReduceHlo(new_dot, init_value, reduce_dims, HloOpcode::kAdd));
4757 reduce->SetupDerivedInstruction(new_reduce);
4758 return ReplaceInstruction(hlo, new_reduce);
4759 }
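// Illustrative example: if d = dot(f32[8,3,5] a, f32[8,5,7] b) has batch
// dimension 0, then reduce(d, dims={0}, add) turns that batch dimension into
// an extra contracting dimension, yielding dot(a, b) with lhs/rhs contracting
// dims {2,0} and {1,0} and shape f32[3,7], with no separate reduce left.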
4760 return Status::OK();
4761 }
4762
4763 Status AlgebraicSimplifierVisitor::HandleReduceWindow(
4764 HloInstruction* reduce_window) {
4765 // TODO(b/73062247) Variadic reduce window is not yet supported in simplifier.
4766 if (reduce_window->shape().IsTuple()) {
4767 return Status::OK();
4768 }
4769 if (ShapeUtil::IsZeroElementArray(reduce_window->operand(0)->shape())) {
4770 return ReplaceWithNewInstruction(
4771 reduce_window,
4772 HloInstruction::CreateBroadcast(reduce_window->shape(),
4773 reduce_window->mutable_operand(1), {}));
4774 }
4775 auto operand = reduce_window->mutable_operand(0);
4776 const Window& window = reduce_window->window();
4777 auto function = reduce_window->to_apply();
4778 if (ShapeUtil::IsScalar(operand->shape())) {
4779 TF_RET_CHECK(ShapeUtil::IsScalar(reduce_window->shape()));
4780 return ReplaceWithNewInstruction(
4781 reduce_window,
4782 HloInstruction::CreateMap(reduce_window->shape(),
4783 {reduce_window->mutable_operand(1), operand},
4784 function));
4785 }
4786
4787 if (options_.enable_window_reduce_to_reduce_replacement()) {
4788 // A reduce window can be expressed as a reduce and a reshape if all
4789 // dimensions either have a window size of one or the entire dimension. If
4790 // there is no stride, dilation, or padding, this is as easy as checking the
4791 // size of the output shape and window dimension.
4792 //
4793 // The reshape is a bitcast since it adds one-sized dimensions. Often these
4794 // ones are immediately removed as well with another reshape. The
4795 // implementation of reduce tends to be slightly more efficient at reducing
4796 // entire dimensions compared to reduce window.
4797 auto effective_reduce_dims = [&] {
4798 if (window_util::HasStride(window) || window_util::HasDilation(window) ||
4799 window_util::HasPadding(window)) {
4800 return absl::InlinedVector<int64, 8>{};
4801 }
4802 absl::InlinedVector<int64, 8> reduce_dims;
4803 for (int64 i = 0; i < window.dimensions_size(); ++i) {
4804 if (window.dimensions(i).size() == 1) {
4805 continue;
4806 } else if (reduce_window->shape().dimensions(i) == 1) {
4807 reduce_dims.push_back(i);
4808 } else {
4809 return absl::InlinedVector<int64, 8>{};
4810 }
4811 }
4812 return reduce_dims;
4813 }();
4814
4815 // If a reduce window can be expressed as a reduce, do so and reshape the
4816 // output.
4817 if (!effective_reduce_dims.empty()) {
4818 Shape reduce_shape = ShapeUtil::FilterDimensions(
4819 [&](int64 dim) {
4820 return !absl::c_linear_search(effective_reduce_dims, dim);
4821 },
4822 reduce_window->shape());
4823 simplifier_->UpdateLayout(&reduce_shape);
4824 HloInstruction* reduce =
4825 computation_->AddInstruction(HloInstruction::CreateReduce(
4826 /*shape=*/reduce_shape,
4827 /*operand=*/operand,
4828 /*init_value=*/reduce_window->mutable_operand(1),
4829 /*dimensions_to_reduce=*/effective_reduce_dims,
4830 /*reduce_computation=*/function));
4831 return ReplaceWithNewInstruction(
4832 reduce_window,
4833 HloInstruction::CreateReshape(reduce_window->shape(), reduce));
4834 }
4835 }
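// Illustrative example: a reduce-window over f32[4,6] with window {1,6} (no
// stride, dilation, or padding) produces f32[4,1]; it is rewritten as
// reduce(operand, dims={1}) of shape f32[4] followed by a (bitcast) reshape
// back to f32[4,1].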
4836
4837 // This optimization folds a pad op into reduce_window.
4838 HloInstruction* pad;
4839 const HloInstruction* convert = nullptr;
4840 if (operand->opcode() == HloOpcode::kPad) {
4841 pad = operand;
4842 } else if (operand->opcode() == HloOpcode::kConvert &&
4843 operand->operand(0)->opcode() == HloOpcode::kPad) {
4844 convert = operand;
4845 pad = operand->mutable_operand(0);
4846 } else {
4847 VLOG(10) << "Not folding pad into reduce-window as there is no pad.";
4848 return Status::OK();
4849 }
4850
4851 VLOG(10) << "Considering folding Pad: " << pad->ToString()
4852 << "\ninto reduce-window: " << reduce_window->ToString()
4853 << (convert != nullptr
4854 ? absl::StrCat("\nvia convert: ", convert->ToString())
4855 : "");
4856
4857 // Do not fold interior padding into a base-dilated ReduceWindow, since
4858 // composing the two forms of dilation is not supported by the backends.
4859 const PaddingConfig& pad_config = pad->padding_config();
4860 if (HasInteriorPadding(pad_config) && window_util::HasBaseDilation(window)) {
4861 VLOG(10) << "Not folding interior pad into base-dilated reduce-window.";
4862 return Status::OK();
4863 }
4864
4865 // If reduce_window already has padding, the pad value of the pad op and the
4866 // init value of reduce_window must match to allow folding the pad.
4867 const HloInstruction* pad_value = pad->operand(1);
4868 const HloInstruction* reduce_init_value = reduce_window->operand(1);
4869 if (pad_value != reduce_init_value) {
4870 auto literals_are_equivalent = [&] {
4871 auto& pad_literal = pad_value->literal();
4872 auto& reduce_init_literal = reduce_init_value->literal();
4873 if (pad_literal == reduce_init_literal) {
4874 return true;
4875 }
4876 auto converted_pad_literal =
4877 pad_literal.ConvertToShape(reduce_init_value->shape());
4878 if (!converted_pad_literal.ok()) {
4879 return false;
4880 }
4881 return converted_pad_literal.ValueOrDie() == reduce_init_literal;
4882 };
4883 // The pad value is usually a constant, so we handle that case and do not
4884 // try to get more fancy about proving equivalence in cases beyond that.
4885 if (pad_value->opcode() != HloOpcode::kConstant ||
4886 reduce_init_value->opcode() != HloOpcode::kConstant ||
4887 !literals_are_equivalent()) {
4888 VLOG(10) << "Not folding pad into reduce-window due to different pad "
4889 "values.";
4890 return Status::OK();
4891 }
4892 }
4893
4894 // If the pad puts a single non-identity value in each window that we're
4895 // reducing, then this is a broadcast.
4896 HloInstruction* pad_operand = pad->mutable_operand(0);
4897 auto is_effective_broadcast = [&] {
4898 if (window_util::HasStride(window)) {
4899 VLOG(10) << "Window has stride.";
4900 return false;
4901 }
4902 if (!window_util::HasSymmetricPadding(pad_config)) {
4903 VLOG(10) << "Window has uneven padding.";
4904 return false;
4905 }
4906 if (HasInteriorPadding(pad_config)) {
4907 VLOG(10) << "Window has interior padding.";
4908 return false;
4909 }
4910 for (int64 i = 0; i < pad_config.dimensions_size(); ++i) {
4911 const auto& pad_dimension = pad_config.dimensions(i);
4912 if ((pad_dimension.edge_padding_low() != 0 ||
4913 pad_dimension.edge_padding_high() != 0) &&
4914 pad_operand->shape().dimensions(i) != 1) {
4915 VLOG(10) << "Found non-trivial dimension being padded: " << i;
4916 return false;
4917 }
4918 }
4919 VLOG(10) << "Found to be padding trivial dimensions only.";
4920
4921 for (int64 i = 0; i < window.dimensions_size(); ++i) {
4922 const auto& pad_dimension = pad_config.dimensions(i);
4923 const WindowDimension& window_dimension = window.dimensions(i);
4924 bool dimension_has_padding = (pad_dimension.edge_padding_low() != 0 ||
4925 pad_dimension.edge_padding_high() != 0);
4926 if (dimension_has_padding &&
4927 window_dimension.size() < pad_dimension.edge_padding_low() + 1) {
4928 VLOG(10) << "Found window did not cover single unpadded element in "
4929 "dimension: "
4930 << i;
4931 return false;
4932 }
4933 if (pad_operand->shape().dimensions(i) != 1 &&
4934 window_dimension.size() != 1) {
4935 VLOG(10) << "Found window covers more than one element in non-trivial "
4936 "dimension: "
4937 << i;
4938 return false;
4939 }
4940 }
4941 VLOG(10) << "Found window covers a single unpadded element.";
4942 return true;
4943 };
4944
4945 HloInstruction* new_reduce_window_operand;
4946 if (convert != nullptr) {
4947 Shape changed_shape = ShapeUtil::ChangeElementType(
4948 pad_operand->shape(), convert->shape().element_type());
4949 simplifier_->UpdateLayout(&changed_shape);
4950 new_reduce_window_operand = computation_->AddInstruction(
4951 HloInstruction::CreateConvert(changed_shape, pad_operand));
4952 } else {
4953 new_reduce_window_operand = pad_operand;
4954 }
4955
4956 if (is_effective_broadcast()) {
4957 VLOG(10) << "Replacing pad/reduce-window with broadcast.";
4958 auto fadd = [this](std::unique_ptr<HloInstruction> x) {
4959 return computation_->AddInstruction(std::move(x));
4960 };
4961 return ReplaceWithNewInstruction(
4962 reduce_window, HloInstruction::CreateBroadcastSequence(
4963 /*output_shape=*/reduce_window->shape(),
4964 /*operand=*/new_reduce_window_operand, fadd));
4965 }
4966
4967 // Carry out the folding of the pad into reduce_window.
4968 VLOG(10) << "Folding pad into reduce-window.";
4969 Window new_window = window;
4970 const int64 rank = reduce_window->shape().rank();
4971 TF_RET_CHECK(pad_config.dimensions_size() == rank);
4972 TF_RET_CHECK(window.dimensions_size() == rank);
4973 for (int64 i = 0; i < rank; ++i) {
4974 const auto& pad_dim = pad_config.dimensions(i);
4975 auto& window_dim = *new_window.mutable_dimensions(i);
4976 window_dim.set_padding_low(window_dim.padding_low() +
4977 pad_dim.edge_padding_low());
4978 window_dim.set_padding_high(window_dim.padding_high() +
4979 pad_dim.edge_padding_high());
4980 if (pad_dim.interior_padding() != 0) {
4981 CHECK_EQ(window_dim.base_dilation(), 1);
4982 window_dim.set_base_dilation(1 + pad_dim.interior_padding());
4983 }
4984 }
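// Illustrative example: folding pad(low=1, high=2) on a dimension adds 1 and
// 2 to that window dimension's padding_low/padding_high, and interior padding
// of k becomes base_dilation k+1 (valid because the interior-pad plus
// base-dilated case was rejected above).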
4985
4986 return ReplaceWithNewInstruction(
4987 reduce_window, HloInstruction::CreateReduceWindow(
4988 /*shape=*/reduce_window->shape(),
4989 /*operand=*/new_reduce_window_operand,
4990 /*init_value=*/reduce_window->mutable_operand(1),
4991 /*window=*/new_window,
4992 /*reduce_computation=*/function));
4993 }
4994
4995 Status AlgebraicSimplifierVisitor::HandleSelect(HloInstruction* select) {
4996 // select(x, y, y) -> y.
4997 if (select->operand(1) == select->operand(2)) {
4998 return ReplaceInstruction(select, select->mutable_operand(1));
4999 }
5000 // select(true, x, y) -> x.
5001 if (IsAll(select->operand(0), true)) {
5002 return ReplaceInstruction(select, select->mutable_operand(1));
5003 }
5004 // select(false, x, y) -> y.
5005 if (IsAll(select->operand(0), false)) {
5006 return ReplaceInstruction(select, select->mutable_operand(2));
5007 }
5008 return Status::OK();
5009 }
5010
5011 Status AlgebraicSimplifierVisitor::HandleScatter(HloInstruction* scatter) {
5012 if (ShapeUtil::IsZeroElementArray(scatter->operand(2)->shape()) &&
5013 ReplaceInstructionIfSameShape(scatter, scatter->mutable_operand(0))) {
5014 return Status::OK();
5015 }
5016 if (ShapeUtil::IsZeroElementArray(scatter->operand(1)->shape()) &&
5017 SameShape(scatter, scatter->operand(0)) &&
5018 SameShape(scatter, scatter->operand(2))) {
5019 return ReplaceWithNewInstruction(
5020 scatter, HloInstruction::CreateMap(
5021 scatter->shape(),
5022 {scatter->mutable_operand(0), scatter->mutable_operand(2)},
5023 scatter->to_apply()));
5024 }
5025 return Status::OK();
5026 }
5027 Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) {
5028 auto operand = sort->mutable_operand(0);
5029 int64 dimension_to_sort = sort->dimensions(0);
5030 if (ShapeUtil::IsZeroElementArray(operand->shape()) ||
5031 operand->shape().dimensions(dimension_to_sort) <= 1) {
5032 if (sort->operand_count() == 1) {
5033 return ReplaceInstruction(sort, operand);
5034 }
5035 // If it is key/value sort, the output of sort is a tuple.
5036 return ReplaceWithNewInstruction(
5037 sort, HloInstruction::CreateTuple(sort->operands()));
5038 }
5039 return Status::OK();
5040 }
5041
5042 Status AlgebraicSimplifierVisitor::HandleSqrt(HloInstruction* sqrt) {
5043 VLOG(10) << "trying transform [sqrt(A*A) => |A|] " << sqrt->ToString();
5044 HloInstruction* sqrt_operand = sqrt->mutable_operand(0);
5045 if (sqrt_operand->opcode() == HloOpcode::kMultiply &&
5046 sqrt_operand->operand(0) == sqrt_operand->operand(1)) {
5047 return ReplaceWithNewInstruction(
5048 sqrt, HloInstruction::CreateUnary(
5049 sqrt_operand->mutable_operand(0)->shape(), HloOpcode::kAbs,
5050 sqrt_operand->mutable_operand(0)));
5051 }
5052 return Status::OK();
5053 }
5054
5055 namespace {
5056 bool OnlyPermutesDegenerateDims(const Shape& shape,
5057 absl::Span<const int64> perm) {
5058 std::vector<int64> new_permutation;
5059 int64 degenerate_count = 0;
5060 for (int64 i = 0; i < perm.size(); ++i) {
5061 if (shape.dimensions(i) != 1) {
5062 new_permutation.push_back(perm[i]);
5063 } else {
5064 ++degenerate_count;
5065 }
5066 }
5067 return degenerate_count > 0 && absl::c_is_sorted(new_permutation);
5068 }
5069 } // namespace
5070
5071 Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) {
5072 auto operand = transpose->mutable_operand(0);
5073 if (std::is_sorted(transpose->dimensions().begin(),
5074 transpose->dimensions().end())) {
5075 VLOG(10) << "deleting no-op transpose";
5076 return ReplaceInstruction(transpose, operand);
5077 }
5078
5079 if (HloOpcode::kTranspose == operand->opcode()) {
5080 return ReplaceWithNewInstruction(
5081 transpose, HloInstruction::CreateTranspose(
5082 transpose->shape(), operand->mutable_operand(0),
5083 ComposePermutations(operand->dimensions(),
5084 transpose->dimensions())));
5085 }
5086
5087 // Convert transpose(dot(a,b)) to dot(b,a).
5088 if (operand->opcode() == HloOpcode::kDot && operand->user_count() == 1 &&
5089 operand->shape().rank() == 2) {
5090 TF_ASSIGN_OR_RETURN(bool did_transform, [&]() -> StatusOr<bool> {
5091 const auto& dnums = operand->dot_dimension_numbers();
5092 if (dnums.lhs_batch_dimensions_size() != 0) {
5093 return false;
5094 }
5095 HloInstruction* lhs = operand->mutable_operand(0);
5096 if (lhs->shape().rank() != 1 + dnums.lhs_contracting_dimensions_size()) {
5097 return false;
5098 }
5099 HloInstruction* rhs = operand->mutable_operand(1);
5100 if (rhs->shape().rank() != 1 + dnums.rhs_contracting_dimensions_size()) {
5101 return false;
5102 }
5103 DotDimensionNumbers new_dnums;
5104 *new_dnums.mutable_lhs_contracting_dimensions() =
5105 dnums.rhs_contracting_dimensions();
5106 *new_dnums.mutable_rhs_contracting_dimensions() =
5107 dnums.lhs_contracting_dimensions();
5108 TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
5109 transpose, HloInstruction::CreateDot(transpose->shape(), /*lhs=*/rhs,
5110 /*rhs=*/lhs, new_dnums,
5111 operand->precision_config())));
5112 return true;
5113 }());
5114 if (did_transform) {
5115 return Status::OK();
5116 }
5117 }
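// Illustrative example: transpose(dot(f32[3,5] a, f32[5,7] b), {1,0}) of
// shape f32[7,3] is rewritten to dot(b, a), contracting b's dim 0 with a's
// dim 1, which yields f32[7,3] directly without a transpose.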
5118
5119 // Replace the transpose with a reshape if it only permutes degenerate
5120 // (size-1) dimensions.
5121 if (OnlyPermutesDegenerateDims(transpose->shape(), transpose->dimensions())) {
5122 return ReplaceWithNewInstruction(
5123 transpose, HloInstruction::CreateReshape(
5124 transpose->shape(), transpose->mutable_operand(0)));
5125 }
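// Illustrative example: transpose(f32[4,1,8] x, {1,0,2}) of shape f32[1,4,8]
// moves only the size-1 dimension; the non-degenerate dimensions stay in
// order, so the transpose is just a reshape.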
5126
5127 if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) {
5128 *operand->mutable_shape() = transpose->shape();
5129 return ReplaceInstruction(transpose, operand);
5130 }
5131
5132 if (options_.is_layout_sensitive() &&
5133 options_.replace_transpose_with_bitcast() &&
5134 TransposeIsBitcast(transpose)) {
5135 ReplaceWithBitcast(transpose);
5136 return Status::OK();
5137 }
5138
5139 return Status::OK();
5140 }
5141
5142 StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvInputPad(
5143 HloInstruction* convolution) {
5144 auto* lhs = convolution->mutable_operand(0);
5145 auto* rhs = convolution->mutable_operand(1);
5146 const auto& window = convolution->window();
5147 const ConvolutionDimensionNumbers& dnums =
5148 convolution->convolution_dimension_numbers();
5149
5150 if (lhs->opcode() != HloOpcode::kPad) {
5151 return false;
5152 }
5153
5154 // Convolution's padding is always zero, so bail if the kPad is adding
5155 // something other than zero.
5156 if (!IsAll(lhs->operand(1), 0)) {
5157 return false;
5158 }
5159
5160 const auto& padding = lhs->padding_config();
5161
5162 // Can't pad batch or feature dims.
5163 for (int64 dim :
5164 {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) {
5165 const auto& p = padding.dimensions(dim);
5166 if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
5167 p.interior_padding() != 0) {
5168 return false;
5169 }
5170 }
5171
5172 // Compute the window which is the result of merging the kPad and the
5173 // convolution's existing window.
5174 Window new_window = window;
5175 for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) {
5176 auto& w = *new_window.mutable_dimensions(dim);
5177 const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim));
5178 // Edge padding composes with itself in the straightforward way, but
5179 // composing interior padding is nontrivial, and we cowardly refuse to
5180 // think about it. If we see interior padding in either the kPad or conv,
5181 // bail if there's any sort of padding in the other.
5182 if (p.interior_padding() != 0 &&
5183 (w.padding_low() != 0 || w.padding_high() != 0 ||
5184 w.base_dilation() != 1)) {
5185 return false;
5186 }
5187 if (w.base_dilation() != 1 &&
5188 (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
5189 p.interior_padding() != 0)) {
5190 return false;
5191 }
5192
5193 w.set_padding_low(w.padding_low() + p.edge_padding_low());
5194 w.set_padding_high(w.padding_high() + p.edge_padding_high());
5195 if (p.interior_padding() != 0) {
5196 CHECK_EQ(w.base_dilation(), 1);
5197 w.set_base_dilation(1 + p.interior_padding());
5198 }
5199 }
5200
5201 auto new_conv = convolution->CloneWithNewOperands(
5202 convolution->shape(), {lhs->mutable_operand(0), rhs});
5203 new_conv->set_window(new_window);
5204 TF_RETURN_IF_ERROR(
5205 ReplaceWithNewInstruction(convolution, std::move(new_conv)));
5206 return true;
5207 }
5208
5209 StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvFilterPad(
5210 HloInstruction* convolution) {
5211 auto* lhs = convolution->mutable_operand(0);
5212 auto* rhs = convolution->mutable_operand(1);
5213 const ConvolutionDimensionNumbers& dnums =
5214 convolution->convolution_dimension_numbers();
5215
5216 if (rhs->opcode() != HloOpcode::kPad) {
5217 return false;
5218 }
5219
5220 // Convolution's padding is always zero, so bail if the kPad is adding
5221 // something other than zero.
5222 if (!IsAll(rhs->operand(1), 0)) {
5223 return false;
5224 }
5225
5226 const auto& padding = rhs->padding_config();
5227
5228 // Can't pad or dilate feature dims.
5229 for (int64 dim : {dnums.kernel_input_feature_dimension(),
5230 dnums.kernel_output_feature_dimension()}) {
5231 const auto& p = padding.dimensions(dim);
5232 if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
5233 p.interior_padding() != 0) {
5234 return false;
5235 }
5236 }
5237
5238 // Compute the window which is the result of merging the kPad and the
5239 // convolution's existing window.
5240 Window new_window = convolution->window();
5241 for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) {
5242 auto& w = *new_window.mutable_dimensions(dim);
5243 const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim));
5244
5245 // We can only do this transformation if p adds dilation to the filter --
5246 // edge padding on the filter is not supported in conv.
5247 if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) {
5248 return false;
5249 }
5250
5251 // Nothing to do if the kPad for this dim is entirely a nop.
5252 if (p.interior_padding() == 0) {
5253 continue;
5254 }
5255
5256 // We cowardly refuse to think about how dilation composes with itself;
5257 // bail if both the kPad and conv have dilation on this dimension.
5258 if (w.window_dilation() > 1) {
5259 return false;
5260 }
5261 CHECK_EQ(w.window_dilation(), 1);
5262 w.set_window_dilation(1 + p.interior_padding());
5263 w.set_size(rhs->operand(0)->shape().dimensions(
5264 dnums.kernel_spatial_dimensions(dim)));
5265 }
5266
5267 auto new_conv = convolution->CloneWithNewOperands(
5268 convolution->shape(), {lhs, rhs->mutable_operand(0)});
5269 new_conv->set_window(new_window);
5270 TF_RETURN_IF_ERROR(
5271 ReplaceWithNewInstruction(convolution, std::move(new_conv)));
5272 return true;
5273 }
5274
5275 StatusOr<bool> AlgebraicSimplifierVisitor::SwapConvOperands(
5276 HloInstruction* convolution) {
5277 if (!options_.enable_conv_operand_swap() || options_.is_layout_sensitive()) {
5278 return false;
5279 }
5280 if (convolution->feature_group_count() > 1 ||
5281 convolution->batch_group_count() > 1) {
5282 return false;
5283 }
5284
5285 const auto& dnums = convolution->convolution_dimension_numbers();
5286 const auto& window_dims = convolution->window().dimensions();
5287 Window swapped_window;
5288
5289 HloInstruction *input = convolution->mutable_operand(0),
5290 *kernel = convolution->mutable_operand(1);
5291 int64 kernel_product = 1;
5292 int64 swapped_kernel_product = 1;
5293 DimensionVector reverse_dimensions;
5294 for (int64 spatial_dim = 0;
5295 spatial_dim < dnums.input_spatial_dimensions_size(); ++spatial_dim) {
5296 const int64 kernel_size = window_dims[spatial_dim].size();
5297 const bool can_be_group_or_contraction =
5298 !window_dims[spatial_dim].window_reversal() &&
5299 window_dims[spatial_dim].padding_low() == 0 &&
5300 window_dims[spatial_dim].padding_high() == 0 &&
5301 window_dims[spatial_dim].window_dilation() == 1;
5302 const bool is_group_dim =
5303 can_be_group_or_contraction &&
5304 window_dims[spatial_dim].base_dilation() == kernel_size &&
5305 window_dims[spatial_dim].stride() == kernel_size - 1;
5306 const int64 input_size =
5307 input->shape().dimensions(dnums.input_spatial_dimensions(spatial_dim));
5308 const bool is_pure_contraction_dim =
5309 kernel_size == input_size && can_be_group_or_contraction &&
5310 window_dims[spatial_dim].base_dilation() == 1 &&
5311 window_dims[spatial_dim].stride() == 1;
5312 if (is_group_dim || is_pure_contraction_dim) {
5313 *(swapped_window.add_dimensions()) = window_dims[spatial_dim];
5314 continue;
5315 }
5316
5317 const int64 dilated_kernel_size =
5318 1 + (kernel_size - 1) * window_dims[spatial_dim].window_dilation();
5319 const int64 dilated_input_size =
5320 1 + (input_size - 1) * window_dims[spatial_dim].base_dilation();
5321
5322 // Don't decide to swap if the input size is one, since many convolution
5323 // implementations can handle that special case efficiently.
5324 kernel_product *= kernel_size;
5325 swapped_kernel_product *= input_size == 1 ? kernel_size : input_size;
5326
5327 auto new_dim = swapped_window.add_dimensions();
5328 new_dim->set_size(input_size);
5329 // If the kernel is not reversed, the activations must be manually reversed.
5330 if (!window_dims[spatial_dim].window_reversal()) {
5331 reverse_dimensions.push_back(
5332 dnums.kernel_spatial_dimensions(spatial_dim));
5333 }
5334 // The input is not originally reversed so it must be reversed to move the
5335 // kernel.
5336 new_dim->set_window_reversal(true);
5337 // Base dilation and window dilation switch places.
5338 new_dim->set_base_dilation(window_dims[spatial_dim].window_dilation());
5339 new_dim->set_window_dilation(window_dims[spatial_dim].base_dilation());
5340 new_dim->set_stride(window_dims[spatial_dim].stride());
5341 new_dim->set_padding_low(dilated_input_size +
5342 window_dims[spatial_dim].padding_low() -
5343 dilated_kernel_size);
5344 new_dim->set_padding_high(dilated_input_size +
5345 window_dims[spatial_dim].padding_high() -
5346 dilated_kernel_size);
5347 }
5348
5349 // Don't transform if a naive convolution implementation would not have fewer
5350 // flops.
5351 if (kernel_product <= swapped_kernel_product) {
5352 return false;
5353 }
5354 ConvolutionDimensionNumbers swapped_dnums;
5355 *swapped_dnums.mutable_output_spatial_dimensions() =
5356 dnums.output_spatial_dimensions();
5357 // Swap batch and output feature of the output.
5358 swapped_dnums.set_output_batch_dimension(dnums.output_feature_dimension());
5359 swapped_dnums.set_output_feature_dimension(dnums.output_batch_dimension());
5360
5361 // Swap input dnums with kernel dnums
5362 *swapped_dnums.mutable_input_spatial_dimensions() =
5363 dnums.kernel_spatial_dimensions();
5364 swapped_dnums.set_input_batch_dimension(
5365 dnums.kernel_output_feature_dimension());
5366 swapped_dnums.set_input_feature_dimension(
5367 dnums.kernel_input_feature_dimension());
5368
5369 // Swap kernel dnums with input dnums
5370 *swapped_dnums.mutable_kernel_spatial_dimensions() =
5371 dnums.input_spatial_dimensions();
5372 swapped_dnums.set_kernel_output_feature_dimension(
5373 dnums.input_batch_dimension());
5374 swapped_dnums.set_kernel_input_feature_dimension(
5375 dnums.input_feature_dimension());
5376
5377 PrecisionConfig precision_config;
5378 precision_config.add_operand_precision(
5379 convolution->precision_config().operand_precision(1));
5380 precision_config.add_operand_precision(
5381 convolution->precision_config().operand_precision(0));
5382 if (!reverse_dimensions.empty()) {
5383 TF_ASSIGN_OR_RETURN(kernel, MakeReverseHlo(kernel, reverse_dimensions));
5384 }
5385 TF_ASSIGN_OR_RETURN(
5386 HloInstruction * new_convolution,
5387 MakeConvolveHlo(
5388 kernel, input, /*feature_group_count=*/1,
5389 /*batch_group_count=*/1, swapped_window, swapped_dnums,
5390 precision_config,
5391 /*preferred_element_type=*/convolution->shape().element_type()));
5392
5393 convolution->SetupDerivedInstruction(new_convolution);
5394 TF_RETURN_IF_ERROR(ReplaceInstruction(convolution, new_convolution));
5395
5396 return true;
5397 }
5398
5399 StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot(
5400 HloInstruction* convolution) {
5401 auto* lhs = convolution->mutable_operand(0);
5402 auto* rhs = convolution->mutable_operand(1);
5403 const auto& window = convolution->window();
5404 const ConvolutionDimensionNumbers& dnums =
5405 convolution->convolution_dimension_numbers();
5406
5407 if (!options_.enable_conv_simplification()) {
5408 return false;
5409 }
5410
5411 // TODO(b/31337498): For now, we cowardly refuse to do this optimization in
5412 // layout-insensitive mode, for fear of adding nontrivial reshapes.
5413 if (!options_.is_layout_sensitive()) {
5414 return false;
5415 }
5416
5417 const Shape& input_shape = lhs->shape();
5418 const Shape& filter_shape = rhs->shape();
5419 const Shape& convolution_shape = convolution->shape();
5420 TF_RET_CHECK(LayoutUtil::HasLayout(input_shape));
5421 TF_RET_CHECK(LayoutUtil::HasLayout(filter_shape));
5422 TF_RET_CHECK(LayoutUtil::HasLayout(convolution_shape));
5423
5424 // Require the spatial dimensions in the kernel to have a bound of one.
5425 for (int64 i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) {
5426 if (filter_shape.dimensions(dnums.kernel_spatial_dimensions(i)) != 1) {
5427 return false;
5428 }
5429 }
5430
5431 // Stride ignores part of the output, which matrix multiplication does not do,
5432 // so require no stride. Padding and base (lhs) dilation both implicitly
5433 // extend the data, which matrix multiplication also does not do, so require
5434 // no padding and no base (lhs) dilation. Window (rhs) dilation has no effect
5435 // for a 1x1 window, so window dilation is no problem.
5436 if (window_util::HasStride(window) || window_util::HasPadding(window) ||
5437 window_util::HasBaseDilation(window)) {
5438 return false;
5439 }
5440
5441 // Also, the shapes must align for a rowmajor matmul:
5442 // - the input and output have the same layout.
5443 // - for input/output, the channel dimension must be the most minor. Other
5444 // spatial dims can be in any order.
5445 // - for filters, the input channel dimension must be more major than the
5446 // output channel dimension. The width+height don't matter because
5447 // they are 1.
5448 //
5449 // These constraints are harsh. If the channel dimension is the most major
5450 // and/or the layout of input/output feature dimensions are reversed, we can
5451 // still convert Conv into more efficient Matmul with operand transposition
5452 // (such as the transposition flags in cuBLAS SGEMM).
5453 if (!LayoutUtil::Equal(input_shape.layout(), convolution_shape.layout()) ||
5454 LayoutUtil::Minor(input_shape.layout(), 0) !=
5455 dnums.input_feature_dimension() ||
5456 LayoutUtil::Minor(convolution_shape.layout(), 0) !=
5457 dnums.output_feature_dimension() ||
5458 // The input feature dimension should come later in the minor-to-major
5459 // order.
5460 (PositionInContainer(LayoutUtil::MinorToMajor(filter_shape),
5461 dnums.kernel_input_feature_dimension()) <
5462 PositionInContainer(LayoutUtil::MinorToMajor(filter_shape),
5463 dnums.kernel_output_feature_dimension()))) {
5464 return false;
5465 }
5466
5467 auto add_bitcast = [&](Shape shape, HloInstruction* operand) {
5468 std::vector<int64> dims(operand->shape().dimensions_size());
5469 std::iota(dims.begin(), dims.end(), 0);
5470 return computation_->AddInstruction(
5471 HloInstruction::CreateBitcast(shape, operand));
5472 };
5473
5474 // Replace it with a dot, with bitcasts around it to get the right shape.
5475 const int64 input_channels =
5476 input_shape.dimensions(dnums.input_feature_dimension());
5477 const int64 output_channels =
5478 filter_shape.dimensions(dnums.kernel_output_feature_dimension());
5479
5480 // Computes the product of the non-feature dimensions.
5481 int64 conv_width = 1;
5482 for (int i = 0; i < input_shape.dimensions_size(); ++i) {
5483 if (i != dnums.input_feature_dimension()) {
5484 conv_width *= input_shape.dimensions(i);
5485 }
5486 }
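// Illustrative example (assumed layouts as checked above): a 1x1 convolution
// with input f32[8,32,32,16] (feature minor) and filter f32[1,1,16,64] is a
// matmul of shape [8*32*32, 16] x [16, 64]; conv_width here is 8*32*32.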
5487
5488 // We already checked feature_dimension is most minor, so data in input_shape
5489 // and row-major {conv_width,input_channels} are bitwise identical.
5490 Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout(
5491 input_shape.element_type(), {conv_width, input_channels});
5492 simplifier_->UpdateLayout(&new_input_shape);
5493 // We already checked input_feature_dimension is more major than
5494 // output_feature_dimension, so data in filter_shape and row-major
5495 // {input_channels,output_channels} are bitwise identical.
5496 Shape new_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout(
5497 filter_shape.element_type(), {input_channels, output_channels});
5498 simplifier_->UpdateLayout(&new_filter_shape);
5499 Shape dot_output_shape = ShapeUtil::MakeShapeWithDescendingLayout(
5500 convolution_shape.element_type(), {conv_width, output_channels});
5501 simplifier_->UpdateLayout(&dot_output_shape);
5502
5503 auto new_lhs = add_bitcast(new_input_shape, lhs);
5504 auto new_rhs = add_bitcast(new_filter_shape, rhs);
5505 DotDimensionNumbers dot_dimension_numbers;
5506 dot_dimension_numbers.add_lhs_contracting_dimensions(1);
5507 dot_dimension_numbers.add_rhs_contracting_dimensions(0);
5508 auto dot = computation_->AddInstruction(HloInstruction::CreateDot(
5509 dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers,
5510 convolution->precision_config()));
5511
5512 TF_RETURN_IF_ERROR(
5513 ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)));
5514 return true;
5515 }
5516
5517 Status AlgebraicSimplifierVisitor::HandleConvolution(
5518 HloInstruction* convolution) {
5519 if (options_.enable_scalar_multiply_reduction()) {
5520 TF_RETURN_IF_ERROR(ScalarMultiplyReduction(convolution));
5521 }
5522
5523 // Zero-sized input or filter.
5524 if (ShapeUtil::IsZeroElementArray(convolution->operand(0)->shape()) ||
5525 ShapeUtil::IsZeroElementArray(convolution->operand(1)->shape())) {
5526 return ReplaceInstruction(convolution, MakeScalarLike(convolution, 0));
5527 }
5528
5529 // Try to merge padding/dilation of the input with the convolution's window.
5530 TF_ASSIGN_OR_RETURN(bool folded_input_pad, FoldConvInputPad(convolution));
5531 if (folded_input_pad) {
5532 return Status::OK();
5533 }
5534
5535 // Try to merge dilation of the filter with the convolution's window.
5536 TF_ASSIGN_OR_RETURN(bool folded_filter_pad, FoldConvFilterPad(convolution));
5537 if (folded_filter_pad) {
5538 return Status::OK();
5539 }
5540
5541 // Try to swap convolution operands.
5542 TF_ASSIGN_OR_RETURN(bool swapped, SwapConvOperands(convolution));
5543 if (swapped) {
5544 return Status::OK();
5545 }
5546 // Try to replace the convolution with a kDot instruction.
5547 TF_ASSIGN_OR_RETURN(bool replaced_with_dot, SimplifyConvToDot(convolution));
5548 if (replaced_with_dot) {
5549 return Status::OK();
5550 }
5551
5552 return Status::OK();
5553 }
5554
5555 bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape(
5556 HloInstruction* root, HloInstruction* min, HloInstruction* min_operand,
5557 HloInstruction* operand, HloInstruction* max, HloInstruction* max_operand) {
5558 // Ensure shapes of min and max operand are equal to match current shape
5559 // inference.
5560 if (!SameShape(min_operand, max_operand)) {
5561 return false;
5562 }
5563
5564 auto clamp = HloInstruction::CreateTernary(root->shape(), HloOpcode::kClamp,
5565 max_operand, operand, min_operand);
5566 TF_CHECK_OK(ReplaceWithNewInstruction(root, std::move(clamp)));
5567 return true;
5568 }
5569
5570 Status AlgebraicSimplifierVisitor::HandleMap(HloInstruction* map) {
5571 auto* map_computation = map->to_apply();
5572 auto* map_root = map_computation->root_instruction();
5573 if (map_root->opcode() == HloOpcode::kParameter) {
5574 ReplaceInstructionIfSameShape(
5575 map, map->mutable_operand(map_root->parameter_number()));
5576 return Status::OK();
5577 }
5578 if (map_root->opcode() == HloOpcode::kConstant) {
5579 if (!ShapeUtil::IsScalar(map_root->shape())) {
5580 return Status::OK();
5581 }
5582 auto clone = map_root->CloneWithNewOperands(map_root->shape(), {});
5583 if (ShapeUtil::IsScalar(map->shape())) {
5584 return ReplaceWithNewInstruction(map, std::move(clone));
5585 }
5586 return ReplaceWithNewInstruction(
5587 map,
5588 HloInstruction::CreateBroadcast(
5589 map->shape(), computation_->AddInstruction(std::move(clone)), {}));
5590 }
5591 // Inline the map if the map computation only contains an elementwise
5592 // operation that can accept arbitrary shapes.
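// Illustrative example: map({a, b}, to_apply: (p0, p1) -> add(p0, p1)) is
// rewritten to add(a, b), since elementwise ops already apply across all
// elements of their operands.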
5593 if (map_root->opcode() == HloOpcode::kFusion || !map_root->IsElementwise()) {
5594 return Status::OK();
5595 }
5596 std::vector<HloInstruction*> new_operands;
5597 for (auto* root_operand : map_root->operands()) {
5598 if (root_operand->opcode() != HloOpcode::kParameter) {
5599 return Status::OK();
5600 }
5601 new_operands.push_back(
5602 map->mutable_operand(root_operand->parameter_number()));
5603 }
5604 auto clone = map_root->CloneWithNewOperands(map->shape(), new_operands);
5605 return ReplaceWithNewInstruction(map, std::move(clone));
5606 }
5607
5608 StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) {
5609 XLA_VLOG_LINES(2,
5610 "AlgebraicSimplifier::Run(), before:\n" + module->ToString());
5611 bool changed = false;
5612 AlgebraicSimplifierVisitor visitor(options_, this);
5613 for (auto* comp : module->MakeNonfusionComputations()) {
5614 if (visitor.Run(comp, options_, this)) {
5615 changed = true;
5616 }
5617 }
5618 XLA_VLOG_LINES(2,
5619 "AlgebraicSimplifier::Run(), after:\n" + module->ToString());
5620 return changed;
5621 }
5622
5623 } // namespace xla
5624