1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <algorithm>
17
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SmallSet.h"
20 #include "llvm/ADT/StringMap.h"
21 #include "llvm/ADT/StringRef.h"
22 #include "llvm/ADT/iterator_range.h"
23 #include "llvm/Support/Casting.h"
24 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
25 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
26 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
27 #include "mlir/IR/Diagnostics.h" // from @llvm-project
28 #include "mlir/IR/Operation.h" // from @llvm-project
29 #include "mlir/IR/OperationSupport.h" // from @llvm-project
30 #include "mlir/IR/SymbolTable.h" // from @llvm-project
31 #include "mlir/Support/LLVM.h" // from @llvm-project
32 #include "mlir/Support/LogicalResult.h" // from @llvm-project
33 #include "mlir/Transforms/Passes.h" // from @llvm-project
34 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
35 #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
36
37 namespace tensorflow {
38 namespace tfrt_compiler {
39 namespace {
40
41 using ::mlir::ArrayRef;
42 using ::mlir::FuncOp;
43 using ::mlir::ModuleOp;
44 using ::mlir::Operation;
45 using ::mlir::SymbolTable;
46 using ::mlir::SymbolTableCollection;
47 using ::mlir::SymbolUserMap;
48
49 // This only includes some preliminary checks as this is a short term solution.
AreEquivalent(FuncOp & lhs,FuncOp & rhs)50 bool AreEquivalent(FuncOp& lhs, FuncOp& rhs) {
51 if (lhs.getType() != rhs.getType()) return false;
52
53 for (auto arg_pair : llvm::zip(lhs.getArguments(), rhs.getArguments())) {
54 auto& lhs_arg = std::get<0>(arg_pair);
55 auto& rhs_arg = std::get<1>(arg_pair);
56 if (lhs_arg.getType() != rhs_arg.getType()) return false;
57 }
58
59 auto lhs_ops = lhs.body().getOps();
60 auto rhs_ops = rhs.body().getOps();
61 if (std::distance(lhs_ops.begin(), lhs_ops.end()) !=
62 std::distance(rhs_ops.begin(), rhs_ops.end()))
63 return false;
64
65 for (auto op_pair : llvm::zip(lhs_ops, rhs_ops)) {
66 auto& lhs_op = std::get<0>(op_pair);
67 auto& rhs_op = std::get<1>(op_pair);
68 if (lhs_op.getName() != rhs_op.getName()) return false;
69 if (lhs_op.getNumRegions() != rhs_op.getNumRegions()) return false;
70 if (lhs_op.getNumSuccessors() != rhs_op.getNumSuccessors()) return false;
71 if (!std::equal(lhs_op.getOperandTypes().begin(),
72 lhs_op.getOperandTypes().end(),
73 rhs_op.getOperandTypes().begin()))
74 return false;
75 if (!std::equal(lhs_op.getResultTypes().begin(),
76 lhs_op.getResultTypes().end(),
77 rhs_op.getResultTypes().begin()))
78 return false;
79 }
80
81 return true;
82 }
83
84 // Deduplicate the functions if all users are BatchFunctionOp and have the same
85 // shared_name.
86 //
87 // TODO(b/192463730): this is the short term solution and not needed anymore
88 // after the shape inference pass is revamped with ideal solution
89 // (b/192463730#comment11).
90 class DeduplicateFunctionsInovkedByBatchFunction
91 : public mlir::PassWrapper<DeduplicateFunctionsInovkedByBatchFunction,
92 mlir::OperationPass<mlir::ModuleOp>> {
getArgument() const93 llvm::StringRef getArgument() const final {
94 return "tfrt-deduplicate-functions-invoked-by-batch-function";
95 }
getDescription() const96 llvm::StringRef getDescription() const final {
97 return "Deduplicate the functions invoked by tf.BatchFunction with the "
98 "same shared_name";
99 }
runOnOperation()100 void runOnOperation() override {
101 if (failed(Run())) {
102 signalPassFailure();
103 }
104 }
105
106 mlir::LogicalResult Run();
107 };
108
Run()109 mlir::LogicalResult DeduplicateFunctionsInovkedByBatchFunction::Run() {
110 ModuleOp module = getOperation();
111 SymbolTableCollection symbol_table_collection;
112 SymbolTable& symbol_table = symbol_table_collection.getSymbolTable(module);
113 SymbolUserMap symbol_users(symbol_table_collection, module);
114
115 // Categorize the functions invoked by BatchFunctionOp by its shared_name.
116 llvm::StringMap<llvm::SmallVector<FuncOp, 2>> shared_name_to_func_ops;
117
118 for (auto func : llvm::make_early_inc_range(module.getOps<FuncOp>())) {
119 ArrayRef<Operation*> users = symbol_users.getUsers(func);
120 llvm::StringRef shared_name;
121 // Deduplicate the function only if all users are BatchFunctionOp and have
122 // the same shared_name
123 if (!users.empty() && llvm::all_of(users, [&shared_name](Operation* user) {
124 auto op = llvm::dyn_cast_or_null<mlir::TF::BatchFunctionOp>(user);
125 // User is not a BatchFunctionOp
126 if (!op) return false;
127 if (shared_name.empty()) {
128 shared_name = op.shared_name();
129 return true;
130 }
131 return shared_name == op.shared_name();
132 })) {
133 shared_name_to_func_ops[shared_name].push_back(func);
134 }
135 }
136
137 for (auto& it : shared_name_to_func_ops) {
138 auto& func_ops = it.second;
139 FuncOp& func_op_to_keep = func_ops.front();
140 for (FuncOp& func_op_to_remove : llvm::drop_begin(func_ops)) {
141 if (!AreEquivalent(func_op_to_keep, func_op_to_remove)) {
142 return func_op_to_remove.emitError(
143 "func_ops for BatchFunctionOp with the same shared name are "
144 "different");
145 }
146 if (failed(SymbolTable::replaceAllSymbolUses(
147 func_op_to_remove, func_op_to_keep.getName(), module))) {
148 return func_op_to_remove.emitError("unable to replace the symbol use");
149 }
150 symbol_table.erase(func_op_to_remove);
151 }
152 }
153
154 return mlir::success();
155 }
156 } // namespace
157
158 std::unique_ptr<mlir::OperationPass<ModuleOp>>
CreateDeduplicateFunctionsInovkedByBatchFunctionPass()159 CreateDeduplicateFunctionsInovkedByBatchFunctionPass() {
160 return std::make_unique<DeduplicateFunctionsInovkedByBatchFunction>();
161 }
162
163 static mlir::PassRegistration<DeduplicateFunctionsInovkedByBatchFunction>
164 register_pass;
165
166 } // namespace tfrt_compiler
167 } // namespace tensorflow
168