• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
17 #include "mlir/Transforms/Passes.h"  // from @llvm-project
18 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
19 #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
20 
21 namespace tensorflow {
22 namespace tfrt_compiler {
23 namespace {
24 
25 // A special DenseMapInfo that hashes only operands of a operation, and treats
26 // two operations equivalent if their operands are the same.
27 struct OpWithSameArgsInfo : llvm::DenseMapInfo<mlir::Operation *> {
getHashValuetensorflow::tfrt_compiler::__anoncd6046970111::OpWithSameArgsInfo28   static unsigned getHashValue(const mlir::Operation *const_op) {
29     auto *op = const_cast<mlir::Operation *>(const_op);
30     return llvm::hash_combine(
31         llvm::hash_combine_range(op->operand_begin(), op->operand_end()));
32   }
33 
isEqualtensorflow::tfrt_compiler::__anoncd6046970111::OpWithSameArgsInfo34   static bool isEqual(const mlir::Operation *const_lhs,
35                       const mlir::Operation *const_rhs) {
36     auto *lhs = const_cast<mlir::Operation *>(const_lhs);
37     auto *rhs = const_cast<mlir::Operation *>(const_rhs);
38     if (lhs == rhs) return true;
39     if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
40         rhs == getTombstoneKey() || rhs == getEmptyKey())
41       return false;
42 
43     return std::equal(lhs->operand_begin(), lhs->operand_end(),
44                       rhs->operand_begin(), rhs->operand_end());
45   }
46 };
47 
48 // This pass merges non-side-effecting tf.If ops if their operands are the same.
49 // For example,
50 //  %r0 = tf.If(%cond, %x) {else = @else_0, then = @then_0}
51 //  %r1, %r2 = tf.If(%cond, %x) {else = @else_1, then = @then_1}
52 //
53 // will be converted to:
54 //  func private @merge_else(%arg) {
55 //    %r0 = tf.PartitionedCall(%arg) {f = @else_0}
56 //    %r1, %r2 = tf.PartitionedCall(%arg) {f = @else_1}
57 //    return %r0, %r1, %r2
58 //  }
59 //  func private @merge_then(%arg) {
60 //    %r0 = tf.PartitionedCall(%arg) {f = @then_0}
61 //    %r1, %r2 = tf.PartitionedCall(%arg) {f = @then_1}
62 //    return %r0, %r1, %r2
63 //  }
64 //
65 //  %r0, %r1, %r2 = tf.If(%cond, %arg) {else = @merge_else, then = @merge_then}
66 //
67 // Note that the results will be concatenated.
68 class MergeTfIfOpsPass
69     : public mlir::PassWrapper<MergeTfIfOpsPass,
70                                mlir::OperationPass<mlir::ModuleOp>> {
getArgument() const71   llvm::StringRef getArgument() const final { return "tfrt-merge-tf-if-ops"; }
getDescription() const72   llvm::StringRef getDescription() const final {
73     return "Merge stateless tf.If ops with the same arguments.";
74   }
75 
runOnOperation()76   void runOnOperation() override {
77     auto module = getOperation();
78 
79     for (auto func_op :
80          llvm::make_early_inc_range(module.getOps<mlir::FuncOp>())) {
81       ProcessFunction(func_op);
82     }
83   }
84 
ProcessFunction(mlir::FuncOp op)85   void ProcessFunction(mlir::FuncOp op) {
86     // Use a hash map to group tf.If ops with the same operands.
87     llvm::SmallDenseMap<mlir::Operation *, llvm::SmallVector<mlir::TF::IfOp, 2>,
88                         2, OpWithSameArgsInfo>
89         if_ops_to_merge;
90 
91     for (mlir::Operation &op : op.front()) {
92       auto if_op = llvm::dyn_cast<mlir::TF::IfOp>(&op);
93 
94       // Skip non tf.If ops and tf.If ops that are side-effecting.
95       if (!if_op || !if_op.is_stateless()) continue;
96 
97       if_ops_to_merge[if_op].push_back(if_op);
98     }
99 
100     int id = 0;
101 
102     // Set the insertion point to the current function, as we will insert new
103     // functions here.
104     mlir::OpBuilder builder(op);
105 
106     // Track the tf.If ops that should be removed as they are merged.
107     llvm::SmallVector<mlir::TF::IfOp, 4> if_ops_to_remove;
108 
109     for (auto &iter : if_ops_to_merge) {
110       if (iter.second.size() <= 1) continue;
111 
112       // Merge tf.If ops that have the same operands. The merged branches will
113       // be given unique names.
114       MergeIfOpsWithSameArgs(
115           builder, iter.first->getLoc(),
116           /*branch_prefix=*/
117           absl::StrCat(op.sym_name().str(), "_merged_if_", id++), iter.second);
118 
119       if_ops_to_remove.append(iter.second.begin(), iter.second.end());
120     }
121 
122     // Now that we are no longer using `if_ops_to_merge` or any other data
123     // structures that uses the operations that will be removed, we can now
124     // erase these if ops.
125     for (auto op : if_ops_to_remove) op->erase();
126   }
127 
MergeIfOpsWithSameArgs(mlir::OpBuilder & builder,mlir::Location loc,absl::string_view branch_prefix,llvm::MutableArrayRef<mlir::TF::IfOp> if_ops)128   void MergeIfOpsWithSameArgs(mlir::OpBuilder &builder, mlir::Location loc,
129                               absl::string_view branch_prefix,
130                               llvm::MutableArrayRef<mlir::TF::IfOp> if_ops) {
131     assert(if_ops.size() > 1);
132 
133     // The results of the merged tf.If op are the concatenation of results of
134     // the original tf.If ops.
135     llvm::SmallVector<mlir::Type, 4> new_result_types;
136     for (auto if_op : if_ops) {
137       new_result_types.append(if_op->result_type_begin(),
138                               if_op->result_type_end());
139     }
140 
141     auto branch_function_type = builder.getFunctionType(
142         if_ops.front().input().getTypes(), new_result_types);
143 
144     // Create new branches for the merged tf.If op.
145     auto then_branch_name = CreateBranchFunction(
146         builder, loc, branch_prefix,
147         /*branch_suffix=*/"_then", branch_function_type, if_ops,
148         [](mlir::TF::IfOp op) { return op.then_branchAttr(); });
149 
150     auto else_branch_name = CreateBranchFunction(
151         builder, loc, branch_prefix,
152         /*branch_suffix=*/"_else", branch_function_type, if_ops,
153         [](mlir::TF::IfOp op) { return op.else_branchAttr(); });
154 
155     mlir::OpBuilder::InsertionGuard guard(builder);
156     builder.setInsertionPoint(if_ops.front());
157 
158     // Create the merged tf.If op using the new branches.
159     auto new_if_op = builder.create<mlir::TF::IfOp>(
160         loc, new_result_types, if_ops.front().cond(), if_ops.front().input(),
161         then_branch_name, else_branch_name, /*is_stateless=*/true);
162 
163     // Replace the uses of results of the original tf.If ops with the results of
164     // the merged tf.If op.
165     auto new_result_iter = new_if_op.output().begin();
166     for (auto if_op : if_ops) {
167       for (auto result : if_op.output()) {
168         assert(new_result_iter != new_if_op.output().end());
169         result.replaceAllUsesWith(*new_result_iter);
170         ++new_result_iter;
171       }
172     }
173   }
174 
CreateBranchFunction(mlir::OpBuilder & builder,mlir::Location loc,absl::string_view branch_prefix,absl::string_view branch_suffix,mlir::FunctionType branch_function_type,llvm::ArrayRef<mlir::TF::IfOp> if_ops,llvm::function_ref<mlir::FlatSymbolRefAttr (mlir::TF::IfOp)> get_branch)175   llvm::StringRef CreateBranchFunction(
176       mlir::OpBuilder &builder, mlir::Location loc,
177       absl::string_view branch_prefix, absl::string_view branch_suffix,
178       mlir::FunctionType branch_function_type,
179       llvm::ArrayRef<mlir::TF::IfOp> if_ops,
180       llvm::function_ref<mlir::FlatSymbolRefAttr(mlir::TF::IfOp)> get_branch) {
181     std::string branch_name = absl::StrCat(branch_prefix, branch_suffix);
182     auto branch =
183         builder.create<mlir::FuncOp>(loc, branch_name, branch_function_type);
184     branch.setVisibility(mlir::FuncOp::Visibility::Private);
185 
186     mlir::OpBuilder::InsertionGuard guard(builder);
187 
188     // In the body of newly created branch function, we insert
189     // tf.PartitionedCall ops to call the original branches.
190     auto *block = branch.addEntryBlock();
191     builder.setInsertionPointToStart(block);
192     auto empty_string_attr = builder.getStringAttr("");
193 
194     llvm::SmallVector<mlir::Value, 4> results;
195     results.reserve(branch_function_type.getNumResults());
196 
197     for (auto if_op : if_ops) {
198       // Create the the call op to the original branch. The arguments are simply
199       // the arguments from the wrapper function.
200       auto call_op = builder.create<mlir::TF::PartitionedCallOp>(
201           if_op.getLoc(), if_op.getResultTypes(), block->getArguments(),
202           get_branch(if_op), empty_string_attr, empty_string_attr,
203           empty_string_attr);
204 
205       // The results are the concatenation of the original branches.
206       results.append(call_op.output().begin(), call_op.output().end());
207     }
208 
209     builder.create<mlir::ReturnOp>(loc, results);
210 
211     return branch.sym_name();
212   }
213 };
214 
215 }  // namespace
216 
CreateMergeTfIfOpsPass()217 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> CreateMergeTfIfOpsPass() {
218   return std::make_unique<MergeTfIfOpsPass>();
219 }
220 
// Registers the pass globally so it can be invoked by its argument
// ("tfrt-merge-tf-if-ops") from textual pass pipelines and opt-style tools.
static mlir::PassRegistration<MergeTfIfOpsPass> register_pass(
    CreateMergeTfIfOpsPass);
223 
224 }  // namespace tfrt_compiler
225 }  // namespace tensorflow
226