• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
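// This file implements logic for lowering memref.tensor_load ops that are
// inserted during `mhlo-legalize-to-lmhlo`: tensor.extract and shape.shape_of
// users of a tensor_load result are rewritten to operate directly on the
// source memref.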
18 
19 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
20 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/Dialect/Shape/IR/Shape.h"
23 #include "mlir/Dialect/StandardOps/IR/Ops.h"
24 #include "mlir/Dialect/Tensor/IR/Tensor.h"  // TF:llvm-project
25 #include "mlir/IR/Builders.h"
26 #include "mlir/IR/BuiltinOps.h"
27 #include "mlir/IR/BuiltinTypes.h"
28 #include "mlir/Pass/Pass.h"
29 #include "mlir/Pass/PassRegistry.h"
30 #include "mlir/Support/LogicalResult.h"
31 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
32 
33 namespace mlir {
34 namespace lmhlo {
35 namespace {
36 using shape::ShapeOfOp;
37 using tensor::ExtractOp;
38 
39 // Converting:
40 //   memref (operand) -> memref.tensor_load -> tensor.extract
41 //     to
42 //   memref (operand) -> memref.load
43 struct ForwardExtractOp : public OpRewritePattern<ExtractOp> {
44   using OpRewritePattern<ExtractOp>::OpRewritePattern;
45 
matchAndRewritemlir::lmhlo::__anon3f4708000111::ForwardExtractOp46   LogicalResult matchAndRewrite(ExtractOp extract,
47                                 PatternRewriter& rewriter) const override {
48     auto tensor_load = extract.tensor().getDefiningOp<memref::TensorLoadOp>();
49     if (!tensor_load) return failure();
50 
51     rewriter.replaceOpWithNewOp<memref::LoadOp>(
52         extract, extract.getType(), tensor_load.memref(), extract.indices());
53     return success();
54   }
55 };
56 
57 // Converting:
58 //   memref (operand) -> memref.tensor_load -> shape.shape_of
59 //     to
60 //   memref (operand) -> shape.shape_of
61 struct ForwardShapeOfOp : public OpRewritePattern<ShapeOfOp> {
62   using OpRewritePattern<ShapeOfOp>::OpRewritePattern;
63 
matchAndRewritemlir::lmhlo::__anon3f4708000111::ForwardShapeOfOp64   LogicalResult matchAndRewrite(ShapeOfOp shape_of,
65                                 PatternRewriter& rewriter) const override {
66     auto tensor_load = shape_of.arg().getDefiningOp<memref::TensorLoadOp>();
67     if (!tensor_load) return failure();
68 
69     rewriter.replaceOpWithNewOp<ShapeOfOp>(shape_of, shape_of.getType(),
70                                            tensor_load.memref());
71     return success();
72   }
73 };
74 
75 struct LegalizeTensorLoadOpPass
76     : public LegalizeTensorLoadOpPassBase<LegalizeTensorLoadOpPass> {
77   // Perform the lowering to remove memref.tensor_load ops inserted during
78   // `mhlo-legalize-to-lmhlo`.
runOnFunctionmlir::lmhlo::__anon3f4708000111::LegalizeTensorLoadOpPass79   void runOnFunction() override {
80     auto func = getFunction();
81     auto context = &getContext();
82     OwningRewritePatternList patterns(context);
83     patterns.insert<ForwardShapeOfOp, ForwardExtractOp>(context);
84     if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
85       func.emitError("applyPatternsAndFoldGreedily does not converge");
86       signalPassFailure();
87     }
88   }
89 };
90 
91 }  // namespace
92 
93 }  // namespace lmhlo
94 }  // namespace mlir
95 
96 std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
createLegalizeTensorLoadOpPass()97 mlir::lmhlo::createLegalizeTensorLoadOpPass() {
98   return std::make_unique<LegalizeTensorLoadOpPass>();
99 }
100