/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TFL {
namespace {

// Module pass to optimize TensorFlow functional ops.
struct OptimizeFunctionalOpsPass
    : public PassWrapper<OptimizeFunctionalOpsPass, OperationPass<ModuleOp>> {
  void runOnOperation() override;
};

// Updates the return type of the given function to match the types of its
// terminator op's operands.
//
// Requires that the function has exactly one block.
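//
// For example (an illustrative sketch): if inlining refines a returned value
// from tensor<*xf32> to tensor<2xf32>, the enclosing function's type must be
// rewritten from () -> tensor<*xf32> to () -> tensor<2xf32> so it stays
// consistent with its terminator's operand types.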
void UpdateFuncType(FuncOp func) {
  Operation* terminator = func.front().getTerminator();
  auto return_types = llvm::to_vector<4>(terminator->getOperandTypes());

  FunctionType func_type = func.getType();
  if (llvm::makeArrayRef(return_types) == func_type.getResults()) return;

  auto updated_type =
      FunctionType::get(func.getContext(), func_type.getInputs(), return_types);
  func.setType(updated_type);
}

// Returns true if the given function contains no ops with side effects
// (terminator ops are ignored).
// TODO(jpienaar): Remove when recursive side-effect modeling is added.
bool IsSideEffectFree(FuncOp func) {
  return !func.getBody()
              .walk([&](Operation* op) {
                if (!MemoryEffectOpInterface::hasNoEffect(op) &&
                    !op->hasTrait<OpTrait::IsTerminator>())
                  return WalkResult::interrupt();
                return WalkResult::advance();
              })
              .wasInterrupted();
}

// Folds a TensorFlow If op with a constant conditional operand by inlining
// the body of the branch function selected by the conditional value.
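//
// For example (a hedged sketch; the exact textual syntax depends on the
// dialect version, and @add_fn/@sub_fn are hypothetical branch functions):
//
//   %0 = "tf.If"(%cond, %x) {then_branch = @add_fn, else_branch = @sub_fn,
//                            is_stateless = true}
//       : (tensor<i1>, tensor<f32>) -> tensor<f32>
//
// where %cond is a constant true would be replaced by the inlined body of
// @add_fn applied to %x.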
class FoldIfOp : public OpRewritePattern<TF::IfOp> {
 public:
  explicit FoldIfOp(MLIRContext* context)
      : OpRewritePattern<TF::IfOp>(context) {}

  LogicalResult matchAndRewrite(TF::IfOp op,
                                PatternRewriter& rewriter) const override {
    // This pattern is restricted to tf.If ops in functions with exactly one
    // block, and therefore exactly one terminator op, so that the function
    // return type can be updated if operand shapes change after inlining.
    // Without this restriction, tensor cast ops would be required.
    FuncOp parent_op = op->getParentOfType<FuncOp>();
    if (!llvm::hasSingleElement(parent_op)) return failure();

    // Find the then and else branch functions.
    FuncOp then_func = op.then_function();
    FuncOp else_func = op.else_function();

    // If the tf.If op has no uses and its branch functions are side-effect
    // free, then remove it.
    // TODO(jpienaar): Remove once recursive side effects are supported.
    if (op.use_empty() &&
        (op.is_stateless() ||
         (IsSideEffectFree(then_func) && IsSideEffectFree(else_func)))) {
      rewriter.eraseOp(op.getOperation());
      return success();
    }

    // Extract the constant cond value.
    DenseElementsAttr cond;
    if (!matchPattern(op.cond(), m_Constant(&cond))) return failure();

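    // Only scalar (0-d) tensor<i1> constants are handled by the check below;
    // anything else bails out of the pattern.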
    // TODO(hinsu): Handle constants that are not scalar booleans.
    auto cond_type = cond.getType().dyn_cast<RankedTensorType>();
    if (!cond_type || !cond_type.getShape().equals({}) ||
        !cond_type.getElementType().isInteger(/*width=*/1))
      return failure();

    // Identify the branch to inline.
    bool cond_value = (*cond.int_value_begin()).getSExtValue();
    FuncOp func = cond_value ? then_func : else_func;

    // Make sure that the function has exactly one block to simplify inlining.
    // TFLite doesn't use control flow with blocks, so functions with more
    // than one block are not encountered in practice.
    if (!llvm::hasSingleElement(func)) return failure();

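    // Map the branch function's arguments to the tf.If op's inputs: operand 0
    // of tf.If is the condition, so branch inputs begin at operand 1.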
    BlockAndValueMapping mapper;
    for (int i = 0, e = func.getNumArguments(); i != e; ++i)
      mapper.map(func.getArgument(i), op.getOperand(i + 1));

    llvm::SmallVector<Value, 4> updated_results;
    for (auto& op_to_inline : func.front()) {
      // If this is a terminator, identify the values to use to replace the
      // original If op.
      if (op_to_inline.hasTrait<OpTrait::IsTerminator>()) {
        updated_results.reserve(op_to_inline.getNumOperands());
        for (Value operand : op_to_inline.getOperands())
          updated_results.push_back(mapper.lookup(operand));
        break;
      }

      // Otherwise, clone the op here.
      rewriter.clone(op_to_inline, mapper);
    }
    rewriter.replaceOp(op, updated_results);

    // The shapes of the updated_results may not match those of the original
    // values. If any of those values are operands of the terminator op, the
    // function return type should be updated.
    UpdateFuncType(parent_op);

    return success();
  }
};

void OptimizeFunctionalOpsPass::runOnOperation() {
  OwningRewritePatternList patterns;

  patterns.insert<FoldIfOp>(&getContext());

  ModuleOp module = getOperation();
  (void)applyPatternsAndFoldGreedily(module, std::move(patterns));
}

PassRegistration<OptimizeFunctionalOpsPass> pass(
    "tfl-optimize-functional-ops", "Optimize TensorFlow functional ops");
}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeFunctionalOpsPass() {
  return std::make_unique<OptimizeFunctionalOpsPass>();
}

}  // namespace TFL
}  // namespace mlir