/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

==============================================================================*/
16
17 #include <utility>
18
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/Support/Casting.h"
22 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
23 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
24 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
25 #include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
26 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
27 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
28 #include "mlir/Dialect/SCF/SCF.h"
29 #include "mlir/Dialect/Shape/IR/Shape.h"
30 #include "mlir/Dialect/StandardOps/IR/Ops.h"
31 #include "mlir/Dialect/Tensor/IR/Tensor.h"
32 #include "mlir/IR/BlockAndValueMapping.h"
33 #include "mlir/IR/BuiltinOps.h"
34 #include "mlir/IR/BuiltinTypes.h"
35 #include "mlir/IR/MLIRContext.h"
36 #include "mlir/IR/Operation.h"
37 #include "mlir/IR/OperationSupport.h"
38 #include "mlir/IR/PatternMatch.h"
39 #include "mlir/Interfaces/InferTypeOpInterface.h"
40 #include "mlir/Pass/Pass.h"
41 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
42
43 namespace mlir {
44 namespace mhlo {
45 namespace {
46
47 struct ShapeReificationPattern : public OpRewritePattern<shape::ShapeOfOp> {
ShapeReificationPatternmlir::mhlo::__anon5720ca320111::ShapeReificationPattern48 explicit ShapeReificationPattern(MLIRContext *context)
49 : OpRewritePattern<shape::ShapeOfOp>(context) {
50 // Recursively reify until we hit an op that doesn't support it.
51 setHasBoundedRewriteRecursion();
52 }
53
matchAndRewritemlir::mhlo::__anon5720ca320111::ShapeReificationPattern54 LogicalResult matchAndRewrite(shape::ShapeOfOp op,
55 PatternRewriter &rewriter) const override {
56 // Only reify shape computation if operand allows for it.
57 auto shape_origin = op.arg().getDefiningOp<InferShapedTypeOpInterface>();
58 if (!shape_origin) return failure();
59
60 llvm::SmallVector<Value, 1> reifications;
61 if (failed(shape_origin.reifyReturnTypeShapes(
62 rewriter, shape_origin->getOperands(), reifications)))
63 return failure();
64 assert(reifications.size() == 1);
65 Value reified_shape = reifications.front();
66
67 // Insert cast if needed.
68 if (reified_shape.getType() != op.getType()) {
69 reified_shape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
70 reified_shape);
71 }
72
73 rewriter.replaceOp(op, reified_shape);
74 return success();
75 }
76 };
77
78 template <typename OpTy>
79 struct InlineBroadcastedShapeOperandsPattern : public OpRewritePattern<OpTy> {
80 using OpRewritePattern<OpTy>::OpRewritePattern;
81
matchAndRewritemlir::mhlo::__anon5720ca320111::InlineBroadcastedShapeOperandsPattern82 LogicalResult matchAndRewrite(OpTy op,
83 PatternRewriter &rewriter) const override {
84 // Find all the shape operands, direct and indirect.
85 SmallVector<Value, 8> inlined_operands;
86 for (Value direct : op->getOperands()) {
87 if (auto bcast_op = direct.getDefiningOp<shape::BroadcastOp>()) {
88 for (Value indirect : bcast_op->getOperands())
89 inlined_operands.push_back(indirect);
90 } else {
91 inlined_operands.push_back(direct);
92 }
93 }
94
95 // Only rewrite if it makes a difference.
96 if (inlined_operands.size() == op.getNumOperands()) return failure();
97
98 // Inline shape operands.
99 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
100 inlined_operands, op->getAttrs());
101 return success();
102 }
103 };
104
MoveIntoAssumingOpMatchAndRewrite(Operation * op,PatternRewriter & rewriter)105 LogicalResult MoveIntoAssumingOpMatchAndRewrite(Operation *op,
106 PatternRewriter &rewriter) {
107 // Only move into immediately preceding `assuming` op.
108 auto assuming_op =
109 llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
110 if (!assuming_op) return failure();
111
112 Block *body = assuming_op.getBody();
113 auto yield_op = cast<shape::AssumingYieldOp>(body->getTerminator());
114
115 // Find the operands to use if the op was within the assuming region. We
116 // will later use their copies, as we copy the assuming op and its body.
117 SmallVector<Value, 8> new_operands_unmapped =
118 llvm::to_vector<8>(llvm::map_range(op->getOperands(), [&](Value v) {
119 for (auto result : llvm::enumerate(assuming_op->getResults())) {
120 if (result.value() == v) return yield_op->getOperand(result.index());
121 }
122 return v;
123 }));
124
125 // Insert the rewritten assuming op right before the old one.
126 OpBuilder::InsertionGuard guard(rewriter);
127 rewriter.setInsertionPoint(assuming_op);
128 auto new_assuming_op = rewriter.create<shape::AssumingOp>(
129 assuming_op.getLoc(), assuming_op.witness(), [&](OpBuilder &b, Location) {
130 // Copy body.
131 BlockAndValueMapping mapping;
132 for (auto &nested : body->without_terminator())
133 b.clone(nested, mapping);
134
135 // Copy op into the new body and use the mapped operands.
136 for (auto it : llvm::zip(op->getOperands(), new_operands_unmapped)) {
137 Value old_operand, new_operand_unmapped;
138 std::tie(old_operand, new_operand_unmapped) = it;
139 mapping.map(old_operand,
140 mapping.lookupOrDefault(new_operand_unmapped));
141 }
142 Operation *new_op = b.clone(*op, mapping);
143
144 // Yield the previous results and also the new ones.
145 auto mapped_results = llvm::to_vector<8>(llvm::map_range(
146 yield_op.operands(),
147 [&](Value v) { return mapping.lookupOrDefault(v); }));
148 mapped_results.append(new_op->getResults().begin(),
149 new_op->getResults().end());
150 return mapped_results;
151 });
152
153 // Replace the assuming op and the root op with the corresponding result
154 // value.
155 ValueRange new_assuming_op_results = new_assuming_op->getResults();
156 rewriter.replaceOp(assuming_op, new_assuming_op_results.drop_back());
157 rewriter.replaceOp(op, new_assuming_op_results.back());
158 return success();
159 }
160
161 /// Move operation into a preceding assuming op. This allows to process
162 /// operations that depend on the assuming op's results. It will eventually
163 /// allow to make assuming regions' constraints independent from each other.
164 template <typename OpTy>
165 struct MoveIntoAssumingOpPattern : public OpRewritePattern<OpTy> {
166 using OpRewritePattern<OpTy>::OpRewritePattern;
167
matchAndRewritemlir::mhlo::__anon5720ca320111::MoveIntoAssumingOpPattern168 LogicalResult matchAndRewrite(OpTy op,
169 PatternRewriter &rewriter) const override {
170 return MoveIntoAssumingOpMatchAndRewrite(op.getOperation(), rewriter);
171 }
172 };
173
174 // Move elementwise operations into assuming regions. This will eventually allow
175 // for more fusion opportunities.
176 struct MoveElementwiseOpsIntoAssumingOpPattern : public RewritePattern {
MoveElementwiseOpsIntoAssumingOpPatternmlir::mhlo::__anon5720ca320111::MoveElementwiseOpsIntoAssumingOpPattern177 explicit MoveElementwiseOpsIntoAssumingOpPattern(MLIRContext *ctx)
178 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
179
matchAndRewritemlir::mhlo::__anon5720ca320111::MoveElementwiseOpsIntoAssumingOpPattern180 LogicalResult matchAndRewrite(Operation *op,
181 PatternRewriter &rewriter) const override {
182 // Apply to all elementwise and broadcasting elementwise operations.
183 if (!op->hasTrait<mlir::OpTrait::Elementwise>() &&
184 !op->hasTrait<mhlo::OpTrait::BroadcastingElementwise>())
185 return failure();
186
187 return MoveIntoAssumingOpMatchAndRewrite(op, rewriter);
188 }
189 };
190
191 /// Move operation out of assuming op. This is only valid for
192 /// constraint-independent ops, like `cstr_broadcastable` and `shape_of`. It
193 /// will eventually allow to make assuming regions' constraints independent from
194 /// each other.
195 template <typename OpTy>
196 struct MoveOutOfAssumingOpPattern : public OpRewritePattern<OpTy> {
197 using OpRewritePattern<OpTy>::OpRewritePattern;
198
matchAndRewritemlir::mhlo::__anon5720ca320111::MoveOutOfAssumingOpPattern199 LogicalResult matchAndRewrite(OpTy op,
200 PatternRewriter &rewriter) const override {
201 // Must be inside of an assuming op.
202 auto assuming_op = op->template getParentOfType<shape::AssumingOp>();
203 if (!assuming_op) return failure();
204
205 // Operands must not be defined within the assuming op.
206 Block *body = assuming_op.getBody();
207 auto is_available = [&](Value v) {
208 Operation *def = v.getDefiningOp();
209 return def == nullptr || def->getBlock() != body;
210 };
211 if (!llvm::all_of(op->getOperands(), is_available)) return failure();
212
213 // Move op before the assuming region.
214 OpBuilder::InsertionGuard guard(rewriter);
215 rewriter.setInsertionPoint(assuming_op);
216 Operation *new_op = rewriter.clone(*op);
217 rewriter.replaceOp(op, new_op->getResults());
218
219 // If the assuming region yields none of the new op's results, these values
220 // are exclusively used in the assuming op's body. In these cases there is
221 // no need for further rewrites.
222 auto is_new_op_result = [&](Value v) {
223 return llvm::is_contained(new_op->getResults(), v);
224 };
225 auto yield_op = cast<shape::AssumingYieldOp>(body->getTerminator());
226 if (llvm::none_of(yield_op.operands(), is_new_op_result)) return success();
227
228 // If the assuming region yields any of the new op's results, these values
229 // can instead bypass the assuming region. There is no need to yield them
230 // explicitly as they are assumed to be independent. The assuming op is
231 // rewritten accordingly.
232 SmallVector<Value, 2> replacement_values;
233 auto new_assuming_op = rewriter.create<shape::AssumingOp>(
234 assuming_op.getLoc(), assuming_op.witness(),
235 [&](OpBuilder &b, Location) {
236 // Copy body.
237 BlockAndValueMapping mapping;
238 for (Operation &nested : body->without_terminator()) {
239 b.clone(nested, mapping);
240 }
241
242 // Collect new yield operands.
243 SmallVector<Value, 2> new_yield_operands;
244 for (Value result : yield_op.operands()) {
245 if (is_new_op_result(result)) {
246 replacement_values.push_back(result);
247 } else {
248 new_yield_operands.push_back(mapping.lookup(result));
249 replacement_values.push_back(nullptr);
250 }
251 }
252 return new_yield_operands;
253 });
254
255 // Use the assuming op's results for the missing replacement values.
256 auto src = new_assuming_op.getResults().begin();
257 for (auto &dst : replacement_values) {
258 if (dst) continue;
259 dst = *src++;
260 }
261
262 rewriter.replaceOp(assuming_op, replacement_values);
263 return success();
264 }
265 };
266
267 /// Merge assuming regions if their constraints are independent from each other.
268 struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> {
269 using OpRewritePattern<shape::AssumingOp>::OpRewritePattern;
270
matchAndRewritemlir::mhlo::__anon5720ca320111::MergeAssumingOpsPattern271 LogicalResult matchAndRewrite(shape::AssumingOp op,
272 PatternRewriter &rewriter) const override {
273 // Merge assuming op with directly preceding one if both witnesses are
274 // availiable.
275 auto preceding_op =
276 llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
277 if (!preceding_op) return failure();
278 if (op.witness().getDefiningOp() == preceding_op) return failure();
279
280 // Merge witnesses.
281 OpBuilder::InsertionGuard guard(rewriter);
282 rewriter.setInsertionPoint(preceding_op);
283 Value new_witness = rewriter.create<shape::AssumingAllOp>(
284 op.witness().getDefiningOp()->getLoc(),
285 ValueRange{preceding_op.witness(), op.witness()});
286
287 // Merge assuming ops.
288 Block *body_a = preceding_op.getBody();
289 Block *body_b = op.getBody();
290 auto new_assuming_op = rewriter.create<shape::AssumingOp>(
291 preceding_op.getLoc(), new_witness, [&](OpBuilder &b, Location) {
292 // Copy preceding op's body.
293 BlockAndValueMapping mapping;
294 for (auto &nested : body_a->without_terminator()) {
295 b.clone(nested, mapping);
296 }
297
298 // Map result values of preceding assuming op.
299 auto yield_op_a =
300 llvm::dyn_cast<shape::AssumingYieldOp>(body_a->getTerminator());
301 for (auto pair :
302 llvm::zip(preceding_op->getResults(), yield_op_a.operands())) {
303 mapping.map(std::get<0>(pair),
304 mapping.lookupOrDefault(std::get<1>(pair)));
305 }
306
307 // Copy op's body.
308 for (auto &nested : body_b->without_terminator()) {
309 b.clone(nested, mapping);
310 }
311
312 // Collect merged assuming op's results.
313 SmallVector<Value, 4> mapped_results;
314 auto yield_op_b =
315 llvm::dyn_cast<shape::AssumingYieldOp>(body_b->getTerminator());
316 for (Value v : yield_op_a.operands()) {
317 mapped_results.push_back(mapping.lookupOrDefault(v));
318 }
319 for (Value v : yield_op_b.operands()) {
320 mapped_results.push_back(mapping.lookupOrDefault(v));
321 }
322 return mapped_results;
323 });
324
325 // Replace the two assuming ops with the new corresponding results.
326 ValueRange new_results = new_assuming_op->getResults();
327 size_t split_at = preceding_op->getNumResults();
328 rewriter.replaceOp(preceding_op, new_results.take_front(split_at));
329 rewriter.replaceOp(op, new_results.drop_front(split_at));
330 return success();
331 }
332 };
333
334 struct EliminateDuplicateCstrBroadcastableOps
335 : public OpRewritePattern<shape::CstrBroadcastableOp> {
336 using OpRewritePattern<shape::CstrBroadcastableOp>::OpRewritePattern;
337
matchAndRewritemlir::mhlo::__anon5720ca320111::EliminateDuplicateCstrBroadcastableOps338 LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
339 PatternRewriter &rewriter) const override {
340 // Search for previous occurence of the same constraint.
341 Operation *it = op->getPrevNode();
342 while (it != nullptr) {
343 if (auto candidate = llvm::dyn_cast<shape::CstrBroadcastableOp>(it)) {
344 if (candidate.shapes() == op.shapes()) {
345 rewriter.replaceOp(op, candidate.result());
346 return success();
347 }
348 }
349 it = it->getPrevNode();
350 }
351
352 return failure();
353 }
354 };
355
356 struct EarlyBroadcastInDimOpPattern
357 : public OpRewritePattern<DynamicBroadcastInDimOp> {
358 using OpRewritePattern<DynamicBroadcastInDimOp>::OpRewritePattern;
359
matchAndRewritemlir::mhlo::__anon5720ca320111::EarlyBroadcastInDimOpPattern360 LogicalResult matchAndRewrite(DynamicBroadcastInDimOp bcast_op,
361 PatternRewriter &rewriter) const override {
362 Operation *producer_op = bcast_op.operand().getDefiningOp();
363 if (!producer_op ||
364 !producer_op->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>() ||
365 !producer_op->hasTrait<mlir::OpTrait::Elementwise>()) {
366 return failure();
367 }
368
369 // Materialize broadcast on operands.
370 SmallVector<Value, 2> bcasted_operands;
371 Location loc = bcast_op.getLoc();
372 ArrayRef<int64_t> ty_shape = bcast_op.getType().getShape();
373 for (Value operand : producer_op->getOperands()) {
374 // The broadcast only works on ranked operations.
375 auto operand_ty = operand.getType().dyn_cast<RankedTensorType>();
376 if (!operand_ty) {
377 return bcast_op.emitError()
378 << "Can only move up broadcasts over ranked tensor operands.";
379 }
380
381 auto bcasted_operand_ty =
382 RankedTensorType::get(ty_shape, operand_ty.getElementType());
383 bcasted_operands.push_back(rewriter.create<DynamicBroadcastInDimOp>(
384 loc, bcasted_operand_ty, operand, bcast_op.output_dimensions(),
385 bcast_op.broadcast_dimensions()));
386 }
387
388 // Create a copy of the producer op with the new broadcasted operands.
389 OperationState new_producer_op_state(
390 loc, producer_op->getName().getStringRef(), bcasted_operands,
391 bcast_op.getType(), producer_op->getAttrs());
392 Operation *new_producer_op =
393 rewriter.createOperation(new_producer_op_state);
394
395 // The original result of the broadcast now falls directly out of the new
396 // producer op. Use it instead.
397 rewriter.replaceOp(bcast_op, new_producer_op->getResults());
398
399 return success();
400 }
401 };
402
403 struct BroadcastPropagationPass
404 : public BroadcastPropagationPassBase<BroadcastPropagationPass> {
getDependentDialectsmlir::mhlo::__anon5720ca320111::BroadcastPropagationPass405 void getDependentDialects(DialectRegistry ®istry) const override {
406 registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
407 }
408
runOnFunctionmlir::mhlo::__anon5720ca320111::BroadcastPropagationPass409 void runOnFunction() override {
410 MLIRContext *ctx = &getContext();
411 RewritePatternSet patterns(ctx);
412 mhlo::PopulateBroadcastsPropagationPatterns(ctx, &patterns);
413 if (failed(
414 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
415 return signalPassFailure();
416 }
417 }
418 };
419
420 } // namespace
421
PopulateBroadcastsPropagationPatterns(MLIRContext * context,OwningRewritePatternList * patterns)422 void PopulateBroadcastsPropagationPatterns(MLIRContext *context,
423 OwningRewritePatternList *patterns) {
424 // clang-format off
425 patterns->insert<
426 EliminateDuplicateCstrBroadcastableOps,
427 InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
428 MergeAssumingOpsPattern,
429 MoveElementwiseOpsIntoAssumingOpPattern,
430 MoveIntoAssumingOpPattern<shape::CstrBroadcastableOp>,
431 MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
432 MoveOutOfAssumingOpPattern<shape::CstrBroadcastableOp>,
433 MoveOutOfAssumingOpPattern<shape::ShapeOfOp>,
434 EarlyBroadcastInDimOpPattern,
435 ShapeReificationPattern>(context);
436 // clang-format on
437 mhlo::DynamicBroadcastInDimOp::getCanonicalizationPatterns(*patterns,
438 context);
439 mhlo::DynamicReshapeOp::getCanonicalizationPatterns(*patterns, context);
440 shape::AssumingAllOp::getCanonicalizationPatterns(*patterns, context);
441 shape::AssumingOp::getCanonicalizationPatterns(*patterns, context);
442 shape::BroadcastOp::getCanonicalizationPatterns(*patterns, context);
443 shape::CstrBroadcastableOp::getCanonicalizationPatterns(*patterns, context);
444 tensor::CastOp::getCanonicalizationPatterns(*patterns, context);
445 }
446
createBroadcastPropagationPass()447 std::unique_ptr<FunctionPass> createBroadcastPropagationPass() {
448 return std::make_unique<BroadcastPropagationPass>();
449 }
450
451 } // namespace mhlo
452 } // namespace mlir
453