1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// Converts TF While ops to TFL While ops whose cond and body regions each
// contain a single call to the original function.
17
18 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
19 #include "mlir/IR/Attributes.h" // from @llvm-project
20 #include "mlir/IR/Builders.h" // from @llvm-project
21 #include "mlir/IR/Operation.h" // from @llvm-project
22 #include "mlir/IR/PatternMatch.h" // from @llvm-project
23 #include "mlir/Pass/Pass.h" // from @llvm-project
24 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
25 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
27
28 namespace mlir {
29 namespace TFL {
30 namespace {
31
32 // Legalize TF While to TFL While with calls to the original functions from the
33 // cond and body regions.
34 struct LegalizeWhile
35 : public PassWrapper<LegalizeWhile, OperationPass<ModuleOp>> {
getDependentDialectsmlir::TFL::__anon9936039b0111::LegalizeWhile36 void getDependentDialects(DialectRegistry& registry) const override {
37 registry.insert<TFL::TensorFlowLiteDialect>();
38 }
39
getArgumentmlir::TFL::__anon9936039b0111::LegalizeWhile40 StringRef getArgument() const final {
41 // This is the argument used to refer to the pass in
42 // the textual format (on the commandline for example).
43 return "tfl-legalize-tf-while";
44 }
getDescriptionmlir::TFL::__anon9936039b0111::LegalizeWhile45 StringRef getDescription() const final {
46 // This is a brief description of the pass.
47 return "Legalize from TensorFlow While to TensorFlow Lite While";
48 }
49
50 void RunOnFunction(FuncOp func);
51
runOnOperationmlir::TFL::__anon9936039b0111::LegalizeWhile52 void runOnOperation() override {
53 for (auto op : getOperation().getOps<FuncOp>()) RunOnFunction(op);
54 }
55 };
56
57 } // namespace
58
59 // Inserts call to the given function into the 'region'.
CreateRegionWithCall(FuncOp func,Region & region,Location loc)60 void CreateRegionWithCall(FuncOp func, Region& region, Location loc) {
61 OpBuilder builder(region);
62 auto block = builder.createBlock(®ion);
63 SmallVector<Value, 4> new_operands;
64 for (Type t : func.getType().getInputs())
65 new_operands.push_back(block->addArgument(t));
66 auto call = builder.create<CallOp>(loc, func, new_operands);
67 builder.create<YieldOp>(loc, call.getResults());
68 // Mark old function as private so that it can be DCE'd if not called.
69 func.setPrivate();
70 }
71
RunOnWhile(TF::WhileOp while_op)72 void RunOnWhile(TF::WhileOp while_op) {
73 Operation* op = while_op.getOperation();
74 // Create new TFL While op that will be used to replace TF While op.
75 auto new_op = OpBuilder(op).create<TFL::WhileOp>(
76 op->getLoc(), op->getResultTypes(), op->getOperands(),
77 while_op.is_stateless());
78 Location loc = while_op->getLoc();
79 CreateRegionWithCall(while_op.cond_function(), new_op.cond(), loc);
80 CreateRegionWithCall(while_op.body_function(), new_op.body(), loc);
81
82 op->replaceAllUsesWith(new_op.getResults());
83 op->erase();
84 }
85
RunOnFunction(FuncOp func)86 void LegalizeWhile::RunOnFunction(FuncOp func) {
87 // Convert all TF WhileOps inside the function body to TFL While ops.
88 func.getBody().walk([](TF::WhileOp while_op) { RunOnWhile(while_op); });
89 }
90
91 // Creates an instance of the TensorFlow While to TFLite While pass.
CreateLegalizeTFWhilePass()92 std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeTFWhilePass() {
93 return std::make_unique<LegalizeWhile>();
94 }
95
96 static PassRegistration<LegalizeWhile> pass;
97
98 } // namespace TFL
99 } // namespace mlir
100