• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
17 #include "mlir/Transforms/Passes.h"  // from @llvm-project
18 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
19 #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
20 
21 namespace tensorflow {
22 namespace tfrt_compiler {
23 namespace {
24 
25 // This pass removes tf.If ops' operands that are produced by tf.Const ops.
26 // These constants can be moved into branches' function body for further
27 // optimziation.
28 class RemoveTfIfConstArgs
29     : public mlir::PassWrapper<RemoveTfIfConstArgs,
30                                mlir::OperationPass<mlir::ModuleOp>> {
getArgument() const31   llvm::StringRef getArgument() const final {
32     return "tfrt-remove-tf-if-const-args";
33   }
getDescription() const34   llvm::StringRef getDescription() const final {
35     return "Remove const args from tf.If ops";
36   }
37 
runOnOperation()38   void runOnOperation() override {
39     auto module = getOperation();
40     for (auto func_op :
41          llvm::make_early_inc_range(module.getOps<mlir::FuncOp>())) {
42       ProcessFunction(func_op);
43     }
44   }
45 
ProcessFunction(mlir::FuncOp op)46   void ProcessFunction(mlir::FuncOp op) {
47     // Set the insertion point to the current function, as we will insert new
48     // functions here.
49     mlir::OpBuilder builder(op);
50     for (mlir::Operation &op : op.front()) {
51       auto if_op = llvm::dyn_cast<mlir::TF::IfOp>(&op);
52       if (!if_op) continue;
53 
54       // Record the operands that are produced by tf.Const ops.
55       llvm::SmallVector<mlir::TF::ConstOp, 2> const_args;
56       // Record these operands's corresponding operand indices.
57       llvm::SmallVector<unsigned, 2> const_arg_indices;
58       // Record the remaining operands that won't be removed.
59       llvm::SmallVector<mlir::Value, 2> remaining_args;
60       for (auto iter : llvm::enumerate(if_op.input())) {
61         mlir::Value operand = iter.value();
62         if (auto const_op = operand.getDefiningOp<mlir::TF::ConstOp>()) {
63           const_args.push_back(const_op);
64           const_arg_indices.push_back(iter.index());
65         } else {
66           remaining_args.push_back(operand);
67         }
68       }
69 
70       if (const_args.empty()) continue;
71 
72       RemoveConstArgsFromTfIfOp(builder, if_op, const_args, const_arg_indices,
73                                 remaining_args);
74     }
75   }
76 
RemoveConstArgsFromTfIfOp(mlir::OpBuilder & builder,mlir::TF::IfOp if_op,llvm::ArrayRef<mlir::TF::ConstOp> const_args,llvm::ArrayRef<unsigned> const_arg_indices,llvm::ArrayRef<mlir::Value> remaining_args)77   void RemoveConstArgsFromTfIfOp(mlir::OpBuilder &builder, mlir::TF::IfOp if_op,
78                                  llvm::ArrayRef<mlir::TF::ConstOp> const_args,
79                                  llvm::ArrayRef<unsigned> const_arg_indices,
80                                  llvm::ArrayRef<mlir::Value> remaining_args) {
81     auto branch_suffix = absl::StrCat("_removed_const_args_", id_++);
82 
83     // Create wrapper functions with the new arguments (as const args are
84     // removed) for both then function and else function.
85     auto new_then_function_name =
86         CreateBranchFunction(builder, if_op.then_function(), branch_suffix,
87                              const_args, const_arg_indices);
88     auto new_else_function_name =
89         CreateBranchFunction(builder, if_op.else_function(), branch_suffix,
90                              const_args, const_arg_indices);
91 
92     // Change the if_op's argumetns to the new arguments, branches to new
93     // branches. Note that the outputs are not changed.
94     if_op.inputMutable().assign(remaining_args);
95     if_op.then_branchAttr(builder.getSymbolRefAttr(new_then_function_name));
96     if_op.else_branchAttr(builder.getSymbolRefAttr(new_else_function_name));
97   }
98 
CreateBranchFunction(mlir::OpBuilder & builder,mlir::FuncOp branch,absl::string_view branch_suffix,llvm::ArrayRef<mlir::TF::ConstOp> const_args,llvm::ArrayRef<unsigned> const_arg_indices)99   llvm::StringRef CreateBranchFunction(
100       mlir::OpBuilder &builder, mlir::FuncOp branch,
101       absl::string_view branch_suffix,
102       llvm::ArrayRef<mlir::TF::ConstOp> const_args,
103       llvm::ArrayRef<unsigned> const_arg_indices) {
104     // Get the new function type as const args are removed.
105     auto new_branch_type =
106         branch.getType().getWithoutArgsAndResults(const_arg_indices, {});
107     std::string new_branch_name =
108         absl::StrCat(branch.sym_name().str(), branch_suffix);
109     // Create the wrapper function with the new arguments that calls the
110     // original branch.
111     auto new_branch = builder.create<mlir::FuncOp>(
112         branch.getLoc(), new_branch_name, new_branch_type);
113     new_branch.setVisibility(mlir::FuncOp::Visibility::Private);
114 
115     // In its function body, we will add the corresponding const ops and call
116     // the original branch.
117 
118     mlir::OpBuilder::InsertionGuard guard(builder);
119     auto *block = new_branch.addEntryBlock();
120     builder.setInsertionPointToStart(block);
121 
122     // Prepare the function arguments of the original branch.
123     llvm::SmallVector<mlir::Value, 4> call_args(branch.getNumArguments());
124 
125     // For those removed const args, we copy the tf.Const op, and use that as
126     // the corresponding argument when calling the original branch.
127     for (const auto &iter : llvm::zip(const_args, const_arg_indices)) {
128       auto const_op =
129           llvm::cast<mlir::TF::ConstOp>(builder.clone(*std::get<0>(iter)));
130       unsigned index = std::get<1>(iter);
131       call_args[index] = const_op;
132     }
133 
134     // For the rest, they are now coming from the wrapper function's arguments
135     // in the original order.
136     for (int i = 0, j = 0; i < call_args.size(); ++i) {
137       if (!call_args[i]) {
138         assert(j < block->getNumArguments());
139         call_args[i] = block->getArgument(j++);
140       }
141     }
142 
143     // Now create the call op to the original branch.
144     auto call_op = builder.create<mlir::TF::StatefulPartitionedCallOp>(
145         new_branch.getLoc(), new_branch_type.getResults(), call_args,
146         branch.sym_name(), "", "", "");
147     // Note that the outputs are not changed.
148     builder.create<mlir::ReturnOp>(new_branch.getLoc(), call_op.output());
149 
150     return new_branch.sym_name();
151   }
152 
153   int id_ = 0;
154 };
155 
156 }  // namespace
157 
158 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateRemoveTfIfConstArgsPass()159 CreateRemoveTfIfConstArgsPass() {
160   return std::make_unique<RemoveTfIfConstArgs>();
161 }
162 
163 static mlir::PassRegistration<RemoveTfIfConstArgs> register_pass(
164     CreateRemoveTfIfConstArgsPass);
165 
166 }  // namespace tfrt_compiler
167 }  // namespace tensorflow
168