/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/lookup_table_op.h"
#define EIGEN_USE_THREADS

#include <string>
#include <type_traits>
#include <utility>

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/kernels/initializable_lookup_table.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/hash/hash.h"

namespace tensorflow {
namespace lookup {

// Lookup table that wraps an unordered_map, where the key and value data
// types are specified. Each individual value must be a scalar. If vector
// values are required, use MutableHashTableOfTensors.
//
// This table is mutable and thread safe - Insert can be called at any time.
//
// Sample use case:
//
// MutableHashTableOfScalars<int64, int64> table;  // int64 -> int64.
// // Populate the table; elements can be added in one or multiple calls.
// table.Insert(key_tensor, value_tensor);  // Populate the table.
//
// table.Find(in_t, &out_t, default_t)
//
template <class K, class V>
class MutableHashTableOfScalars final : public LookupInterface {
 public:
  MutableHashTableOfScalars(OpKernelContext* ctx, OpKernel* kernel) {}

  size_t size() const override {
    tf_shared_lock l(mu_);
    return table_.size();
  }

  Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
              const Tensor& default_value) override {
    const V default_val = default_value.flat<V>()(0);
    const auto key_values = key.flat<K>();
    auto value_values = value->flat<V>();

    tf_shared_lock l(mu_);
    for (int64 i = 0; i < key_values.size(); ++i) {
      value_values(i) = gtl::FindWithDefault(
          table_, SubtleMustCopyIfIntegral(key_values(i)), default_val);
    }

    return Status::OK();
  }

  Status DoInsert(bool clear, const Tensor& keys, const Tensor& values) {
    const auto key_values = keys.flat<K>();
    const auto value_values = values.flat<V>();

    mutex_lock l(mu_);
    if (clear) {
      table_.clear();
    }
    for (int64 i = 0; i < key_values.size(); ++i) {
      gtl::InsertOrUpdate(&table_, SubtleMustCopyIfIntegral(key_values(i)),
                          SubtleMustCopyIfIntegral(value_values(i)));
    }
    return Status::OK();
  }

  Status Insert(OpKernelContext* ctx, const Tensor& keys,
                const Tensor& values) override {
    return DoInsert(false, keys, values);
  }

  Status Remove(OpKernelContext* ctx, const Tensor& keys) override {
    const auto key_values = keys.flat<K>();

    mutex_lock l(mu_);
    for (int64 i = 0; i < key_values.size(); ++i) {
      table_.erase(SubtleMustCopyIfIntegral(key_values(i)));
    }
    return Status::OK();
  }

  Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
                      const Tensor& values) override {
    return DoInsert(true, keys, values);
  }

  Status ExportValues(OpKernelContext* ctx) override {
    tf_shared_lock l(mu_);
    int64 size = table_.size();

    Tensor* keys;
    Tensor* values;
    TF_RETURN_IF_ERROR(
        ctx->allocate_output("keys", TensorShape({size}), &keys));
    TF_RETURN_IF_ERROR(
        ctx->allocate_output("values", TensorShape({size}), &values));

    auto keys_data = keys->flat<K>();
    auto values_data = values->flat<V>();
    int64 i = 0;
    for (auto it = table_.begin(); it != table_.end(); ++it, ++i) {
      keys_data(i) = it->first;
      values_data(i) = it->second;
    }
    return Status::OK();
  }

  DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }

  DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }

  TensorShape key_shape() const final { return TensorShape(); }

  TensorShape value_shape() const override { return TensorShape(); }

  int64 MemoryUsed() const override {
    int64 ret = 0;
    tf_shared_lock l(mu_);
    for (unsigned i = 0; i < table_.bucket_count(); ++i) {
      size_t bucket_size = table_.bucket_size(i);
      if (bucket_size == 0) {
        ret++;
      } else {
        ret += bucket_size;
      }
    }
    return sizeof(MutableHashTableOfScalars) + ret;
  }

 private:
  mutable mutex mu_;
  std::unordered_map<K, V> table_ GUARDED_BY(mu_);
};

// Lookup table that wraps an unordered_map. Behaves identically to
// MutableHashTableOfScalars except that each value must be a vector.
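//
// A minimal usage sketch, mirroring the sample for MutableHashTableOfScalars
// above (tensor names here are placeholders, not defined in this file):
//
// MutableHashTableOfTensors<int64, int64> table;  // int64 -> vector of int64.
// table.Insert(key_tensor, value_tensor);  // values shaped [batch, value_dim].
// table.Find(in_t, &out_t, default_t);     // default_t holds value_dim elements.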
template <class K, class V>
class MutableHashTableOfTensors final : public LookupInterface {
 public:
  MutableHashTableOfTensors(OpKernelContext* ctx, OpKernel* kernel) {
    OP_REQUIRES_OK(ctx,
                   GetNodeAttr(kernel->def(), "value_shape", &value_shape_));
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(value_shape_),
        errors::InvalidArgument("Default value must be a vector, got shape ",
                                value_shape_.DebugString()));
  }

  size_t size() const override {
    tf_shared_lock l(mu_);
    return table_.size();
  }

  Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
              const Tensor& default_value) override {
    const auto default_flat = default_value.flat<V>();
    const auto key_values = key.flat<K>();
    auto value_values = value->flat_inner_dims<V, 2>();
    int64 value_dim = value_shape_.dim_size(0);

    tf_shared_lock l(mu_);
    for (int64 i = 0; i < key_values.size(); ++i) {
      ValueArray* value_vec =
          gtl::FindOrNull(table_, SubtleMustCopyIfIntegral(key_values(i)));
      if (value_vec != nullptr) {
        for (int64 j = 0; j < value_dim; j++) {
          value_values(i, j) = value_vec->at(j);
        }
      } else {
        for (int64 j = 0; j < value_dim; j++) {
          value_values(i, j) = default_flat(j);
        }
      }
    }

    return Status::OK();
  }

  Status DoInsert(bool clear, const Tensor& keys, const Tensor& values) {
    const auto key_values = keys.flat<K>();
    const auto value_values = values.flat_inner_dims<V, 2>();
    int64 value_dim = value_shape_.dim_size(0);

    mutex_lock l(mu_);
    if (clear) {
      table_.clear();
    }
    for (int64 i = 0; i < key_values.size(); ++i) {
      ValueArray value_vec;
      for (int64 j = 0; j < value_dim; j++) {
        V value = value_values(i, j);
        value_vec.push_back(value);
      }
      gtl::InsertOrUpdate(&table_, SubtleMustCopyIfIntegral(key_values(i)),
                          value_vec);
    }
    return Status::OK();
  }

  Status Insert(OpKernelContext* ctx, const Tensor& keys,
                const Tensor& values) override {
    return DoInsert(false, keys, values);
  }

  Status Remove(OpKernelContext* ctx, const Tensor& keys) override {
    const auto key_values = keys.flat<K>();

    mutex_lock l(mu_);
    for (int64 i = 0; i < key_values.size(); ++i) {
      table_.erase(SubtleMustCopyIfIntegral(key_values(i)));
    }
    return Status::OK();
  }

  Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
                      const Tensor& values) override {
    return DoInsert(true, keys, values);
  }

  Status ExportValues(OpKernelContext* ctx) override {
    tf_shared_lock l(mu_);
    int64 size = table_.size();
    int64 value_dim = value_shape_.dim_size(0);

    Tensor* keys;
    Tensor* values;
    TF_RETURN_IF_ERROR(
        ctx->allocate_output("keys", TensorShape({size}), &keys));
    TF_RETURN_IF_ERROR(ctx->allocate_output(
        "values", TensorShape({size, value_dim}), &values));

    auto keys_data = keys->flat<K>();
    auto values_data = values->matrix<V>();
    int64 i = 0;
    for (auto it = table_.begin(); it != table_.end(); ++it, ++i) {
      K key = it->first;
      ValueArray value = it->second;
      keys_data(i) = key;
      for (int64 j = 0; j < value_dim; j++) {
        values_data(i, j) = value[j];
      }
    }
    return Status::OK();
  }

  DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }

  DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }

  TensorShape key_shape() const final { return TensorShape(); }

  TensorShape value_shape() const override { return value_shape_; }

  int64 MemoryUsed() const override {
    int64 ret = 0;
    tf_shared_lock l(mu_);
    for (unsigned i = 0; i < table_.bucket_count(); ++i) {
      size_t bucket_size = table_.bucket_size(i);
      if (bucket_size == 0) {
        ret++;
      } else {
        ret += bucket_size;
      }
    }
    return sizeof(MutableHashTableOfTensors) + ret;
  }

 private:
  TensorShape value_shape_;
  mutable mutex mu_;
  typedef gtl::InlinedVector<V, 4> ValueArray;
  std::unordered_map<K, ValueArray> table_ GUARDED_BY(mu_);
};

namespace {

template <typename T>
inline uint64 HashScalar(const T& key) {
  return static_cast<uint64>(key);
}

inline uint64 HashScalar(const string& key) { return Hash64(key); }

// If the given shape is a scalar return {1} instead. Otherwise leave it alone.
TensorShape MaybeVectorizeShape(const TensorShape& shape) {
  if (shape.dims() == 0) {
    return TensorShape({1});
  }
  return shape;
}

}  // namespace

// Modeled after densehashtable in https://github.com/sparsehash/sparsehash
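//
// Implementation notes, summarizing the code below: keys and values live in
// two flat bucket tensors whose length is always a power of two. Two
// caller-provided sentinel keys, empty_key and deleted_key, mark unused and
// removed buckets, so neither may ever be inserted or looked up. Collisions
// are resolved by probing with linearly growing step sizes (the i-th probe
// advances by i buckets, masked by num_buckets_ - 1); with a power-of-two
// bucket count this sequence visits every bucket within num_buckets_ probes,
// which is why the probe loops treat exceeding that bound as an internal
// error.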
template <class K, class V>
class MutableDenseHashTable final : public LookupInterface {
 public:
  MutableDenseHashTable(OpKernelContext* ctx, OpKernel* kernel) {
    OP_REQUIRES_OK(
        ctx, GetNodeAttr(kernel->def(), "max_load_factor", &max_load_factor_));
    OP_REQUIRES(ctx, max_load_factor_ > 0 && max_load_factor_ < 1,
                errors::InvalidArgument(
                    "max_load_factor must be between 0 and 1, got: ",
                    max_load_factor_));

    OP_REQUIRES_OK(ctx,
                   GetNodeAttr(kernel->def(), "value_shape", &value_shape_));
    OP_REQUIRES(ctx,
                TensorShapeUtils::IsScalar(value_shape_) ||
                    TensorShapeUtils::IsVector(value_shape_),
                errors::InvalidArgument(
                    "Empty value must be a scalar or a vector, got shape ",
                    value_shape_.DebugString()));

    const Tensor* empty_key_input;
    OP_REQUIRES_OK(ctx, ctx->input("empty_key", &empty_key_input));
    key_shape_ = empty_key_input->shape();
    OP_REQUIRES(ctx,
                TensorShapeUtils::IsScalar(key_shape_) ||
                    TensorShapeUtils::IsVector(key_shape_),
                errors::InvalidArgument(
                    "Empty key must be a scalar or a vector, got shape ",
                    key_shape_.DebugString()));
    empty_key_ = PersistentTensor(*empty_key_input);
    empty_key_hash_ = HashKey(
        empty_key_input->template shaped<K, 2>({1, key_shape_.num_elements()}),
        0);

    const Tensor* deleted_key_input;
    OP_REQUIRES_OK(ctx, ctx->input("deleted_key", &deleted_key_input));
    OP_REQUIRES(ctx, key_shape_.IsSameSize(deleted_key_input->shape()),
                errors::InvalidArgument(
                    "Empty and deleted keys must have same shape, got shapes: ",
                    key_shape_.DebugString(), " and ",
                    deleted_key_input->shape().DebugString()));
    deleted_key_ = PersistentTensor(*deleted_key_input);
    deleted_key_hash_ = HashKey(deleted_key_input->template shaped<K, 2>(
                                    {1, key_shape_.num_elements()}),
                                0);

    if (empty_key_hash_ == deleted_key_hash_) {
      const int64 key_size = key_shape_.num_elements();
      const auto empty_key_matrix =
          empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
      const auto deleted_key_matrix =
          deleted_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
      OP_REQUIRES(
          ctx, !IsEqualKey(empty_key_matrix, 0, deleted_key_matrix, 0),
          errors::InvalidArgument("Empty and deleted keys cannot be equal"));
    }

    int64 initial_num_buckets;
    OP_REQUIRES_OK(ctx, GetNodeAttr(kernel->def(), "initial_num_buckets",
                                    &initial_num_buckets));
    OP_REQUIRES_OK(ctx, AllocateBuckets(ctx, initial_num_buckets));
  }

  size_t size() const override LOCKS_EXCLUDED(mu_) {
    tf_shared_lock l(mu_);
    return num_entries_;
  }

  Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
              const Tensor& default_value) override LOCKS_EXCLUDED(mu_) {
    const int64 num_elements = (key.dims() == 0) ? 1 : key.dim_size(0);
    const int64 key_size = key_shape_.num_elements();
    const int64 value_size = value_shape_.num_elements();
    if (key.NumElements() != num_elements * key_size) {
      TensorShape expected_shape({num_elements});
      expected_shape.AppendShape(key_shape_);
      return errors::InvalidArgument("Expected key shape ",
                                     expected_shape.DebugString(), " got ",
                                     key.shape().DebugString());
    }
    const auto key_matrix = key.shaped<K, 2>({num_elements, key_size});
    auto value_matrix = value->shaped<V, 2>({num_elements, value_size});
    const auto default_flat = default_value.flat<V>();

    tf_shared_lock l(mu_);
    const auto key_buckets_matrix =
        key_buckets_.AccessTensor(ctx)->template matrix<K>();
    const auto value_buckets_matrix =
        value_buckets_.AccessTensor(ctx)->template matrix<V>();
    const auto empty_key_matrix =
        empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
    const auto deleted_key_matrix =
        deleted_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
    const int64 bit_mask = num_buckets_ - 1;
    // TODO(andreasst): parallelize using work_sharder
    for (int64 i = 0; i < num_elements; ++i) {
      const uint64 key_hash = HashKey(key_matrix, i);
      if (empty_key_hash_ == key_hash &&
          IsEqualKey(empty_key_matrix, 0, key_matrix, i)) {
        return errors::InvalidArgument(
            "Using the empty_key as a table key is not allowed");
      }
      if (deleted_key_hash_ == key_hash &&
          IsEqualKey(deleted_key_matrix, 0, key_matrix, i)) {
        return errors::InvalidArgument(
            "Using the deleted_key as a table key is not allowed");
      }
      int64 bucket_index = key_hash & bit_mask;
      int64 num_probes = 0;
      while (true) {
        if (IsEqualKey(key_buckets_matrix, bucket_index, key_matrix, i)) {
          for (int64 j = 0; j < value_size; ++j) {
            // TODO(andreasst): check if we can get rid of SubtleMustCopy
            // here and elsewhere in this file.
            value_matrix(i, j) =
                SubtleMustCopyIfIntegral(value_buckets_matrix(bucket_index, j));
          }
          break;
        }
        if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_matrix, 0)) {
          for (int64 j = 0; j < value_size; ++j) {
            value_matrix(i, j) = SubtleMustCopyIfIntegral(default_flat(j));
          }
          break;
        }
        ++num_probes;
        bucket_index =
            (bucket_index + num_probes) & bit_mask;  // quadratic probing
        if (num_probes >= num_buckets_) {
          return errors::Internal(
              "Internal error in MutableDenseHashTable lookup");
        }
      }
    }
    return Status::OK();
  }

  Status Insert(OpKernelContext* ctx, const Tensor& key,
                const Tensor& value) override LOCKS_EXCLUDED(mu_) {
    const int64 batch_size = (key.dims() == 0) ? 1 : key.dim_size(0);
    if (key.NumElements() != batch_size * key_shape_.num_elements()) {
      TensorShape expected_shape({batch_size});
      expected_shape.AppendShape(key_shape_);
      return errors::InvalidArgument("Expected key shape ",
                                     expected_shape.DebugString(), " got ",
                                     key.shape().DebugString());
    }
    mutex_lock l(mu_);
    // For simplicity we assume that all keys in the input result in inserts
    // rather than updates. That means we may grow the table even though we
    // don't need to. As long as the number of keys inserted in one call is
    // small compared to the size of the map, the impact of this is minimal.
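    // For example (illustrative numbers only): with max_load_factor_ = 0.5,
    // num_buckets_ = 16 and num_entries_ = 7, a batch of 4 keys gives
    // pending_num_entries = 11 > 16 * 0.5, so the loop below doubles the
    // bucket count once to 32 (11 <= 32 * 0.5) before rebucketing.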
    const int64 pending_num_entries = num_entries_ + batch_size;
    if (pending_num_entries > num_buckets_ * max_load_factor_) {
      int64 new_num_buckets = num_buckets_;
      do {
        new_num_buckets <<= 1;
      } while (pending_num_entries > new_num_buckets * max_load_factor_);
      TF_RETURN_IF_ERROR(Rebucket(ctx, new_num_buckets));
    }
    return DoInsert(ctx, key, value, false);
  }

  Status Remove(OpKernelContext* ctx, const Tensor& key) override
      LOCKS_EXCLUDED(mu_) {
    if (key.NumElements() != key.dim_size(0) * key_shape_.num_elements()) {
      TensorShape expected_shape({key.dim_size(0)});
      expected_shape.AppendShape(key_shape_);
      return errors::InvalidArgument("Expected key shape ",
                                     expected_shape.DebugString(), " got ",
                                     key.shape().DebugString());
    }
    mutex_lock l(mu_);
    return DoRemove(ctx, key);
  }

  Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
                      const Tensor& values) override LOCKS_EXCLUDED(mu_) {
    mutex_lock l(mu_);
    num_buckets_ = keys.dim_size(0);
    key_buckets_ = PersistentTensor(keys);
    value_buckets_ = PersistentTensor(values);
    // Count the number of keys that are not the empty_key or deleted_key.
    // This requires iterating through the whole table but that is OK as we
    // only execute it during checkpoint restore.
    num_entries_ = 0;
    const auto empty_key_tensor =
        empty_key_.AccessTensor(ctx)->template shaped<K, 2>(
            {1, key_shape_.num_elements()});
    const auto deleted_key_tensor =
        deleted_key_.AccessTensor(ctx)->template shaped<K, 2>(
            {1, key_shape_.num_elements()});
    const auto key_buckets_tensor =
        key_buckets_.AccessTensor(ctx)->template matrix<K>();
    for (int64 i = 0; i < num_buckets_; ++i) {
      if (!IsEqualKey(key_buckets_tensor, i, empty_key_tensor, 0) &&
          !IsEqualKey(key_buckets_tensor, i, deleted_key_tensor, 0)) {
        ++num_entries_;
      }
    }
    return Status::OK();
  }

  Status ExportValues(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) {
    tf_shared_lock l(mu_);
    Tensor key_buckets_tensor = *key_buckets_.AccessTensor(ctx);
    Tensor value_buckets_tensor = *value_buckets_.AccessTensor(ctx);
    TF_RETURN_IF_ERROR(ctx->set_output("keys", key_buckets_tensor));
    TF_RETURN_IF_ERROR(ctx->set_output("values", value_buckets_tensor));
    return Status::OK();
  }

  Status CheckKeyAndValueTensorsForImport(const Tensor& keys,
                                          const Tensor& values) override {
    TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(keys, values));
    TF_RETURN_IF_ERROR(CheckKeyShape(keys.shape()));

    // The storage format in key_buckets_ and value_buckets_ is always vectors,
    // even if the inputs are scalars. This is what eventually gets exported
    // and is expected by the import method as well.
    TensorShape key_shape = MaybeVectorizeShape(key_shape_);
    TensorShape value_shape = MaybeVectorizeShape(value_shape_);

    // Compute the final expected shape of the value by starting with the shape
    // of all keys, removing the dimensions particular to each key and then
    // appending the shape of a single value.
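    // For example (hypothetical shapes): vector keys of size 2 imported as an
    // [N, 2] tensor with value_shape [3] must come with values of shape
    // [N, 3]; RemoveLastDims(1) leaves [N] and AppendShape([3]) gives [N, 3].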
    TensorShape expected_value_shape = keys.shape();
    expected_value_shape.RemoveLastDims(key_shape.dims());
    expected_value_shape.AppendShape(value_shape);
    if (values.shape() != expected_value_shape) {
      return errors::InvalidArgument(
          "Expected shape ", expected_value_shape.DebugString(),
          " for value, got ", values.shape().DebugString());
    }
    return Status::OK();
  }

  DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }

  DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }

  TensorShape key_shape() const override { return key_shape_; }

  TensorShape value_shape() const override { return value_shape_; }

  int64 MemoryUsed() const override {
    tf_shared_lock l(mu_);
    return sizeof(MutableDenseHashTable) + key_buckets_.AllocatedBytes() +
           value_buckets_.AllocatedBytes() + empty_key_.AllocatedBytes();
  }

 private:
  Status DoInsert(OpKernelContext* ctx, const Tensor& key, const Tensor& value,
                  bool ignore_empty_and_deleted_key)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    const int64 num_elements = (key.dims() == 0) ? 1 : key.dim_size(0);
    const int64 value_size = value_shape_.num_elements();
    const int64 key_size = key_shape_.num_elements();
    const auto key_matrix = key.shaped<K, 2>({num_elements, key_size});
    auto value_matrix = value.shaped<V, 2>({num_elements, value_size});

    auto key_buckets_matrix =
        key_buckets_.AccessTensor(ctx)->template matrix<K>();
    auto value_buckets_matrix =
        value_buckets_.AccessTensor(ctx)->template matrix<V>();
    const auto empty_key_tensor =
        empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
    const auto deleted_key_tensor =
        deleted_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
    const int64 bit_mask = num_buckets_ - 1;
    for (int64 i = 0; i < num_elements; ++i) {
      const uint64 key_hash = HashKey(key_matrix, i);
      if (empty_key_hash_ == key_hash &&
          IsEqualKey(empty_key_tensor, 0, key_matrix, i)) {
        if (ignore_empty_and_deleted_key) {
          continue;
        }
        return errors::InvalidArgument(
            "Using the empty_key as a table key is not allowed");
      }
      if (deleted_key_hash_ == key_hash &&
          IsEqualKey(deleted_key_tensor, 0, key_matrix, i)) {
        if (ignore_empty_and_deleted_key) {
          continue;
        }
        return errors::InvalidArgument(
            "Using the deleted_key as a table key is not allowed");
      }
      int64 bucket_index = key_hash & bit_mask;
      int64 num_probes = 0;
      while (true) {
        if (IsEqualKey(key_buckets_matrix, bucket_index, key_matrix, i)) {
          for (int64 j = 0; j < value_size; ++j) {
            value_buckets_matrix(bucket_index, j) =
                SubtleMustCopyIfIntegral(value_matrix(i, j));
          }
          break;
        }
        if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_tensor, 0) ||
            IsEqualKey(key_buckets_matrix, bucket_index, deleted_key_tensor,
                       0)) {
          ++num_entries_;
          for (int64 j = 0; j < key_size; ++j) {
            key_buckets_matrix(bucket_index, j) =
                SubtleMustCopyIfIntegral(key_matrix(i, j));
          }
          for (int64 j = 0; j < value_size; ++j) {
            value_buckets_matrix(bucket_index, j) =
                SubtleMustCopyIfIntegral(value_matrix(i, j));
          }
          break;
        }
        ++num_probes;
        bucket_index =
            (bucket_index + num_probes) & bit_mask;  // quadratic probing
        if (num_probes >= num_buckets_) {
          return errors::Internal(
              "Internal error in MutableDenseHashTable insert");
        }
      }
    }
    return Status::OK();
  }

  Status DoRemove(OpKernelContext* ctx, const Tensor& key)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    const int64 num_elements = key.dim_size(0);
    const int64 key_size = key_shape_.num_elements();
    const auto key_matrix = key.shaped<K, 2>({num_elements, key_size});

    auto key_buckets_matrix =
        key_buckets_.AccessTensor(ctx)->template matrix<K>();
    const auto empty_key_tensor =
        empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
    const auto deleted_key_tensor =
        deleted_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
    const auto deleted_key_flat =
        deleted_key_.AccessTensor(ctx)->template flat<K>();
    const int64 bit_mask = num_buckets_ - 1;
    for (int64 i = 0; i < num_elements; ++i) {
      const uint64 key_hash = HashKey(key_matrix, i);
      if (empty_key_hash_ == key_hash &&
          IsEqualKey(empty_key_tensor, 0, key_matrix, i)) {
        return errors::InvalidArgument(
            "Using the empty_key as a table key is not allowed");
      }
      if (deleted_key_hash_ == key_hash &&
          IsEqualKey(deleted_key_tensor, 0, key_matrix, i)) {
        return errors::InvalidArgument(
            "Using the deleted_key as a table key is not allowed");
      }
      int64 bucket_index = key_hash & bit_mask;
      int64 num_probes = 0;
      while (true) {
        if (IsEqualKey(key_buckets_matrix, bucket_index, key_matrix, i)) {
          --num_entries_;
          for (int64 j = 0; j < key_size; ++j) {
            key_buckets_matrix(bucket_index, j) =
                SubtleMustCopyIfIntegral(deleted_key_flat(j));
          }
          break;
        }
        if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_tensor, 0)) {
          break;
        }
        ++num_probes;
        bucket_index =
            (bucket_index + num_probes) & bit_mask;  // quadratic probing
        if (num_probes >= num_buckets_) {
          return errors::Internal(
              "Internal error in MutableDenseHashTable remove");
        }
      }
    }
    return Status::OK();
  }

  Status AllocateBuckets(OpKernelContext* ctx, int64 new_num_buckets)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (new_num_buckets < 4 ||
        ((new_num_buckets & (new_num_buckets - 1)) != 0)) {
      return errors::InvalidArgument(
          "Number of buckets must be at least 4 and a power of 2, got: ",
          new_num_buckets);
    }
    num_buckets_ = new_num_buckets;
    num_entries_ = 0;

    const int64 key_size = key_shape_.num_elements();
    Tensor* key_buckets_tensor;
    TF_RETURN_IF_ERROR(ctx->allocate_persistent(
        key_dtype(), TensorShape({num_buckets_, key_size}), &key_buckets_,
        &key_buckets_tensor));
    auto key_buckets_matrix = key_buckets_tensor->matrix<K>();
    const auto empty_key_flat =
        empty_key_.AccessTensor(ctx)->template flat<K>();
    for (int64 i = 0; i < num_buckets_; ++i) {
      for (int64 j = 0; j < key_size; ++j) {
        key_buckets_matrix(i, j) = empty_key_flat(j);
      }
    }

    const int64 value_size = value_shape_.num_elements();
    Tensor* value_buckets_tensor;
    TF_RETURN_IF_ERROR(ctx->allocate_persistent(
        value_dtype(), TensorShape({num_buckets_, value_size}),
        &value_buckets_, &value_buckets_tensor));
    auto value_buckets_matrix = value_buckets_tensor->matrix<V>();
    for (int64 i = 0; i < num_buckets_; ++i) {
      for (int64 j = 0; j < value_size; ++j) {
        // Initialize values to the default value for the type to avoid
        // exposing uninitialized memory in ExportValues().
        value_buckets_matrix(i, j) = V();
      }
    }
    return Status::OK();
  }

  Status Rebucket(OpKernelContext* ctx, int64 num_new_buckets)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    Tensor old_key_buckets = *key_buckets_.AccessTensor(ctx);
    Tensor old_value_buckets = *value_buckets_.AccessTensor(ctx);
    TF_RETURN_IF_ERROR(AllocateBuckets(ctx, num_new_buckets));
    return DoInsert(ctx, old_key_buckets, old_value_buckets, true);
  }

  uint64 HashKey(typename TTypes<K>::ConstMatrix key, int64 index) const {
    if (key_shape_.num_elements() == 1) {
      return HashScalar(key(index, 0));
    }
    uint64 result = 0;
    for (int64 i = 0; i < key_shape_.num_elements(); ++i) {
      result = Hash64Combine(result, HashScalar(key(index, i)));
    }
    return result;
  }

  // Use a template to allow this function to be used both with Matrix and
  // ConstMatrix types.
  template <typename MT2>
  bool IsEqualKey(typename TTypes<K>::Matrix tensor1, int64 index1,
                  MT2 tensor2, int64 index2) const {
    for (int64 i = 0; i < key_shape_.num_elements(); ++i) {
      if (tensor1(index1, i) != tensor2(index2, i)) {
        return false;
      }
    }
    return true;
  }

  TensorShape key_shape_;
  TensorShape value_shape_;
  float max_load_factor_;
  mutable mutex mu_;
  int64 num_entries_ GUARDED_BY(mu_);
  int64 num_buckets_ GUARDED_BY(mu_);
  PersistentTensor key_buckets_ GUARDED_BY(mu_);
  PersistentTensor value_buckets_ GUARDED_BY(mu_);
  PersistentTensor empty_key_;
  uint64 empty_key_hash_;
  PersistentTensor deleted_key_;
  uint64 deleted_key_hash_;
};

}  // namespace lookup

// Table lookup op. Perform the lookup operation on the given table.
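// Inputs are the table handle, a tensor of keys and a default value; the
// "values" output is allocated with the keys' shape (minus the per-key
// dimensions) followed by the table's value_shape, and is filled by Find().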
class LookupTableFindOp : public OpKernel {
 public:
  explicit LookupTableFindOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    lookup::LookupInterface* table;
    OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
    core::ScopedUnref unref_me(table);

    // Input 0 could be a STRING_REF or a RESOURCE
    DataType expected_input_0 =
        (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
    DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(),
                                      table->value_dtype()};
    DataTypeVector expected_outputs = {table->value_dtype()};
    OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs));

    const Tensor& key = ctx->input(1);
    const Tensor& default_value = ctx->input(2);
    OP_REQUIRES_OK(ctx, table->CheckFindArguments(key, default_value));

    TensorShape output_shape = key.shape();
    output_shape.RemoveLastDims(table->key_shape().dims());
    output_shape.AppendShape(table->value_shape());
    Tensor* out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("values", output_shape, &out));

    OP_REQUIRES_OK(ctx, table->Find(ctx, key, out, default_value));
  }
};

REGISTER_KERNEL_BUILDER(Name("LookupTableFind").Device(DEVICE_CPU),
                        LookupTableFindOp);
REGISTER_KERNEL_BUILDER(Name("LookupTableFindV2").Device(DEVICE_CPU),
                        LookupTableFindOp);

// Table insert op.
class LookupTableInsertOp : public OpKernel {
 public:
  explicit LookupTableInsertOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    lookup::LookupInterface* table;
    OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
    core::ScopedUnref unref_me(table);

    DataType expected_input_0 =
        (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
    DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(),
                                      table->value_dtype()};
    OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));

    const Tensor& keys = ctx->input(1);
    const Tensor& values = ctx->input(2);
    OP_REQUIRES_OK(ctx, table->CheckKeyAndValueTensorsForInsert(keys, values));

    int64 memory_used_before = 0;
    if (ctx->track_allocations()) {
      memory_used_before = table->MemoryUsed();
    }
    OP_REQUIRES_OK(ctx, table->Insert(ctx, keys, values));
    if (ctx->track_allocations()) {
      ctx->record_persistent_memory_allocation(table->MemoryUsed() -
                                               memory_used_before);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("LookupTableInsert").Device(DEVICE_CPU),
                        LookupTableInsertOp);
REGISTER_KERNEL_BUILDER(Name("LookupTableInsertV2").Device(DEVICE_CPU),
                        LookupTableInsertOp);

// Table remove op.
class LookupTableRemoveOp : public OpKernel {
 public:
  explicit LookupTableRemoveOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    lookup::LookupInterface* table;
    OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
    core::ScopedUnref unref_me(table);

    DataType expected_input_0 =
        (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
    DataTypeVector expected_inputs = {expected_input_0, table->key_dtype()};
    OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));

    const Tensor& key = ctx->input(1);
    OP_REQUIRES_OK(ctx, table->CheckKeyTensorForRemove(key));

    int64 memory_used_before = 0;
    if (ctx->track_allocations()) {
      memory_used_before = table->MemoryUsed();
    }
    OP_REQUIRES_OK(ctx, table->Remove(ctx, key));
    if (ctx->track_allocations()) {
      ctx->record_persistent_memory_allocation(table->MemoryUsed() -
                                               memory_used_before);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("LookupTableRemoveV2").Device(DEVICE_CPU),
                        LookupTableRemoveOp);

// Op that returns the size of the given table.
class LookupTableSizeOp : public OpKernel {
 public:
  explicit LookupTableSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    lookup::LookupInterface* table;
    OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
    core::ScopedUnref unref_me(table);

    Tensor* out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("size", TensorShape({}), &out));
    out->flat<int64>().setConstant(table->size());
  }
};

REGISTER_KERNEL_BUILDER(Name("LookupTableSize").Device(DEVICE_CPU),
                        LookupTableSizeOp);
REGISTER_KERNEL_BUILDER(Name("LookupTableSizeV2").Device(DEVICE_CPU),
                        LookupTableSizeOp);

// Op that outputs tensors of all keys and all values.
class LookupTableExportOp : public OpKernel {
 public:
  explicit LookupTableExportOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    lookup::LookupInterface* table;
    OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
    core::ScopedUnref unref_me(table);

    OP_REQUIRES_OK(ctx, table->ExportValues(ctx));
  }
};

REGISTER_KERNEL_BUILDER(Name("LookupTableExport").Device(DEVICE_CPU),
                        LookupTableExportOp);
REGISTER_KERNEL_BUILDER(Name("LookupTableExportV2").Device(DEVICE_CPU),
                        LookupTableExportOp);

// Clear the table and insert data.
class LookupTableImportOp : public OpKernel {
 public:
  explicit LookupTableImportOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    lookup::LookupInterface* table;
    OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
    core::ScopedUnref unref_me(table);

    DataType expected_input_0 =
        (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
    DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(),
                                      table->value_dtype()};
    OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));

    const Tensor& keys = ctx->input(1);
    const Tensor& values = ctx->input(2);
    OP_REQUIRES_OK(ctx, table->CheckKeyAndValueTensorsForImport(keys, values));

    int64 memory_used_before = 0;
    if (ctx->track_allocations()) {
      memory_used_before = table->MemoryUsed();
    }
    OP_REQUIRES_OK(ctx, table->ImportValues(ctx, keys, values));
    if (ctx->track_allocations()) {
      ctx->record_persistent_memory_allocation(table->MemoryUsed() -
                                               memory_used_before);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("LookupTableImport").Device(DEVICE_CPU),
                        LookupTableImportOp);
REGISTER_KERNEL_BUILDER(Name("LookupTableImportV2").Device(DEVICE_CPU),
                        LookupTableImportOp);

// Register the HashTable op with the currently supported key and value types.
#define REGISTER_KERNEL(key_dtype, value_dtype)                             \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("HashTable")                                                     \
          .Device(DEVICE_CPU)                                               \
          .TypeConstraint<key_dtype>("key_dtype")                           \
          .TypeConstraint<value_dtype>("value_dtype"),                      \
      LookupTableOp<lookup::HashTable<key_dtype, value_dtype>, key_dtype,   \
                    value_dtype>)                                           \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("HashTableV2")                                                   \
          .Device(DEVICE_CPU)                                               \
          .TypeConstraint<key_dtype>("key_dtype")                           \
          .TypeConstraint<value_dtype>("value_dtype"),                      \
      LookupTableOp<lookup::HashTable<key_dtype, value_dtype>, key_dtype,   \
                    value_dtype>)

REGISTER_KERNEL(int32, double);
REGISTER_KERNEL(int32, float);
REGISTER_KERNEL(int32, int32);
REGISTER_KERNEL(int32, string);
REGISTER_KERNEL(int64, double);
REGISTER_KERNEL(int64, float);
REGISTER_KERNEL(int64, int32);
REGISTER_KERNEL(int64, int64);
REGISTER_KERNEL(int64, string);
REGISTER_KERNEL(string, bool);
REGISTER_KERNEL(string, double);
REGISTER_KERNEL(string, float);
REGISTER_KERNEL(string, int32);
REGISTER_KERNEL(string, int64);
REGISTER_KERNEL(string, string);

#undef REGISTER_KERNEL

// Register the MutableHashTable op.
#define REGISTER_KERNEL(key_dtype, value_dtype)                                \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("MutableHashTable")                                                 \
          .Device(DEVICE_CPU)                                                  \
          .TypeConstraint<key_dtype>("key_dtype")                              \
          .TypeConstraint<value_dtype>("value_dtype"),                         \
      LookupTableOp<lookup::MutableHashTableOfScalars<key_dtype, value_dtype>, \
                    key_dtype, value_dtype>)                                   \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("MutableHashTableV2")                                               \
          .Device(DEVICE_CPU)                                                  \
          .TypeConstraint<key_dtype>("key_dtype")                              \
          .TypeConstraint<value_dtype>("value_dtype"),                         \
      LookupTableOp<lookup::MutableHashTableOfScalars<key_dtype, value_dtype>, \
                    key_dtype, value_dtype>)

REGISTER_KERNEL(int32, double);
REGISTER_KERNEL(int32, float);
REGISTER_KERNEL(int32, int32);
REGISTER_KERNEL(int64, double);
REGISTER_KERNEL(int64, float);
REGISTER_KERNEL(int64, int32);
REGISTER_KERNEL(int64, int64);
REGISTER_KERNEL(int64, string);
REGISTER_KERNEL(int64, Variant);
REGISTER_KERNEL(string, bool);
REGISTER_KERNEL(string, double);
REGISTER_KERNEL(string, float);
REGISTER_KERNEL(string, int32);
REGISTER_KERNEL(string, int64);

#undef REGISTER_KERNEL

// Register the MutableHashTableOfTensors op.
#define REGISTER_KERNEL(key_dtype, value_dtype)                                \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("MutableHashTableOfTensors")                                        \
          .Device(DEVICE_CPU)                                                  \
          .TypeConstraint<key_dtype>("key_dtype")                              \
          .TypeConstraint<value_dtype>("value_dtype"),                         \
      LookupTableOp<lookup::MutableHashTableOfTensors<key_dtype, value_dtype>, \
                    key_dtype, value_dtype>)                                   \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("MutableHashTableOfTensorsV2")                                      \
          .Device(DEVICE_CPU)                                                  \
          .TypeConstraint<key_dtype>("key_dtype")                              \
          .TypeConstraint<value_dtype>("value_dtype"),                         \
      LookupTableOp<lookup::MutableHashTableOfTensors<key_dtype, value_dtype>, \
                    key_dtype, value_dtype>)

REGISTER_KERNEL(int32, double);
REGISTER_KERNEL(int32, float);
REGISTER_KERNEL(int32, int32);
REGISTER_KERNEL(int64, double);
REGISTER_KERNEL(int64, float);
REGISTER_KERNEL(int64, int32);
REGISTER_KERNEL(int64, int64);
REGISTER_KERNEL(int64, string);
REGISTER_KERNEL(string, bool);
REGISTER_KERNEL(string, double);
REGISTER_KERNEL(string, float);
REGISTER_KERNEL(string, int32);
REGISTER_KERNEL(string, int64);

#undef REGISTER_KERNEL

// Register the MutableDenseHashTable op.
#define REGISTER_KERNEL(key_dtype, value_dtype)                             \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("MutableDenseHashTable")                                         \
          .Device(DEVICE_CPU)                                               \
          .TypeConstraint<key_dtype>("key_dtype")                           \
          .TypeConstraint<value_dtype>("value_dtype"),                      \
      LookupTableOp<lookup::MutableDenseHashTable<key_dtype, value_dtype>,  \
                    key_dtype, value_dtype>)                                \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("MutableDenseHashTableV2")                                       \
          .Device(DEVICE_CPU)                                               \
          .TypeConstraint<key_dtype>("key_dtype")                           \
          .TypeConstraint<value_dtype>("value_dtype"),                      \
      LookupTableOp<lookup::MutableDenseHashTable<key_dtype, value_dtype>,  \
                    key_dtype, value_dtype>)

REGISTER_KERNEL(int32, double);
REGISTER_KERNEL(int32, float);
REGISTER_KERNEL(int32, int32);
REGISTER_KERNEL(int64, bool);
REGISTER_KERNEL(int64, double);
REGISTER_KERNEL(int64, float);
REGISTER_KERNEL(int64, int32);
REGISTER_KERNEL(int64, int64);
REGISTER_KERNEL(int64, Variant);
REGISTER_KERNEL(string, bool);
REGISTER_KERNEL(string, double);
REGISTER_KERNEL(string, float);
REGISTER_KERNEL(string, int32);
REGISTER_KERNEL(string, int64);

#undef REGISTER_KERNEL

}  // namespace tensorflow