1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
17 #include "mlir/Transforms/Passes.h" // from @llvm-project
18 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
19 #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
20
21 namespace tensorflow {
22 namespace tfrt_compiler {
23 namespace {
24
25 // This pass removes tf.If ops' operands that are produced by tf.Const ops.
26 // These constants can be moved into branches' function body for further
27 // optimziation.
28 class RemoveTfIfConstArgs
29 : public mlir::PassWrapper<RemoveTfIfConstArgs,
30 mlir::OperationPass<mlir::ModuleOp>> {
getArgument() const31 llvm::StringRef getArgument() const final {
32 return "tfrt-remove-tf-if-const-args";
33 }
getDescription() const34 llvm::StringRef getDescription() const final {
35 return "Remove const args from tf.If ops";
36 }
37
runOnOperation()38 void runOnOperation() override {
39 auto module = getOperation();
40 for (auto func_op :
41 llvm::make_early_inc_range(module.getOps<mlir::FuncOp>())) {
42 ProcessFunction(func_op);
43 }
44 }
45
ProcessFunction(mlir::FuncOp op)46 void ProcessFunction(mlir::FuncOp op) {
47 // Set the insertion point to the current function, as we will insert new
48 // functions here.
49 mlir::OpBuilder builder(op);
50 for (mlir::Operation &op : op.front()) {
51 auto if_op = llvm::dyn_cast<mlir::TF::IfOp>(&op);
52 if (!if_op) continue;
53
54 // Record the operands that are produced by tf.Const ops.
55 llvm::SmallVector<mlir::TF::ConstOp, 2> const_args;
56 // Record these operands's corresponding operand indices.
57 llvm::SmallVector<unsigned, 2> const_arg_indices;
58 // Record the remaining operands that won't be removed.
59 llvm::SmallVector<mlir::Value, 2> remaining_args;
60 for (auto iter : llvm::enumerate(if_op.input())) {
61 mlir::Value operand = iter.value();
62 if (auto const_op = operand.getDefiningOp<mlir::TF::ConstOp>()) {
63 const_args.push_back(const_op);
64 const_arg_indices.push_back(iter.index());
65 } else {
66 remaining_args.push_back(operand);
67 }
68 }
69
70 if (const_args.empty()) continue;
71
72 RemoveConstArgsFromTfIfOp(builder, if_op, const_args, const_arg_indices,
73 remaining_args);
74 }
75 }
76
RemoveConstArgsFromTfIfOp(mlir::OpBuilder & builder,mlir::TF::IfOp if_op,llvm::ArrayRef<mlir::TF::ConstOp> const_args,llvm::ArrayRef<unsigned> const_arg_indices,llvm::ArrayRef<mlir::Value> remaining_args)77 void RemoveConstArgsFromTfIfOp(mlir::OpBuilder &builder, mlir::TF::IfOp if_op,
78 llvm::ArrayRef<mlir::TF::ConstOp> const_args,
79 llvm::ArrayRef<unsigned> const_arg_indices,
80 llvm::ArrayRef<mlir::Value> remaining_args) {
81 auto branch_suffix = absl::StrCat("_removed_const_args_", id_++);
82
83 // Create wrapper functions with the new arguments (as const args are
84 // removed) for both then function and else function.
85 auto new_then_function_name =
86 CreateBranchFunction(builder, if_op.then_function(), branch_suffix,
87 const_args, const_arg_indices);
88 auto new_else_function_name =
89 CreateBranchFunction(builder, if_op.else_function(), branch_suffix,
90 const_args, const_arg_indices);
91
92 // Change the if_op's argumetns to the new arguments, branches to new
93 // branches. Note that the outputs are not changed.
94 if_op.inputMutable().assign(remaining_args);
95 if_op.then_branchAttr(builder.getSymbolRefAttr(new_then_function_name));
96 if_op.else_branchAttr(builder.getSymbolRefAttr(new_else_function_name));
97 }
98
CreateBranchFunction(mlir::OpBuilder & builder,mlir::FuncOp branch,absl::string_view branch_suffix,llvm::ArrayRef<mlir::TF::ConstOp> const_args,llvm::ArrayRef<unsigned> const_arg_indices)99 llvm::StringRef CreateBranchFunction(
100 mlir::OpBuilder &builder, mlir::FuncOp branch,
101 absl::string_view branch_suffix,
102 llvm::ArrayRef<mlir::TF::ConstOp> const_args,
103 llvm::ArrayRef<unsigned> const_arg_indices) {
104 // Get the new function type as const args are removed.
105 auto new_branch_type =
106 branch.getType().getWithoutArgsAndResults(const_arg_indices, {});
107 std::string new_branch_name =
108 absl::StrCat(branch.sym_name().str(), branch_suffix);
109 // Create the wrapper function with the new arguments that calls the
110 // original branch.
111 auto new_branch = builder.create<mlir::FuncOp>(
112 branch.getLoc(), new_branch_name, new_branch_type);
113 new_branch.setVisibility(mlir::FuncOp::Visibility::Private);
114
115 // In its function body, we will add the corresponding const ops and call
116 // the original branch.
117
118 mlir::OpBuilder::InsertionGuard guard(builder);
119 auto *block = new_branch.addEntryBlock();
120 builder.setInsertionPointToStart(block);
121
122 // Prepare the function arguments of the original branch.
123 llvm::SmallVector<mlir::Value, 4> call_args(branch.getNumArguments());
124
125 // For those removed const args, we copy the tf.Const op, and use that as
126 // the corresponding argument when calling the original branch.
127 for (const auto &iter : llvm::zip(const_args, const_arg_indices)) {
128 auto const_op =
129 llvm::cast<mlir::TF::ConstOp>(builder.clone(*std::get<0>(iter)));
130 unsigned index = std::get<1>(iter);
131 call_args[index] = const_op;
132 }
133
134 // For the rest, they are now coming from the wrapper function's arguments
135 // in the original order.
136 for (int i = 0, j = 0; i < call_args.size(); ++i) {
137 if (!call_args[i]) {
138 assert(j < block->getNumArguments());
139 call_args[i] = block->getArgument(j++);
140 }
141 }
142
143 // Now create the call op to the original branch.
144 auto call_op = builder.create<mlir::TF::StatefulPartitionedCallOp>(
145 new_branch.getLoc(), new_branch_type.getResults(), call_args,
146 branch.sym_name(), "", "", "");
147 // Note that the outputs are not changed.
148 builder.create<mlir::ReturnOp>(new_branch.getLoc(), call_op.output());
149
150 return new_branch.sym_name();
151 }
152
153 int id_ = 0;
154 };
155
156 } // namespace
157
158 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateRemoveTfIfConstArgsPass()159 CreateRemoveTfIfConstArgsPass() {
160 return std::make_unique<RemoveTfIfConstArgs>();
161 }
162
163 static mlir::PassRegistration<RemoveTfIfConstArgs> register_pass(
164 CreateRemoveTfIfConstArgsPass);
165
166 } // namespace tfrt_compiler
167 } // namespace tensorflow
168