• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/None.h"
17 #include "llvm/Support/Casting.h"
18 #include "llvm/Support/raw_ostream.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
20 #include "mlir/IR/Attributes.h"  // from @llvm-project
21 #include "mlir/IR/Builders.h"  // from @llvm-project
22 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
23 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
24 #include "mlir/IR/Operation.h"  // from @llvm-project
25 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
26 #include "mlir/Pass/Pass.h"  // from @llvm-project
27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
28 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
29 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
32 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
33 
34 namespace mlir {
35 namespace TFL {
36 namespace {
// This file contains the hash table legalization pass, which is responsible
// for:
// - Converting static hash table ops to the equivalent TFLite ops.
//
// The following cases still need to fall back to Flex:
// - Mutable hash tables.
// - Other resource operators consuming a hash table resource tensor.
43 
44 class LegalizeHashTableOpPattern : public OpRewritePattern<TF::HashTableV2Op> {
45  public:
46   using OpRewritePattern<TF::HashTableV2Op>::OpRewritePattern;
47 
matchAndRewrite(TF::HashTableV2Op hashtable_op,PatternRewriter & rewriter) const48   LogicalResult matchAndRewrite(TF::HashTableV2Op hashtable_op,
49                                 PatternRewriter& rewriter) const override {
50     auto output_type = RankedTensorType::get(
51         {1}, TF::ResourceType::get(rewriter.getContext()));
52 
53     // Hash the shared name to generate integer hash table id. The TFLite
54     // native resource design is based on integer keys to identify the
55     // corresponding resource objects.
56     auto table_id =
57         static_cast<int32_t>(::llvm::hash_value(hashtable_op.shared_name()));
58     auto key_dtype = hashtable_op.key_dtype();
59     auto value_dtype = hashtable_op.value_dtype();
60 
61     rewriter.replaceOpWithNewOp<TFL::HashtableOp>(
62         hashtable_op, output_type, table_id, key_dtype, value_dtype);
63     return success();
64   }
65 };
66 
67 class LegalizeHashTableFindOpPattern
68     : public OpRewritePattern<TF::LookupTableFindV2Op> {
69  public:
70   using OpRewritePattern<TF::LookupTableFindV2Op>::OpRewritePattern;
71 
matchAndRewrite(TF::LookupTableFindV2Op find_op,PatternRewriter & rewriter) const72   LogicalResult matchAndRewrite(TF::LookupTableFindV2Op find_op,
73                                 PatternRewriter& rewriter) const override {
74     auto handle_op = find_op.table_handle().getDefiningOp();
75     if (handle_op == nullptr) return failure();
76     auto hashtable_op = llvm::dyn_cast<TFL::HashtableOp>(handle_op);
77     if (hashtable_op == nullptr) return failure();
78     rewriter.replaceOpWithNewOp<TFL::HashtableFindOp>(
79         find_op, find_op->getResultTypes(), find_op.table_handle(),
80         find_op.keys(), find_op.default_value());
81     return success();
82   }
83 };
84 
85 class LegalizeHashTableImportOpPattern
86     : public OpRewritePattern<TF::LookupTableImportV2Op> {
87  public:
88   using OpRewritePattern<TF::LookupTableImportV2Op>::OpRewritePattern;
89 
matchAndRewrite(TF::LookupTableImportV2Op import_op,PatternRewriter & rewriter) const90   LogicalResult matchAndRewrite(TF::LookupTableImportV2Op import_op,
91                                 PatternRewriter& rewriter) const override {
92     auto handle_op = import_op.table_handle().getDefiningOp();
93     if (handle_op == nullptr) return failure();
94     auto hashtable_op = llvm::dyn_cast<TFL::HashtableOp>(handle_op);
95     if (hashtable_op == nullptr) return failure();
96     rewriter.replaceOpWithNewOp<TFL::HashtableImportOp>(
97         import_op, import_op->getResultTypes(), import_op.table_handle(),
98         import_op.keys(), import_op.values());
99     return success();
100   }
101 };
102 
103 class LegalizeHashTableSizeOpPattern
104     : public OpRewritePattern<TF::LookupTableSizeV2Op> {
105  public:
106   using OpRewritePattern<TF::LookupTableSizeV2Op>::OpRewritePattern;
107 
matchAndRewrite(TF::LookupTableSizeV2Op size_op,PatternRewriter & rewriter) const108   LogicalResult matchAndRewrite(TF::LookupTableSizeV2Op size_op,
109                                 PatternRewriter& rewriter) const override {
110     auto handle_op = size_op.table_handle().getDefiningOp();
111     if (handle_op == nullptr) return failure();
112     auto hashtable_op = llvm::dyn_cast<TFL::HashtableOp>(handle_op);
113     if (hashtable_op == nullptr) return failure();
114     rewriter.replaceOpWithNewOp<TFL::HashtableSizeOp>(
115         size_op, size_op->getResultTypes(), size_op.table_handle());
116     return success();
117   }
118 };
119 
120 template <typename T>
GetAllOps(mlir::ModuleOp * module)121 std::vector<T> GetAllOps(mlir::ModuleOp* module) {
122   std::vector<T> ops;
123   module->walk([&](T op) { ops.emplace_back(op); });
124   return ops;
125 }
126 
checkWhetherGraphHasValidStaticLookupTables(ModuleOp module)127 bool checkWhetherGraphHasValidStaticLookupTables(ModuleOp module) {
128   auto hashtables = GetAllOps<TF::HashTableV2Op>(&module);
129   // No needs to run the legalization patterns.
130   if (hashtables.empty()) {
131     return false;
132   }
133 
134   for (auto hashtable : hashtables) {
135     auto key_dtype = hashtable.key_dtype();
136     auto value_dtype = hashtable.value_dtype();
137 
138     // Only allow string -> int64 and int64 -> string mappings due to kernel
139     // capability.
140     if (!((key_dtype.isa<TF::StringType>() && value_dtype.isa<IntegerType>() &&
141            value_dtype.cast<IntegerType>().getWidth() == 64) ||
142           (value_dtype.isa<TF::StringType>() && key_dtype.isa<IntegerType>() &&
143            key_dtype.cast<IntegerType>().getWidth() == 64))) {
144       return false;
145     }
146 
147     for (auto& use : hashtable->getUses()) {
148       Operation* user = use.getOwner();
149 
150       // Allow consuming hash table ops that can be covered by TensorFlow Lite
151       // hash table kernels.
152       if (auto find_op = llvm::dyn_cast<TF::LookupTableFindV2Op>(user))
153         continue;
154       if (auto import_op = llvm::dyn_cast<TF::LookupTableImportV2Op>(user))
155         continue;
156       if (auto size_op = llvm::dyn_cast<TF::LookupTableSizeV2Op>(user))
157         continue;
158 
159       return false;
160     }
161   }
162   return true;
163 }
164 
165 // Pass which legalizes TF hash tables only when they are covered by the
166 // TensorFlow Lite hash table kernels.
167 class LegalizeHashTables
168     : public PassWrapper<LegalizeHashTables, OperationPass<ModuleOp>> {
getDependentDialects(DialectRegistry & registry) const169   void getDependentDialects(DialectRegistry& registry) const override {
170     registry.insert<TensorFlowLiteDialect>();
171   }
172 
173  public:
174   LegalizeHashTables() = default;
LegalizeHashTables(const LegalizeHashTables &)175   LegalizeHashTables(const LegalizeHashTables&) {}
176 
getArgument() const177   StringRef getArgument() const final {
178     // This is the argument used to refer to the pass in
179     // the textual format (on the commandline for example).
180     return "tfl-legalize-hashtables-tf";
181   }
getDescription() const182   StringRef getDescription() const final {
183     // This is a brief description of the pass.
184     return "Legalize TensorFlow hash tables to TensorFlow Lite dialect";
185   }
186 
runOnOperation()187   void runOnOperation() override {
188     auto module = getOperation();
189 
190     if (!checkWhetherGraphHasValidStaticLookupTables(module)) {
191       return;
192     }
193 
194     OwningRewritePatternList patterns(&getContext());
195     patterns.insert<LegalizeHashTableOpPattern, LegalizeHashTableFindOpPattern,
196                     LegalizeHashTableImportOpPattern,
197                     LegalizeHashTableSizeOpPattern>(&getContext());
198     if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) {
199       signalPassFailure();
200       return;
201     }
202   }
203 };
204 
205 }  // namespace
206 
CreateLegalizeHashTablesPass()207 std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeHashTablesPass() {
208   return std::make_unique<LegalizeHashTables>();
209 }
210 
211 static PassRegistration<LegalizeHashTables> pass;
212 
213 }  // namespace TFL
214 }  // namespace mlir
215