1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Converts TF While to TFL While with single call in body and cond.
17
18 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
19 #include "mlir/IR/Attributes.h" // from @llvm-project
20 #include "mlir/IR/Builders.h" // from @llvm-project
21 #include "mlir/IR/Operation.h" // from @llvm-project
22 #include "mlir/IR/PatternMatch.h" // from @llvm-project
23 #include "mlir/Pass/Pass.h" // from @llvm-project
24 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
25 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
27
28 namespace mlir {
29 namespace TFL {
30 namespace {
31
32 // Legalize TF While to TFL While with calls to the original functions from the
33 // cond and body regions.
34 struct LegalizeWhile
35 : public PassWrapper<LegalizeWhile, OperationPass<ModuleOp>> {
getDependentDialectsmlir::TFL::__anon4efec2b60111::LegalizeWhile36 void getDependentDialects(DialectRegistry& registry) const override {
37 registry.insert<TFL::TensorFlowLiteDialect>();
38 }
39
40 void RunOnFunction(FuncOp func);
41
runOnOperationmlir::TFL::__anon4efec2b60111::LegalizeWhile42 void runOnOperation() override {
43 for (auto op : getOperation().getOps<FuncOp>()) RunOnFunction(op);
44 }
45 };
46
47 } // namespace
48
RunOnWhile(TF::WhileOp while_op)49 void RunOnWhile(TF::WhileOp while_op) {
50 Operation* op = while_op.getOperation();
51 // Create new TFL While op that will be used to replace TF While op.
52 auto new_op = OpBuilder(op).create<TFL::WhileOp>(
53 op->getLoc(), op->getResultTypes(), op->getOperands(),
54 while_op.is_stateless());
55 // Insert call to the given function into the 'region'.
56 auto create_region_with_call = [&while_op](FuncOp func, Region& region) {
57 OpBuilder builder(region);
58 auto block = builder.createBlock(®ion);
59 SmallVector<Value, 4> new_operands;
60 for (Type t : func.getType().getInputs())
61 new_operands.push_back(block->addArgument(t));
62 auto call = builder.create<CallOp>(while_op.getLoc(), func, new_operands);
63 builder.create<YieldOp>(while_op.getLoc(), call.getResults());
64 // Mark old function as private so that it can be DCE'd if not called.
65 func.setPrivate();
66 };
67 create_region_with_call(while_op.cond_function(), new_op.cond());
68 create_region_with_call(while_op.body_function(), new_op.body());
69
70 op->replaceAllUsesWith(new_op.getResults());
71 op->erase();
72 }
73
RunOnFunction(FuncOp func)74 void LegalizeWhile::RunOnFunction(FuncOp func) {
75 // Convert all TF WhileOps inside the function body to TFL While ops.
76 func.getBody().walk([](TF::WhileOp while_op) { RunOnWhile(while_op); });
77 }
78
79 // Creates an instance of the TensorFlow While to TFLite While pass.
CreateLegalizeTFWhilePass()80 std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeTFWhilePass() {
81 return std::make_unique<LegalizeWhile>();
82 }
83
84 static PassRegistration<LegalizeWhile> pass(
85 "tfl-legalize-tf-while",
86 "Legalize from TensorFlow While to TensorFlow Lite While");
87
88 } // namespace TFL
89 } // namespace mlir
90