• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SmallSet.h"
20 #include "llvm/ADT/StringMap.h"
21 #include "llvm/ADT/StringRef.h"
22 #include "llvm/ADT/iterator_range.h"
23 #include "llvm/Support/Casting.h"
24 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
25 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
27 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
28 #include "mlir/IR/Operation.h"  // from @llvm-project
29 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
30 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
31 #include "mlir/Support/LLVM.h"  // from @llvm-project
32 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
33 #include "mlir/Transforms/Passes.h"  // from @llvm-project
34 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
35 #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
36 
37 namespace tensorflow {
38 namespace tfrt_compiler {
39 namespace {
40 
41 using ::mlir::ArrayRef;
42 using ::mlir::FuncOp;
43 using ::mlir::ModuleOp;
44 using ::mlir::Operation;
45 using ::mlir::SymbolTable;
46 using ::mlir::SymbolTableCollection;
47 using ::mlir::SymbolUserMap;
48 
49 // This only includes some preliminary checks as this is a short term solution.
AreEquivalent(FuncOp & lhs,FuncOp & rhs)50 bool AreEquivalent(FuncOp& lhs, FuncOp& rhs) {
51   if (lhs.getType() != rhs.getType()) return false;
52 
53   for (auto arg_pair : llvm::zip(lhs.getArguments(), rhs.getArguments())) {
54     auto& lhs_arg = std::get<0>(arg_pair);
55     auto& rhs_arg = std::get<1>(arg_pair);
56     if (lhs_arg.getType() != rhs_arg.getType()) return false;
57   }
58 
59   auto lhs_ops = lhs.body().getOps();
60   auto rhs_ops = rhs.body().getOps();
61   if (std::distance(lhs_ops.begin(), lhs_ops.end()) !=
62       std::distance(rhs_ops.begin(), rhs_ops.end()))
63     return false;
64 
65   for (auto op_pair : llvm::zip(lhs_ops, rhs_ops)) {
66     auto& lhs_op = std::get<0>(op_pair);
67     auto& rhs_op = std::get<1>(op_pair);
68     if (lhs_op.getName() != rhs_op.getName()) return false;
69     if (lhs_op.getNumRegions() != rhs_op.getNumRegions()) return false;
70     if (lhs_op.getNumSuccessors() != rhs_op.getNumSuccessors()) return false;
71     if (!std::equal(lhs_op.getOperandTypes().begin(),
72                     lhs_op.getOperandTypes().end(),
73                     rhs_op.getOperandTypes().begin()))
74       return false;
75     if (!std::equal(lhs_op.getResultTypes().begin(),
76                     lhs_op.getResultTypes().end(),
77                     rhs_op.getResultTypes().begin()))
78       return false;
79   }
80 
81   return true;
82 }
83 
84 // Deduplicate the functions if all users are BatchFunctionOp and have the same
85 // shared_name.
86 //
87 // TODO(b/192463730): this is the short term solution and not needed anymore
88 // after the shape inference pass is revamped with ideal solution
89 // (b/192463730#comment11).
90 class DeduplicateFunctionsInovkedByBatchFunction
91     : public mlir::PassWrapper<DeduplicateFunctionsInovkedByBatchFunction,
92                                mlir::OperationPass<mlir::ModuleOp>> {
getArgument() const93   llvm::StringRef getArgument() const final {
94     return "tfrt-deduplicate-functions-invoked-by-batch-function";
95   }
getDescription() const96   llvm::StringRef getDescription() const final {
97     return "Deduplicate the functions invoked by tf.BatchFunction with the "
98            "same shared_name";
99   }
runOnOperation()100   void runOnOperation() override {
101     if (failed(Run())) {
102       signalPassFailure();
103     }
104   }
105 
106   mlir::LogicalResult Run();
107 };
108 
Run()109 mlir::LogicalResult DeduplicateFunctionsInovkedByBatchFunction::Run() {
110   ModuleOp module = getOperation();
111   SymbolTableCollection symbol_table_collection;
112   SymbolTable& symbol_table = symbol_table_collection.getSymbolTable(module);
113   SymbolUserMap symbol_users(symbol_table_collection, module);
114 
115   // Categorize the functions invoked by BatchFunctionOp by its shared_name.
116   llvm::StringMap<llvm::SmallVector<FuncOp, 2>> shared_name_to_func_ops;
117 
118   for (auto func : llvm::make_early_inc_range(module.getOps<FuncOp>())) {
119     ArrayRef<Operation*> users = symbol_users.getUsers(func);
120     llvm::StringRef shared_name;
121     // Deduplicate the function only if all users are BatchFunctionOp and have
122     // the same shared_name
123     if (!users.empty() && llvm::all_of(users, [&shared_name](Operation* user) {
124           auto op = llvm::dyn_cast_or_null<mlir::TF::BatchFunctionOp>(user);
125           // User is not a BatchFunctionOp
126           if (!op) return false;
127           if (shared_name.empty()) {
128             shared_name = op.shared_name();
129             return true;
130           }
131           return shared_name == op.shared_name();
132         })) {
133       shared_name_to_func_ops[shared_name].push_back(func);
134     }
135   }
136 
137   for (auto& it : shared_name_to_func_ops) {
138     auto& func_ops = it.second;
139     FuncOp& func_op_to_keep = func_ops.front();
140     for (FuncOp& func_op_to_remove : llvm::drop_begin(func_ops)) {
141       if (!AreEquivalent(func_op_to_keep, func_op_to_remove)) {
142         return func_op_to_remove.emitError(
143             "func_ops for BatchFunctionOp with the same shared name are "
144             "different");
145       }
146       if (failed(SymbolTable::replaceAllSymbolUses(
147               func_op_to_remove, func_op_to_keep.getName(), module))) {
148         return func_op_to_remove.emitError("unable to replace the symbol use");
149       }
150       symbol_table.erase(func_op_to_remove);
151     }
152   }
153 
154   return mlir::success();
155 }
156 }  // namespace
157 
158 std::unique_ptr<mlir::OperationPass<ModuleOp>>
CreateDeduplicateFunctionsInovkedByBatchFunctionPass()159 CreateDeduplicateFunctionsInovkedByBatchFunctionPass() {
160   return std::make_unique<DeduplicateFunctionsInovkedByBatchFunction>();
161 }
162 
163 static mlir::PassRegistration<DeduplicateFunctionsInovkedByBatchFunction>
164     register_pass;
165 
166 }  // namespace tfrt_compiler
167 }  // namespace tensorflow
168