1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
17 #include "mlir/Transforms/Passes.h" // from @llvm-project
18 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
19 #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
20
21 namespace tensorflow {
22 namespace tfrt_compiler {
23 namespace {
24
25 // A special DenseMapInfo that hashes only operands of a operation, and treats
26 // two operations equivalent if their operands are the same.
27 struct OpWithSameArgsInfo : llvm::DenseMapInfo<mlir::Operation *> {
getHashValuetensorflow::tfrt_compiler::__anoncd6046970111::OpWithSameArgsInfo28 static unsigned getHashValue(const mlir::Operation *const_op) {
29 auto *op = const_cast<mlir::Operation *>(const_op);
30 return llvm::hash_combine(
31 llvm::hash_combine_range(op->operand_begin(), op->operand_end()));
32 }
33
isEqualtensorflow::tfrt_compiler::__anoncd6046970111::OpWithSameArgsInfo34 static bool isEqual(const mlir::Operation *const_lhs,
35 const mlir::Operation *const_rhs) {
36 auto *lhs = const_cast<mlir::Operation *>(const_lhs);
37 auto *rhs = const_cast<mlir::Operation *>(const_rhs);
38 if (lhs == rhs) return true;
39 if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
40 rhs == getTombstoneKey() || rhs == getEmptyKey())
41 return false;
42
43 return std::equal(lhs->operand_begin(), lhs->operand_end(),
44 rhs->operand_begin(), rhs->operand_end());
45 }
46 };
47
48 // This pass merges non-side-effecting tf.If ops if their operands are the same.
49 // For example,
50 // %r0 = tf.If(%cond, %x) {else = @else_0, then = @then_0}
51 // %r1, %r2 = tf.If(%cond, %x) {else = @else_1, then = @then_1}
52 //
53 // will be converted to:
54 // func private @merge_else(%arg) {
55 // %r0 = tf.PartitionedCall(%arg) {f = @else_0}
56 // %r1, %r2 = tf.PartitionedCall(%arg) {f = @else_1}
57 // return %r0, %r1, %r2
58 // }
59 // func private @merge_then(%arg) {
60 // %r0 = tf.PartitionedCall(%arg) {f = @then_0}
61 // %r1, %r2 = tf.PartitionedCall(%arg) {f = @then_1}
62 // return %r0, %r1, %r2
63 // }
64 //
65 // %r0, %r1, %r2 = tf.If(%cond, %arg) {else = @merge_else, then = @merge_then}
66 //
67 // Note that the results will be concatenated.
68 class MergeTfIfOpsPass
69 : public mlir::PassWrapper<MergeTfIfOpsPass,
70 mlir::OperationPass<mlir::ModuleOp>> {
getArgument() const71 llvm::StringRef getArgument() const final { return "tfrt-merge-tf-if-ops"; }
getDescription() const72 llvm::StringRef getDescription() const final {
73 return "Merge stateless tf.If ops with the same arguments.";
74 }
75
runOnOperation()76 void runOnOperation() override {
77 auto module = getOperation();
78
79 for (auto func_op :
80 llvm::make_early_inc_range(module.getOps<mlir::FuncOp>())) {
81 ProcessFunction(func_op);
82 }
83 }
84
ProcessFunction(mlir::FuncOp op)85 void ProcessFunction(mlir::FuncOp op) {
86 // Use a hash map to group tf.If ops with the same operands.
87 llvm::SmallDenseMap<mlir::Operation *, llvm::SmallVector<mlir::TF::IfOp, 2>,
88 2, OpWithSameArgsInfo>
89 if_ops_to_merge;
90
91 for (mlir::Operation &op : op.front()) {
92 auto if_op = llvm::dyn_cast<mlir::TF::IfOp>(&op);
93
94 // Skip non tf.If ops and tf.If ops that are side-effecting.
95 if (!if_op || !if_op.is_stateless()) continue;
96
97 if_ops_to_merge[if_op].push_back(if_op);
98 }
99
100 int id = 0;
101
102 // Set the insertion point to the current function, as we will insert new
103 // functions here.
104 mlir::OpBuilder builder(op);
105
106 // Track the tf.If ops that should be removed as they are merged.
107 llvm::SmallVector<mlir::TF::IfOp, 4> if_ops_to_remove;
108
109 for (auto &iter : if_ops_to_merge) {
110 if (iter.second.size() <= 1) continue;
111
112 // Merge tf.If ops that have the same operands. The merged branches will
113 // be given unique names.
114 MergeIfOpsWithSameArgs(
115 builder, iter.first->getLoc(),
116 /*branch_prefix=*/
117 absl::StrCat(op.sym_name().str(), "_merged_if_", id++), iter.second);
118
119 if_ops_to_remove.append(iter.second.begin(), iter.second.end());
120 }
121
122 // Now that we are no longer using `if_ops_to_merge` or any other data
123 // structures that uses the operations that will be removed, we can now
124 // erase these if ops.
125 for (auto op : if_ops_to_remove) op->erase();
126 }
127
MergeIfOpsWithSameArgs(mlir::OpBuilder & builder,mlir::Location loc,absl::string_view branch_prefix,llvm::MutableArrayRef<mlir::TF::IfOp> if_ops)128 void MergeIfOpsWithSameArgs(mlir::OpBuilder &builder, mlir::Location loc,
129 absl::string_view branch_prefix,
130 llvm::MutableArrayRef<mlir::TF::IfOp> if_ops) {
131 assert(if_ops.size() > 1);
132
133 // The results of the merged tf.If op are the concatenation of results of
134 // the original tf.If ops.
135 llvm::SmallVector<mlir::Type, 4> new_result_types;
136 for (auto if_op : if_ops) {
137 new_result_types.append(if_op->result_type_begin(),
138 if_op->result_type_end());
139 }
140
141 auto branch_function_type = builder.getFunctionType(
142 if_ops.front().input().getTypes(), new_result_types);
143
144 // Create new branches for the merged tf.If op.
145 auto then_branch_name = CreateBranchFunction(
146 builder, loc, branch_prefix,
147 /*branch_suffix=*/"_then", branch_function_type, if_ops,
148 [](mlir::TF::IfOp op) { return op.then_branchAttr(); });
149
150 auto else_branch_name = CreateBranchFunction(
151 builder, loc, branch_prefix,
152 /*branch_suffix=*/"_else", branch_function_type, if_ops,
153 [](mlir::TF::IfOp op) { return op.else_branchAttr(); });
154
155 mlir::OpBuilder::InsertionGuard guard(builder);
156 builder.setInsertionPoint(if_ops.front());
157
158 // Create the merged tf.If op using the new branches.
159 auto new_if_op = builder.create<mlir::TF::IfOp>(
160 loc, new_result_types, if_ops.front().cond(), if_ops.front().input(),
161 then_branch_name, else_branch_name, /*is_stateless=*/true);
162
163 // Replace the uses of results of the original tf.If ops with the results of
164 // the merged tf.If op.
165 auto new_result_iter = new_if_op.output().begin();
166 for (auto if_op : if_ops) {
167 for (auto result : if_op.output()) {
168 assert(new_result_iter != new_if_op.output().end());
169 result.replaceAllUsesWith(*new_result_iter);
170 ++new_result_iter;
171 }
172 }
173 }
174
CreateBranchFunction(mlir::OpBuilder & builder,mlir::Location loc,absl::string_view branch_prefix,absl::string_view branch_suffix,mlir::FunctionType branch_function_type,llvm::ArrayRef<mlir::TF::IfOp> if_ops,llvm::function_ref<mlir::FlatSymbolRefAttr (mlir::TF::IfOp)> get_branch)175 llvm::StringRef CreateBranchFunction(
176 mlir::OpBuilder &builder, mlir::Location loc,
177 absl::string_view branch_prefix, absl::string_view branch_suffix,
178 mlir::FunctionType branch_function_type,
179 llvm::ArrayRef<mlir::TF::IfOp> if_ops,
180 llvm::function_ref<mlir::FlatSymbolRefAttr(mlir::TF::IfOp)> get_branch) {
181 std::string branch_name = absl::StrCat(branch_prefix, branch_suffix);
182 auto branch =
183 builder.create<mlir::FuncOp>(loc, branch_name, branch_function_type);
184 branch.setVisibility(mlir::FuncOp::Visibility::Private);
185
186 mlir::OpBuilder::InsertionGuard guard(builder);
187
188 // In the body of newly created branch function, we insert
189 // tf.PartitionedCall ops to call the original branches.
190 auto *block = branch.addEntryBlock();
191 builder.setInsertionPointToStart(block);
192 auto empty_string_attr = builder.getStringAttr("");
193
194 llvm::SmallVector<mlir::Value, 4> results;
195 results.reserve(branch_function_type.getNumResults());
196
197 for (auto if_op : if_ops) {
198 // Create the the call op to the original branch. The arguments are simply
199 // the arguments from the wrapper function.
200 auto call_op = builder.create<mlir::TF::PartitionedCallOp>(
201 if_op.getLoc(), if_op.getResultTypes(), block->getArguments(),
202 get_branch(if_op), empty_string_attr, empty_string_attr,
203 empty_string_attr);
204
205 // The results are the concatenation of the original branches.
206 results.append(call_op.output().begin(), call_op.output().end());
207 }
208
209 builder.create<mlir::ReturnOp>(loc, results);
210
211 return branch.sym_name();
212 }
213 };
214
215 } // namespace
216
CreateMergeTfIfOpsPass()217 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> CreateMergeTfIfOpsPass() {
218 return std::make_unique<MergeTfIfOpsPass>();
219 }
220
221 static mlir::PassRegistration<MergeTfIfOpsPass> register_pass(
222 CreateMergeTfIfOpsPass);
223
224 } // namespace tfrt_compiler
225 } // namespace tensorflow
226