/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_OP_H_
17 #define TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_OP_H_
18
19 #include "absl/container/flat_hash_map.h"
20 #include "tensorflow/core/framework/bounds_check.h"
21 #include "tensorflow/core/framework/lookup_interface.h"
22 #include "tensorflow/core/framework/op_kernel.h"
23 #include "tensorflow/core/framework/resource_mgr.h"
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/framework/tensor_shape.h"
26 #include "tensorflow/core/graph/graph_def_builder.h"
27 #include "tensorflow/core/kernels/lookup_util.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/core/status.h"
30 #include "tensorflow/core/lib/gtl/map_util.h"
31 #include "tensorflow/core/platform/errors.h"
32 #include "tensorflow/core/platform/macros.h"
33 #include "tensorflow/core/platform/thread_annotations.h"
34
35 namespace tensorflow {
36
// Lookup table op that supports different table implementations specified by
// the 'Container' template. Container must be derived from LookupInterface.
// The key and value are of the templated types "key_dtype" and "value_dtype"
// respectively.
template <class Container, class key_dtype, class value_dtype>
class LookupTableOp : public OpKernel {
 public:
  // ctx is not owned by this class.
  explicit LookupTableOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), table_set_(false) {
    // The handle tensor emitted as output 0 is allocated once here and reused
    // for the kernel's lifetime: a scalar DT_RESOURCE handle when the op has a
    // resource output, otherwise a 2-element DT_STRING vector holding
    // [container, name] for the legacy ref-output form.
    if (ctx->output_type(0) == DT_RESOURCE) {
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(tensorflow::DT_RESOURCE,
                                        tensorflow::TensorShape({}), &table_));
    } else {
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(tensorflow::DT_STRING,
                                        tensorflow::TensorShape({2}), &table_));
    }
    OP_REQUIRES_OK(
        ctx, ctx->GetAttr("use_node_name_sharing", &use_node_name_sharing_));
  }

  // ctx is not owned by this function.
  void Compute(OpKernelContext* ctx) override {
    mutex_lock l(mu_);

    // Resolve the shared-resource (container, name) lazily, on the first
    // invocation only; cinfo_ is immutable afterwards.
    if (!table_set_) {
      OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(),
                                      def(), use_node_name_sharing_));
    }

    // Factory invoked by LookupOrCreate only when no table is registered yet
    // under (container, name).
    auto creator =
        [ctx, this](lookup::LookupInterface** ret)
            TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
              lookup::LookupInterface* container = new Container(ctx, this);
              // The Container constructor reports failures via ctx->status();
              // drop the half-constructed table instead of registering it.
              if (!ctx->status().ok()) {
                container->Unref();
                return ctx->status();
              }
              if (ctx->track_allocations()) {
                ctx->record_persistent_memory_allocation(
                    container->MemoryUsed() + table_.AllocatedBytes());
              }
              *ret = container;
              return Status::OK();
            };

    // Find or atomically create the shared table. `table` carries a reference
    // we own; ScopedUnref releases it when Compute returns.
    lookup::LookupInterface* table = nullptr;
    OP_REQUIRES_OK(ctx,
                   cinfo_.resource_manager()
                       ->template LookupOrCreate<lookup::LookupInterface>(
                           cinfo_.container(), cinfo_.name(), &table, creator));
    core::ScopedUnref unref_me(table);

    // A pre-existing shared table may have been created with different
    // key/value dtypes; fail rather than alias it.
    OP_REQUIRES_OK(ctx, lookup::CheckTableDataTypes(
                            *table, DataTypeToEnum<key_dtype>::v(),
                            DataTypeToEnum<value_dtype>::v(), cinfo_.name()));

    if (ctx->expected_output_dtype(0) == DT_RESOURCE) {
      // Fill in the resource handle once; subsequent calls re-emit the same
      // cached tensor.
      if (!table_set_) {
        auto h = table_.template scalar<ResourceHandle>();
        h() = MakeResourceHandle<lookup::LookupInterface>(
            ctx, cinfo_.container(), cinfo_.name());
      }
      ctx->set_output(0, table_);
    } else {
      // Legacy ref output: expose the [container, name] pair, guarded by mu_.
      if (!table_set_) {
        auto h = table_.template flat<tstring>();
        h(0) = cinfo_.container();
        h(1) = cinfo_.name();
      }
      ctx->set_output_ref(0, &mu_, &table_);
    }
    // Only mark initialization complete after the output has been produced.
    table_set_ = true;
  }

  ~LookupTableOp() override {
    // If the table object was not shared, delete it.
    if (table_set_ && cinfo_.resource_is_private_to_kernel()) {
      if (!cinfo_.resource_manager()
               ->template Delete<lookup::LookupInterface>(cinfo_.container(),
                                                          cinfo_.name())
               .ok()) {
        // Do nothing; the resource can have been deleted by session resets.
      }
    }
  }

 private:
  mutex mu_;
  // Handle tensor emitted as output 0; see the constructor for its
  // shape/dtype per output kind.
  Tensor table_ TF_GUARDED_BY(mu_);
  // True once cinfo_ and the handle stored in table_ are initialized.
  bool table_set_ TF_GUARDED_BY(mu_);
  ContainerInfo cinfo_;
  bool use_node_name_sharing_;

  TF_DISALLOW_COPY_AND_ASSIGN(LookupTableOp);
};
135
136 namespace lookup {
137
// Ensure that the compiler cannot elide the copy into a local, so that bounds
// checks on source tensors that might be updated asynchronously remain valid
// for integral types. Non-integral types are not subject to this concern, so
// the pass-through overloads below avoid the unnecessary copy.
142 template <typename T>
SubtleMustCopyIfIntegral(const T & value)143 T SubtleMustCopyIfIntegral(const T& value) {
144 return internal::SubtleMustCopy(value);
145 }
146
// Overload for tstring: strings are not integral, so the reference is
// returned as-is with no defensive copy.
inline const tstring& SubtleMustCopyIfIntegral(const tstring& value) {
  return value;
}
150
// Overload for float: not integral, so no defensive copy is needed.
// The top-level `const` on the by-value return type was meaningless
// (clang readability-const-return-type) and has been dropped; callers
// are unaffected.
inline float SubtleMustCopyIfIntegral(const float value) { return value; }
152
// Overload for double: not integral, so no defensive copy is needed.
// The top-level `const` on the by-value return type was meaningless
// (clang readability-const-return-type) and has been dropped; callers
// are unaffected.
inline double SubtleMustCopyIfIntegral(const double value) {
  return value;
}
156
// Overload for Variant: not integral, so the reference is returned as-is
// with no defensive copy.
inline const Variant& SubtleMustCopyIfIntegral(const Variant& value) {
  return value;
}
160
// Overload for ResourceHandle: not integral, so the reference is returned
// as-is with no defensive copy.
inline const ResourceHandle& SubtleMustCopyIfIntegral(
    const ResourceHandle& value) {
  return value;
}
165
// Returns a unique node name starting with "base".
167 std::string UniqueNodeName(const std::string& base);
168
// Lookup table that wraps a flat_hash_map, where the key and value data types
// are specified.
//
// This table is recommended for any variations to key values.
//
// For lookup, the table is required to be initialized (allocated
// and populated). Once the table is marked as initialized it becomes read-only.
//
// Sample use case:
//
// HashTable<int64, int64> table;  // int64 -> int64.
// table.Initialize(...);
// table.Find(in_t, &out_t, default_t)
//
template <class K, class V>
class HashTable : public InitializableLookupTable {
 public:
  // ctx and kernel are unused; the two-argument signature is what the
  // LookupTableOp 'Container' factory invokes.
  HashTable(OpKernelContext* ctx, OpKernel* kernel) {}

  // Serializes this table into `builder` as a HashTableV2 node plus, when an
  // initializer serializer is available, its initialization subgraph.
  Status AsGraphDef(GraphDefBuilder* builder, Node** out) const override {
    // We set use_node_name_sharing with a unique node name so that the resource
    // can outlive the HashTableV2 kernel. This means that the lifetime of the
    // HashTable resource will be tied to the lifetime of the resource manager
    // it is created in.
    // TODO(b/181695913): Provide a mechanism for deleting this resource
    // earlier when appropriate.
    Node* hash_table_node = ops::SourceOp(
        "HashTableV2", builder->opts()
                           .WithName(UniqueNodeName("HashTableFromGraphDef"))
                           .WithAttr("key_dtype", key_dtype())
                           .WithAttr("value_dtype", value_dtype())
                           .WithAttr("use_node_name_sharing", true));
    // An empty table needs no initializer subgraph: the bare node suffices.
    if (table_.empty()) {
      *out = hash_table_node;
      return Status::OK();
    }

    if (initializer_serializer_ == nullptr) {
      std::string message =
          "Failed to serialize lookup table: no initialization function was "
          "specified. Falling back to serializing a handle to the table.";
      LOG(WARNING) << message;
      return errors::Unimplemented(message);
    }
    Node* initializer;
    TF_RETURN_IF_ERROR(initializer_serializer_->AsGraphDef(
        builder, hash_table_node, &initializer));
    // Emit an Identity with a control dependency on the initializer so that
    // consumers of *out only run after the table has been populated.
    *out = ops::UnaryOp("Identity", hash_table_node,
                        builder->opts().WithControlInput(initializer));
    return Status::OK();
  }

  // Number of key/value pairs; reports 0 until initialization completes.
  size_t size() const override {
    if (!is_initialized())
      return 0;
    else
      return table_.size();
  }

  // Writes all keys and values into the "keys" and "values" outputs. Pair
  // ordering follows the flat_hash_map's iteration order, which is
  // unspecified.
  Status ExportValues(OpKernelContext* context) override {
    if (!is_initialized()) {
      return errors::Aborted("HashTable is not initialized.");
    }

    const int64_t size = table_.size();

    Tensor* keys;
    Tensor* values;
    TF_RETURN_IF_ERROR(
        context->allocate_output("keys", TensorShape({size}), &keys));
    TF_RETURN_IF_ERROR(
        context->allocate_output("values", TensorShape({size}), &values));

    auto keys_data = keys->flat<K>();
    auto values_data = values->flat<V>();
    int64_t i = 0;
    for (auto it = table_.begin(); it != table_.end(); ++it, ++i) {
      keys_data(i) = it->first;
      values_data(i) = it->second;
    }
    return Status::OK();
  }

  DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }

  DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }

 protected:
  // Pre-reserves capacity for `size` entries; fails if the table has already
  // been initialized.
  Status DoPrepare(size_t size) override {
    if (is_initialized()) {
      return errors::Aborted("HashTable already initialized.");
    }
    if (size > 0) {
      table_.reserve(size);
    }
    return Status::OK();
  };

  Status DoLazyPrepare(std::function<int64(void)> size_fn) override {
    return DoPrepare(size_fn());
  }

  // Inserts each key/value pair. Re-inserting an existing key with the same
  // value is a no-op; a different value for an existing key is an error.
  // SubtleMustCopyIfIntegral guards integral reads against concurrent
  // mutation of the source tensors.
  Status DoInsert(const Tensor& keys, const Tensor& values) override {
    const auto key_values = keys.flat<K>();
    const auto value_values = values.flat<V>();
    for (int64_t i = 0; i < key_values.size(); ++i) {
      auto&& key = SubtleMustCopyIfIntegral(key_values(i));
      auto&& value = SubtleMustCopyIfIntegral(value_values(i));
      auto result = table_.try_emplace(key, value);
      if (!result.second && result.first->second != value) {
        return errors::FailedPrecondition(
            "HashTable has different value for same key. Key ", key, " has ",
            result.first->second, " and trying to add value ", value);
      }
    }
    return Status::OK();
  }

  // Looks up each key; misses yield the default value.
  // NOTE(review): only default_value.flat<V>()(0) is read — this assumes a
  // scalar (or at least non-empty) default; confirm against the
  // InitializableLookupTable contract.
  Status DoFind(const Tensor& key, Tensor* value,
                const Tensor& default_value) override {
    const V default_val = default_value.flat<V>()(0);
    const auto key_values = key.flat<K>();
    auto value_values = value->flat<V>();

    for (int64_t i = 0; i < key_values.size(); ++i) {
      value_values(i) = gtl::FindWithDefault(
          table_, SubtleMustCopyIfIntegral(key_values(i)), default_val);
    }
    return Status::OK();
  }

  // Rough payload estimate: element count times sizeof(K)+sizeof(V). The
  // hash table's own bucket/control-byte overhead is not counted.
  int64 MemoryUsed() const override {
    if (!is_initialized()) {
      return 0;
    }
    const int64_t num_elements = table_.size();
    return num_elements * (sizeof(K) + sizeof(V));
  }

 private:
  absl::flat_hash_map<K, V> table_;
};
311
312 } // namespace lookup
313
314 } // namespace tensorflow
315
316 #endif // TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_OP_H_
317