1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/None.h"
17 #include "llvm/Support/Casting.h"
18 #include "llvm/Support/raw_ostream.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
20 #include "mlir/IR/Attributes.h" // from @llvm-project
21 #include "mlir/IR/Builders.h" // from @llvm-project
22 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
23 #include "mlir/IR/MLIRContext.h" // from @llvm-project
24 #include "mlir/IR/Operation.h" // from @llvm-project
25 #include "mlir/IR/PatternMatch.h" // from @llvm-project
26 #include "mlir/Pass/Pass.h" // from @llvm-project
27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
28 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
29 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
32 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
33
34 namespace mlir {
35 namespace TFL {
36 namespace {
// This file contains the hash table legalization pass, which is responsible
// for:
// - Converting static hash table ops to their TFLite equivalent ops.
//
// Hash table ops must fall back to Flex in the following cases:
// - Mutable hash tables.
// - Other resource ops that consume a hash table resource tensor.
43
44 class LegalizeHashTableOpPattern : public OpRewritePattern<TF::HashTableV2Op> {
45 public:
46 using OpRewritePattern<TF::HashTableV2Op>::OpRewritePattern;
47
matchAndRewrite(TF::HashTableV2Op hashtable_op,PatternRewriter & rewriter) const48 LogicalResult matchAndRewrite(TF::HashTableV2Op hashtable_op,
49 PatternRewriter& rewriter) const override {
50 auto output_type = RankedTensorType::get(
51 {1}, TF::ResourceType::get(rewriter.getContext()));
52
53 // Hash the shared name to generate integer hash table id. The TFLite
54 // native resource design is based on integer keys to identify the
55 // corresponding resource objects.
56 auto table_id =
57 static_cast<int32_t>(::llvm::hash_value(hashtable_op.shared_name()));
58 auto key_dtype = hashtable_op.key_dtype();
59 auto value_dtype = hashtable_op.value_dtype();
60
61 rewriter.replaceOpWithNewOp<TFL::HashtableOp>(
62 hashtable_op, output_type, table_id, key_dtype, value_dtype);
63 return success();
64 }
65 };
66
67 class LegalizeHashTableFindOpPattern
68 : public OpRewritePattern<TF::LookupTableFindV2Op> {
69 public:
70 using OpRewritePattern<TF::LookupTableFindV2Op>::OpRewritePattern;
71
matchAndRewrite(TF::LookupTableFindV2Op find_op,PatternRewriter & rewriter) const72 LogicalResult matchAndRewrite(TF::LookupTableFindV2Op find_op,
73 PatternRewriter& rewriter) const override {
74 auto handle_op = find_op.table_handle().getDefiningOp();
75 if (handle_op == nullptr) return failure();
76 auto hashtable_op = llvm::dyn_cast<TFL::HashtableOp>(handle_op);
77 if (hashtable_op == nullptr) return failure();
78 rewriter.replaceOpWithNewOp<TFL::HashtableFindOp>(
79 find_op, find_op->getResultTypes(), find_op.table_handle(),
80 find_op.keys(), find_op.default_value());
81 return success();
82 }
83 };
84
85 class LegalizeHashTableImportOpPattern
86 : public OpRewritePattern<TF::LookupTableImportV2Op> {
87 public:
88 using OpRewritePattern<TF::LookupTableImportV2Op>::OpRewritePattern;
89
matchAndRewrite(TF::LookupTableImportV2Op import_op,PatternRewriter & rewriter) const90 LogicalResult matchAndRewrite(TF::LookupTableImportV2Op import_op,
91 PatternRewriter& rewriter) const override {
92 auto handle_op = import_op.table_handle().getDefiningOp();
93 if (handle_op == nullptr) return failure();
94 auto hashtable_op = llvm::dyn_cast<TFL::HashtableOp>(handle_op);
95 if (hashtable_op == nullptr) return failure();
96 rewriter.replaceOpWithNewOp<TFL::HashtableImportOp>(
97 import_op, import_op->getResultTypes(), import_op.table_handle(),
98 import_op.keys(), import_op.values());
99 return success();
100 }
101 };
102
103 class LegalizeHashTableSizeOpPattern
104 : public OpRewritePattern<TF::LookupTableSizeV2Op> {
105 public:
106 using OpRewritePattern<TF::LookupTableSizeV2Op>::OpRewritePattern;
107
matchAndRewrite(TF::LookupTableSizeV2Op size_op,PatternRewriter & rewriter) const108 LogicalResult matchAndRewrite(TF::LookupTableSizeV2Op size_op,
109 PatternRewriter& rewriter) const override {
110 auto handle_op = size_op.table_handle().getDefiningOp();
111 if (handle_op == nullptr) return failure();
112 auto hashtable_op = llvm::dyn_cast<TFL::HashtableOp>(handle_op);
113 if (hashtable_op == nullptr) return failure();
114 rewriter.replaceOpWithNewOp<TFL::HashtableSizeOp>(
115 size_op, size_op->getResultTypes(), size_op.table_handle());
116 return success();
117 }
118 };
119
120 template <typename T>
GetAllOps(mlir::ModuleOp * module)121 std::vector<T> GetAllOps(mlir::ModuleOp* module) {
122 std::vector<T> ops;
123 module->walk([&](T op) { ops.emplace_back(op); });
124 return ops;
125 }
126
checkWhetherGraphHasValidStaticLookupTables(ModuleOp module)127 bool checkWhetherGraphHasValidStaticLookupTables(ModuleOp module) {
128 auto hashtables = GetAllOps<TF::HashTableV2Op>(&module);
129 // No needs to run the legalization patterns.
130 if (hashtables.empty()) {
131 return false;
132 }
133
134 for (auto hashtable : hashtables) {
135 auto key_dtype = hashtable.key_dtype();
136 auto value_dtype = hashtable.value_dtype();
137
138 // Only allow string -> int64 and int64 -> string mappings due to kernel
139 // capability.
140 if (!((key_dtype.isa<TF::StringType>() && value_dtype.isa<IntegerType>() &&
141 value_dtype.cast<IntegerType>().getWidth() == 64) ||
142 (value_dtype.isa<TF::StringType>() && key_dtype.isa<IntegerType>() &&
143 key_dtype.cast<IntegerType>().getWidth() == 64))) {
144 return false;
145 }
146
147 for (auto& use : hashtable->getUses()) {
148 Operation* user = use.getOwner();
149
150 // Allow consuming hash table ops that can be covered by TensorFlow Lite
151 // hash table kernels.
152 if (auto find_op = llvm::dyn_cast<TF::LookupTableFindV2Op>(user))
153 continue;
154 if (auto import_op = llvm::dyn_cast<TF::LookupTableImportV2Op>(user))
155 continue;
156 if (auto size_op = llvm::dyn_cast<TF::LookupTableSizeV2Op>(user))
157 continue;
158
159 return false;
160 }
161 }
162 return true;
163 }
164
165 // Pass which legalizes TF hash tables only when they are covered by the
166 // TensorFlow Lite hash table kernels.
167 class LegalizeHashTables
168 : public PassWrapper<LegalizeHashTables, OperationPass<ModuleOp>> {
getDependentDialects(DialectRegistry & registry) const169 void getDependentDialects(DialectRegistry& registry) const override {
170 registry.insert<TensorFlowLiteDialect>();
171 }
172
173 public:
174 LegalizeHashTables() = default;
LegalizeHashTables(const LegalizeHashTables &)175 LegalizeHashTables(const LegalizeHashTables&) {}
176
getArgument() const177 StringRef getArgument() const final {
178 // This is the argument used to refer to the pass in
179 // the textual format (on the commandline for example).
180 return "tfl-legalize-hashtables-tf";
181 }
getDescription() const182 StringRef getDescription() const final {
183 // This is a brief description of the pass.
184 return "Legalize TensorFlow hash tables to TensorFlow Lite dialect";
185 }
186
runOnOperation()187 void runOnOperation() override {
188 auto module = getOperation();
189
190 if (!checkWhetherGraphHasValidStaticLookupTables(module)) {
191 return;
192 }
193
194 OwningRewritePatternList patterns(&getContext());
195 patterns.insert<LegalizeHashTableOpPattern, LegalizeHashTableFindOpPattern,
196 LegalizeHashTableImportOpPattern,
197 LegalizeHashTableSizeOpPattern>(&getContext());
198 if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) {
199 signalPassFailure();
200 return;
201 }
202 }
203 };
204
205 } // namespace
206
CreateLegalizeHashTablesPass()207 std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeHashTablesPass() {
208 return std::make_unique<LegalizeHashTables>();
209 }
210
211 static PassRegistration<LegalizeHashTables> pass;
212
213 } // namespace TFL
214 } // namespace mlir
215