• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/experimental/resource/static_hashtable.h"
17 
18 #include <memory>
19 #include <string>
20 
21 #include "tensorflow/lite/experimental/resource/lookup_interfaces.h"
22 
23 namespace tflite {
24 namespace resource {
25 namespace internal {
26 
27 template <typename KeyType, typename ValueType>
Lookup(TfLiteContext * context,const TfLiteTensor * keys,TfLiteTensor * values,const TfLiteTensor * default_value)28 TfLiteStatus StaticHashtable<KeyType, ValueType>::Lookup(
29     TfLiteContext* context, const TfLiteTensor* keys, TfLiteTensor* values,
30     const TfLiteTensor* default_value) {
31   if (!is_initialized_) {
32     TF_LITE_KERNEL_LOG(context,
33                        "hashtable need to be initialized before using");
34     return kTfLiteError;
35   }
36   const int size =
37       MatchingFlatSize(GetTensorShape(keys), GetTensorShape(values));
38 
39   auto key_tensor_reader = TensorReader<KeyType>(keys);
40   auto value_tensor_writer = TensorWriter<ValueType>(values);
41   auto default_value_tensor_reader = TensorReader<ValueType>(default_value);
42   ValueType first_default_value = default_value_tensor_reader.GetData(0);
43 
44   for (int i = 0; i < size; ++i) {
45     auto result = map_.find(key_tensor_reader.GetData(i));
46     if (result != map_.end()) {
47       value_tensor_writer.SetData(i, result->second);
48     } else {
49       value_tensor_writer.SetData(i, first_default_value);
50     }
51   }
52 
53   // This is for a string tensor case in order to write buffer back to the
54   // actual tensor destination. Otherwise, it does nothing since the scalar data
55   // will be written into the tensor storage directly.
56   value_tensor_writer.Commit();
57 
58   return kTfLiteOk;
59 }
60 
61 template <typename KeyType, typename ValueType>
Import(TfLiteContext * context,const TfLiteTensor * keys,const TfLiteTensor * values)62 TfLiteStatus StaticHashtable<KeyType, ValueType>::Import(
63     TfLiteContext* context, const TfLiteTensor* keys,
64     const TfLiteTensor* values) {
65   // Import nodes can be invoked twice because the converter will not extract
66   // the initializer graph separately from the original graph. The invocations
67   // after the first call will be ignored.
68   if (is_initialized_) {
69     return kTfLiteOk;
70   }
71 
72   const int size =
73       MatchingFlatSize(GetTensorShape(keys), GetTensorShape(values));
74 
75   auto key_tensor_reader = TensorReader<KeyType>(keys);
76   auto value_tensor_writer = TensorReader<ValueType>(values);
77   for (int i = 0; i < size; ++i) {
78     map_.insert({key_tensor_reader.GetData(i), value_tensor_writer.GetData(i)});
79   }
80 
81   is_initialized_ = true;
82   return kTfLiteOk;
83 }
84 
CreateStaticHashtable(TfLiteType key_type,TfLiteType value_type)85 LookupInterface* CreateStaticHashtable(TfLiteType key_type,
86                                        TfLiteType value_type) {
87   if (key_type == kTfLiteInt64 && value_type == kTfLiteString) {
88     return new StaticHashtable<std::int64_t, std::string>(key_type, value_type);
89   } else if (key_type == kTfLiteString && value_type == kTfLiteInt64) {
90     return new StaticHashtable<std::string, std::int64_t>(key_type, value_type);
91   }
92   return nullptr;
93 }
94 
95 }  // namespace internal
96 
CreateHashtableResourceIfNotAvailable(ResourceMap * resources,int resource_id,TfLiteType key_dtype,TfLiteType value_dtype)97 void CreateHashtableResourceIfNotAvailable(ResourceMap* resources,
98                                            int resource_id,
99                                            TfLiteType key_dtype,
100                                            TfLiteType value_dtype) {
101   if (resources->count(resource_id) != 0) {
102     return;
103   }
104   auto* hashtable = internal::CreateStaticHashtable(key_dtype, value_dtype);
105   resources->emplace(resource_id, std::unique_ptr<LookupInterface>(hashtable));
106 }
107 
GetHashtableResource(ResourceMap * resources,int resource_id)108 LookupInterface* GetHashtableResource(ResourceMap* resources, int resource_id) {
109   auto it = resources->find(resource_id);
110   if (it != resources->end()) {
111     return static_cast<LookupInterface*>(it->second.get());
112   }
113   return nullptr;
114 }
115 
116 }  // namespace resource
117 }  // namespace tflite
118