/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"

#include <algorithm>
#include <cmath>
#include <functional>
#include <iterator>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

namespace {

namespace m = match;

bool IsAll(const HloInstruction* op, int8 value) {
  switch (op->opcode()) {
    case HloOpcode::kBroadcast:
      return IsAll(op->operand(0), value);
    case HloOpcode::kConstant:
      return op->literal().IsAll(value);
    default:
      return false;
  }
}

// Checks whether `op` is a floating-point constant or broadcast of a constant
// of the form +/- 2^k for some integer k positive, negative, or zero. Such
// values are interesting because multiplying by a power of 2 just moves the
// exponent.
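// For example, 0.5, -2.0, and 8.0 all qualify, while 3.0 and 0.0 do not.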
bool IsAllFpConstantPowerOf2(const HloInstruction* op) {
  // Unwrap the broadcast if necessary.
  const HloInstruction* c;
  if (!Match(op, m::ConstantEffectiveScalar(&c)) &&
      !Match(op, m::Broadcast(m::Constant(&c).WithShape(
                     m::Shape().IsEffectiveScalar())))) {
    return false;
  }
  auto val = [&]() -> absl::optional<double> {
    switch (c->shape().element_type()) {
      case BF16:
        return static_cast<double>(c->literal().GetFirstElement<bfloat16>());
      case F16:
        return static_cast<double>(c->literal().GetFirstElement<Eigen::half>());
      case F32:
        return c->literal().GetFirstElement<float>();
      case F64:
        return c->literal().GetFirstElement<double>();
      default:
        // Cowardly refuse to consider complex types.
        return absl::nullopt;
    }
  }();
  if (!val) {
    return false;
  }

  int exp;
  double mantissa = std::frexp(*val, &exp);
  // frexp returns a value in the range (-1, -0.5] U [0.5, 1). A return value
  // of +/-0.5 therefore indicates that the floating point value is a power of
  // 2.
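  // For example, frexp(8.0) yields (0.5, 4) and frexp(-0.25) yields
  // (-0.5, -1), so both 8.0 and -0.25 are powers of 2.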
  return mantissa == 0.5 || mantissa == -0.5;
}

// Returns whether the given transpose produces a result which is bit-wise
// identical to its operand and thus may be replaced with a bitcast.
bool TransposeIsBitcast(const HloInstruction* transpose) {
  CHECK_EQ(HloOpcode::kTranspose, transpose->opcode());
  const HloInstruction* operand = transpose->operand(0);
  return ShapeUtil::TransposeIsBitcast(operand->shape(), transpose->shape(),
                                       transpose->dimensions());
}

// Recursive helper for the method below.
HloInstruction* BitcastingOperandOfReshapeOrCopyChainHelper(
    HloInstruction* instr, HloInstruction* operand,
    const AlgebraicSimplifierOptions& options) {
  // Can't replace a chain of copies and reshapes with bitcasts if the compiler
  // used a memory layout which isn't compatible.
  if (options.ReshapeIsBitcast(operand->shape(), instr->shape())) {
    return operand;
  }

  // If the operand is a copy or reshape, try to see if the operand's operand
  // would produce a bitcast with the initial instruction.
  if (HloOpcode::kReshape == operand->opcode() ||
      HloOpcode::kCopy == operand->opcode()) {
    return BitcastingOperandOfReshapeOrCopyChainHelper(
        instr, operand->mutable_operand(0), options);
  }
  return nullptr;
}

// Returns an operand of a chain of reshapes and copies that is bit-wise
// identical to the first reshape or copy in the chain.
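// For example, if reshape(copy(X)) describes the same byte layout as X, the
// whole chain can be replaced with a single bitcast from X.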
HloInstruction* BitcastingOperandOfReshapeOrCopyChain(
    HloInstruction* instr, const AlgebraicSimplifierOptions& options) {
  if (!options.is_layout_sensitive()) {
    return nullptr;
  }
  CHECK(HloOpcode::kReshape == instr->opcode() ||
        HloOpcode::kCopy == instr->opcode());
  return BitcastingOperandOfReshapeOrCopyChainHelper(
      instr, instr->mutable_operand(0), options);
}

bool IsUnstridedSlice(const HloInstruction* hlo) {
  return absl::c_all_of(hlo->slice_strides(),
                        [](int64 stride) { return stride == 1; });
}

// AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain
// algebraic expressions to simplified forms. Note: This only supports
// simplifications that simply look at the operands of an instruction. For the
// more general case a worklist based approach would be needed.
class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
 public:
  // Default visitor action is to do nothing and return OK.
  Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
    return Status::OK();
  }

  Status HandleAdd(HloInstruction* add) override;

  Status HandleAnd(HloInstruction* logical_and) override;

  Status HandleBitcast(HloInstruction* bitcast) override;

  Status HandleBitcastConvert(HloInstruction* bitcast) override;

  Status HandleBroadcast(HloInstruction* broadcast) override;

  Status HandleConcatenate(HloInstruction* concatenate) override;

  Status HandleConstant(HloInstruction* constant) override;

  Status HandleCopy(HloInstruction* copy) override;

  Status HandleConvert(HloInstruction* convert) override;

  Status HandleComplex(HloInstruction* complex) override;

  Status HandleReal(HloInstruction* real) override;

  Status HandleImag(HloInstruction* imag) override;

  Status HandleIota(HloInstruction* instruction) override;

  Status HandleConvolution(HloInstruction* convolution) override;

  Status HandleDivide(HloInstruction* divide) override;

  Status HandleDot(HloInstruction* dot) override;

  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;

  Status HandleLog(HloInstruction* log) override;

  Status HandleMultiply(HloInstruction* multiply) override;

  Status HandleNegate(HloInstruction* negate) override;

  Status HandleNot(HloInstruction* logical_not) override;

  Status HandleOr(HloInstruction* logical_or) override;

  Status HandlePad(HloInstruction* pad) override;

  Status HandlePower(HloInstruction* power) override;

  Status HandleRemainder(HloInstruction* remainder) override;

  Status HandleReshape(HloInstruction* reshape) override;

  Status HandleReduce(HloInstruction* reduce) override;

  Status HandleReduceWindow(HloInstruction* reduce_window) override;

  Status HandleReverse(HloInstruction* reverse) override;
  Status HandleSlice(HloInstruction* slice) override;
  Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
  Status HandleDynamicUpdateSlice(
      HloInstruction* dynamic_update_slice) override;

  Status HandleSelect(HloInstruction* select) override;

  Status HandleSort(HloInstruction* sort) override;

  Status HandleTranspose(HloInstruction* transpose) override;

  Status HandleSubtract(HloInstruction* sub) override;

  Status HandleMap(HloInstruction* map) override;

  // Returns whether algebraic simplification has occurred.
  bool changed() const { return changed_; }

  // Runs the visitor on a computation.
  static bool Run(HloComputation* computation,
                  const AlgebraicSimplifierOptions& options);

 private:
  explicit AlgebraicSimplifierVisitor(HloComputation* computation,
                                      const AlgebraicSimplifierOptions& options)
      : computation_(computation), options_(options) {}

  // Transforms Dots where at least one input is a vector or has a degenerate
  // dimension and converts it into a multiply and reduce. This should enable
  // more fusion than leaving the nodes as Dot operations.
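  // For example, dot(a[M,K], b[K]) can become
  // reduce_sum(multiply(a, broadcast(b)), {1}).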
  StatusOr<bool> HandleDotStrengthReduction(HloInstruction* dot);

  // Removes dimension dim from hlo.
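  // For example, stripping dim 1 from an f32[2,1,3] operand yields an
  // f32[2,3] reshape.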
  HloInstruction* StripDim(HloInstruction* hlo, int64 dim) {
    CHECK_EQ(hlo->shape().dimensions(dim), 1);
    return computation_->AddInstruction(HloInstruction::CreateReshape(
        ShapeUtil::DeleteDimension(dim, hlo->shape()), hlo));
  }

  // Reshapes an instruction to rank 1 if it is not already rank 1.
  HloInstruction* Flatten(HloInstruction* hlo) {
    if (hlo->shape().rank() == 1) {
      return hlo;
    }
    return computation_->AddInstruction(HloInstruction::CreateReshape(
        ShapeUtil::MakeShape(hlo->shape().element_type(),
                             {ShapeUtil::ElementsIn(hlo->shape())}),
        hlo));
  }

  // Converts to primitive type if the input hlo is not that type, otherwise
  // returns the original hlo.
  HloInstruction* AsType(HloInstruction* hlo,
                         const PrimitiveType element_type) {
    if (hlo->shape().element_type() == element_type) {
      return hlo;
    }
    return computation_->AddInstruction(HloInstruction::CreateConvert(
        ShapeUtil::ChangeElementType(hlo->shape(), element_type), hlo));
  }

  // Transposes a dot operand such that the batch dimensions are the most
  // major and the contracting dimensions are the most minor.
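  // For example, an operand of shape f32[C,B,F] with batch dimension {1} and
  // contracting dimension {0} is transposed to f32[B,F,C].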
  StatusOr<HloInstruction*> NormalizeDotOperandToBatchMajorAndContractingMinor(
      HloInstruction* dot_operand, absl::Span<const int64> batch_dimensions,
      absl::Span<const int64> contracting_dimensions) {
    std::vector<int64> transpose_dimensions(batch_dimensions.begin(),
                                            batch_dimensions.end());
    for (int64 i = 0; i < dot_operand->shape().rank(); ++i) {
      if (!(absl::c_linear_search(batch_dimensions, i) ||
            absl::c_linear_search(contracting_dimensions, i))) {
        transpose_dimensions.push_back(i);
      }
    }
    transpose_dimensions.insert(transpose_dimensions.end(),
                                contracting_dimensions.begin(),
                                contracting_dimensions.end());
    return MakeTransposeHlo(dot_operand, transpose_dimensions);
  }

  // Helper method to perform an add reduction over a list of dimensions.
  HloInstruction* AddReduce(HloInstruction* hlo, absl::Span<const int64> dims) {
    HloInstruction* zero =
        computation_->AddInstruction(HloInstruction::CreateConstant(
            LiteralUtil::Zero(hlo->shape().element_type()).Clone()));
    HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation();
    Shape shape = ShapeUtil::FilterDimensions(
        [&](int64 dim) { return !absl::c_linear_search(dims, dim); },
        hlo->shape());
    return computation_->AddInstruction(HloInstruction::CreateReduce(
        shape, hlo, zero, dims, AddReduce_computation));
  }

  HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) {
    return AddReduce(hlo, std::vector<int64>{dim});
  }

  // Convenience method for replacing an instruction with a bitcast. If operand
  // is not null, then the bitcast will use the specified operand instead of
  // the operand of the instruction.
  void ReplaceWithBitcast(HloInstruction* instruction,
                          HloInstruction* operand = nullptr);

  // Replace old instruction with new instruction if old and new instructions
  // have the same shape. Updates uses and root instruction. Returns whether a
  // replacement was made.
  bool ReplaceInstructionIfSameShape(HloInstruction* old_instruction,
                                     HloInstruction* new_instruction);

  // Returns whether the shapes of the outputs of the given instructions are
  // the same for the purposes of simplification. If
  // options_.is_layout_sensitive() is true, then this tests shape equality
  // including layout (ShapeUtil::Equal). If options_.is_layout_sensitive() is
  // false, then this tests shape compatibility (ShapeUtil::Compatible).
  bool SameShape(const HloInstruction* lhs, const HloInstruction* rhs) const;

  // Returns whether it was possible to transform `root` to a clamp
  // instruction, with min a minimum instruction, max a maximum instruction,
  // min_operand an operand of min, and max_operand an operand of max.
  // Precondition: root is either a minimum or a maximum.
  bool TransformToClampIfSameShape(HloInstruction* root, HloInstruction* min,
                                   HloInstruction* min_operand,
                                   HloInstruction* operand, HloInstruction* max,
                                   HloInstruction* max_operand);

  // A Broadcast that feeds an element-wise operation with a unique non-scalar
  // operand can sink to after the operation.
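  // For example, negate(broadcast(v)) can become broadcast(negate(v)),
  // performing the operation on the smaller pre-broadcast shape.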
  StatusOr<bool> TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
      HloInstruction* broadcast);

  // Replaces the existing HLO instruction old_instruction, with
  // new_instruction, and marks the optimizer status as changed.
  // Returns the Status representing the result of the replace operation.
  Status ReplaceWithNewInstruction(
      HloInstruction* old_instruction,
      std::unique_ptr<HloInstruction> new_instruction) {
    VLOG(3) << "Replacing instruction:";
    VLOG(3) << "  old: " << old_instruction->ToString();
    VLOG(3) << "  new: " << new_instruction->ToString();
    TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
        old_instruction, std::move(new_instruction)));
    changed_ = true;
    return Status::OK();
  }

  // Replaces the existing HLO instruction old_instruction, with
  // new_instruction, and marks the optimizer status as changed.
  // Returns the Status representing the result of the replace operation.
  Status ReplaceInstruction(HloInstruction* old_instruction,
                            HloInstruction* new_instruction) {
    VLOG(3) << "Replacing instruction:";
    VLOG(3) << "  old: " << old_instruction->ToString();
    VLOG(3) << "  new: " << new_instruction->ToString();
    TF_RETURN_IF_ERROR(
        computation_->ReplaceInstruction(old_instruction, new_instruction));
    changed_ = true;
    return Status::OK();
  }

  StatusOr<HloInstruction*> OptimizeDotOfConcat(HloInstruction* dot);
  StatusOr<HloInstruction*> OptimizeDotOfConcatHelper(
      const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim,
      HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped);

  StatusOr<HloInstruction*> OptimizeDotOfGather(HloInstruction* dot);

  HloComputation* GetOrCreateScalarAddComputation() {
    if (scalar_add_computation_) {
      return scalar_add_computation_;
    }

    HloComputation::Builder b("scalar_add_computation");
    Shape shape = ShapeUtil::MakeShape(F32, {});
    auto scalar_lhs = b.AddInstruction(
        HloInstruction::CreateParameter(0, shape, "scalar_lhs"));
    auto scalar_rhs = b.AddInstruction(
        HloInstruction::CreateParameter(1, shape, "scalar_rhs"));
    auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs));
    scalar_add_computation_ =
        computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
    return scalar_add_computation_;
  }

  // Tries to fold a kPad in the input or filter into the convolution
  // instruction's window.
  StatusOr<bool> FoldConvInputPad(HloInstruction* convolution);
  StatusOr<bool> FoldConvFilterPad(HloInstruction* convolution);

  // Tries to use a kDot in place of the given convolution.
  StatusOr<bool> SimplifyConvToDot(HloInstruction* convolution);

  // Tries to simplify a slice where the result of the slice is a scalar.
  StatusOr<bool> TrySimplifyScalarSlice(HloInstruction* slice);

  // Tries to convert slice(reshape(X)) into reshape(slice(X))
  StatusOr<bool> TryToReorderSliceAndReshape(HloInstruction* slice);

  // Current HloComputation instance the AlgebraicSimplifierVisitor is
  // traversing.
  HloComputation* computation_;

  // The backend-specific options selected for the algebraic simplifier.
  const AlgebraicSimplifierOptions& options_;

  // Whether algebraic simplification has occurred.
  bool changed_ = false;

  // Cached computation for adding two scalar F32.
  HloComputation* scalar_add_computation_ = nullptr;
};

}  // namespace

bool AlgebraicSimplifierVisitor::Run(
    HloComputation* computation, const AlgebraicSimplifierOptions& options) {
  AlgebraicSimplifierVisitor visitor(computation, options);
  TF_CHECK_OK(computation->Accept(&visitor));
  return visitor.changed_;
}

bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs,
                                           const HloInstruction* rhs) const {
  if (options_.is_layout_sensitive()) {
    return ShapeUtil::Equal(lhs->shape(), rhs->shape());
  } else {
    return ShapeUtil::Compatible(lhs->shape(), rhs->shape());
  }
}

void AlgebraicSimplifierVisitor::ReplaceWithBitcast(HloInstruction* instruction,
                                                    HloInstruction* operand) {
  CHECK_EQ(1, instruction->operand_count());
  if (operand == nullptr) {
    operand = instruction->mutable_operand(0);
  }
  CHECK_EQ(ShapeUtil::ElementsIn(instruction->shape()),
           ShapeUtil::ElementsIn(operand->shape()));
  CHECK_EQ(ShapeUtil::ByteSizeOf(instruction->shape()),
           ShapeUtil::ByteSizeOf(operand->shape()));

  auto bitcast = computation_->AddInstruction(HloInstruction::CreateUnary(
      instruction->shape(), HloOpcode::kBitcast, operand));
  TF_CHECK_OK(ReplaceInstruction(instruction, bitcast));
}

bool AlgebraicSimplifierVisitor::ReplaceInstructionIfSameShape(
    HloInstruction* old_instruction, HloInstruction* new_instruction) {
  if (!SameShape(old_instruction, new_instruction)) {
    return false;
  }
  TF_CHECK_OK(ReplaceInstruction(old_instruction, new_instruction));
  return true;
}

Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
  HloInstruction *lhs, *rhs;
  CHECK(Match(add, m::Add(m::Op(&lhs), m::Op(&rhs))));

  // A + 0 => A
  VLOG(10) << "trying transform [A + 0 => A]: " << add->ToString();
  if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(add, lhs)) {
    return Status::OK();
  }
  // 0 + A => A
  VLOG(10) << "trying transform [0 + A => A]: " << add->ToString();
  if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(add, rhs)) {
    return Status::OK();
  }

  // Canonicalization: Put constants on the right. This makes the reassociation
  // rules below simpler.
  VLOG(10) << "trying transform [Const + A => A + Const]";
  if (Match(add, m::Add(m::Constant(), m::NonConstant()))) {
    return ReplaceWithNewInstruction(
        add,
        HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, rhs, lhs));
  }

  // Reassociate to allow constant folding.
  //
  // Note: This is not general. For example, we won't reassociate
  //
  //   (A + C1) + (B + C2) => A + B + (C1 + C2).
  //
  VLOG(10) << "trying transform [(A + C1) + C2 => A + (C1 + C2)]";
  HloInstruction *a, *c1, *c2;
  if (Match(add, m::Add(m::Add(m::NonConstant(&a), m::Constant(&c1)),
                        m::Constant(&c2)))) {
    TF_ASSIGN_OR_RETURN(auto* sum_of_constants,
                        MakeBinaryHlo(HloOpcode::kAdd, c1, c2));
    return ReplaceWithNewInstruction(
        add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, a,
                                          sum_of_constants));
  }

  // A*C + B*C => (A+B)*C
  //
  //  - If A, B, and C are integers, do this unconditionally. Proof of
  //    correctness: https://rise4fun.com/Alive/u9X.
  //
  //  - If A, B, and C are floating point, do this if C is a scalar constant or
  //    broadcast of scalar constant and is equal to +/- 2^k for some (possibly
  //    negative) integer k.
  //
  //    Multiplying by a power of 2 just moves the exponent, so our answer is
  //    exact modulo rounding of intermediate results so long as
  //
  //     - none of the three products has an exponent which underflows (so the
  //       result is 0 or denormal), and
  //     - none of the three products overflows to inf.
  //
  //    Proof: See algebraic_simplifier_proof_distributive_property.py.
  //
  //    We deem these differences in rounding, underflow, and overflow
  //    acceptable in the ML context.
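  //
  //    For example, x*4.0 + y*4.0 becomes (x+y)*4.0, since 4.0 is a power of
  //    2, while x*3.0 + y*3.0 is left alone for floating-point types.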
  HloInstruction *b, *c;
  if (((Match(lhs, m::Multiply(m::Op(&a), m::Op(&c))) &&
        Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b)))) ||
       (Match(lhs, m::Multiply(m::Op(&c), m::Op(&a))) &&
        Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b))))) &&
      (ShapeUtil::ElementIsIntegral(add->shape()) ||
       IsAllFpConstantPowerOf2(c))) {
    return ReplaceWithNewInstruction(
        add, HloInstruction::CreateBinary(
                 add->shape(), HloOpcode::kMultiply,
                 computation_->AddInstruction(HloInstruction::CreateBinary(
                     add->shape(), HloOpcode::kAdd, a, b)),
                 c));
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) {
  HloInstruction *lhs, *rhs;
  CHECK(Match(logical_and, m::And(m::Op(&lhs), m::Op(&rhs))));
  // Simplify logical and
  if (ShapeUtil::HasPrimitiveType(lhs->shape(), xla::PRED) &&
      ShapeUtil::HasPrimitiveType(rhs->shape(), xla::PRED)) {
    // A && True => A
    VLOG(10) << "trying transform [A && True => A]: "
             << logical_and->ToString();
    if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(logical_and, lhs)) {
      return Status::OK();
    }
    // True && A => A
    VLOG(10) << "trying transform [True && A => A]: "
             << logical_and->ToString();
    if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(logical_and, rhs)) {
      return Status::OK();
    }

    // A && False => False
    VLOG(10) << "trying transform [A && False => False]: "
             << logical_and->ToString();
    if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(logical_and, rhs)) {
      return Status::OK();
    }

    // False && A => False
    VLOG(10) << "trying transform [False && A => False]: "
             << logical_and->ToString();
    if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(logical_and, lhs)) {
      return Status::OK();
    }
  }

  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleBitcast(HloInstruction* bitcast) {
  // If a bitcast feeds a bitcast, make it a single bitcast.
  HloInstruction* op;
  if (Match(bitcast, m::Bitcast(m::Bitcast(m::Op(&op))))) {
    return ReplaceWithNewInstruction(
        bitcast,
        HloInstruction::CreateUnary(bitcast->shape(), HloOpcode::kBitcast, op));
  }
  // All bitcasts can be eliminated (assuming layout constraints are
  // satisfied).
  ReplaceInstructionIfSameShape(bitcast, bitcast->mutable_operand(0));
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleBitcastConvert(
    HloInstruction* bitcast) {
  // Eliminate bitcast converts between identical shapes.
  ReplaceInstructionIfSameShape(bitcast, bitcast->mutable_operand(0));
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) {
  // If a copy feeds a copy, make it a single copy.
  HloInstruction* op;
  if (Match(copy, m::Copy(m::Copy(m::Op(&op))))) {
    return ReplaceWithNewInstruction(
        copy, HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy, op));
  }
  // All copies can be eliminated (assuming layout constraints are satisfied).
  if (ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0))) {
    return Status::OK();
  }

  if (HloInstruction* bitcast_operand =
          BitcastingOperandOfReshapeOrCopyChain(copy, options_)) {
    ReplaceWithBitcast(copy, bitcast_operand);
  }

  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleConcatenate(
    HloInstruction* concatenate) {
  absl::Span<HloInstruction* const> operands(concatenate->operands());
  if (operands.size() == 1) {
    // Unary concatenates are useless.
    ReplaceInstructionIfSameShape(concatenate, operands[0]);
    return Status::OK();
  }
  // Filter out and remove empty operands.
  std::vector<HloInstruction*> nonempty_operands;
  for (HloInstruction* operand : operands) {
    if (!ShapeUtil::IsZeroElementArray(operand->shape())) {
      nonempty_operands.push_back(operand);
    }
  }
  if (nonempty_operands.size() < operands.size()) {
    HloInstruction* replacement;
    if (nonempty_operands.empty()) {
      replacement = operands[0];
    } else if (nonempty_operands.size() == 1) {
      replacement = nonempty_operands[0];
    } else {
      replacement =
          computation_->AddInstruction(concatenate->CloneWithNewOperands(
              concatenate->shape(), nonempty_operands));
    }
    VLOG(10) << "trying to replace " << concatenate->ToString() << " with "
             << replacement->ToString();
    ReplaceInstructionIfSameShape(concatenate, replacement);
    return Status::OK();
  }

  // Check if we can merge "adjacent" slice operands which take slices from the
  // same other op. For simplicity we only merge unstrided slices.
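  // For example, concat(slice(x, [0:4]), slice(x, [4:9])) along the sliced
  // dimension can be merged into slice(x, [0:9]).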
  int64 concatenate_dimension = concatenate->concatenate_dimension();
  for (int64 i = 0; i < operands.size(); ++i) {
    if (operands[i]->opcode() != HloOpcode::kSlice ||
        !IsUnstridedSlice(operands[i])) {
      continue;
    }
    int64 slice_end = operands[i]->slice_limits(concatenate_dimension);
    HloInstruction* slice_operand = operands[i]->mutable_operand(0);
    int64 j = i + 1;
    while (j < operands.size() && operands[j]->opcode() == HloOpcode::kSlice &&
           IsUnstridedSlice(operands[j]) &&
           operands[j]->operand(0) == slice_operand &&
           operands[j]->slice_starts(concatenate_dimension) == slice_end) {
      // Check that all the slice_start values are the same in all other
      // dimensions. This implies that the slice_limit values are also the
      // same, because operands of concatenate need to have the same shape,
      // and we already checked that the slices are unstrided.
      bool same_other_starts = true;
      for (int64 k = 0; k < operands[j]->slice_starts().size(); ++k) {
        if (k == concatenate_dimension) {
          continue;
        }
        if (operands[i]->slice_starts(k) != operands[j]->slice_starts(k)) {
          same_other_starts = false;
          break;
        }
      }
      if (!same_other_starts) {
        break;
      }
      slice_end = operands[j]->slice_limits(concatenate_dimension);
      ++j;
    }
    if (j - i > 1) {
      Shape new_slice_shape = operands[i]->shape();
      new_slice_shape.set_dimensions(
          concatenate_dimension,
          slice_end - operands[i]->slice_starts(concatenate_dimension));
      auto new_limit_indices = operands[i]->slice_limits();
      new_limit_indices[concatenate_dimension] = slice_end;
      auto new_slice_op =
          computation_->AddInstruction(HloInstruction::CreateSlice(
              new_slice_shape, slice_operand,
              /*start_indices=*/operands[i]->slice_starts(),
              /*limit_indices=*/new_limit_indices,
              /*strides=*/operands[i]->slice_strides()));
      std::vector<HloInstruction*> new_operands;
      for (int64 k = 0; k < i; ++k) {
        new_operands.push_back(operands[k]);
      }
      new_operands.push_back(new_slice_op);
      for (int64 k = j; k < operands.size(); ++k) {
        new_operands.push_back(operands[k]);
      }
      auto replacement =
          computation_->AddInstruction(concatenate->CloneWithNewOperands(
              concatenate->shape(), new_operands));
      ReplaceInstructionIfSameShape(concatenate, replacement);
      return Status::OK();
    }
  }

  if (operands.size() == 2) {
    // A binary concat with a broadcasted scalar as an operand can be converted
    // into a pad which is simpler to fold into other operations.
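    // For example, concat(broadcast(s), x) along dimension d becomes
    // pad(x, s), with edge_padding_low on d equal to the size of the
    // broadcast in d.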
    bool is_effective_low_pad = Match(
        operands[0], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar())));
    bool is_effective_high_pad = Match(
        operands[1], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar())));
    if (!is_effective_low_pad && !is_effective_high_pad) {
      return Status::OK();
    }
    PaddingConfig padding_config;
    for (int64 dim = 0; dim < operands[0]->shape().rank(); ++dim) {
      auto padding_config_dim = padding_config.add_dimensions();
      padding_config_dim->set_edge_padding_high(0);
      padding_config_dim->set_edge_padding_low(0);
      padding_config_dim->set_interior_padding(0);
      if (dim == concatenate_dimension) {
        if (is_effective_low_pad) {
          padding_config_dim->set_edge_padding_low(
              operands[0]->shape().dimensions(dim));
        } else {
          padding_config_dim->set_edge_padding_high(
              operands[1]->shape().dimensions(dim));
        }
      }
    }
    int64 operand_to_pad = is_effective_low_pad ? 1 : 0;
    int64 pad_value_operand = is_effective_low_pad ? 0 : 1;
    HloInstruction* pad =
        computation_->AddInstruction(HloInstruction::CreatePad(
            concatenate->shape(), operands[operand_to_pad],
            operands[pad_value_operand]->mutable_operand(0), padding_config));
    return ReplaceInstruction(concatenate, pad);
  }
  return Status::OK();
}

static HloInstruction* BuildTupleConstant(HloComputation* computation,
                                          const LiteralSlice& literal) {
  if (literal.shape().IsTuple()) {
    std::vector<HloInstruction*> elems;
    elems.reserve(ShapeUtil::TupleElementCount(literal.shape()));
    for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) {
      elems.push_back(
          BuildTupleConstant(computation, LiteralSlice(literal, {i})));
    }
    return computation->AddInstruction(HloInstruction::CreateTuple(elems));
  } else {
    return computation->AddInstruction(
        HloInstruction::CreateConstant(literal.Clone()));
  }
}

Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
  // Tuple constants aren't directly supported by any backend. Expand them into
  // explicit Tuple instructions.
  if (constant->shape().IsTuple()) {
    return ReplaceInstruction(
        constant, BuildTupleConstant(computation_, constant->literal()));
  }

  if (constant->shape().element_type() == TOKEN) {
    return Status::OK();
  }

  // If a literal is all the same element, replace it with a scalar broadcast.
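  // For example, a constant f32[4] of all 5.0 becomes
  // broadcast(constant(5.0)).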
  if (ShapeUtil::ElementsIn(constant->shape()) > 1 &&
      constant->literal().IsAllFirst()) {
    Literal unique_scalar(
        LiteralUtil::GetFirstScalarLiteral(constant->literal()));
    HloInstruction* scalar = computation_->AddInstruction(
        HloInstruction::CreateConstant(std::move(unique_scalar)));
    return ReplaceWithNewInstruction(
        constant,
        HloInstruction::CreateBroadcast(constant->shape(), scalar, {}));
  }

  // If a literal is an increasing sequence from zero, replace it with an iota.
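  // For example, a constant s32[4] of {0, 1, 2, 3} becomes iota(s32[4], 0).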
  if (constant->shape().rank() == 1 &&
      ShapeUtil::ElementsIn(constant->shape()) > 1 &&
      constant->literal().IsR1Iota()) {
    return ReplaceWithNewInstruction(
        constant, HloInstruction::CreateIota(constant->shape(), 0));
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) {
  HloInstruction *lhs, *rhs;
  CHECK(Match(sub, m::Subtract(m::Op(&lhs), m::Op(&rhs))));
  // A - 0 => A
  VLOG(10) << "trying transform [A - 0 => A]: " << sub->ToString();
  if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(sub, lhs)) {
    return Status::OK();
  }

  // Canonicalize subtraction of a constant to addition.
  VLOG(10) << "trying transform [A - Const => A + (-Const)]";
  if (Match(sub, m::Subtract(m::NonConstant(&lhs), m::Constant(&rhs)))) {
    HloInstruction* negative_const = computation_->AddInstruction(
        HloInstruction::CreateUnary(rhs->shape(), HloOpcode::kNegate, rhs));
    return ReplaceWithNewInstruction(
        sub, HloInstruction::CreateBinary(sub->shape(), HloOpcode::kAdd, lhs,
                                          negative_const));
  }

  return Status::OK();
}

namespace {
template <typename T>
Status InvertConstant(const HloInstruction& constant, Literal* result) {
  return result->Populate<T>([&](absl::Span<const int64> indices) {
    return T{1.0} / constant.literal().Get<T>(indices);
  });
}

template <typename T>
std::unique_ptr<HloInstruction> TryDivideToShift(HloInstruction* divide,
                                                 HloComputation* computation) {
  HloInstruction *a, *b, *c;
  CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b))));

  // Bail out unless the type is integral and the divisor is an
  // effective-scalar constant; otherwise `c` would be left unset below.
  if (!ShapeUtil::ElementIsIntegral(divide->shape()) ||
      (!Match(b, m::ConstantEffectiveScalar(&c)) &&
       !Match(b, m::Broadcast(m::ConstantEffectiveScalar(&c))))) {
    return nullptr;
  }

  if (ShapeUtil::ElementIsSigned(divide->shape())) {
    int64 b_value = c->literal().GetFirstElement<T>();
    if (b_value > 0 && IsPowerOfTwo(static_cast<uint64>(b_value))) {
      // Handle negative dividends by negating the result of the division.
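      // XLA integer division truncates toward zero, while a plain arithmetic
      // shift floors: for example, -10 / 4 must be -2, but -10 >> 2 is -3, so
      // we shift the absolute value and negate the quotient when the dividend
      // is negative.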
      HloInstruction* zero_like_a = BroadcastZeros(
          computation, a->shape().element_type(), a->shape().dimensions());

      auto* dividend_is_negative =
          computation->AddInstruction(HloInstruction::CreateCompare(
              ShapeUtil::ChangeElementType(a->shape(), PRED), a, zero_like_a,
              ComparisonDirection::kLt));

      auto* negated_dividend = computation->AddInstruction(
          HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a));

      auto* abs_dividend =
          computation->AddInstruction(HloInstruction::CreateTernary(
              a->shape(), HloOpcode::kSelect, dividend_is_negative,
              negated_dividend, a));

      int log2_abs_b_value = tensorflow::Log2Floor64(b_value);

      auto* shift_amount =
          computation->AddInstruction(HloInstruction::CreateConstant(
              LiteralUtil::CreateR0<T>(log2_abs_b_value)));
      if (!ShapeUtil::IsScalar(b->shape())) {
        shift_amount = computation->AddInstruction(
            HloInstruction::CreateBroadcast(b->shape(), shift_amount, {}));
      }

      auto* quotient = computation->AddInstruction(HloInstruction::CreateBinary(
          divide->shape(), HloOpcode::kShiftRightLogical, abs_dividend,
          shift_amount));

      auto* negated_quotient =
          computation->AddInstruction(HloInstruction::CreateUnary(
              quotient->shape(), HloOpcode::kNegate, quotient));

      return HloInstruction::CreateTernary(divide->shape(), HloOpcode::kSelect,
                                           dividend_is_negative,
                                           negated_quotient, quotient);
    }
  } else {
    uint64 b_value = c->literal().GetFirstElement<T>();
    if (IsPowerOfTwo(b_value)) {
      int log2_abs_b_value = tensorflow::Log2Floor64(b_value);
      HloInstruction* shift_amount =
          computation->AddInstruction(HloInstruction::CreateConstant(
              LiteralUtil::CreateR0<T>(log2_abs_b_value)));
      if (!ShapeUtil::IsScalar(b->shape())) {
        shift_amount = computation->AddInstruction(
            HloInstruction::CreateBroadcast(b->shape(), shift_amount, {}));
      }
      return HloInstruction::CreateBinary(
          divide->shape(), HloOpcode::kShiftRightLogical, a, shift_amount);
    }
  }

  return nullptr;
}
}  // namespace

Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
  HloInstruction *a, *b, *c, *d;
  CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b))));
  // A/1 => A
  VLOG(10) << "trying transform [A/1 => A]: " << divide->ToString();
  if (IsAll(b, 1) && ReplaceInstructionIfSameShape(divide, a)) {
    return Status::OK();
  }

  // A / B => A >> log2(B) if B is a power of 2.
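  // For example, an unsigned x / 8 becomes x >> 3.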
  switch (divide->shape().element_type()) {
    case S8:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<int8>(divide, computation_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    case S16:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<int16>(divide, computation_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    case S32:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<int32>(divide, computation_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    case S64:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<int64>(divide, computation_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    case U8:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<uint8>(divide, computation_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    case U16:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<uint16>(divide, computation_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    case U32:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<uint32>(divide, computation_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    case U64:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<uint64>(divide, computation_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    default:
      break;
  }

  Shape* shape;
  // exp(A)/exp(B) => exp(A-B)
  if (Match(divide, m::Divide(m::Exp(m::Op(&a)), m::Exp(m::Op(&b)))
                        .WithShape(m::Shape(&shape)))) {
    VLOG(10) << "transform [exp(A)/exp(B) => exp(A-B)]: " << divide->ToString();
    HloInstruction* subtract = computation_->AddInstruction(
        HloInstruction::CreateBinary(*shape, HloOpcode::kSubtract, a, b));
    return ReplaceWithNewInstruction(
        divide, HloInstruction::CreateUnary(*shape, HloOpcode::kExp, subtract));
  }

  // A/exp(B) => A*exp(-B)
  if (Match(divide, m::Divide(m::Op(&a), m::Exp(m::Op(&b))))) {
    VLOG(10) << "transform [A/exp(B) => A*exp(-B)]: " << divide->ToString();
    HloInstruction* negate = computation_->AddInstruction(
        HloInstruction::CreateUnary(divide->shape(), HloOpcode::kNegate, b));
    HloInstruction* new_exp = computation_->AddInstruction(
        HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp, negate));
    return ReplaceWithNewInstruction(
        divide, HloInstruction::CreateBinary(divide->shape(),
                                             HloOpcode::kMultiply, a, new_exp));
  }

  // A/pow(B,C) => A*pow(B,-C)
  if (Match(divide, m::Divide(m::Op(&a), m::Power(m::Op(&b), m::Op(&c))))) {
    VLOG(10) << "transform [A/pow(B,C) => A*pow(B,-C)]: " << divide->ToString();
    // The output shape of the created negate operator should be the same as
    // the input.
    const Shape& negate_shape = c->shape();
    HloInstruction* negate = computation_->AddInstruction(
        HloInstruction::CreateUnary(negate_shape, HloOpcode::kNegate, c));
    // And the power operator should retain the output shape of the old one.
    const Shape& new_power_shape = b->shape();
    HloInstruction* new_power =
        computation_->AddInstruction(HloInstruction::CreateBinary(
            new_power_shape, HloOpcode::kPower, b, negate));
    return ReplaceWithNewInstruction(
        divide, HloInstruction::CreateBinary(
                    divide->shape(), HloOpcode::kMultiply, a, new_power));
  }

  // A/sqrt(B) => A*rsqrt(B).
  if (Match(divide, m::Divide(m::Op(&a), m::Sqrt(m::Op(&b))))) {
    auto* rsqrt = computation_->AddInstruction(
        HloInstruction::CreateUnary(divide->shape(), HloOpcode::kRsqrt, b));
    return ReplaceWithNewInstruction(
        divide, HloInstruction::CreateBinary(rsqrt->shape(),
                                             HloOpcode::kMultiply, a, rsqrt));
  }

  // A/rsqrt(B) => A*sqrt(B).
  if (Match(divide, m::Divide(m::Op(&a), m::Rsqrt(m::Op(&b))))) {
    auto* sqrt = computation_->AddInstruction(
        HloInstruction::CreateUnary(divide->shape(), HloOpcode::kSqrt, b));
    return ReplaceWithNewInstruction(
        divide, HloInstruction::CreateBinary(sqrt->shape(),
                                             HloOpcode::kMultiply, a, sqrt));
  }

  // Simplifying integral division would produce unexpected results.
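  // For example, rewriting 5 / 2 as 5 * (1 / 2) would yield 0, because the
  // constant 1 / 2 rounds to 0 in integer arithmetic.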
  if (ShapeUtil::ElementIsIntegral(divide->shape())) {
    return Status::OK();
  }

  // A / Const => A * (1 / Const)
  //
  // (Backends can do this transformation, but generally only if the constant
  // is a scalar.)
  if (Match(divide, m::Divide(m::NonConstant(&a), m::Constant(&b)))) {
    Shape result_shape = b->literal().shape();
    Literal new_literal(result_shape);
    switch (result_shape.element_type()) {
      case F16:
        TF_RETURN_IF_ERROR(InvertConstant<half>(*b, &new_literal));
        break;
      case F32:
        TF_RETURN_IF_ERROR(InvertConstant<float>(*b, &new_literal));
        break;
      case BF16:
        TF_RETURN_IF_ERROR(InvertConstant<bfloat16>(*b, &new_literal));
        break;
      case F64:
        TF_RETURN_IF_ERROR(InvertConstant<double>(*b, &new_literal));
        break;
      case C64:
        TF_RETURN_IF_ERROR(InvertConstant<complex64>(*b, &new_literal));
        break;
      case C128:
        TF_RETURN_IF_ERROR(InvertConstant<complex128>(*b, &new_literal));
        break;
      default:
        return Status::OK();
    }
    auto inverse = computation_->AddInstruction(
        HloInstruction::CreateConstant((new_literal.Clone())));
    TF_ASSIGN_OR_RETURN(auto new_divide,
                        MakeBinaryHlo(HloOpcode::kMultiply, a, inverse));
    return ReplaceInstruction(divide, new_divide);
  }

  // (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C)
  if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)),
                              m::Divide(m::Op(&c), m::Op(&d))))) {
    TF_ASSIGN_OR_RETURN(auto a_times_d,
                        MakeBinaryHlo(HloOpcode::kMultiply, a, d));
    TF_ASSIGN_OR_RETURN(auto b_times_c,
                        MakeBinaryHlo(HloOpcode::kMultiply, b, c));
    TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kDivide,
                                                       a_times_d, b_times_c));

    return ReplaceInstruction(divide, new_divide);
  }

  // (A / B) / C => A / (B * C)
  if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), m::Op(&c)))) {
    TF_ASSIGN_OR_RETURN(auto b_times_c,
                        MakeBinaryHlo(HloOpcode::kMultiply, b, c));
    TF_ASSIGN_OR_RETURN(auto new_divide,
                        MakeBinaryHlo(HloOpcode::kDivide, a, b_times_c));
    return ReplaceInstruction(divide, new_divide);
  }

  // A / (B / C) => (A*C) / B
  if (Match(divide, m::Divide(m::Op(&a), m::Divide(m::Op(&b), m::Op(&c))))) {
    TF_ASSIGN_OR_RETURN(auto a_times_c,
                        MakeBinaryHlo(HloOpcode::kMultiply, a, c));
    TF_ASSIGN_OR_RETURN(auto new_divide,
                        MakeBinaryHlo(HloOpcode::kDivide, a_times_c, b));
    return ReplaceInstruction(divide, new_divide);
  }

  return Status::OK();
}

StatusOr<bool> AlgebraicSimplifierVisitor::HandleDotStrengthReduction(
    HloInstruction* dot) {
  HloInstruction *lhs, *rhs;
  CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));

  const auto kept_dim = [](int64 rank, int64 contracting_dimension,
                           absl::Span<const int64> batch_dimensions) -> int64 {
    for (int64 i = 0; i < rank; ++i) {
      if (i != contracting_dimension &&
          !absl::c_linear_search(batch_dimensions, i)) {
        return i;
      }
    }
    return -1;
  };

  const int64 dot_rank = dot->shape().rank();
  const int64 rhs_rank = rhs->shape().rank();
  const int64 lhs_rank = lhs->shape().rank();
  const auto& dnums = dot->dot_dimension_numbers();
  if (dnums.rhs_contracting_dimensions_size() != 1) {
    return false;
  }
  if (dot_rank > 2 && (lhs_rank != rhs_rank || lhs_rank != dot_rank)) {
    return false;
  }
  int64 lhs_collapsing_dim = dnums.lhs_contracting_dimensions(0);
  int64 lhs_kept_dim = kept_dim(lhs_rank, lhs_collapsing_dim,
                                AsInt64Slice(dnums.lhs_batch_dimensions()));
  // If there is no non-contracting dimension in rank 2, do not strength
  // reduce.
  if (lhs_kept_dim == -1 && lhs_rank > 1) {
    return false;
  }
  if (lhs->IsRank2Transpose()) {
    lhs = lhs->mutable_operand(0);
    std::swap(lhs_collapsing_dim, lhs_kept_dim);
  }

  int64 rhs_collapsing_dim = dnums.rhs_contracting_dimensions(0);
  int64 rhs_kept_dim = kept_dim(rhs_rank, rhs_collapsing_dim,
                                AsInt64Slice(dnums.rhs_batch_dimensions()));
  // If there is no non-contracting dimension in rank 2, do not strength
  // reduce.
  if (rhs_kept_dim == -1 && rhs_rank > 1) {
    return false;
  }
  if (rhs->IsRank2Transpose()) {
    rhs = rhs->mutable_operand(0);
    std::swap(rhs_collapsing_dim, rhs_kept_dim);
  }

  auto reshape_if_necessary = [&](HloInstruction* hlo) {
    hlo = AsType(hlo, dot->shape().element_type());
    if (!ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) {
      hlo = computation_->AddInstruction(
          HloInstruction::CreateReshape(dot->shape(), hlo));
    }
    return hlo;
  };

  auto add_reduce_in_f32 = [&](HloInstruction* hlo, const int64 dim) {
    return AddReduce(AsType(hlo, F32), dim);
  };

  auto broadcast = [&](HloInstruction* hlo, const Shape& shape,
                       absl::Span<const int64> dims) {
    return computation_->AddInstruction(
        HloInstruction::CreateBroadcast(shape, hlo, dims));
  };

  auto broadcast_to_dim = [&](HloInstruction* hlo, const Shape& shape,
                              int64 dim) {
    return broadcast(hlo, shape, {dim});
  };

  auto multiply = [&](HloInstruction* local_lhs, HloInstruction* local_rhs) {
    return computation_->AddInstruction(HloInstruction::CreateBinary(
        local_lhs->shape(), HloOpcode::kMultiply, local_lhs, local_rhs));
  };

  // Strength reduce dot(a[K], b[K]) =
  //  reshape(result.shape,
  //          reduce_sum(multiply(a, b), {0}))
  if (rhs_rank == 1 && lhs_rank == 1) {
    TF_RETURN_IF_ERROR(ReplaceInstruction(
        dot, reshape_if_necessary(add_reduce_in_f32(multiply(lhs, rhs), 0))));
    return true;
  }

  if (ShapeUtil::IsEffectiveScalar(rhs->shape()) &&
      ShapeUtil::IsEffectiveScalar(lhs->shape())) {
    TF_RETURN_IF_ERROR(ReplaceInstruction(
        dot, reshape_if_necessary(multiply(Flatten(lhs), Flatten(rhs)))));
    return true;
  }

  // Simplify outer product into multiply with broadcasting.
  //
  // A dot(a[M, 1], b[1, N]) = multiply(a [M,1], b [1, N])
  if (rhs_rank == 2 && rhs->shape().dimensions(rhs_collapsing_dim) == 1) {
    TF_RETURN_IF_ERROR(ReplaceInstruction(
        dot, multiply(broadcast_to_dim(Flatten(lhs), dot->shape(), 0),
                      broadcast_to_dim(Flatten(rhs), dot->shape(), 1))));
    return true;
  }

  // Strength reduce dot(a[1, K], b) =
  //  reshape(result.shape,
  //    reduce_sum(
  //      multiply(broadcast(reshape(a, [K]), {0}), b),
  //      {0})
  //  )
  if (lhs_rank == 1 ||
      (lhs_rank == 2 && lhs->shape().dimensions(lhs_kept_dim) == 1)) {
    if (rhs->shape().rank() == 1) {
      TF_RETURN_IF_ERROR(
          ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32(
                                      multiply(Flatten(lhs), rhs), 0))));
      return true;
    }
    TF_RETURN_IF_ERROR(ReplaceInstruction(
        dot, reshape_if_necessary(add_reduce_in_f32(
                 multiply(broadcast_to_dim(Flatten(lhs), rhs->shape(),
                                           rhs_collapsing_dim),
                          rhs),
                 rhs_collapsing_dim))));
    return true;
  }

  // Strength reduce dot(a, b[K, 1]) =
  //  reshape(result.shape,
  //    reduce_sum(multiply(a, broadcast(reshape(b, [K]), {1})), {0})
  //  )
  if (rhs_rank == 1 ||
      (rhs_rank == 2 && rhs->shape().dimensions(rhs_kept_dim) == 1)) {
    TF_RETURN_IF_ERROR(ReplaceInstruction(
        dot, reshape_if_necessary(add_reduce_in_f32(
                 multiply(lhs, broadcast_to_dim(Flatten(rhs), lhs->shape(),
                                                lhs_collapsing_dim)),
                 lhs_collapsing_dim))));
    return true;
  }

  // Only consider kDot with batch dimension.
  if (dot_rank <= 2) {
    return false;
  }

  CHECK_EQ(rhs_rank, lhs_rank);
  CHECK_EQ(dot_rank, lhs_rank);
  // If there is more than one non-contracting dimension or the batch
  // dimensions are not equal, bail out since transposes may be required to do
  // a strength reduction.
  if (dnums.rhs_batch_dimensions_size() + 2 != dot_rank ||
      !absl::c_equal(dnums.lhs_batch_dimensions(),
                     dnums.rhs_batch_dimensions())) {
    return false;
  }

  auto broadcast_dims = [](int64 rank, int64 non_broadcast_dim) {
    absl::InlinedVector<int64, 8> dims;
    for (int64 i = 0; i < rank; ++i) {
      if (i != non_broadcast_dim) {
        dims.push_back(i);
      }
    }
    return dims;
  };

  // If the contracting dimension is 1, remove the degenerate dimensions from
  // the lhs and rhs, broadcast each to the result shape and multiply.
  if (lhs->shape().dimensions(lhs_collapsing_dim) == 1 &&
      (rhs_kept_dim == rhs_rank - 1 ||
       (rhs_collapsing_dim == rhs_rank - 1 && rhs_kept_dim == rhs_rank - 2))) {
    CHECK_EQ(rhs->shape().dimensions(rhs_collapsing_dim), 1);
    const int64 lhs_kept_dim_in_output =
        lhs_kept_dim > lhs_collapsing_dim ? (lhs_kept_dim - 1) : lhs_kept_dim;
    absl::InlinedVector<int64, 8> lhs_broadcast_dims;
    for (const int64 dim : dnums.lhs_batch_dimensions()) {
      lhs_broadcast_dims.push_back(dim > lhs_collapsing_dim ? (dim - 1) : dim);
    }
    absl::InlinedVector<int64, 8> rhs_broadcast_dims = lhs_broadcast_dims;
    lhs_broadcast_dims.push_back(lhs_kept_dim_in_output);
    absl::c_sort(lhs_broadcast_dims);
    rhs_broadcast_dims.push_back(dot_rank - 1);
    absl::c_sort(rhs_broadcast_dims);
    TF_RETURN_IF_ERROR(ReplaceInstruction(
        dot, reshape_if_necessary(
                 multiply(broadcast(StripDim(lhs, lhs_collapsing_dim),
                                    dot->shape(), lhs_broadcast_dims),
                          broadcast(StripDim(rhs, rhs_collapsing_dim),
                                    dot->shape(), rhs_broadcast_dims)))));
    return true;
  }

  // If the lhs and rhs non-contracting dimensions are both one, strip each
  // one, multiply and then reduce the collapsing dimension.
  if (lhs->shape().dimensions(lhs_kept_dim) == 1 &&
      rhs->shape().dimensions(rhs_kept_dim) == 1 &&
      lhs_kept_dim == rhs_kept_dim) {
    auto new_lhs = StripDim(lhs, lhs_kept_dim);
    auto new_rhs = StripDim(rhs, rhs_kept_dim);
    const int64 reduce_dim = rhs_kept_dim < rhs_collapsing_dim
                                 ? (rhs_collapsing_dim - 1)
                                 : rhs_collapsing_dim;
    TF_RETURN_IF_ERROR(
        ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32(
                                    multiply(new_lhs, new_rhs), reduce_dim))));
    return true;
  }

  // If the lhs non-contracting dimension is one, strip it, broadcast to the
  // rhs shape, multiply and then reduce the collapsing dimension.
  if (lhs->shape().dimensions(lhs_kept_dim) == 1) {
    auto new_lhs = broadcast(StripDim(lhs, lhs_kept_dim), rhs->shape(),
                             broadcast_dims(rhs_rank, rhs_kept_dim));
    TF_RETURN_IF_ERROR(ReplaceInstruction(
        dot, reshape_if_necessary(add_reduce_in_f32(multiply(new_lhs, rhs),
                                                    rhs_collapsing_dim))));
    return true;
  }

  // If the rhs non-contracting dimension is one, strip it, broadcast to the
  // lhs shape, multiply and then reduce the collapsing dimension.
  if (rhs->shape().dimensions(rhs_kept_dim) == 1) {
    auto new_rhs = broadcast(StripDim(rhs, rhs_kept_dim), lhs->shape(),
                             broadcast_dims(lhs_rank, lhs_kept_dim));
    TF_RETURN_IF_ERROR(ReplaceInstruction(
        dot, reshape_if_necessary(add_reduce_in_f32(multiply(lhs, new_rhs),
                                                    lhs_collapsing_dim))));
    return true;
  }

  return false;
}

OptimizeDotOfConcat(HloInstruction * dot)1363 StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcat(
1364 HloInstruction* dot) {
1365 const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
1366 if (dnums.lhs_contracting_dimensions_size() != 1 ||
1367 dnums.lhs_batch_dimensions_size() != 0) {
1368 return nullptr;
1369 }
1370
1371 const int64 lhs_contracting_dim = dnums.lhs_contracting_dimensions(0);
1372 const int64 rhs_contracting_dim = dnums.rhs_contracting_dimensions(0);
1373 HloInstruction *lhs, *rhs;
1374 CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
1375
1376 TF_ASSIGN_OR_RETURN(
1377 HloInstruction * optimized_lhs_concat,
1378 OptimizeDotOfConcatHelper(*dot, lhs, lhs_contracting_dim, rhs,
1379 rhs_contracting_dim, /*swapped=*/false));
1380 if (optimized_lhs_concat) {
1381 return optimized_lhs_concat;
1382 }
1383
1384 return OptimizeDotOfConcatHelper(*dot, rhs, rhs_contracting_dim, lhs,
1385 lhs_contracting_dim, /*swapped=*/true);
1386 }
1387
OptimizeDotOfConcatHelper(const HloInstruction & dot,HloInstruction * lhs,int64 lhs_contracting_dim,HloInstruction * rhs,int64 rhs_contracting_dim,bool swapped)1388 StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper(
1389 const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim,
1390 HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped) {
1391 bool can_optimize = lhs->opcode() == HloOpcode::kConcatenate &&
1392 lhs->concatenate_dimension() == lhs_contracting_dim &&
1393 rhs->opcode() == HloOpcode::kConstant;
1394 if (!can_optimize) {
1395 return nullptr;
1396 }
1397
1398 // We're replacing this:
1399 //
1400 // +-----+-----+-----+ +-------------------+
1401 // | | | | | |
1402 // | | | | | R_0 |
1403 // | | | | | |
1404 // | | | | +-------------------+
1405 // | | | | | |
1406 // | L_0 | L_1 | L_2 | * | R_1 |
1407 // | | | | | |
1408 // | | | | +-------------------+
1409 // | | | | | |
1410 // | | | | | R_2 |
1411 // | | | | | |
1412 // +-----+-----+-----+ +-------------------+
1413 //
1414 // with this:
1415 //
1416 // [Sum over i]
1417 //
1418 // +-----+ +-------------------+
1419 // | | | |
1420 // | | * | R_i |
1421 // | | | |
1422 // | | +-------------------+
1423 // | |
1424 // | L_i |
1425 // | |
1426 // | |
1427 // | |
1428 // | |
1429 // | |
1430 // +-----+
1431 //
1432 // where the LHS is a concatenate operation (so we can "split" the LHS tensor
1433 // for free) and the RHS is a constant tensor (and thus can be split at
1434 // compile time). In the future, we may also want to do this when both the
1435 // LHS and the RHS are concatenate operations that line up along the dimension
1436 // being contracted over.
1437 //
1438 // We should be able to generalize this transform to work on a non-constant
1439 // RHS when/if we have in-place slices or support input-fusing slices into
1440 // Dots.
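  //
  // A concrete example (shapes chosen purely for illustration): if
  // L = concat(L_0 {4x2}, L_1 {4x3}) along dimension 1 and R is a {5x7}
  // constant, then dot(L {4x5}, R {5x7}) becomes
  // dot(L_0, slice(R, [0:2, 0:7])) + dot(L_1, slice(R, [2:5, 0:7])),
  // where each partial dot already has the final {4x7} shape.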
1441
1442 // Dimension numbers for the new dot instructions we'll create (L_i * R_i in
1443 // the diagram above).
1444 DotDimensionNumbers new_dot_dnums;
1445 new_dot_dnums.add_lhs_contracting_dimensions(swapped ? rhs_contracting_dim
1446 : lhs_contracting_dim);
1447 new_dot_dnums.add_rhs_contracting_dimensions(swapped ? lhs_contracting_dim
1448 : rhs_contracting_dim);
1449
1450 // Here we use the MKN notation, where the contracted dimension has K
1451 // elements and the two non-contracted dimensions have M and N elements.
1452 HloInstruction* add_result = nullptr;
1453 int64 rhs_contracting_dim_offset = 0;
1454 int64 n = rhs->shape().dimensions(1 - rhs_contracting_dim);
1455 for (HloInstruction* concat_op : lhs->operands()) {
1456 int64 sub_k = concat_op->shape().dimensions(lhs_contracting_dim);
1457 Shape rhs_slice_shape(rhs->shape());
1458 rhs_slice_shape.set_dimensions(rhs_contracting_dim, sub_k);
1459
1460 std::array<int64, 2> start_indices;
1461 start_indices[rhs_contracting_dim] = rhs_contracting_dim_offset;
1462 start_indices[1 - rhs_contracting_dim] = 0;
1463
1464 std::array<int64, 2> limit_indices;
1465 limit_indices[rhs_contracting_dim] = rhs_contracting_dim_offset + sub_k;
1466 limit_indices[1 - rhs_contracting_dim] = n;
1467
1468 HloInstruction* rhs_slice =
1469 computation_->AddInstruction(HloInstruction::CreateSlice(
1470 rhs_slice_shape, rhs, /*start_indices=*/start_indices,
1471 /*limit_indices=*/limit_indices, /*strides=*/{1, 1}));
1472
1473 // TODO(b/69062148): We can get rid of `swapped` once all backends support
1474 // "non-canonical" contraction dimensions (that contracts dimension 1 of the
1475 // LHS with dimension 0 of the RHS). But for now we keep the same
1476 // contraction dimensions as the incoming dot operation to ensure the new
1477 // dot operations can be lowered.
1478 HloInstruction *new_dot_lhs, *new_dot_rhs;
1479 if (swapped) {
1480 new_dot_lhs = rhs_slice;
1481 new_dot_rhs = concat_op;
1482 } else {
1483 new_dot_lhs = concat_op;
1484 new_dot_rhs = rhs_slice;
1485 }
1486
1487 auto* new_dot = computation_->AddInstruction(
1488 HloInstruction::CreateDot(dot.shape(), new_dot_lhs, new_dot_rhs,
1489 new_dot_dnums, dot.precision_config()));
1490
1491 if (add_result) {
1492 add_result = computation_->AddInstruction(HloInstruction::CreateBinary(
1493 dot.shape(), HloOpcode::kAdd, add_result, new_dot));
1494 } else {
1495 add_result = new_dot;
1496 }
1497
1498 rhs_contracting_dim_offset += sub_k;
1499 }
1500
1501 return add_result;
1502 }
1503
1504 StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
1505 HloInstruction* dot) {
1506 const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
1507 if (dnums.lhs_contracting_dimensions_size() != 1 ||
1508 dnums.rhs_contracting_dimensions_size() != 1 ||
1509 dnums.lhs_batch_dimensions_size() != 0 ||
1510 dnums.rhs_batch_dimensions_size() != 0 ||
1511 dot->shape().dimensions_size() != 2) { // dot output 2D
1512 VLOG(10) << "DotOfGather: Can only optimize 2D, non-batch dot operations.";
1513 return nullptr;
1514 }
1515
1516 // Optimize either dot(DS(ctA), ctB) or dot(ctB, DS(ctA)).
1517 // Currently a Gather is a DynamicSlice.
1518 auto is_dynamic_slice_constant_combination =
1519 [](HloInstruction* a, HloInstruction* b, int a_contracting_dimension) {
1520 // First operand is a DynamicSlice(Constant).
1521 if (a->opcode() != HloOpcode::kDynamicSlice) {
1522 return false;
1523 }
1524 auto* dynamic_slice_op = a->operand(0);
1525 if (dynamic_slice_op->opcode() != HloOpcode::kConstant) {
1526 return false;
1527 }
1528 // Second operand is a Constant.
1529 if (b->opcode() != HloOpcode::kConstant) {
1530 return false;
1531 }
1532 // The DynamicSlice output is a vector.
1533 const Shape& dynamic_slice_shape = a->shape();
1534 if (dynamic_slice_shape.dimensions(1 - a_contracting_dimension) != 1) {
1535 return false;
1536 }
1537 // Constant size is the same before and after slice in the contracting
1538 // dimension; otherwise we would either have to precompute for all
1539 // possible slice indices, or the dot would be invalid.
1540 const Shape& dynamic_slice_op_shape = dynamic_slice_op->shape();
1541 if (dynamic_slice_op_shape.dimensions(a_contracting_dimension) !=
1542 dynamic_slice_shape.dimensions(a_contracting_dimension)) {
1543 return false;
1544 }
1545 return true;
1546 };
1547
1548 HloInstruction* lhs = dot->mutable_operand(0);
1549 HloInstruction* rhs = dot->mutable_operand(1);
1550 int lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0);
1551 int rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0);
1552
1553 if (!is_dynamic_slice_constant_combination(
1554 lhs, rhs, /*a_contracting_dimension=*/lhs_contracting_dimension) &&
1555 !is_dynamic_slice_constant_combination(
1556 rhs, lhs, /*a_contracting_dimension=*/rhs_contracting_dimension)) {
1557 VLOG(10) << "DotOfGather: Can only optimize dot(DS(ctA), ctB) or "
1558 "dot(ctB, DS(ctA)), where the two constants have equal "
1559 "contracting dimensions.";
1560 return nullptr;
1561 }
1562
1563 // LHS is DynamicSlice:
1564 // input: dot(DS(ctA), ctB)
1565 // where DS(ctA) = DS({M x K}, {start, 0}, {1, K}) and ctB = {K x N}.
1566 // => input dimensions: dot({1 x K}, {K x N}) => {1 x N}.
1567 // output: DS(dot(ctA, ctB))
1568 // => output dimensions: DS ({M x N}, {start, 0}, {1, N}) => {1 x N}.
1569
1570 // RHS is DynamicSlice:
1571 // input: dot(ctA, DS(ctB))
1572 // where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, start}, {K, 1}).
1573 // => input dimensions: dot({M x K}, {K x 1}) => {M x 1}.
1574 // output: DS(dot(ctA, ctB))
1575 // => output dimensions: DS ({M x N}, {0, start}, {M, 1}) => {M x 1}.
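  //
  // A concrete example (numbers chosen only for illustration): with M=6, K=4,
  // N=8 and start=3, dot(DS(ctA {6x4}, {3,0}, {1,4}), ctB {4x8}) yields a
  // {1x8} result; the rewrite instead computes the full dot(ctA, ctB) {6x8}
  // once and extracts DS({6x8}, {3,0}, {1,8}) from it.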
1576
1577 bool lhs_is_dynamic_slice = lhs->opcode() == HloOpcode::kDynamicSlice;
1578 HloDynamicSliceInstruction* dynamic_slice =
1579 lhs_is_dynamic_slice ? Cast<HloDynamicSliceInstruction>(lhs)
1580 : Cast<HloDynamicSliceInstruction>(rhs);
1581
1582 // ctA:
1583 HloInstruction* left_operand =
1584 lhs_is_dynamic_slice ? lhs->mutable_operand(0) : lhs;
1585 // ctB:
1586 HloInstruction* right_operand =
1587 lhs_is_dynamic_slice ? rhs : rhs->mutable_operand(0);
1588 // Build ctA x ctB.
1589 const int m = left_operand->shape().dimensions(1 - lhs_contracting_dimension);
1590 const int n =
1591 right_operand->shape().dimensions(1 - rhs_contracting_dimension);
1592 auto memoized_shape =
1593 ShapeUtil::MakeShape(dot->shape().element_type(), {m, n});
1594 auto* memoized_inst = computation_->AddInstruction(
1595 HloInstruction::CreateDot(memoized_shape, left_operand, right_operand,
1596 dnums, dot->precision_config()));
1597 // Get pair {start, 0} or {0, start}.
1598 // Position of start:
1599 int index_of_non_zero_start = lhs_is_dynamic_slice
1600 ? 1 - lhs_contracting_dimension
1601 : 1 - rhs_contracting_dimension;
1602 // Position of zero:
1603 int index_of_zero_start = 1 - index_of_non_zero_start;
1604
1605 // Slice out start and 0 components and reorder if necessary.
1606 auto indices_type = dynamic_slice->operand(1)->shape().element_type();
1607 Shape s_shape = ShapeUtil::MakeShape(indices_type, {1});
1608 Shape d_shape = ShapeUtil::MakeShape(indices_type, {2});
1609 HloInstruction* non_zero_start =
1610 dynamic_slice->mutable_operand(1 + index_of_non_zero_start);
1611 HloInstruction* zero_start =
1612 dynamic_slice->mutable_operand(1 + index_of_zero_start);
1613 std::vector<HloInstruction*> new_start_indices;
1614 if (lhs_is_dynamic_slice) {
1615 new_start_indices = {non_zero_start, zero_start};
1616 } else {
1617 new_start_indices = {zero_start, non_zero_start};
1618 }
1619
1620 // Build DynamicSlice(ctA x ctB).
1621 const int new_slice_m = lhs_is_dynamic_slice ? 1 : m;
1622 const int new_slice_n = lhs_is_dynamic_slice ? n : 1;
1623 auto* memoized_lookup =
1624 computation_->AddInstruction(HloInstruction::CreateDynamicSlice(
1625 dot->shape(), memoized_inst, new_start_indices,
1626 {new_slice_m, new_slice_n}));
1627
1628 return memoized_lookup;
1629 }
1630
1631 Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
1632 HloInstruction *lhs, *rhs;
1633 CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
1634 if (options_.is_layout_sensitive()) {
1635 return Status::OK();
1636 }
1637 // Replace a zero element dot with a broadcast of the constant 0.
1638 if (ShapeUtil::IsZeroElementArray(dot->shape()) ||
1639 ShapeUtil::IsZeroElementArray(lhs->shape()) ||
1640 ShapeUtil::IsZeroElementArray(rhs->shape())) {
1641 auto zero = computation_->AddInstruction(HloInstruction::CreateConstant(
1642 LiteralUtil::Zero(dot->shape().element_type())));
1643 return ReplaceWithNewInstruction(
1644 dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {}));
1645 }
1646
1647 // Only optimize F32 or BF16 dot operations where the dot, rhs and lhs are
1648 // rank 2 or below.
1649 if (dot->shape().element_type() != F32 &&
1650 dot->shape().element_type() != BF16) {
1651 return Status::OK();
1652 }
1653
1654 // If there are no contracting dimensions, a dot can be rewritten as
1655 // mul(broadcast(transpose(x)),broadcast(transpose(y)))
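  // For instance (an illustrative case): with no batch or contracting
  // dimensions, dot(x {3}, y {4}) is the outer product {3x4}, which equals
  // multiply(broadcast(x, dims={0}), broadcast(y, dims={1})).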
1656 if (dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 0) {
1657 TF_ASSIGN_OR_RETURN(
1658 HloInstruction * new_lhs,
1659 NormalizeDotOperandToBatchMajorAndContractingMinor(
1660 lhs,
1661 AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()),
1662 AsInt64Slice(
1663 dot->dot_dimension_numbers().lhs_contracting_dimensions())));
1664 if (dot->shape().rank() != lhs->shape().rank()) {
1665 std::vector<int64> lhs_broadcast_dims(lhs->shape().rank());
1666 absl::c_iota(lhs_broadcast_dims, 0);
1667 new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
1668 dot->shape(), new_lhs, lhs_broadcast_dims));
1669 }
1670 TF_ASSIGN_OR_RETURN(
1671 HloInstruction * new_rhs,
1672 NormalizeDotOperandToBatchMajorAndContractingMinor(
1673 rhs,
1674 AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()),
1675 AsInt64Slice(
1676 dot->dot_dimension_numbers().rhs_contracting_dimensions())));
1677 if (dot->shape().rank() != rhs->shape().rank()) {
1678 std::vector<int64> rhs_broadcast_dims(
1679 dot->dot_dimension_numbers().lhs_batch_dimensions_size());
1680 absl::c_iota(rhs_broadcast_dims, 0);
1681 for (int64 i = lhs->shape().rank(); i < dot->shape().rank(); ++i) {
1682 rhs_broadcast_dims.push_back(i);
1683 }
1684 new_rhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
1685 dot->shape(), new_rhs, rhs_broadcast_dims));
1686 }
1687 return ReplaceWithNewInstruction(
1688 dot, HloInstruction::CreateBinary(dot->shape(), HloOpcode::kMultiply,
1689 new_lhs, new_rhs));
1690 }
1691
1692 // If the lhs or rhs has only batch and contracting dimensions, a dot can be
1693 // rewritten as reduce(mul(broadcast(transpose(x)), broadcast(transpose(y)))).
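  // For instance (illustrative shapes): dot(x {4}, A {4x5}) contracting
  // dimension 0 of both sides first transposes A to {5x4} (contracting
  // dimension minor), broadcasts x into {5x4} along dimension 1, multiplies,
  // and then reduces dimension 1 to produce the {5} result.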
1694 if ((dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
1695 dot->dot_dimension_numbers().lhs_contracting_dimensions_size() ==
1696 lhs->shape().rank()) ||
1697 (dot->dot_dimension_numbers().rhs_contracting_dimensions_size() +
1698 dot->dot_dimension_numbers().rhs_batch_dimensions_size() ==
1699 rhs->shape().rank())) {
1700 TF_ASSIGN_OR_RETURN(
1701 HloInstruction * new_lhs,
1702 NormalizeDotOperandToBatchMajorAndContractingMinor(
1703 lhs,
1704 AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()),
1705 AsInt64Slice(
1706 dot->dot_dimension_numbers().lhs_contracting_dimensions())));
1707 TF_ASSIGN_OR_RETURN(
1708 HloInstruction * new_rhs,
1709 NormalizeDotOperandToBatchMajorAndContractingMinor(
1710 rhs,
1711 AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()),
1712 AsInt64Slice(
1713 dot->dot_dimension_numbers().rhs_contracting_dimensions())));
1714
1715 int64 lhs_outer_dims =
1716 lhs->shape().rank() -
1717 (dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
1718 dot->dot_dimension_numbers().lhs_contracting_dimensions_size());
1719 int64 rhs_outer_dims =
1720 rhs->shape().rank() -
1721 (dot->dot_dimension_numbers().rhs_batch_dimensions_size() +
1722 dot->dot_dimension_numbers().rhs_contracting_dimensions_size());
1723 CHECK(lhs_outer_dims == 0 || rhs_outer_dims == 0);
1724 if (rhs_outer_dims > 0) {
1725 std::vector<int64> lhs_broadcast_dims(
1726 dot->dot_dimension_numbers().lhs_batch_dimensions_size());
1727 absl::c_iota(lhs_broadcast_dims, 0);
1728 lhs_broadcast_dims.resize(lhs->shape().rank());
1729 std::iota(lhs_broadcast_dims.begin() +
1730 dot->dot_dimension_numbers().lhs_batch_dimensions_size(),
1731 lhs_broadcast_dims.end(),
1732 dot->dot_dimension_numbers().lhs_batch_dimensions_size() +
1733 rhs_outer_dims);
1734 new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
1735 new_rhs->shape(), new_lhs, lhs_broadcast_dims));
1736 } else if (lhs_outer_dims > 0) {
1737 std::vector<int64> rhs_broadcast_dims(
1738 dot->dot_dimension_numbers().rhs_batch_dimensions_size());
1739 absl::c_iota(rhs_broadcast_dims, 0);
1740 rhs_broadcast_dims.resize(rhs->shape().rank());
1741 std::iota(rhs_broadcast_dims.begin() +
1742 dot->dot_dimension_numbers().rhs_batch_dimensions_size(),
1743 rhs_broadcast_dims.end(),
1744 dot->dot_dimension_numbers().rhs_batch_dimensions_size() +
1745 lhs_outer_dims);
1746 new_rhs = computation_->AddInstruction(HloInstruction::CreateBroadcast(
1747 new_lhs->shape(), new_rhs, rhs_broadcast_dims));
1748 }
1749
1750 TF_ASSIGN_OR_RETURN(HloInstruction * new_dot,
1751 MakeBinaryHlo(HloOpcode::kMultiply, new_lhs, new_rhs));
1752 std::vector<int64> reduce_dims(
1753 dot->dot_dimension_numbers().lhs_contracting_dimensions_size());
1754 new_dot = AsType(new_dot, F32);
1755 const int64 outer_dims = std::max(rhs_outer_dims, lhs_outer_dims);
1756 absl::c_iota(
1757 reduce_dims,
1758 outer_dims + dot->dot_dimension_numbers().lhs_batch_dimensions_size());
1759 new_dot = AddReduce(new_dot, reduce_dims);
1760 new_dot = AsType(new_dot, dot->shape().element_type());
1761 return ReplaceInstruction(dot, new_dot);
1762 }
1763
1764 if (lhs->shape().rank() > 2 || rhs->shape().rank() > 2 ||
1765 dot->shape().rank() > 2) {
1766 if (options_.enable_dot_strength_reduction() &&
1767 !options_.is_layout_sensitive()) {
1768 TF_RETURN_IF_ERROR(HandleDotStrengthReduction(dot).status());
1769 }
1770 return Status::OK();
1771 }
1772
1773 TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_concat_optimized,
1774 OptimizeDotOfConcat(dot));
1775 if (dot_of_concat_optimized) {
1776 VLOG(10) << "Replaced dot(concat(...), constant) with add(dot(..., "
1777 "constant)...)";
1778 return ReplaceInstruction(dot, dot_of_concat_optimized);
1779 }
1780
1781 // Simplify dot(ConstA, Gather(Index, ConstB)) to:
1782 // Gather(Index, dot*(ConstA, ConstB)), where dot* is an appropriately
1783 // batched version of dot.
1784 TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_gather_optimized,
1785 OptimizeDotOfGather(dot));
1786 if (dot_of_gather_optimized) {
1787 VLOG(10) << "Replaced dot(constA, gather(i, constB)) with "
1788 "gather(i, dot*(constA, constB))";
1789 return ReplaceInstruction(dot, dot_of_gather_optimized);
1790 }
1791
1792 if (options_.enable_dot_strength_reduction() &&
1793 !options_.is_layout_sensitive()) {
1794 TF_ASSIGN_OR_RETURN(bool did_strength_reduction,
1795 HandleDotStrengthReduction(dot));
1796 if (did_strength_reduction) {
1797 return Status::OK();
1798 }
1799 }
1800
1801 // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)).
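  // This relies on the rank-2 identity transpose(a) * transpose(b) ==
  // transpose(b * a), so a single transpose replaces two.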
1802 if (dot->dot_dimension_numbers().lhs_batch_dimensions_size() == 0 &&
1803 dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 1 &&
1804 dot->dot_dimension_numbers().lhs_contracting_dimensions(0) == 1 &&
1805 dot->dot_dimension_numbers().rhs_contracting_dimensions(0) == 0 &&
1806 lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) {
1807 DotDimensionNumbers dot_dimension_numbers;
1808 dot_dimension_numbers.add_lhs_contracting_dimensions(1);
1809 dot_dimension_numbers.add_rhs_contracting_dimensions(0);
1810 auto new_dot = computation_->AddInstruction(HloInstruction::CreateDot(
1811 ShapeUtil::PermuteDimensions({1, 0}, dot->shape()),
1812 rhs->mutable_operand(0), lhs->mutable_operand(0), dot_dimension_numbers,
1813 dot->precision_config()));
1814 return ReplaceWithNewInstruction(
1815 dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0}));
1816 }
1817
1818 return Status::OK();
1819 }
1820
1821 Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
1822 HloInstruction *lhs, *rhs;
1823 CHECK(Match(multiply, m::Multiply(m::Op(&lhs), m::Op(&rhs))));
1824 // A*1 => A
1825 VLOG(10) << "trying transform [A*1 => A]: " << multiply->ToString();
1826 if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(multiply, lhs)) {
1827 return Status::OK();
1828 }
1829 // 1*A => A
1830 VLOG(10) << "trying transform [1*A => A]: " << multiply->ToString();
1831 if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(multiply, rhs)) {
1832 return Status::OK();
1833 }
1834
1835 // 0*A => 0. Only applies to integral types; for floats, 0*NaN is NaN.
1836 if (IsAll(lhs, 0) &&
1837 primitive_util::IsIntegralType(multiply->shape().element_type()) &&
1838 ReplaceInstructionIfSameShape(multiply, lhs)) {
1839 return Status::OK();
1840 }
1841 // A*0 => 0
1842 if (IsAll(rhs, 0) &&
1843 primitive_util::IsIntegralType(multiply->shape().element_type()) &&
1844 ReplaceInstructionIfSameShape(multiply, rhs)) {
1845 return Status::OK();
1846 }
1847
1848 // exp(A) * exp(B) => exp(A+B)
1849 if (Match(multiply, m::Multiply(m::Exp(m::Op(&lhs)), m::Exp(m::Op(&rhs))))) {
1850 auto add = computation_->AddInstruction(HloInstruction::CreateBinary(
1851 multiply->shape(), HloOpcode::kAdd, lhs, rhs));
1852 return ReplaceWithNewInstruction(
1853 multiply,
1854 HloInstruction::CreateUnary(multiply->shape(), HloOpcode::kExp, add));
1855 }
1856 return Status::OK();
1857 }
1858
1859 Status AlgebraicSimplifierVisitor::HandleNegate(HloInstruction* negate) {
1860 // negate(negate(x)) => x
1861 HloInstruction* x;
1862 if (Match(negate, m::Negate(m::Negate(m::Op(&x)))) &&
1863 ReplaceInstructionIfSameShape(negate, x)) {
1864 return Status::OK();
1865 }
1866 return Status::OK();
1867 }
1868
1869 Status AlgebraicSimplifierVisitor::HandleNot(HloInstruction* logical_not) {
1870 // not(not(x)) => x
1871 HloInstruction* x;
1872 if (Match(logical_not, m::Not(m::Not(m::Op(&x)))) &&
1873 ReplaceInstructionIfSameShape(logical_not, x)) {
1874 return Status::OK();
1875 }
1876 return Status::OK();
1877 }
1878
1879 Status AlgebraicSimplifierVisitor::HandleOr(HloInstruction* logical_or) {
1880 HloInstruction *lhs, *rhs;
1881 CHECK(Match(logical_or, m::Or(m::Op(&lhs), m::Op(&rhs))));
1882
1883 // Simplify logical or
1884 if (ShapeUtil::HasPrimitiveType(lhs->shape(), xla::PRED) &&
1885 ShapeUtil::HasPrimitiveType(rhs->shape(), xla::PRED)) {
1886 // A || True => True
1887 VLOG(10) << "trying transform [A || True => True]: "
1888 << logical_or->ToString();
1889 if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(logical_or, rhs)) {
1890 return Status::OK();
1891 }
1892 // True || A => True
1893 VLOG(10) << "trying transform [True || A => True]: "
1894 << logical_or->ToString();
1895 if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(logical_or, lhs)) {
1896 return Status::OK();
1897 }
1898
1899 // A || False => A
1900 VLOG(10) << "trying transform [A || False => A]: "
1901 << logical_or->ToString();
1902 if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(logical_or, lhs)) {
1903 return Status::OK();
1904 }
1905
1906 // False || A => A
1907 VLOG(10) << "trying transform [False || A => A]: "
1908 << logical_or->ToString();
1909 if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(logical_or, rhs)) {
1910 return Status::OK();
1911 }
1912 }
1913
1914 return Status::OK();
1915 }
1916
1917 Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log) {
1918 // ln(exp(A)) => A
1919 VLOG(10) << "trying transform [ln(exp(A)) => A]: " << log->ToString();
1920 HloInstruction *a, *b;
1921 if (Match(log, m::Log(m::Exp(m::Op(&a)))) &&
1922 ReplaceInstructionIfSameShape(log, a)) {
1923 return Status::OK();
1924 }
1925
1926 // ln(pow(A,B)) => B*ln(A)
1927 if (Match(log, m::Log(m::Power(m::Op(&a), m::Op(&b))))) {
1928 auto new_log = computation_->AddInstruction(
1929 HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, a));
1930 return ReplaceWithNewInstruction(
1931 log, HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply,
1932 new_log, b));
1933 }
1934
1935 return Status::OK();
1936 }
1937
1938 Status AlgebraicSimplifierVisitor::HandleGetTupleElement(
1939 HloInstruction* get_tuple_element) {
1940 auto operand = get_tuple_element->mutable_operand(0);
1941 if (operand->opcode() == HloOpcode::kTuple) {
1942 // get_tuple_element(make_tuple({A_0, A_1, ..., A_n}), i) => A_i
1943 VLOG(10) << "trying transform "
1944 << "[get_tuple_element(make_tuple({...,A_i,...}), i)] => A_i: "
1945 << get_tuple_element->ToString();
1946 if (ReplaceInstructionIfSameShape(
1947 get_tuple_element,
1948 operand->mutable_operand(get_tuple_element->tuple_index()))) {
1949 return Status::OK();
1950 }
1951 }
1952 return Status::OK();
1953 }
1954
1955 namespace {
1956
1957 // Returns the output indices of the dimensions at the given input indices if
1958 // the reshape leaves those dimensions unmodified, and nullopt otherwise.
1959 //
1960 // Example:
1961 // input_dim_indices = {2, 3}
1962 // input shape = T[a, b, x, y, cd]
1963 // output shape = T[ab, x, 1, y, c, d]
1964 // return value = {1, 3}
1965 //
1966 // Precondition: input_dim_indices is sorted.
1967 absl::optional<std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
1968 const HloInstruction* hlo, absl::Span<const int64> input_dim_indices) {
1969 CHECK_EQ(HloOpcode::kReshape, hlo->opcode());
1970 CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end()));
1971
1972 std::vector<int64> output_dim_indices;
1973 std::vector<std::pair<int64, int64>> unmodified_dims =
1974 ShapeUtil::DimensionsUnmodifiedByReshape(hlo->operand(0)->shape(),
1975 hlo->shape());
1976 size_t i = 0; // index to unmodified_dims
1977 for (int64 input_dim_index : input_dim_indices) {
1978 // Search unmodified_dims for input_dim_index. We can search from the last
1979 // matching position because input_dim_indices is guaranteed to be sorted.
1980 while (i < unmodified_dims.size() &&
1981 unmodified_dims[i].first < input_dim_index) {
1982 ++i;
1983 }
1984 if (i >= unmodified_dims.size() ||
1985 unmodified_dims[i].first != input_dim_index) {
1986 return absl::nullopt;
1987 }
1988 output_dim_indices.push_back(unmodified_dims[i].second);
1989 }
1990 return output_dim_indices;
1991 }
1992
1993 // Returns true if the output of "instruction" is a permutation of the
1994 // elements of "operand". Precondition: "operand" is an operand of
1995 // "instruction".
1996 bool OutputIsPermutationOfOperandElements(HloInstruction* instruction,
1997 HloInstruction* operand) {
1998 DCHECK(!instruction->OperandIndices(operand).empty());
1999 switch (instruction->opcode()) {
2000 case HloOpcode::kReshape:
2001 case HloOpcode::kReverse:
2002 case HloOpcode::kTranspose:
2003 return true;
2004 case HloOpcode::kSort:
2005 return (!instruction->shape().IsTuple());
2006 default:
2007 return false;
2008 }
2009 }
2010
2011 // Returns true if the output of "instruction" is a subset of the elements of
2012 // "operand". Precondition: "operand" is an operand of "instruction".
2013 bool OutputIsSubsetOfOperandElements(HloInstruction* instruction,
2014 HloInstruction* operand) {
2015 std::vector<int64> operand_indices = instruction->OperandIndices(operand);
2016 CHECK(!operand_indices.empty());
2017 if (operand_indices.size() != 1) {
2018 return false;
2019 }
2020 int64 operand_index = operand_indices[0];
2021 switch (instruction->opcode()) {
2022 case HloOpcode::kSlice:
2023 CHECK_EQ(0, operand_index);
2024 return true;
2025 case HloOpcode::kDynamicSlice:
2026 return operand_index == 0;
2027 default:
2028 return false;
2029 }
2030 }
2031
2032 } // namespace
2033
2034 Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
2035 HloInstruction* operand;
2036 CHECK(Match(broadcast, m::Broadcast(m::Op(&operand))));
2037 auto dims = broadcast->dimensions();
2038 // A degenerate broadcast that does not change the number of elements behaves
2039 // like a reshape and can be replaced by one.
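  // For example (an illustrative shape): broadcast(f32[2,3] -> f32[1,2,3])
  // with dimensions={1,2} keeps the element count at 6, so it is a reshape.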
2040 if (std::is_sorted(dims.begin(), dims.end()) &&
2041 ShapeUtil::ElementsIn(broadcast->shape()) ==
2042 ShapeUtil::ElementsIn(operand->shape())) {
2043 VLOG(10) << "transform broadcast(X) -> reshape(X) where "
2044 "n(broadcast(X)) == n(X)";
2045 return ReplaceWithNewInstruction(
2046 broadcast, HloInstruction::CreateReshape(broadcast->shape(), operand));
2047 }
2048
2049 // A degenerate broadcast that has the same input and output rank can be
2050 // converted into a transpose.
2051 if (broadcast->shape().rank() == operand->shape().rank() &&
2052 ShapeUtil::ElementsIn(broadcast->shape()) ==
2053 ShapeUtil::ElementsIn(operand->shape())) {
2054 VLOG(10) << "transform broadcast(X) -> transpose(X) where "
2055 "n(broadcast(X)) == n(X)";
2056 return ReplaceWithNewInstruction(
2057 broadcast,
2058 HloInstruction::CreateTranspose(broadcast->shape(), operand, dims));
2059 }
2060
2061 // A broadcast whose operand is a reshape that merely inserts 1-sized
2062 // dimensions can skip the reshape and broadcast its operand directly.
2063 {
2064 bool merely_inserts_or_deletes_1_sized_dimensions;
2065 std::vector<int64> inserted_indices, deleted_indices;
2066 std::tie(merely_inserts_or_deletes_1_sized_dimensions, deleted_indices,
2067 inserted_indices) =
2068 operand->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
2069 if (merely_inserts_or_deletes_1_sized_dimensions &&
2070 deleted_indices.empty()) {
2071 std::reverse(inserted_indices.begin(), inserted_indices.end());
2072 for (auto inserted_index : inserted_indices) {
2073 dims.erase(dims.begin() + inserted_index);
2074 }
2075 return ReplaceWithNewInstruction(
2076 broadcast,
2077 HloInstruction::CreateBroadcast(broadcast->shape(),
2078 operand->mutable_operand(0), dims));
2079 }
2080 }
2081
2082 // A Broadcast that feeds a unary element-wise operation can sink the
2083 // broadcast after the unary element-wise operation.
2084 TF_ASSIGN_OR_RETURN(
2085 bool sink_succeeded,
2086 TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(broadcast));
2087 changed_ |= sink_succeeded;
2088 if (sink_succeeded) {
2089 return Status::OK();
2090 }
2091
2092 // A scalar broadcast feeding an instruction which only permutes (reshape,
2093 // transpose, sort, reverse) or selects a subset of operand elements (slice,
2094 // dynamic slice) can be replaced with a broadcast directly to the output
2095 // shape of the instruction.
2096 if (ShapeUtil::IsScalar(operand->shape())) {
2097 for (HloInstruction* user : broadcast->users()) {
2098 // Skip if the broadcast user has no uses itself.
2099 if (user->user_count() == 0 && user != computation_->root_instruction()) {
2100 continue;
2101 }
2102 if (OutputIsPermutationOfOperandElements(user, broadcast) ||
2103 OutputIsSubsetOfOperandElements(user, broadcast)) {
2104 VLOG(10) << "transform permuting/subset of a scalar broadcast into "
2105 << "a single broadcast";
2106 HloInstruction* new_broadcast = computation_->AddInstruction(
2107 HloInstruction::CreateBroadcast(user->shape(), operand, {}));
2108 // Use HloInstruction::ReplaceAllUsesWith instead of
2109 // HloComputation::ReplaceWithNewInstruction because we are replacing an
2110 // instruction other than the visited instruction.
2111 changed_ = true;
2112 return user->ReplaceAllUsesWith(new_broadcast);
2113 }
2114 }
2115 return Status::OK();
2116 }
2117
2118 // broadcast(iota) -> iota.
2119 if (operand->opcode() == HloOpcode::kIota) {
2120 return ReplaceWithNewInstruction(
2121 broadcast,
2122 HloInstruction::CreateIota(
2123 broadcast->shape(),
2124 dims[Cast<HloIotaInstruction>(operand)->iota_dimension()]));
2125 }
2126
2127 // Merge two consecutive broadcasts into a single one.
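  // For example (shapes for illustration only): broadcast(broadcast(x {3} ->
  // {3,4}, dims={0}) -> {2,3,4}, dims={1,2}) maps x's dimension 0 through
  // 0 -> 1, so it collapses to broadcast(x -> {2,3,4}, dims={1}).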
2128 if (operand->opcode() == HloOpcode::kBroadcast) {
2129 std::vector<int64> new_dimensions;
2130 for (auto dim : operand->dimensions()) {
2131 new_dimensions.push_back(dims[dim]);
2132 }
2133 return ReplaceWithNewInstruction(
2134 broadcast,
2135 HloInstruction::CreateBroadcast(
2136 broadcast->shape(), operand->mutable_operand(0), new_dimensions));
2137 }
2138 return Status::OK();
2139 }
2140
2141 // A conversion to the same element type as the operand is a nop and can be
2142 // removed. A conversion of a constant can be simplified by making a new
2143 // constant.
2144 Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) {
2145 PrimitiveType src_type = convert->operand(0)->shape().element_type();
2146 PrimitiveType dest_type = convert->shape().element_type();
2147 if (src_type == dest_type) {
2148 return ReplaceInstruction(convert, convert->mutable_operand(0));
2149 }
2150 return Status::OK();
2151 }
2152
2153 // Complex(Real(c), Imag(c)) -> c
2154 Status AlgebraicSimplifierVisitor::HandleComplex(HloInstruction* complex) {
2155 HloInstruction *c0, *c1;
2156 if (Match(complex, m::Complex(m::Real(m::Op(&c0)), m::Imag(m::Op(&c1)))) &&
2157 c0 == c1) {
2158 return ReplaceInstruction(complex, c0);
2159 }
2160 return Status::OK();
2161 }
2162
2163 // Real(Complex(r, i)) -> r
2164 Status AlgebraicSimplifierVisitor::HandleReal(HloInstruction* real) {
2165 HloInstruction* op;
2166 if (Match(real, m::Real(m::Complex(m::Op(&op), m::Op())))) {
2167 return ReplaceInstruction(real, op);
2168 }
2169 return Status::OK();
2170 }
2171
2172 // Imag(Complex(r, i)) -> i
2173 Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) {
2174 HloInstruction* op;
2175 if (Match(imag, m::Imag(m::Complex(m::Op(), m::Op(&op))))) {
2176 return ReplaceInstruction(imag, op);
2177 }
2178 return Status::OK();
2179 }
2180
2181 Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) {
2182 // iota -> zero if the iota dimension never produces an element other than
2183 // zero.
2184 auto* iota = Cast<HloIotaInstruction>(instruction);
2185 if (iota->shape().dimensions(iota->iota_dimension()) <= 1) {
2186 auto zero = computation_->AddInstruction(HloInstruction::CreateConstant(
2187 LiteralUtil::Zero(iota->shape().element_type()).Clone()));
2188 return ReplaceWithNewInstruction(
2189 iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {}));
2190 }
2191 return Status::OK();
2192 }
2193
2194 Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
2195 if (ShapeUtil::IsZeroElementArray(pad->operand(0)->shape())) {
2196 return ReplaceWithNewInstruction(
2197 pad, HloInstruction::CreateBroadcast(pad->shape(),
2198 pad->mutable_operand(1), {}));
2199 }
2200
2201 // Interior padding on one-sized dimensions has no effect; clearing it can
2202 // enable other simplifications that require no interior padding.
2203 if (HasInteriorPadding(pad->padding_config())) {
2204 PaddingConfig padding_config = pad->padding_config();
2205 bool cleared_interior_padding = false;
2206 for (int64 i = 0; i < pad->shape().rank(); ++i) {
2207 if (padding_config.dimensions(i).interior_padding() > 0 &&
2208 pad->operand(0)->shape().dimensions(i) == 1) {
2209 cleared_interior_padding = true;
2210 padding_config.mutable_dimensions(i)->set_interior_padding(0);
2211 }
2212 }
2213 if (cleared_interior_padding) {
2214 return ReplaceWithNewInstruction(
2215 pad,
2216 HloInstruction::CreatePad(pad->shape(), pad->mutable_operand(0),
2217 pad->mutable_operand(1), padding_config));
2218 }
2219 }
2220
2221 // Eliminate nop pads (padding all zero), and replace a pad with negative
2222 // padding with a pad with non-negative padding followed by a slice.
2223 bool all_zero = true;
2224 bool has_negative = false;
2225 for (auto& padding_dimension : pad->padding_config().dimensions()) {
2226 if (padding_dimension.edge_padding_low() < 0 ||
2227 padding_dimension.edge_padding_high() < 0) {
2228 has_negative = true;
2229 }
2230 if (padding_dimension.edge_padding_low() != 0 ||
2231 padding_dimension.edge_padding_high() != 0) {
2232 all_zero = false;
2233 }
2234 }
2235
2236 if (all_zero) {
2237 ReplaceInstructionIfSameShape(pad, pad->mutable_operand(0));
2238 return Status::OK();
2239 }
2240
2241 if (has_negative) {
2242 // Pad has negative padding. Replace with a pad with the non-negative
2243 // padding followed by a slice which effectively performs the negative
2244 // padding.
2245 // TODO(b/34628603): Add support for negative padding in the backends, or
2246 // change kPad semantics to disallow negative padding and use slice
2247 // instead.
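  //
  // For illustration (hypothetical sizes): pad(f32[10], low=-2, high=3) has
  // shape f32[11]; it becomes slice(pad(f32[10], low=0, high=3) -> f32[13],
  // start=2, limit=13), which drops the two leading elements.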
2248
2249 // First construct the padding config with non-negative entries and then
2250 // compute the shape of this new pad instruction.
2251 PaddingConfig nonzero_padding = pad->padding_config();
2252 for (int i = 0; i < pad->padding_config().dimensions_size(); ++i) {
2253 PaddingConfig::PaddingConfigDimension* padding_dimension =
2254 nonzero_padding.mutable_dimensions(i);
2255 // Set negative padding to zero.
2256 if (padding_dimension->edge_padding_low() < 0) {
2257 padding_dimension->set_edge_padding_low(0);
2258 }
2259 if (padding_dimension->edge_padding_high() < 0) {
2260 padding_dimension->set_edge_padding_high(0);
2261 }
2262 }
2263
2264 TF_ASSIGN_OR_RETURN(HloInstruction * nonzero_pad,
2265 MakePadHlo(pad->mutable_operand(0),
2266 pad->mutable_operand(1), nonzero_padding));
2267 // Copy the layout from the original pad instruction. The new pad and the
2268 // slice instruction should both have the same layout.
2269 TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
2270 pad->shape(), nonzero_pad->mutable_shape()));
2271
2272 // Second, construct the slice instruction to perform the negative padding.
2273 std::vector<int64> start_indices;
2274 std::vector<int64> end_indices;
2275 std::vector<int64> strides;
2276 for (int64 i = 0; i < pad->padding_config().dimensions_size(); ++i) {
2277 const PaddingConfig::PaddingConfigDimension& padding_dimension =
2278 pad->padding_config().dimensions(i);
2279 int64 start = 0;
2280 if (padding_dimension.edge_padding_low() < 0) {
2281 start = -1 * padding_dimension.edge_padding_low();
2282 }
2283 int64 end = nonzero_pad->shape().dimensions(i);
2284 if (padding_dimension.edge_padding_high() < 0) {
2285 end += padding_dimension.edge_padding_high();
2286 }
2287 start_indices.push_back(start);
2288 end_indices.push_back(end);
2289 strides.push_back(1);
2290 }
2291
2292 TF_ASSIGN_OR_RETURN(
2293 HloInstruction * slice,
2294 MakeSliceHlo(nonzero_pad, start_indices, end_indices, strides));
2295
2296 // Verify that the slice shape matches the pad shape.
2297 TF_RET_CHECK(ShapeUtil::Compatible(slice->shape(), pad->shape()));
2298
2299 return ReplaceInstruction(pad, slice);
2300 }
2301
2302 return Status::OK();
2303 }
2304
2305 Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
2306 VLOG(10) << "trying transform [pow(A, 0) => 1]: " << power->ToString();
2307 HloInstruction *lhs, *rhs;
2308 CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs))));
2309 if (IsAll(rhs, 0)) {
2310 auto one = HloInstruction::CreateConstant(
2311 LiteralUtil::One(power->shape().element_type()).Clone());
2312 std::unique_ptr<HloInstruction> ones;
2313 if (ShapeUtil::IsScalar(power->shape())) {
2314 ones = std::move(one);
2315 } else {
2316 ones = HloInstruction::CreateBroadcast(
2317 power->shape(), computation_->AddInstruction(std::move(one)), {});
2318 }
2319 return ReplaceWithNewInstruction(power, std::move(ones));
2320 }
2321
2322 VLOG(10) << "trying transform [pow(A, 1) => A]: " << power->ToString();
2323 if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(power, lhs)) {
2324 return Status::OK();
2325 }
2326
2327 // pow(exp(A),B) => exp(A*B)
2328 HloInstruction *a, *b;
2329 if (Match(power, m::Power(m::Exp(m::Op(&a)), m::Op(&b)))) {
2330 auto a_times_b = computation_->AddInstruction(HloInstruction::CreateBinary(
2331 power->shape(), HloOpcode::kMultiply, a, b));
2332 return ReplaceWithNewInstruction(
2333 power, HloInstruction::CreateUnary(power->shape(), HloOpcode::kExp,
2334 a_times_b));
2335 }
2336 VLOG(10) << "trying transform [pow(A, 2) => A*A]: " << power->ToString();
2337 if (IsAll(rhs, 2)) {
2338 return ReplaceWithNewInstruction(
2339 power, HloInstruction::CreateBinary(power->shape(),
2340 HloOpcode::kMultiply, lhs, lhs));
2341 }
2342
2343 VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
2344 if (IsAll(rhs, -1)) {
2345 auto* one = computation_->AddInstruction(HloInstruction::CreateConstant(
2346 LiteralUtil::One(rhs->shape().element_type()).Clone()));
2347
2348 // Explicitly broadcast scalar 1 to the output shape, to avoid implicit
2349 // broadcast in divide HLO as we are trying to eliminate implicit
2350 // broadcasting at HLO level.
2351 auto* broadcast_one = computation_->AddInstruction(
2352 HloInstruction::CreateBroadcast(power->shape(), one, {}));
2353 return ReplaceWithNewInstruction(
2354 power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide,
2355 broadcast_one, lhs));
2356 }
2357
2358 VLOG(10) << "trying transform [pow(pow(A, X), Y) => pow(A, X*Y)]: "
2359 << power->ToString();
2360
2361 // Don't perform this optimization if either of the exponents is complex; this
2362 // identity is true only for real-valued exponents. In addition, we cowardly
2363 // refuse to do this transformation if the two exponents have different
2364 // element types.
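  // For example, pow(pow(A, 3), 2) with matching real element types folds to
  // pow(A, 3*2) = pow(A, 6).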
2365 if (lhs->opcode() == HloOpcode::kPower &&
2366 !ShapeUtil::ElementIsComplex(lhs->operand(1)->shape()) &&
2367 !ShapeUtil::ElementIsComplex(rhs->shape()) &&
2368 ShapeUtil::SameElementType(lhs->operand(1)->shape(), rhs->shape())) {
2369 auto exponent_product =
2370 computation_->AddInstruction(HloInstruction::CreateBinary(
2371 rhs->shape(), HloOpcode::kMultiply, lhs->mutable_operand(1), rhs));
2372 return ReplaceWithNewInstruction(
2373 power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kPower,
2374 lhs->mutable_operand(0),
2375 exponent_product));
2376 }
2377
2378 return Status::OK();
2379 }
2380
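// Sinks a broadcast past a compatible elementwise user. For instance (an
// illustrative case): add(broadcast(x {3} -> {3,4}), broadcast(c {} -> {3,4}))
// can become broadcast(add(x, broadcast(c {} -> {3}))), so the elementwise op
// runs on the smaller pre-broadcast shape.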
2381 StatusOr<bool>
2382 AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
2383 HloInstruction* broadcast) {
2384 TF_RET_CHECK(broadcast->opcode() == HloOpcode::kBroadcast);
2385 bool changed = false;
2386 if (ShapeUtil::IsScalar(broadcast->shape())) {
2387 return false;
2388 }
2389 HloInstruction* operand = broadcast->mutable_operand(0);
2390 for (HloInstruction* user : broadcast->users()) {
2391 if (user->user_count() == 0 && user != computation_->root_instruction()) {
2392 continue;
2393 }
2394 // Do not move reshapes or broadcasts past copies since the shape the copy
2395 // will operate on will change.
2396 if (user->opcode() == HloOpcode::kCopy) {
2397 continue;
2398 }
2399 // Do not change the shape of fusion nodes in case there are multiple shapes
2400 // inside the fusion node already.
2401 if (user->opcode() == HloOpcode::kFusion) {
2402 continue;
2403 }
2404 if (!user->IsElementwise()) {
2405 continue;
2406 }
2407
2408 // Find the unique non-scalar operand or continue if there isn't one.
2409 int64 scalar_broadcast_count = 0;
2410 int64 broadcast_use_count = 0;
2411 for (HloInstruction* user_operand : user->operands()) {
2412 if (user_operand->opcode() == HloOpcode::kBroadcast &&
2413 ShapeUtil::IsScalar(user_operand->operand(0)->shape())) {
2414 ++scalar_broadcast_count;
2415 } else if (broadcast == user_operand) {
2416 ++broadcast_use_count;
2417 }
2418 }
2419 if (scalar_broadcast_count + broadcast_use_count != user->operand_count()) {
2420 continue;
2421 }
2422 std::vector<HloInstruction*> new_operands;
2423 new_operands.reserve(user->operand_count());
2424
2425 for (HloInstruction* user_operand : user->operands()) {
2426 if (user_operand->opcode() == HloOpcode::kBroadcast &&
2427 ShapeUtil::IsScalar(user_operand->operand(0)->shape())) {
2428 new_operands.push_back(
2429 computation_->AddInstruction(HloInstruction::CreateBroadcast(
2430 ShapeUtil::ChangeElementType(
2431 operand->shape(), user_operand->shape().element_type()),
2432 user_operand->mutable_operand(0), {})));
2433 } else {
2434 CHECK_EQ(broadcast, user_operand);
2435 new_operands.push_back(operand);
2436 }
2437 }
2438 VLOG(4) << "Sinking broadcast after user:";
2439 VLOG(4) << " old broadcast: " << broadcast->ToString();
2440 VLOG(4) << " old user: " << user->ToString();
2441 HloInstruction* new_user =
2442 computation_->AddInstruction(user->CloneWithNewOperands(
2443 ShapeUtil::ChangeElementType(operand->shape(),
2444 user->shape().element_type()),
2445 new_operands));
2446 VLOG(4) << " new user: " << new_user->ToString();
2447 HloInstruction* new_broadcast =
2448 computation_->AddInstruction(HloInstruction::CreateBroadcast(
2449 user->shape(), new_user, broadcast->dimensions()));
2450 VLOG(4) << " new broadcast: " << new_broadcast->ToString();
2451 TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_broadcast));
2452 changed = true;
2453 }
2454 return changed;
2455 }
2456
2457 namespace {
2458 template <typename T>
2459 std::unique_ptr<HloInstruction> TryRemainderToAnd(HloInstruction* remainder,
2460 HloComputation* computation) {
2461 HloInstruction *a, *b, *c;
2462 CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b))));
2463
2464 if (ShapeUtil::ElementIsIntegral(remainder->shape()) &&
2465 !Match(b, m::ConstantEffectiveScalar(&c)) &&
2466 !Match(b, m::Broadcast(m::ConstantEffectiveScalar(&c)))) {
2467 return nullptr;
2468 }
2469
2470 if (ShapeUtil::ElementIsSigned(remainder->shape())) {
2471 int64 b_value = c->literal().GetFirstElement<T>();
2472 if (b_value > 0 && IsPowerOfTwo(static_cast<uint64>(b_value))) {
2473 // Handle negative dividends by negating the result of the division.
2474 HloInstruction* zero_like_a = BroadcastZeros(
2475 computation, a->shape().element_type(), a->shape().dimensions());
2476
2477 auto* dividend_is_negative =
2478 computation->AddInstruction(HloInstruction::CreateCompare(
2479 ShapeUtil::ChangeElementType(a->shape(), PRED), a, zero_like_a,
2480 ComparisonDirection::kLt));
2481
2482 auto* negated_dividend = computation->AddInstruction(
2483 HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a));
2484
2485 auto* abs_dividend =
2486 computation->AddInstruction(HloInstruction::CreateTernary(
2487 a->shape(), HloOpcode::kSelect, dividend_is_negative,
2488 negated_dividend, a));
2489
2490 auto* mask_amount =
2491 computation->AddInstruction(HloInstruction::CreateConstant(
2492 LiteralUtil::CreateR0<T>(b_value - 1)));
2493 if (!ShapeUtil::IsScalar(b->shape())) {
2494 mask_amount = computation->AddInstruction(
2495 HloInstruction::CreateBroadcast(b->shape(), mask_amount, {}));
2496 }
2497
2498 auto* quotient = computation->AddInstruction(HloInstruction::CreateBinary(
2499 remainder->shape(), HloOpcode::kAnd, abs_dividend, mask_amount));
2500
2501 auto* negated_quotient =
2502 computation->AddInstruction(HloInstruction::CreateUnary(
2503 quotient->shape(), HloOpcode::kNegate, quotient));
2504
2505 return HloInstruction::CreateTernary(
2506 remainder->shape(), HloOpcode::kSelect, dividend_is_negative,
2507 negated_quotient, quotient);
2508 }
2509 } else {
2510 uint64 b_value = c->literal().GetFirstElement<T>();
2511 if (IsPowerOfTwo(b_value)) {
2512 HloInstruction* mask_amount =
2513 computation->AddInstruction(HloInstruction::CreateConstant(
2514 LiteralUtil::CreateR0<T>(b_value - 1)));
2515 if (!ShapeUtil::IsScalar(b->shape())) {
2516 mask_amount = computation->AddInstruction(
2517 HloInstruction::CreateBroadcast(b->shape(), mask_amount, {}));
2518 }
2519 return HloInstruction::CreateBinary(remainder->shape(), HloOpcode::kAnd,
2520 a, mask_amount);
2521 }
2522 }
2523 return nullptr;
2524 }
2525 } // namespace
2526
2527 Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) {
2528 HloInstruction *a, *b;
2529 CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b))));
2530
2531 // A % B => A & (B - 1) if B is a power of 2.
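  // For example, x % 8 == x & 7 for unsigned x; for signed x the sign is
  // restored afterwards, e.g. -13 % 8 == -(13 & 7) == -5 under truncated
  // division.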
2532 switch (remainder->shape().element_type()) {
2533 case S8:
2534 if (std::unique_ptr<HloInstruction> shift =
2535 TryRemainderToAnd<int8>(remainder, computation_)) {
2536 return ReplaceWithNewInstruction(remainder, std::move(shift));
2537 }
2538 break;
2539 case S16:
2540 if (std::unique_ptr<HloInstruction> shift =
2541 TryRemainderToAnd<int16>(remainder, computation_)) {
2542 return ReplaceWithNewInstruction(remainder, std::move(shift));
2543 }
2544 break;
2545 case S32:
2546 if (std::unique_ptr<HloInstruction> shift =
2547 TryRemainderToAnd<int32>(remainder, computation_)) {
2548 return ReplaceWithNewInstruction(remainder, std::move(shift));
2549 }
2550 break;
2551 case S64:
2552 if (std::unique_ptr<HloInstruction> shift =
2553 TryRemainderToAnd<int64>(remainder, computation_)) {
2554 return ReplaceWithNewInstruction(remainder, std::move(shift));
2555 }
2556 break;
2557 case U8:
2558 if (std::unique_ptr<HloInstruction> shift =
2559 TryRemainderToAnd<uint8>(remainder, computation_)) {
2560 return ReplaceWithNewInstruction(remainder, std::move(shift));
2561 }
2562 break;
2563 case U16:
2564 if (std::unique_ptr<HloInstruction> shift =
2565 TryRemainderToAnd<uint16>(remainder, computation_)) {
2566 return ReplaceWithNewInstruction(remainder, std::move(shift));
2567 }
2568 break;
2569 case U32:
2570 if (std::unique_ptr<HloInstruction> shift =
2571 TryRemainderToAnd<uint32>(remainder, computation_)) {
2572 return ReplaceWithNewInstruction(remainder, std::move(shift));
2573 }
2574 break;
2575 case U64:
2576 if (std::unique_ptr<HloInstruction> shift =
2577 TryRemainderToAnd<uint64>(remainder, computation_)) {
2578 return ReplaceWithNewInstruction(remainder, std::move(shift));
2579 }
2580 break;
2581 default:
2582 break;
2583 }
2584
2585 return Status::OK();
2586 }
2587
2588 Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
2589 auto operand = reshape->mutable_operand(0);
2590
2591 // Reshape directly to an empty constant if the shape contains a zero-element
2592 // dimension.
2593 if (ShapeUtil::IsZeroElementArray(reshape->shape())) {
2594 // If the instruction doesn't have a layout, use a default layout for
2595 // the literal result.
2596 Shape reshaped_shape = reshape->shape();
2597 if (!LayoutUtil::HasLayout(reshaped_shape)) {
2598 LayoutUtil::SetToDefaultLayout(&reshaped_shape);
2599 }
2600 auto empty_constant = HloInstruction::CreateConstant(
2601 Literal::CreateFromShape(reshaped_shape));
2602
2603 return ReplaceWithNewInstruction(reshape, std::move(empty_constant));
2604 }
2605
2606 // Delete no-op reshapes, i.e. where shape = operand shape.
2607 if (SameShape(reshape, operand)) {
2608 VLOG(10) << "deleting no-op reshape";
2609 return ReplaceInstruction(reshape, operand);
2610 }
2611
2612 // Merge reshapes.
2613 if (HloOpcode::kReshape == operand->opcode()) {
2614 return ReplaceWithNewInstruction(
2615 reshape, HloInstruction::CreateReshape(reshape->shape(),
2616 operand->mutable_operand(0)));
2617 }
2618
2619 if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) {
2620 *operand->mutable_shape() = reshape->shape();
2621 return ReplaceInstruction(reshape, operand);
2622 }
2623
2624 if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) {
2625 auto opt_dims = ReshapeLeavesDimensionsUnmodified(
2626 reshape, reshape->operand(0)->dimensions());
2627 if (opt_dims.has_value()) {
2628 return ReplaceWithNewInstruction(
2629 reshape,
2630 HloInstruction::CreateBroadcast(
2631 reshape->shape(), reshape->mutable_operand(0)->mutable_operand(0),
2632 *opt_dims));
2633 }
2634 }
2635
2636 // reshape(iota) -> iota.
2637 if (operand->opcode() == HloOpcode::kIota) {
2638 auto* iota = Cast<HloIotaInstruction>(operand);
2639 auto opt_dims =
2640 ReshapeLeavesDimensionsUnmodified(reshape, {iota->iota_dimension()});
2641 if (opt_dims.has_value()) {
2642 CHECK_EQ(opt_dims->size(), 1);
2643 return ReplaceWithNewInstruction(
2644 reshape,
2645 HloInstruction::CreateIota(reshape->shape(), opt_dims->front()));
2646 }
2647 }
2648
2649 // Make this a bitcast if possible.
2650 if (HloInstruction* bitcast_operand =
2651 BitcastingOperandOfReshapeOrCopyChain(reshape, options_)) {
2652 ReplaceWithBitcast(reshape, bitcast_operand);
2653 }
2654 return Status::OK();
2655 }
2656
2657 Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) {
2658 // When all the dimensions to reverse are trivial (i.e. the bound is 1),
2659 // there is nothing to be done.
2660 auto dim_is_one = [&](int64 i) -> bool {
2661 return reverse->shape().dimensions(i) == 1;
2662 };
2663 if (absl::c_all_of(reverse->dimensions(), dim_is_one)) {
2664 return ReplaceInstruction(reverse, reverse->mutable_operand(0));
2665 }
2666 return Status::OK();
2667 }
2668
2669 StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyScalarSlice(
2670 HloInstruction* slice) {
2671 // Only try to do this for effective scalars. We could do the same for slicing
2672 // out larger pieces of padding (replacing with a broadcast of the padding
2673 // value), but this is probably not worth it.
2674 if (!ShapeUtil::IsEffectiveScalar(slice->shape())) {
2675 return false;
2676 }
2677
2678 if (slice->operand(0)->opcode() == HloOpcode::kPad) {
2679 VLOG(10) << "Trying to simplify scalar slice of pad";
2680 // Check there's no interior padding. Again, we could handle that too, since
2681 // everything is statically known, but it's not worth it.
2682 auto pad = Cast<HloPadInstruction>(slice->mutable_operand(0));
2683 auto padding_config = pad->padding_config();
2684 int64 rank = padding_config.dimensions_size();
2685 if (HasInteriorPadding(padding_config)) {
2686 VLOG(10) << "Not folding scalar slice of pad, pad has interior padding";
2687 return false;
2688 }
2689
2690 // Check whether the scalar we're slicing out falls into the padding.
2691 bool in_padding = [&]() {
2692 for (int64 i = 0; i < rank; ++i) {
2693 int64 start = slice->slice_starts(i);
2694 int64 low = padding_config.dimensions(i).edge_padding_low();
2695 int64 data = pad->operand(0)->shape().dimensions(i);
2696 if (start < low || start >= low + data) {
2697 return true;
2698 }
2699 }
2700 return false;
2701 }();
2702
2703 if (in_padding) {
2704 VLOG(10) << "Folding scalar slice of pad into padding value";
2705 TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
2706 slice, HloInstruction::CreateReshape(slice->shape(),
2707 pad->mutable_padding_value())));
2708 return true;
2709 } else {
2710 // We already know the output of the slice is scalar. If the padded
2711 // value is scalar, and it's not in the padding, then it's exactly the
2712 // output value.
2713 bool replaced =
2714 ReplaceInstructionIfSameShape(slice, pad->mutable_operand(0));
2715 if (replaced) {
2716 VLOG(10) << "Folding scalar slice of pad into padded value";
2717 } else {
2718 VLOG(10) << "Not folding scalar slice of pad into padded value as they "
2719 "have different shapes.";
2720 }
2721 return replaced;
2722 }
2723 }
2724
2725 if (slice->operand(0)->opcode() == HloOpcode::kConcatenate) {
2726 VLOG(10) << "Trying to simplify scalar slice of concat";
2727 // Only do this for R1, there's no chance of this being useful otherwise.
2728 if (slice->shape().rank() != 1) {
2729 VLOG(10) << "Not folding, slice is not rank 1";
2730 return false;
2731 }
2732 HloConcatenateInstruction* concat =
2733 Cast<HloConcatenateInstruction>(slice->mutable_operand(0));
2734 int64 operand_start = 0;
2735 int64 operand_num = 0;
2736 // Weird loop structure to avoid annoying off-by-one errors.
2737 while (true) {
2738 TF_RET_CHECK(operand_num < concat->operand_count());
2739 const HloInstruction* operand = concat->operand(operand_num);
2740 int64 next_operand_start = operand_start + operand->shape().dimensions(0);
2741 if (next_operand_start > slice->slice_starts(0)) {
2742 break;
2743 }
2744 operand_start = next_operand_start;
2745 operand_num++;
2746 }
2747
2748 bool replaced = ReplaceInstructionIfSameShape(
2749 slice, concat->mutable_operand(operand_num));
2750 if (replaced) {
2751 VLOG(10) << "Folding scalar slice of concat into concat operand";
2752 } else {
2753 VLOG(10) << "Folding scalar slice of concat into slice of concat operand";
2754 TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
2755 slice, HloInstruction::CreateSlice(
2756 slice->shape(), concat->mutable_operand(operand_num),
2757 {slice->slice_starts(0) - operand_start},
2758 {slice->slice_starts(0) - operand_start + 1},
2759 slice->slice_strides())));
2760 }
2761 return true;
2762 }
2763
2764 return false;
2765 }
2766
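// Tries to rewrite slice(reshape(x)) as reshape(slice(x)) when only the
// major-most output dimension is sliced, starting from offset zero. An
// illustrative case: slice(reshape(f32[6,4] -> f32[24]), [0:12]) becomes
// reshape(slice(f32[6,4], [0:3, 0:4]) -> f32[12]).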
2767 StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape(
2768 HloInstruction* slice) {
2769 CHECK_EQ(slice->opcode(), HloOpcode::kSlice);
2770 if (!IsUnstridedSlice(slice)) {
2771 return false;
2772 }
2773 HloInstruction* reshape = slice->mutable_operand(0);
2774 if (reshape->opcode() != HloOpcode::kReshape) {
2775 return false;
2776 }
2777 HloInstruction* new_slice_operand = reshape->mutable_operand(0);
2778 int64 slice_rank = slice->shape().rank();
2779 std::vector<int64> sliced_dims;
2780 for (int64 i = 0; i < slice_rank; ++i) {
2781 if (slice->slice_starts(i) != 0 ||
2782 slice->slice_limits(i) != reshape->shape().dimensions(i)) {
2783 sliced_dims.push_back(i);
2784 }
2785 }
2786
2787 if (sliced_dims.size() == 1 && sliced_dims[0] == 0 &&
2788 slice->slice_starts(0) == 0) {
2789 const Shape& new_slice_shape = new_slice_operand->shape();
2790 const int64 rank = new_slice_shape.rank();
2791 std::vector<int64> new_slice_starts(rank, 0);
2792 std::vector<int64> new_slice_strides(rank, 1);
2793 std::vector<int64> new_slice_limits(new_slice_shape.dimensions().begin(),
2794 new_slice_shape.dimensions().end());
2795 int64 slice_elements = ShapeUtil::ElementsIn(slice->shape());
2796 for (int64 i = rank - 1; i >= 0; --i) {
2797 if (slice_elements >= new_slice_limits[i]) {
2798 if (slice_elements % new_slice_limits[i] != 0) {
2799 return false;
2800 }
2801 slice_elements /= new_slice_limits[i];
2802 } else {
2803 new_slice_limits[i] = slice_elements;
2804 slice_elements = 1;
2805 }
2806 }
2807 HloInstruction* new_slice =
2808 computation_->AddInstruction(HloInstruction::CreateSlice(
2809 ShapeUtil::MakeShape(new_slice_shape.element_type(),
2810 new_slice_limits),
2811 new_slice_operand, new_slice_starts, new_slice_limits,
2812 new_slice_strides));
2813 TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
2814 slice, HloInstruction::CreateReshape(slice->shape(), new_slice)));
2815 return true;
2816 }
2817 return false;
2818 }
2819
2820 Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
2821 // Delete no-op slices, i.e. where shape = operand shape.
2822 if (ReplaceInstructionIfSameShape(slice, slice->mutable_operand(0))) {
2823 return Status::OK();
2824 }
2825
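  // Merge slice(slice(x)) into one slice by composing the offsets. For
  // example (illustrative bounds): slice(slice(x, [2:8]), [1:4]) reads
  // elements 3..5 of x, so it folds to slice(x, [3:6]).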
  if (slice->operand(0)->opcode() == HloOpcode::kSlice &&
      IsUnstridedSlice(slice) && IsUnstridedSlice(slice->operand(0))) {
    HloInstruction* operand_slice = slice->mutable_operand(0);
    std::vector<int64> new_slice_starts = slice->slice_starts();
    std::vector<int64> new_slice_limits = slice->slice_limits();
    for (int64 i = 0; i < new_slice_starts.size(); ++i) {
      new_slice_starts[i] += operand_slice->slice_starts(i);
      new_slice_limits[i] += operand_slice->slice_starts(i);
    }
    return ReplaceWithNewInstruction(
        slice, HloInstruction::CreateSlice(
                   slice->shape(), operand_slice->mutable_operand(0),
                   new_slice_starts, new_slice_limits,
                   slice->slice_strides()));
  }

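  // A slice that only touches dimensions introduced by a broadcast can be
  // folded into a smaller broadcast. For example (hypothetical shapes): with
  //   y = broadcast(x[3]) into shape [5,3] with dimensions={1},
  // slice([1:4, 0:3], y) leaves the mapped dimension 1 intact, so it is just
  //   broadcast(x[3]) into shape [3,3] with dimensions={1}.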
  auto only_broadcast_dims_sliced = [&] {
    if (slice->operand(0)->opcode() != HloOpcode::kBroadcast) {
      return false;
    }
    for (int64 dim : slice->operand(0)->dimensions()) {
      if (slice->slice_starts(dim) != 0 || slice->slice_strides(dim) != 1 ||
          slice->slice_limits(dim) !=
              slice->operand(0)->shape().dimensions(dim)) {
        return false;
      }
    }
    return true;
  };
  if (only_broadcast_dims_sliced()) {
    return ReplaceWithNewInstruction(
        slice,
        HloInstruction::CreateBroadcast(
            slice->shape(), slice->mutable_operand(0)->mutable_operand(0),
            slice->mutable_operand(0)->dimensions()));
  }

  TF_ASSIGN_OR_RETURN(bool replaced, TrySimplifyScalarSlice(slice));
  if (replaced) {
    return Status::OK();
  }

  TF_ASSIGN_OR_RETURN(replaced, TryToReorderSliceAndReshape(slice));
  if (replaced) {
    return Status::OK();
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleDynamicSlice(
    HloInstruction* dynamic_slice) {
  auto operand = dynamic_slice->mutable_operand(0);
  if (ShapeUtil::IsScalar(dynamic_slice->shape())) {
    return ReplaceInstruction(dynamic_slice, operand);
  }
  // DynamicSlice where operand has the same size as the output is simply equal
  // to operand.
  if (SameShape(operand, dynamic_slice)) {
    return ReplaceInstruction(dynamic_slice, operand);
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
    HloInstruction* dynamic_update_slice) {
  auto update = dynamic_update_slice->mutable_operand(1);

  // DynamicUpdateSlice where operand and update have the same size is simply
  // equal to update.
  if (SameShape(dynamic_update_slice, update)) {
    return ReplaceInstruction(dynamic_update_slice, update);
  }

  // If any dimension of update is 0, elide the DynamicUpdateSlice. This
  // optimization becomes invalid should we later prefer to warn about
  // out-of-bounds indices.
  if (ShapeUtil::IsZeroElementArray(update->shape())) {
    return ReplaceInstruction(dynamic_update_slice,
                              dynamic_update_slice->mutable_operand(0));
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
  HloReduceInstruction* reduce = Cast<HloReduceInstruction>(hlo);
  bool multi_output_reduce = reduce->shape().IsTuple();

  // For tuple reduce, we require all reduce shapes to be the same, up to the
  // element types, so we can just use the first operand and the first result
  // as representatives.
  auto arg = reduce->inputs()[0];
  auto init_value = reduce->init_values()[0];
  const Shape& reduce_result_shape =
      multi_output_reduce ? reduce->shape().tuple_shapes(0) : reduce->shape();

  absl::Span<const int64> dimensions(reduce->dimensions());
  HloComputation* function = reduce->to_apply();
  if (ShapeUtil::IsZeroElementArray(arg->shape()) ||
      ShapeUtil::IsZeroElementArray(reduce_result_shape)) {
    if (multi_output_reduce) {
      std::vector<HloInstruction*> broadcast_inits;
      int64 inputs = reduce->input_count();
      for (int64 i = 0; i < inputs; ++i) {
        broadcast_inits.push_back(computation_->AddInstruction(
            HloInstruction::CreateBroadcast(reduce->shape().tuple_shapes(i),
                                            reduce->init_values()[i], {})));
      }
      return ReplaceWithNewInstruction(
          reduce, HloInstruction::CreateTuple(broadcast_inits));
    } else {
      return ReplaceWithNewInstruction(
          reduce,
          HloInstruction::CreateBroadcast(reduce_result_shape, init_value, {}));
    }
  }

  // If the reduction results in the same number of elements, then the only
  // possible side effect would be a reshape. Since the init_value is an
  // identity of the reduction function, we can therefore replace the reduce
  // with a simple reshape, ignoring the reduction function completely.
  if (ShapeUtil::ElementsIn(reduce_result_shape) ==
      ShapeUtil::ElementsIn(arg->shape())) {
    if (multi_output_reduce) {
      std::vector<HloInstruction*> reshaped_args;
      int64 inputs = reduce->input_count();
      for (int64 i = 0; i < inputs; ++i) {
        reshaped_args.push_back(
            computation_->AddInstruction(HloInstruction::CreateReshape(
                reduce->shape().tuple_shapes(i), reduce->inputs()[i])));
      }
      return ReplaceWithNewInstruction(
          reduce, HloInstruction::CreateTuple(reshaped_args));
    } else {
      return ReplaceWithNewInstruction(
          reduce, HloInstruction::CreateReshape(reduce_result_shape, arg));
    }
  }

  // TODO(b/112040122): Most of the optimizations below can be done for
  // multi-output reduces.
  if (multi_output_reduce) {
    return Status::OK();
  }

  // A transpose feeding a reduce can simply permute the reduction dimensions
  // field if the output of the reduce is a vector or scalar. Higher-rank
  // results may require a transpose of the output.
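  // For example (hypothetical 2D shapes):
  //   reduce(transpose(a[M,N], dimensions={1,0}), dimensions={0})
  // reduces over output dimension 0, which is input dimension 1, so it equals
  //   reduce(a[M,N], dimensions={1}).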
  if (reduce_result_shape.rank() <= 1 &&
      arg->opcode() == HloOpcode::kTranspose) {
    auto transpose_dimensions = arg->dimensions();
    std::vector<int64> new_reduce_dimensions;
    for (auto dim : dimensions) {
      new_reduce_dimensions.push_back(transpose_dimensions[dim]);
    }
    return ReplaceWithNewInstruction(
        reduce, HloInstruction::CreateReduce(
                    reduce_result_shape, arg->mutable_operand(0), init_value,
                    new_reduce_dimensions, function));
  }

  // If a reduce feeds a reduce with the same computation and initial value,
  // they can be combined into a single reduce.
  if (arg->opcode() == HloOpcode::kReduce &&
      init_value->Identical(*arg->operand(1)) &&
      *function == *arg->to_apply()) {
    // Create a new reduce with the combined reduction dimensions of both
    // reduces.
    std::vector<int64> arg_dims = arg->dimensions();
    absl::c_sort(arg_dims);
    std::vector<int64> reduce_dims = reduce->dimensions();
    absl::c_sort(reduce_dims);
    // Rewrite reduce_dims as dimension indices of the inner reduce's operand,
    // undoing the dimension renumbering performed by the inner reduce.
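    // For example (hypothetical shapes): with a[A,B,C],
    //   inner = reduce(a, dimensions={1})      // shape [A,C]
    //   outer = reduce(inner, dimensions={1})  // reduces what was C
    // arg_dims = {1} shifts outer's reduce_dims from {1} to {2}, and merging
    // gives reduce(a, dimensions={1,2}).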
    for (int64 arg_dim : arg_dims) {
      for (int64& dim : reduce_dims) {
        if (dim >= arg_dim) {
          ++dim;
        }
      }
    }
    std::vector<int64> new_dimensions;
    new_dimensions.reserve(arg->dimensions().size() +
                           reduce->dimensions().size());
    std::merge(arg_dims.begin(), arg_dims.end(), reduce_dims.begin(),
               reduce_dims.end(), std::back_inserter(new_dimensions));
    return ReplaceWithNewInstruction(
        reduce, HloInstruction::CreateReduce(
                    reduce_result_shape, arg->mutable_operand(0), init_value,
                    new_dimensions, function));
  }

  // A reshape that collapses multiple dimensions into a dimension being
  // reduced can just reduce all of those dimensions instead of doing a
  // collapsing reshape before a reduction.
  if (arg->opcode() == HloOpcode::kReshape) {
    std::vector<std::pair<int64, int64>> unmodified_dims =
        ShapeUtil::DimensionsUnmodifiedByReshape(arg->operand(0)->shape(),
                                                 arg->shape());
    std::vector<bool> arg_dim_in_output(arg->shape().rank(), true);
    std::vector<bool> arg_dim_unmodified(arg->shape().rank(), false);
    for (auto dim : dimensions) {
      arg_dim_in_output[dim] = false;
    }
    for (auto dim_pair : unmodified_dims) {
      arg_dim_unmodified[dim_pair.second] = true;
    }
    // The goal is to verify that all dimensions that are not removed in the
    // reduce are unmodified by the reshape. For example:
    //   reduce(reshape([A,B*C], a[A,B,C]), [1]) = reduce(a[A,B,C], [1,2])
    bool can_move_reshape_into_reduce = true;
    for (int64 i = 0; i < arg_dim_in_output.size(); ++i) {
      if (arg_dim_in_output[i] && !arg_dim_unmodified[i]) {
        can_move_reshape_into_reduce = false;
      }
    }
    if (can_move_reshape_into_reduce) {
      changed_ = true;
      absl::flat_hash_set<int64> dimensions_not_to_reduce;
      for (auto dim_pair : unmodified_dims) {
        if (arg_dim_in_output[dim_pair.second]) {
          dimensions_not_to_reduce.insert(dim_pair.first);
        }
      }
      std::vector<int64> new_reduce_dimensions;
      for (int64 i = 0; i < arg->operand(0)->shape().rank(); ++i) {
        if (!dimensions_not_to_reduce.contains(i)) {
          new_reduce_dimensions.push_back(i);
        }
      }
      return ReplaceWithNewInstruction(
          reduce, HloInstruction::CreateReduce(
                      reduce_result_shape, arg->mutable_operand(0), init_value,
                      new_reduce_dimensions, function));
    }
  }
  // Convert Reduce(concat({a,b,...})) to
  //   map(reduce(a), map(reduce(b), ...))
  //
  // This should make fusion easier or use less memory bandwidth in the unfused
  // case.
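  // For example (hypothetical R1 shapes), with the reduced dimension being the
  // concatenated dimension:
  //   reduce(concat({a,b,c}, dim=0), dimensions={0})
  // becomes a left-nested chain of binary maps over per-operand reduces:
  //   map(map(reduce(a), reduce(b)), reduce(c)),
  // where each map applies the reduce computation elementwise.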
  if (arg->opcode() == HloOpcode::kConcatenate &&
      absl::c_linear_search(reduce->dimensions(),
                            arg->concatenate_dimension())) {
    HloInstruction* old_reduce = nullptr;
    for (HloInstruction* operand : arg->operands()) {
      HloInstruction* new_reduce = computation_->AddInstruction(
          HloInstruction::CreateReduce(reduce_result_shape, operand, init_value,
                                       reduce->dimensions(), function));
      if (old_reduce != nullptr) {
        new_reduce = computation_->AddInstruction(HloInstruction::CreateMap(
            reduce_result_shape, {old_reduce, new_reduce}, function));
      }
      old_reduce = new_reduce;
    }
    return ReplaceInstruction(reduce, old_reduce);
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleReduceWindow(
    HloInstruction* reduce_window) {
  if (ShapeUtil::IsZeroElementArray(reduce_window->operand(0)->shape())) {
    return ReplaceWithNewInstruction(
        reduce_window,
        HloInstruction::CreateBroadcast(reduce_window->shape(),
                                        reduce_window->mutable_operand(1), {}));
  }
  auto operand = reduce_window->mutable_operand(0);
  const Window& window = reduce_window->window();
  auto function = reduce_window->to_apply();
  if (ShapeUtil::IsScalar(operand->shape())) {
    TF_RET_CHECK(ShapeUtil::IsScalar(reduce_window->shape()));
    return ReplaceWithNewInstruction(
        reduce_window,
        HloInstruction::CreateMap(reduce_window->shape(),
                                  {reduce_window->mutable_operand(1), operand},
                                  function));
  }

  if (options_.enable_window_reduce_to_reduce_replacement()) {
    // A reduce-window can be expressed as a reduce and a reshape if all
    // dimensions either have a window size of one or the entire dimension. If
    // there is no stride, dilation, or padding, this is as easy as checking
    // the size of the output shape and window dimension.
    //
    // The reshape is a bitcast since it adds one-sized dimensions. Often these
    // ones are immediately removed as well with another reshape. The
    // implementation of reduce tends to be slightly more efficient at reducing
    // entire dimensions compared to reduce window.
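    // For example (hypothetical shapes): a reduce-window over operand [4,5]
    // with window sizes {4,1} (no padding/stride/dilation) produces shape
    // [1,5]; it equals reshape(reduce(operand, dimensions={0}) -> [1,5]).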
    auto effective_reduce_dims = [&] {
      if (window_util::HasStride(window) || window_util::HasDilation(window) ||
          window_util::HasPadding(window)) {
        return absl::InlinedVector<int64, 8>{};
      }
      absl::InlinedVector<int64, 8> reduce_dims;
      for (int64 i = 0; i < window.dimensions_size(); ++i) {
        if (window.dimensions(i).size() == 1) {
          continue;
        } else if (reduce_window->shape().dimensions(i) == 1) {
          reduce_dims.push_back(i);
        } else {
          return absl::InlinedVector<int64, 8>{};
        }
      }
      return reduce_dims;
    }();

    // If a reduce window can be expressed as a reduce, do so and reshape the
    // output.
    if (!effective_reduce_dims.empty()) {
      Shape reduce_shape = ShapeUtil::FilterDimensions(
          [&](int64 dim) {
            return !absl::c_linear_search(effective_reduce_dims, dim);
          },
          reduce_window->shape());
      HloInstruction* reduce =
          computation_->AddInstruction(HloInstruction::CreateReduce(
              /*shape=*/reduce_shape,
              /*operand=*/operand,
              /*init_value=*/reduce_window->mutable_operand(1),
              /*dimensions_to_reduce=*/effective_reduce_dims,
              /*reduce_computation=*/function));
      return ReplaceWithNewInstruction(
          reduce_window,
          HloInstruction::CreateReshape(reduce_window->shape(), reduce));
    }
  }

  // This optimization folds a pad op into reduce_window.
  HloInstruction* pad;
  const HloInstruction* convert = nullptr;
  if (operand->opcode() == HloOpcode::kPad) {
    pad = operand;
  } else if (operand->opcode() == HloOpcode::kConvert &&
             operand->operand(0)->opcode() == HloOpcode::kPad) {
    convert = operand;
    pad = operand->mutable_operand(0);
  } else {
    VLOG(10) << "Not folding pad into reduce-window as there is no pad.";
    return Status::OK();
  }

  // Bail on dilation.
  if (window_util::HasDilation(window)) {
    VLOG(10) << "Not folding pad into reduce-window as there is dilation.";
    return Status::OK();
  }

  VLOG(10) << "Considering folding Pad: " << pad->ToString()
           << "\ninto reduce-window: " << reduce_window->ToString()
           << (convert != nullptr
                   ? absl::StrCat("\nvia convert: ", convert->ToString())
                   : "");

  // Do not fold interior padding into ReduceWindow since the backends do not
  // support it.
  const PaddingConfig& pad_config = pad->padding_config();
  if (HasInteriorPadding(pad_config)) {
    VLOG(10) << "Not folding pad into reduce-window due to interior padding.";
    return Status::OK();
  }

  // If reduce_window already has padding, the pad value of the pad op and the
  // init value of reduce_window must match to allow folding the pad.
  const HloInstruction* pad_value = pad->operand(1);
  const HloInstruction* reduce_init_value = reduce_window->operand(1);
  if (pad_value != reduce_init_value) {
    auto literals_are_equivalent = [&] {
      auto& pad_literal = pad_value->literal();
      auto& reduce_init_literal = reduce_init_value->literal();
      if (pad_literal == reduce_init_literal) {
        return true;
      }
      auto converted_pad_literal =
          pad_literal.ConvertToShape(reduce_init_value->shape());
      if (!converted_pad_literal.ok()) {
        return false;
      }
      return converted_pad_literal.ValueOrDie() == reduce_init_literal;
    };
    // The pad value is usually a constant, so we handle that case and do not
    // try to get more fancy about proving equivalence in cases beyond that.
    if (pad_value->opcode() != HloOpcode::kConstant ||
        reduce_init_value->opcode() != HloOpcode::kConstant ||
        !literals_are_equivalent()) {
      VLOG(10) << "Not folding pad into reduce-window due to different pad "
                  "values.";
      return Status::OK();
    }
  }

  // If the pad puts a single non-identity value in each window that we're
  // reducing, then this is a broadcast.
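  // For example (hypothetical R1 shapes): pad x[1] with the init value out to
  // [5] (edge padding 2 on each side) and reduce-window with window size 3 and
  // no stride. Every window then contains x once plus identity values, so each
  // output element equals x and the result is broadcast(x) into shape [3].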
  HloInstruction* pad_operand = pad->mutable_operand(0);
  auto is_effective_broadcast = [&] {
    if (window_util::HasStride(window)) {
      VLOG(10) << "Window has stride.";
      return false;
    }
    if (!window_util::HasSymmetricPadding(pad_config)) {
      VLOG(10) << "Window has uneven padding.";
      return false;
    }
    for (int64 i = 0; i < pad_config.dimensions_size(); ++i) {
      const auto& pad_dimension = pad_config.dimensions(i);
      if ((pad_dimension.edge_padding_low() != 0 ||
           pad_dimension.edge_padding_high() != 0) &&
          pad_operand->shape().dimensions(i) != 1) {
        VLOG(10) << "Found non-trivial dimension being padded: " << i;
        return false;
      }
    }
    VLOG(10) << "Found to be padding trivial dimensions only.";

    for (int64 i = 0; i < window.dimensions_size(); ++i) {
      const auto& pad_dimension = pad_config.dimensions(i);
      const WindowDimension& window_dimension = window.dimensions(i);
      bool dimension_has_padding = (pad_dimension.edge_padding_low() != 0 ||
                                    pad_dimension.edge_padding_high() != 0);
      if (dimension_has_padding &&
          window_dimension.size() < pad_dimension.edge_padding_low() + 1) {
        VLOG(10) << "Found window did not cover single unpadded element in "
                    "dimension: "
                 << i;
        return false;
      }
      if (pad_operand->shape().dimensions(i) != 1 &&
          window_dimension.size() != 1) {
        VLOG(10) << "Found window covers more than one element in non-trivial "
                    "dimension: "
                 << i;
        return false;
      }
    }
    VLOG(10) << "Found window covers a single unpadded element.";
    return true;
  };

  HloInstruction* new_reduce_window_operand;
  if (convert != nullptr) {
    new_reduce_window_operand =
        computation_->AddInstruction(HloInstruction::CreateConvert(
            ShapeUtil::ChangeElementType(pad_operand->shape(),
                                         convert->shape().element_type()),
            pad_operand));
  } else {
    new_reduce_window_operand = pad_operand;
  }

  if (is_effective_broadcast()) {
    VLOG(10) << "Replacing pad/reduce-window with broadcast.";
    auto fadd = [this](std::unique_ptr<HloInstruction> x) {
      return computation_->AddInstruction(std::move(x));
    };
    return ReplaceWithNewInstruction(
        reduce_window, HloInstruction::CreateBroadcastSequence(
                           /*output_shape=*/reduce_window->shape(),
                           /*operand=*/new_reduce_window_operand, fadd));
  }

  // Carry out the folding of the pad into reduce_window.
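  // Edge padding simply accumulates into the window's own padding. For example
  // (hypothetical values): a pad with edge_padding_low=1 on some dimension and
  // a window with padding_low=2 fold into a window with padding_low=3.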
  VLOG(10) << "Folding pad into reduce-window.";
  Window new_window = window;
  const int64 rank = reduce_window->shape().rank();
  TF_RET_CHECK(pad_config.dimensions_size() == rank);
  TF_RET_CHECK(window.dimensions_size() == rank);
  for (int64 i = 0; i < rank; ++i) {
    const auto& pad_dim = pad_config.dimensions(i);
    auto& window_dim = *new_window.mutable_dimensions(i);
    window_dim.set_padding_low(window_dim.padding_low() +
                               pad_dim.edge_padding_low());
    window_dim.set_padding_high(window_dim.padding_high() +
                                pad_dim.edge_padding_high());
  }

  return ReplaceWithNewInstruction(
      reduce_window, HloInstruction::CreateReduceWindow(
                         /*shape=*/reduce_window->shape(),
                         /*operand=*/new_reduce_window_operand,
                         /*init_value=*/reduce_window->mutable_operand(1),
                         /*window=*/new_window,
                         /*reduce_computation=*/function));
}

Status AlgebraicSimplifierVisitor::HandleSelect(HloInstruction* select) {
  // select(x, y, y) -> y.
  if (select->operand(1) == select->operand(2)) {
    return ReplaceInstruction(select, select->mutable_operand(1));
  }
  // select(true, x, y) -> x.
  if (IsAll(select->operand(0), true)) {
    return ReplaceInstruction(select, select->mutable_operand(1));
  }
  // select(false, x, y) -> y.
  if (IsAll(select->operand(0), false)) {
    return ReplaceInstruction(select, select->mutable_operand(2));
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) {
  auto operand = sort->mutable_operand(0);
  int64 dimension_to_sort = sort->dimensions(0);
  if (ShapeUtil::IsZeroElementArray(operand->shape()) ||
      operand->shape().dimensions(dimension_to_sort) <= 1) {
    if (sort->operand_count() == 1) {
      return ReplaceInstruction(sort, operand);
    }
    // If it is a key/value sort, the output of sort is a tuple.
    return ReplaceWithNewInstruction(
        sort, HloInstruction::CreateTuple(sort->operands()));
  }
  return Status::OK();
}

namespace {
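// Returns true iff the transpose permutes only degenerate (size-1)
// dimensions, and more than one of them, so it is really a reshape. For
// example (hypothetical shape and permutation): shape [1,5,1,7] with
// perm {2,1,0,3} only swaps the two size-1 dimensions; dropping them leaves
// the sorted permutation {1,3}, so this returns true.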
bool OnlyPermutesMoreThanOneDegenerateDim(const Shape& shape,
                                          absl::Span<const int64> perm) {
  std::vector<int64> new_permutation;
  int64 degenerate_count = 0;
  for (int64 i = 0; i < perm.size(); ++i) {
    if (shape.dimensions(i) != 1) {
      new_permutation.push_back(perm[i]);
    } else {
      ++degenerate_count;
    }
  }
  return degenerate_count > 1 && absl::c_is_sorted(new_permutation);
}
}  // namespace

Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) {
  auto operand = transpose->mutable_operand(0);
  if (std::is_sorted(transpose->dimensions().begin(),
                     transpose->dimensions().end())) {
    VLOG(10) << "deleting no-op transpose";
    return ReplaceInstruction(transpose, operand);
  }

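  // transpose(transpose(x, p1), p2) composes into a single transpose whose
  // permutation maps output dimension i to input dimension p1[p2[i]]. For
  // example (hypothetical permutations): p1 = {1,2,0} followed by p2 = {2,0,1}
  // composes to {p1[2], p1[0], p1[1]} = {0,1,2}, i.e. a no-op.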
  if (HloOpcode::kTranspose == operand->opcode()) {
    return ReplaceWithNewInstruction(
        transpose, HloInstruction::CreateTranspose(
                       transpose->shape(), operand->mutable_operand(0),
                       ComposePermutations(operand->dimensions(),
                                           transpose->dimensions())));
  }

  // Replace the transpose with a reshape if it permutes only degenerate
  // dimensions (more than one of them).
  if (OnlyPermutesMoreThanOneDegenerateDim(transpose->shape(),
                                           transpose->dimensions())) {
    return ReplaceWithNewInstruction(
        transpose, HloInstruction::CreateReshape(
                       transpose->shape(), transpose->mutable_operand(0)));
  }

  if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) {
    *operand->mutable_shape() = transpose->shape();
    return ReplaceInstruction(transpose, operand);
  }

  if (options_.is_layout_sensitive() && TransposeIsBitcast(transpose)) {
    ReplaceWithBitcast(transpose);
    return Status::OK();
  }

  return Status::OK();
}

StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvInputPad(
    HloInstruction* convolution) {
  auto* lhs = convolution->mutable_operand(0);
  auto* rhs = convolution->mutable_operand(1);
  const auto& window = convolution->window();
  const ConvolutionDimensionNumbers& dnums =
      convolution->convolution_dimension_numbers();

  if (lhs->opcode() != HloOpcode::kPad) {
    return false;
  }

  // Convolution's padding is always zero, so bail if the kPad is adding
  // something other than zero.
  if (!IsAll(lhs->operand(1), 0)) {
    return false;
  }

  const auto& padding = lhs->padding_config();

  // Can't pad batch or feature dims.
  for (int64 dim :
       {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) {
    const auto& p = padding.dimensions(dim);
    if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
        p.interior_padding() != 0) {
      return false;
    }
  }

  // Compute the window which is the result of merging the kPad and the
  // convolution's existing window.
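  // For example (hypothetical values for one spatial dimension): a kPad with
  // edge_padding_low=1 and edge_padding_high=2 merges into the window as
  // padding_low += 1 and padding_high += 2, and interior_padding=1 becomes
  // base_dilation=2 (one implicit zero between adjacent input elements).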
  Window new_window = window;
  for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) {
    auto& w = *new_window.mutable_dimensions(dim);
    const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim));
    // Edge padding composes with itself in the straightforward way, but
    // composing interior padding is nontrivial, and we cowardly refuse to
    // think about it. If we see interior padding in either the kPad or conv,
    // bail if there's any sort of padding in the other.
    if (p.interior_padding() != 0 &&
        (w.padding_low() != 0 || w.padding_high() != 0 ||
         w.base_dilation() != 1)) {
      return false;
    }
    if (w.base_dilation() != 1 &&
        (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
         p.interior_padding() != 0)) {
      return false;
    }

    w.set_padding_low(w.padding_low() + p.edge_padding_low());
    w.set_padding_high(w.padding_high() + p.edge_padding_high());
    if (p.interior_padding() != 0) {
      CHECK_EQ(w.base_dilation(), 1);
      w.set_base_dilation(1 + p.interior_padding());
    }
  }

  auto new_conv = convolution->CloneWithNewOperands(
      convolution->shape(), {lhs->mutable_operand(0), rhs});
  new_conv->set_window(new_window);
  TF_RETURN_IF_ERROR(
      ReplaceWithNewInstruction(convolution, std::move(new_conv)));
  return true;
}

StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvFilterPad(
    HloInstruction* convolution) {
  auto* lhs = convolution->mutable_operand(0);
  auto* rhs = convolution->mutable_operand(1);
  const ConvolutionDimensionNumbers& dnums =
      convolution->convolution_dimension_numbers();

  if (rhs->opcode() != HloOpcode::kPad) {
    return false;
  }

  // Convolution's padding is always zero, so bail if the kPad is adding
  // something other than zero.
  if (!IsAll(rhs->operand(1), 0)) {
    return false;
  }

  const auto& padding = rhs->padding_config();

  // Can't pad or dilate feature dims.
  for (int64 dim : {dnums.kernel_input_feature_dimension(),
                    dnums.kernel_output_feature_dimension()}) {
    const auto& p = padding.dimensions(dim);
    if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
        p.interior_padding() != 0) {
      return false;
    }
  }

  // Compute the window which is the result of merging the kPad and the
  // convolution's existing window.
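  // For example (hypothetical values for one kernel spatial dimension): a kPad
  // with interior_padding=1 on a size-3 filter inserts the same zeros as
  // dilating the original filter, so it folds into window_dilation=2 with the
  // window size reset to the unpadded filter size 3.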
  Window new_window = convolution->window();
  for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) {
    auto& w = *new_window.mutable_dimensions(dim);
    const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim));

    // We can only do this transformation if p adds dilation to the filter --
    // edge padding on the filter is not supported in conv.
    if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) {
      return false;
    }

    // Nothing to do if the kPad for this dim is entirely a nop.
    if (p.interior_padding() == 0) {
      continue;
    }

    // We cowardly refuse to think about how dilation composes with itself;
    // bail if both the kPad and conv have dilation on this dimension.
    if (w.window_dilation() > 1) {
      return false;
    }
    CHECK_EQ(w.window_dilation(), 1);
    w.set_window_dilation(1 + p.interior_padding());
    w.set_size(rhs->operand(0)->shape().dimensions(
        dnums.kernel_spatial_dimensions(dim)));
  }

  auto new_conv = convolution->CloneWithNewOperands(
      convolution->shape(), {lhs, rhs->mutable_operand(0)});
  new_conv->set_window(new_window);
  TF_RETURN_IF_ERROR(
      ReplaceWithNewInstruction(convolution, std::move(new_conv)));
  return true;
}

StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot(
    HloInstruction* convolution) {
  auto* lhs = convolution->mutable_operand(0);
  auto* rhs = convolution->mutable_operand(1);
  const auto& window = convolution->window();
  const ConvolutionDimensionNumbers& dnums =
      convolution->convolution_dimension_numbers();

  if (!options_.enable_conv_simplification()) {
    return false;
  }

  // TODO(b/31337498): For now, we cowardly refuse to do this optimization in
  // layout-insensitive mode, for fear of adding nontrivial reshapes.
  if (!options_.is_layout_sensitive()) {
    return false;
  }

  const Shape& input_shape = lhs->shape();
  const Shape& filter_shape = rhs->shape();
  const Shape& convolution_shape = convolution->shape();
  TF_RET_CHECK(LayoutUtil::HasLayout(input_shape));
  TF_RET_CHECK(LayoutUtil::HasLayout(filter_shape));
  TF_RET_CHECK(LayoutUtil::HasLayout(convolution_shape));

  // Require the spatial dimensions in the kernel to have a bound of one.
  for (int64 i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) {
    if (filter_shape.dimensions(dnums.kernel_spatial_dimensions(i)) != 1) {
      return false;
    }
  }

  // Stride ignores part of the output, which matrix multiplication does not
  // do, so require no stride. Padding and base (lhs) dilation both implicitly
  // extend the data, which matrix multiplication also does not do, so require
  // no padding and no base (lhs) dilation. Window (rhs) dilation has no effect
  // for a 1x1 window, so window dilation is no problem.
  if (window_util::HasStride(window) || window_util::HasPadding(window) ||
      window_util::HasBaseDilation(window)) {
    return false;
  }

  // Also, the shapes must align for a rowmajor matmul:
  // - the input and output have the same layout.
  // - for input/output, the channel dimension must be the most minor. Other
  //   spatial dims can be in any order.
  // - for filters, the input channel dimension must be more major than the
  //   output channel dimension. The width+height don't matter because
  //   they are 1.
  //
  // These constraints are harsh. If the channel dimension is the most major
  // and/or the layout of input/output feature dimensions are reversed, we can
  // still convert Conv into more efficient Matmul with operand transposition
  // (such as the transposition flags in cuBLAS SGEMM).
  if (!LayoutUtil::Equal(input_shape.layout(), convolution_shape.layout()) ||
      LayoutUtil::Minor(input_shape.layout(), 0) !=
          dnums.input_feature_dimension() ||
      LayoutUtil::Minor(convolution_shape.layout(), 0) !=
          dnums.output_feature_dimension() ||
      // The input feature dimension should come later in the minor-to-major
      // order.
      (PositionInContainer(LayoutUtil::MinorToMajor(filter_shape),
                           dnums.kernel_input_feature_dimension()) <
       PositionInContainer(LayoutUtil::MinorToMajor(filter_shape),
                           dnums.kernel_output_feature_dimension()))) {
    return false;
  }

  auto add_bitcast = [&](Shape shape, HloInstruction* operand) {
    std::vector<int64> dims(operand->shape().dimensions_size());
    std::iota(dims.begin(), dims.end(), 0);
    return computation_->AddInstruction(
        HloInstruction::CreateUnary(shape, HloOpcode::kBitcast, operand));
  };

  // Replace it with a dot, with bitcasts around it to get the right shape.
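  // For example (hypothetical NHWC shapes): a 1x1 convolution of input
  // f32[N,H,W,Cin] with filter f32[1,1,Cin,Cout] becomes
  //   bitcast(dot(bitcast(input)[N*H*W, Cin], bitcast(filter)[Cin, Cout]))
  // bitcast back to f32[N,H,W,Cout], contracting over the Cin dimension.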
  const int64 input_channels =
      input_shape.dimensions(dnums.input_feature_dimension());
  const int64 output_channels =
      filter_shape.dimensions(dnums.kernel_output_feature_dimension());

  // Computes the product of the non-feature dimensions.
  int64 conv_width = 1;
  for (int i = 0; i < input_shape.dimensions_size(); ++i) {
    if (i != dnums.input_feature_dimension()) {
      conv_width *= input_shape.dimensions(i);
    }
  }

  // We already checked feature_dimension is most minor, so data in input_shape
  // and row-major {conv_width,input_channels} are bitwise identical.
  const Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout(
      input_shape.element_type(), {conv_width, input_channels});
  // We already checked input_feature_dimension is more major than
  // output_feature_dimension, so data in filter_shape and row-major
  // {input_channels,output_channels} are bitwise identical.
  const Shape new_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout(
      filter_shape.element_type(), {input_channels, output_channels});
  const Shape dot_output_shape = ShapeUtil::MakeShapeWithDescendingLayout(
      convolution_shape.element_type(), {conv_width, output_channels});

  auto new_lhs = add_bitcast(new_input_shape, lhs);
  auto new_rhs = add_bitcast(new_filter_shape, rhs);
  DotDimensionNumbers dot_dimension_numbers;
  dot_dimension_numbers.add_lhs_contracting_dimensions(1);
  dot_dimension_numbers.add_rhs_contracting_dimensions(0);
  auto dot = computation_->AddInstruction(HloInstruction::CreateDot(
      dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers,
      convolution->precision_config()));

  TF_RETURN_IF_ERROR(
      ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)));
  return true;
}

Status AlgebraicSimplifierVisitor::HandleConvolution(
    HloInstruction* convolution) {
  // Zero-sized input or filter.
  if (ShapeUtil::IsZeroElementArray(convolution->operand(0)->shape()) ||
      ShapeUtil::IsZeroElementArray(convolution->operand(1)->shape())) {
    return ReplaceWithNewInstruction(
        convolution,
        HloInstruction::CreateBroadcast(
            convolution->shape(),
            computation_->AddInstruction(HloInstruction::CreateConstant(
                LiteralUtil::Zero(convolution->shape().element_type()))),
            {}));
  }

  // Try to merge padding/dilation of the input with the convolution's window.
  TF_ASSIGN_OR_RETURN(bool folded_input_pad, FoldConvInputPad(convolution));
  if (folded_input_pad) {
    return Status::OK();
  }

  // Try to merge dilation of the filter with the convolution's window.
  TF_ASSIGN_OR_RETURN(bool folded_filter_pad, FoldConvFilterPad(convolution));
  if (folded_filter_pad) {
    return Status::OK();
  }

  // Try to replace the convolution with a kDot instruction.
  TF_ASSIGN_OR_RETURN(bool replaced_with_dot, SimplifyConvToDot(convolution));
  if (replaced_with_dot) {
    return Status::OK();
  }

  return Status::OK();
}

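// Rewrites a min/max chain into a single clamp when the operand shapes allow
// it. A sketch of the intended usage (assuming, as at the call sites, that
// root is min(max(x, lo), hi)): callers pass min_operand = hi and
// max_operand = lo, and the rewrite produces clamp(lo, x, hi).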
bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape(
    HloInstruction* root, HloInstruction* min, HloInstruction* min_operand,
    HloInstruction* operand, HloInstruction* max, HloInstruction* max_operand) {
  // Ensure the shapes of the min and max operands are equal, to match current
  // shape inference.
  if (!SameShape(min_operand, max_operand)) {
    return false;
  }

  auto clamp = HloInstruction::CreateTernary(root->shape(), HloOpcode::kClamp,
                                             max_operand, operand, min_operand);
  TF_CHECK_OK(ReplaceWithNewInstruction(root, std::move(clamp)));
  return true;
}

Status AlgebraicSimplifierVisitor::HandleMap(HloInstruction* map) {
  auto* map_computation = map->to_apply();
  auto* map_root = map_computation->root_instruction();
  if (map_root->opcode() == HloOpcode::kParameter) {
    ReplaceInstructionIfSameShape(
        map, map->mutable_operand(map_root->parameter_number()));
    return Status::OK();
  }
  if (map_root->opcode() == HloOpcode::kConstant) {
    if (!ShapeUtil::IsScalar(map_root->shape())) {
      return Status::OK();
    }
    auto clone = map_root->CloneWithNewOperands(map_root->shape(), {});
    if (ShapeUtil::IsScalar(map->shape())) {
      return ReplaceWithNewInstruction(map, std::move(clone));
    }
    return ReplaceWithNewInstruction(
        map,
        HloInstruction::CreateBroadcast(
            map->shape(), computation_->AddInstruction(std::move(clone)), {}));
  }
  // Inline the map if the map computation only contains an elementwise
  // operation that can accept arbitrary shapes.
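  // For example (hypothetical computation): map({a, b}, to_apply={p0, p1 ->
  // add(p0, p1)}) inlines to add(a, b), with each parameter replaced by the
  // corresponding map operand.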
  if (map_root->opcode() == HloOpcode::kFusion || !map_root->IsElementwise()) {
    return Status::OK();
  }
  std::vector<HloInstruction*> new_operands;
  for (auto* root_operand : map_root->operands()) {
    if (root_operand->opcode() != HloOpcode::kParameter) {
      return Status::OK();
    }
    new_operands.push_back(
        map->mutable_operand(root_operand->parameter_number()));
  }
  auto clone = map_root->CloneWithNewOperands(map->shape(), new_operands);
  return ReplaceWithNewInstruction(map, std::move(clone));
}

StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) {
  XLA_VLOG_LINES(2,
                 "AlgebraicSimplifier::Run(), before:\n" + module->ToString());
  bool changed = false;
  for (auto* comp : module->MakeNonfusionComputations()) {
    if (AlgebraicSimplifierVisitor::Run(comp, options_)) {
      changed = true;
    }
  }
  XLA_VLOG_LINES(2,
                 "AlgebraicSimplifier::Run(), after:\n" + module->ToString());
  return changed;
}

}  // namespace xla