• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 
15 ==============================================================================*/
16 
17 #include <utility>
18 
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/Support/Casting.h"
22 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
23 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
24 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
25 #include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
26 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
27 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
28 #include "mlir/Dialect/SCF/SCF.h"
29 #include "mlir/Dialect/Shape/IR/Shape.h"
30 #include "mlir/Dialect/StandardOps/IR/Ops.h"
31 #include "mlir/Dialect/Tensor/IR/Tensor.h"
32 #include "mlir/IR/BlockAndValueMapping.h"
33 #include "mlir/IR/BuiltinOps.h"
34 #include "mlir/IR/BuiltinTypes.h"
35 #include "mlir/IR/MLIRContext.h"
36 #include "mlir/IR/Operation.h"
37 #include "mlir/IR/OperationSupport.h"
38 #include "mlir/IR/PatternMatch.h"
39 #include "mlir/Interfaces/InferTypeOpInterface.h"
40 #include "mlir/Pass/Pass.h"
41 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
42 
43 namespace mlir {
44 namespace mhlo {
45 namespace {
46 
47 struct ShapeReificationPattern : public OpRewritePattern<shape::ShapeOfOp> {
ShapeReificationPatternmlir::mhlo::__anon5720ca320111::ShapeReificationPattern48   explicit ShapeReificationPattern(MLIRContext *context)
49       : OpRewritePattern<shape::ShapeOfOp>(context) {
50     // Recursively reify until we hit an op that doesn't support it.
51     setHasBoundedRewriteRecursion();
52   }
53 
matchAndRewritemlir::mhlo::__anon5720ca320111::ShapeReificationPattern54   LogicalResult matchAndRewrite(shape::ShapeOfOp op,
55                                 PatternRewriter &rewriter) const override {
56     // Only reify shape computation if operand allows for it.
57     auto shape_origin = op.arg().getDefiningOp<InferShapedTypeOpInterface>();
58     if (!shape_origin) return failure();
59 
60     llvm::SmallVector<Value, 1> reifications;
61     if (failed(shape_origin.reifyReturnTypeShapes(
62             rewriter, shape_origin->getOperands(), reifications)))
63       return failure();
64     assert(reifications.size() == 1);
65     Value reified_shape = reifications.front();
66 
67     // Insert cast if needed.
68     if (reified_shape.getType() != op.getType()) {
69       reified_shape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
70                                                       reified_shape);
71     }
72 
73     rewriter.replaceOp(op, reified_shape);
74     return success();
75   }
76 };
77 
78 template <typename OpTy>
79 struct InlineBroadcastedShapeOperandsPattern : public OpRewritePattern<OpTy> {
80   using OpRewritePattern<OpTy>::OpRewritePattern;
81 
matchAndRewritemlir::mhlo::__anon5720ca320111::InlineBroadcastedShapeOperandsPattern82   LogicalResult matchAndRewrite(OpTy op,
83                                 PatternRewriter &rewriter) const override {
84     // Find all the shape operands, direct and indirect.
85     SmallVector<Value, 8> inlined_operands;
86     for (Value direct : op->getOperands()) {
87       if (auto bcast_op = direct.getDefiningOp<shape::BroadcastOp>()) {
88         for (Value indirect : bcast_op->getOperands())
89           inlined_operands.push_back(indirect);
90       } else {
91         inlined_operands.push_back(direct);
92       }
93     }
94 
95     // Only rewrite if it makes a difference.
96     if (inlined_operands.size() == op.getNumOperands()) return failure();
97 
98     // Inline shape operands.
99     rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
100                                       inlined_operands, op->getAttrs());
101     return success();
102   }
103 };
104 
// Moves `op` into the body of the immediately preceding `shape.assuming` op.
// The assuming op is rebuilt from scratch: its body is cloned, `op` is cloned
// into it, and the new assuming op yields the old results plus `op`'s results.
// Fails if `op`'s previous node is not a `shape.assuming` op. Shared by
// `MoveIntoAssumingOpPattern` and
// `MoveElementwiseOpsIntoAssumingOpPattern` below.
LogicalResult MoveIntoAssumingOpMatchAndRewrite(Operation *op,
                                                PatternRewriter &rewriter) {
  // Only move into immediately preceding `assuming` op.
  auto assuming_op =
      llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
  if (!assuming_op) return failure();

  Block *body = assuming_op.getBody();
  auto yield_op = cast<shape::AssumingYieldOp>(body->getTerminator());

  // Find the operands to use if the op was within the assuming region. We
  // will later use their copies, as we copy the assuming op and its body.
  // Any operand that is a result of the assuming op is replaced by the value
  // yielded for that result inside the region; other operands pass through.
  SmallVector<Value, 8> new_operands_unmapped =
      llvm::to_vector<8>(llvm::map_range(op->getOperands(), [&](Value v) {
        for (auto result : llvm::enumerate(assuming_op->getResults())) {
          if (result.value() == v) return yield_op->getOperand(result.index());
        }
        return v;
      }));

  // Insert the rewritten assuming op right before the old one.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(assuming_op);
  auto new_assuming_op = rewriter.create<shape::AssumingOp>(
      assuming_op.getLoc(), assuming_op.witness(), [&](OpBuilder &b, Location) {
        // Copy body.
        BlockAndValueMapping mapping;
        for (auto &nested : body->without_terminator())
          b.clone(nested, mapping);

        // Copy op into the new body and use the mapped operands.
        // Pre-seed the mapping so the clone of `op` picks up the in-region
        // equivalents computed above (mapped through the body clone).
        for (auto it : llvm::zip(op->getOperands(), new_operands_unmapped)) {
          Value old_operand, new_operand_unmapped;
          std::tie(old_operand, new_operand_unmapped) = it;
          mapping.map(old_operand,
                      mapping.lookupOrDefault(new_operand_unmapped));
        }
        Operation *new_op = b.clone(*op, mapping);

        // Yield the previous results and also the new ones.
        auto mapped_results = llvm::to_vector<8>(llvm::map_range(
            yield_op.operands(),
            [&](Value v) { return mapping.lookupOrDefault(v); }));
        mapped_results.append(new_op->getResults().begin(),
                              new_op->getResults().end());
        return mapped_results;
      });

  // Replace the assuming op and the root op with the corresponding result
  // value. `op`'s results were appended last, so they are the trailing
  // results of the new assuming op.
  ValueRange new_assuming_op_results = new_assuming_op->getResults();
  rewriter.replaceOp(assuming_op, new_assuming_op_results.drop_back());
  rewriter.replaceOp(op, new_assuming_op_results.back());
  return success();
}
160 
161 /// Move operation into a preceding assuming op. This allows to process
162 /// operations that depend on the assuming op's results. It will eventually
163 /// allow to make assuming regions' constraints independent from each other.
164 template <typename OpTy>
165 struct MoveIntoAssumingOpPattern : public OpRewritePattern<OpTy> {
166   using OpRewritePattern<OpTy>::OpRewritePattern;
167 
matchAndRewritemlir::mhlo::__anon5720ca320111::MoveIntoAssumingOpPattern168   LogicalResult matchAndRewrite(OpTy op,
169                                 PatternRewriter &rewriter) const override {
170     return MoveIntoAssumingOpMatchAndRewrite(op.getOperation(), rewriter);
171   }
172 };
173 
174 // Move elementwise operations into assuming regions. This will eventually allow
175 // for more fusion opportunities.
176 struct MoveElementwiseOpsIntoAssumingOpPattern : public RewritePattern {
MoveElementwiseOpsIntoAssumingOpPatternmlir::mhlo::__anon5720ca320111::MoveElementwiseOpsIntoAssumingOpPattern177   explicit MoveElementwiseOpsIntoAssumingOpPattern(MLIRContext *ctx)
178       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
179 
matchAndRewritemlir::mhlo::__anon5720ca320111::MoveElementwiseOpsIntoAssumingOpPattern180   LogicalResult matchAndRewrite(Operation *op,
181                                 PatternRewriter &rewriter) const override {
182     // Apply to all elementwise and broadcasting elementwise operations.
183     if (!op->hasTrait<mlir::OpTrait::Elementwise>() &&
184         !op->hasTrait<mhlo::OpTrait::BroadcastingElementwise>())
185       return failure();
186 
187     return MoveIntoAssumingOpMatchAndRewrite(op, rewriter);
188   }
189 };
190 
/// Move operation out of assuming op. This is only valid for
/// constraint-independent ops, like `cstr_broadcastable` and `shape_of`. It
/// will eventually allow to make assuming regions' constraints independent
/// from each other.
template <typename OpTy>
struct MoveOutOfAssumingOpPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Must be inside of an assuming op.
    auto assuming_op = op->template getParentOfType<shape::AssumingOp>();
    if (!assuming_op) return failure();

    // Operands must not be defined within the assuming op.
    // NOTE(review): this checks the assuming op's (single) body block only;
    // values defined in other nested blocks would pass — presumably the
    // assuming region is always single-block here. Confirm if that changes.
    Block *body = assuming_op.getBody();
    auto is_available = [&](Value v) {
      Operation *def = v.getDefiningOp();
      return def == nullptr || def->getBlock() != body;
    };
    if (!llvm::all_of(op->getOperands(), is_available)) return failure();

    // Move op before the assuming region.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(assuming_op);
    Operation *new_op = rewriter.clone(*op);
    rewriter.replaceOp(op, new_op->getResults());

    // If the assuming region yields none of the new op's results, these values
    // are exclusively used in the assuming op's body. In these cases there is
    // no need for further rewrites.
    auto is_new_op_result = [&](Value v) {
      return llvm::is_contained(new_op->getResults(), v);
    };
    auto yield_op = cast<shape::AssumingYieldOp>(body->getTerminator());
    if (llvm::none_of(yield_op.operands(), is_new_op_result)) return success();

    // If the assuming region yields any of the new op's results, these values
    // can instead bypass the assuming region. There is no need to yield them
    // explicitly as they are assumed to be independent. The assuming op is
    // rewritten accordingly.
    SmallVector<Value, 2> replacement_values;
    auto new_assuming_op = rewriter.create<shape::AssumingOp>(
        assuming_op.getLoc(), assuming_op.witness(),
        [&](OpBuilder &b, Location) {
          // Copy body.
          BlockAndValueMapping mapping;
          for (Operation &nested : body->without_terminator()) {
            b.clone(nested, mapping);
          }

          // Collect new yield operands. Results that bypass the region are
          // recorded in `replacement_values` directly; the remaining slots
          // are marked null and filled in from the new assuming op below.
          SmallVector<Value, 2> new_yield_operands;
          for (Value result : yield_op.operands()) {
            if (is_new_op_result(result)) {
              replacement_values.push_back(result);
            } else {
              new_yield_operands.push_back(mapping.lookup(result));
              replacement_values.push_back(nullptr);
            }
          }
          return new_yield_operands;
        });

    // Use the assuming op's results for the missing replacement values.
    // The null slots and the new op's results appear in the same relative
    // order, so a single forward sweep suffices.
    auto src = new_assuming_op.getResults().begin();
    for (auto &dst : replacement_values) {
      if (dst) continue;
      dst = *src++;
    }

    rewriter.replaceOp(assuming_op, replacement_values);
    return success();
  }
};
266 
/// Merge assuming regions if their constraints are independent from each
/// other.
struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> {
  using OpRewritePattern<shape::AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::AssumingOp op,
                                PatternRewriter &rewriter) const override {
    // Merge assuming op with directly preceding one if both witnesses are
    // available. If `op`'s witness is produced by the preceding assuming op
    // itself, the constraints are not independent and merging is invalid.
    auto preceding_op =
        llvm::dyn_cast_or_null<shape::AssumingOp>(op->getPrevNode());
    if (!preceding_op) return failure();
    if (op.witness().getDefiningOp() == preceding_op) return failure();

    // Merge witnesses via `shape.assuming_all` so the combined region is
    // guarded by both constraints.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(preceding_op);
    Value new_witness = rewriter.create<shape::AssumingAllOp>(
        op.witness().getDefiningOp()->getLoc(),
        ValueRange{preceding_op.witness(), op.witness()});

    // Merge assuming ops: clone body A, then body B, into one new region.
    Block *body_a = preceding_op.getBody();
    Block *body_b = op.getBody();
    auto new_assuming_op = rewriter.create<shape::AssumingOp>(
        preceding_op.getLoc(), new_witness, [&](OpBuilder &b, Location) {
          // Copy preceding op's body.
          BlockAndValueMapping mapping;
          for (auto &nested : body_a->without_terminator()) {
            b.clone(nested, mapping);
          }

          // Map result values of preceding assuming op. Uses of its results
          // inside body B must refer to the cloned yielded values instead.
          auto yield_op_a =
              llvm::dyn_cast<shape::AssumingYieldOp>(body_a->getTerminator());
          for (auto pair :
               llvm::zip(preceding_op->getResults(), yield_op_a.operands())) {
            mapping.map(std::get<0>(pair),
                        mapping.lookupOrDefault(std::get<1>(pair)));
          }

          // Copy op's body.
          for (auto &nested : body_b->without_terminator()) {
            b.clone(nested, mapping);
          }

          // Collect merged assuming op's results: A's yields first, then B's,
          // matching the replacement split below.
          SmallVector<Value, 4> mapped_results;
          auto yield_op_b =
              llvm::dyn_cast<shape::AssumingYieldOp>(body_b->getTerminator());
          for (Value v : yield_op_a.operands()) {
            mapped_results.push_back(mapping.lookupOrDefault(v));
          }
          for (Value v : yield_op_b.operands()) {
            mapped_results.push_back(mapping.lookupOrDefault(v));
          }
          return mapped_results;
        });

    // Replace the two assuming ops with the new corresponding results.
    ValueRange new_results = new_assuming_op->getResults();
    size_t split_at = preceding_op->getNumResults();
    rewriter.replaceOp(preceding_op, new_results.take_front(split_at));
    rewriter.replaceOp(op, new_results.drop_front(split_at));
    return success();
  }
};
333 
334 struct EliminateDuplicateCstrBroadcastableOps
335     : public OpRewritePattern<shape::CstrBroadcastableOp> {
336   using OpRewritePattern<shape::CstrBroadcastableOp>::OpRewritePattern;
337 
matchAndRewritemlir::mhlo::__anon5720ca320111::EliminateDuplicateCstrBroadcastableOps338   LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
339                                 PatternRewriter &rewriter) const override {
340     // Search for previous occurence of the same constraint.
341     Operation *it = op->getPrevNode();
342     while (it != nullptr) {
343       if (auto candidate = llvm::dyn_cast<shape::CstrBroadcastableOp>(it)) {
344         if (candidate.shapes() == op.shapes()) {
345           rewriter.replaceOp(op, candidate.result());
346           return success();
347         }
348       }
349       it = it->getPrevNode();
350     }
351 
352     return failure();
353   }
354 };
355 
// Moves a `mhlo.dynamic_broadcast_in_dim` above its producer: if the operand
// is produced by a shape-preserving elementwise op, the broadcast is applied
// to each of the producer's operands and the producer is recreated on the
// broadcasted values. This propagates broadcasts towards the function inputs.
struct EarlyBroadcastInDimOpPattern
    : public OpRewritePattern<DynamicBroadcastInDimOp> {
  using OpRewritePattern<DynamicBroadcastInDimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DynamicBroadcastInDimOp bcast_op,
                                PatternRewriter &rewriter) const override {
    // Only applies when the producer keeps operand/result shapes identical
    // and is elementwise, so broadcasting its operands is semantics-preserving.
    Operation *producer_op = bcast_op.operand().getDefiningOp();
    if (!producer_op ||
        !producer_op->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>() ||
        !producer_op->hasTrait<mlir::OpTrait::Elementwise>()) {
      return failure();
    }

    // Materialize broadcast on operands.
    SmallVector<Value, 2> bcasted_operands;
    Location loc = bcast_op.getLoc();
    ArrayRef<int64_t> ty_shape = bcast_op.getType().getShape();
    for (Value operand : producer_op->getOperands()) {
      // The broadcast only works on ranked operations.
      // NOTE(review): this emits a diagnostic from inside a pattern and
      // returns it as failure — presumably intentional, but emitting errors
      // for a mere non-match is unusual for rewrite patterns; confirm.
      auto operand_ty = operand.getType().dyn_cast<RankedTensorType>();
      if (!operand_ty) {
        return bcast_op.emitError()
               << "Can only move up broadcasts over ranked tensor operands.";
      }

      // Each operand is broadcast to the original broadcast's result shape,
      // reusing its output dimensions and broadcast dimensions.
      auto bcasted_operand_ty =
          RankedTensorType::get(ty_shape, operand_ty.getElementType());
      bcasted_operands.push_back(rewriter.create<DynamicBroadcastInDimOp>(
          loc, bcasted_operand_ty, operand, bcast_op.output_dimensions(),
          bcast_op.broadcast_dimensions()));
    }

    // Create a copy of the producer op with the new broadcasted operands,
    // now yielding the broadcast's result type directly.
    OperationState new_producer_op_state(
        loc, producer_op->getName().getStringRef(), bcasted_operands,
        bcast_op.getType(), producer_op->getAttrs());
    Operation *new_producer_op =
        rewriter.createOperation(new_producer_op_state);

    // The original result of the broadcast now falls directly out of the new
    // producer op. Use it instead.
    rewriter.replaceOp(bcast_op, new_producer_op->getResults());

    return success();
  }
};
402 
403 struct BroadcastPropagationPass
404     : public BroadcastPropagationPassBase<BroadcastPropagationPass> {
getDependentDialectsmlir::mhlo::__anon5720ca320111::BroadcastPropagationPass405   void getDependentDialects(DialectRegistry &registry) const override {
406     registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
407   }
408 
runOnFunctionmlir::mhlo::__anon5720ca320111::BroadcastPropagationPass409   void runOnFunction() override {
410     MLIRContext *ctx = &getContext();
411     RewritePatternSet patterns(ctx);
412     mhlo::PopulateBroadcastsPropagationPatterns(ctx, &patterns);
413     if (failed(
414             applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
415       return signalPassFailure();
416     }
417   }
418 };
419 
420 }  // namespace
421 
// Populates `patterns` with all broadcast propagation rewrites defined above,
// plus the canonicalization patterns of the ops they interact with.
void PopulateBroadcastsPropagationPatterns(MLIRContext *context,
                                           OwningRewritePatternList *patterns) {
  // clang-format off
  patterns->insert<
      EliminateDuplicateCstrBroadcastableOps,
      InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
      MergeAssumingOpsPattern,
      MoveElementwiseOpsIntoAssumingOpPattern,
      MoveIntoAssumingOpPattern<shape::CstrBroadcastableOp>,
      MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
      MoveOutOfAssumingOpPattern<shape::CstrBroadcastableOp>,
      MoveOutOfAssumingOpPattern<shape::ShapeOfOp>,
      EarlyBroadcastInDimOpPattern,
      ShapeReificationPattern>(context);
  // clang-format on
  // Canonicalizations keep the IR in a normal form between applications of
  // the patterns above.
  mhlo::DynamicBroadcastInDimOp::getCanonicalizationPatterns(*patterns,
                                                             context);
  mhlo::DynamicReshapeOp::getCanonicalizationPatterns(*patterns, context);
  shape::AssumingAllOp::getCanonicalizationPatterns(*patterns, context);
  shape::AssumingOp::getCanonicalizationPatterns(*patterns, context);
  shape::BroadcastOp::getCanonicalizationPatterns(*patterns, context);
  shape::CstrBroadcastableOp::getCanonicalizationPatterns(*patterns, context);
  tensor::CastOp::getCanonicalizationPatterns(*patterns, context);
}
446 
createBroadcastPropagationPass()447 std::unique_ptr<FunctionPass> createBroadcastPropagationPass() {
448   return std::make_unique<BroadcastPropagationPass>();
449 }
450 
451 }  // namespace mhlo
452 }  // namespace mlir
453