/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TFL {
namespace {

// Module pass to optimize TensorFlow functional ops.
struct OptimizeFunctionalOpsPass
    : public PassWrapper<OptimizeFunctionalOpsPass, OperationPass<ModuleOp>> {
  void runOnOperation() override;

  StringRef getArgument() const final {
    // This is the argument used to refer to the pass in
    // the textual format (on the command line, for example).
    return "tfl-optimize-functional-ops";
  }

  StringRef getDescription() const final {
    // This is a brief description of the pass.
    return "Optimize TensorFlow functional ops";
  }
};
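
// With the PassRegistration below, the pass can be run from a textual pass
// pipeline via its argument, e.g. (illustrative invocation; the exact opt
// driver binary may vary):
//
//   tf-opt -tfl-optimize-functional-ops input.mlir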

// Updates the function return type of the given function to match the
// terminator op's operand types.
//
// Requires that the function has exactly one block.
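//
// For example (illustrative; the op name is a placeholder), a function whose
// terminator operand types were refined by inlining:
//
//   func @f(%arg0: tensor<2xf32>) -> tensor<*xf32> {
//     %0 = "tf.SomeOp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
//     return %0 : tensor<2xf32>
//   }
//
// would get its type updated to (tensor<2xf32>) -> tensor<2xf32> so that the
// declared result types match the terminator operands.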
void UpdateFuncType(FuncOp func) {
  Operation* terminator = func.front().getTerminator();
  auto return_types = llvm::to_vector<4>(terminator->getOperandTypes());

  FunctionType func_type = func.getType();
  if (llvm::makeArrayRef(return_types) == func_type.getResults()) return;

  auto updated_type =
      FunctionType::get(func.getContext(), func_type.getInputs(), return_types);
  func.setType(updated_type);
}

// TODO(jpienaar): Remove when recursive side-effect modeling is added.
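// Returns true if no op in the function body, other than the terminator,
// reports memory effects.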
bool IsSideEffectFree(FuncOp func) {
  return !func.getBody()
              .walk([&](Operation* op) {
                if (!MemoryEffectOpInterface::hasNoEffect(op) &&
                    !op->hasTrait<OpTrait::IsTerminator>())
                  return WalkResult::interrupt();
                return WalkResult::advance();
              })
              .wasInterrupted();
}

// Folds a TensorFlow If op with a constant conditional operand by inlining
// the body of the selected branch function.
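//
// For example (illustrative IR; value and function names are placeholders):
//
//   %res = "tf.If"(%cond, %x)
//       {then_branch = @then, else_branch = @else, is_stateless = true}
//       : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
//
// with %cond a constant `true` is replaced by the inlined body of @then,
// whose block arguments are mapped to %x.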
class FoldIfOp : public OpRewritePattern<TF::IfOp> {
 public:
  explicit FoldIfOp(MLIRContext* context)
      : OpRewritePattern<TF::IfOp>(context) {}

  LogicalResult matchAndRewrite(TF::IfOp op,
                                PatternRewriter& rewriter) const override {
    // This pattern is restricted to If ops in functions with exactly one
    // block, and therefore exactly one terminator op, so that the function
    // return type can be updated if the operands' shapes change after
    // inlining. Without this restriction, tensor cast ops would be required.
    FuncOp parent_op = op->getParentOfType<FuncOp>();
    if (!llvm::hasSingleElement(parent_op)) return failure();

    // Find the then and else branch functions.
    FuncOp then_func = op.then_function();
    FuncOp else_func = op.else_function();

    // If the If op has no uses and its branch functions are side-effect
    // free, remove it.
    // TODO(jpienaar): Remove once recursive side-effects are supported.
    if (op.use_empty() &&
        (op.is_stateless() ||
         (IsSideEffectFree(then_func) && IsSideEffectFree(else_func)))) {
      rewriter.eraseOp(op.getOperation());
      return success();
    }

    // Extract the constant cond value.
    DenseElementsAttr cond;
    if (!matchPattern(op.cond(), m_Constant(&cond))) return failure();

    // TODO(hinsu): Handle constants that are not scalar booleans.
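    // A cond that passes the check below is a rank-0 i1 constant, e.g.
    // (illustrative): "tf.Const"() {value = dense<true> : tensor<i1>}.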
    auto cond_type = cond.getType().dyn_cast<RankedTensorType>();
    if (!cond_type || !cond_type.getShape().equals({}) ||
        !cond_type.getElementType().isInteger(/*width=*/1))
      return failure();

    // Identify the branch to inline.
    bool cond_value = (*cond.int_value_begin()).getSExtValue();
    FuncOp func = cond_value ? then_func : else_func;

    // Make sure that the function has exactly one block to simplify inlining.
    // TFLite doesn't use control flow with blocks, so functions with more
    // than one block are not encountered in practice.
    if (!llvm::hasSingleElement(func)) return failure();

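    // Map the inlined function's block arguments to the If op's inputs;
    // operand 0 of tf.If is the condition, hence the i + 1 offset.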
    BlockAndValueMapping mapper;
    for (int i = 0, e = func.getNumArguments(); i != e; ++i)
      mapper.map(func.getArgument(i), op.getOperand(i + 1));

    llvm::SmallVector<Value, 4> updated_results;
    for (auto& op_to_inline : func.front()) {
      // If this is a terminator, identify the values to use to replace the
      // original If op.
      if (op_to_inline.hasTrait<OpTrait::IsTerminator>()) {
        updated_results.reserve(op_to_inline.getNumOperands());
        for (Value operand : op_to_inline.getOperands())
          updated_results.push_back(mapper.lookup(operand));
        break;
      }

      // Otherwise, clone the op here.
      rewriter.clone(op_to_inline, mapper);
    }
    rewriter.replaceOp(op, updated_results);

    // The shapes of updated_results may not match those of the original
    // values. If any of these values are operands of the terminator op, the
    // function return type should be updated.
    UpdateFuncType(parent_op);

    return success();
  }
};

void OptimizeFunctionalOpsPass::runOnOperation() {
  OwningRewritePatternList patterns(&getContext());

  patterns.insert<FoldIfOp>(&getContext());

  ModuleOp module = getOperation();
  (void)applyPatternsAndFoldGreedily(module, std::move(patterns));
}

PassRegistration<OptimizeFunctionalOpsPass> pass;
}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeFunctionalOpsPass() {
  return std::make_unique<OptimizeFunctionalOpsPass>();
}

}  // namespace TFL
}  // namespace mlir