• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/kernels/lookup_table_op.h"
17 #define EIGEN_USE_THREADS
18 
19 #include <string>
20 #include <type_traits>
21 #include <utility>
22 
23 #include "tensorflow/core/framework/register_types.h"
24 #include "tensorflow/core/framework/types.h"
25 #include "tensorflow/core/framework/variant.h"
26 #include "tensorflow/core/kernels/initializable_lookup_table.h"
27 #include "tensorflow/core/lib/gtl/inlined_vector.h"
28 #include "tensorflow/core/lib/hash/hash.h"
29 
30 namespace tensorflow {
31 namespace lookup {
32 
33 // Lookup table that wraps an unordered_map, where the key and value data type
34 // is specified. Each individual value must be a scalar. If vector values are
35 // required, use MutableHashTableOfTensors.
36 //
37 // This table is mutable and thread safe - Insert can be called at any time.
38 //
39 // Sample use case:
40 //
41 // MutableHashTableOfScalars<int64, int64> table;  // int64 -> int64.
42 // // Populate the table, elements could be added in one or multiple calls.
43 // table.Insert(key_tensor, value_tensor); // Populate the table.
44 //
45 // table.Find(in_t, &out_t, default_t)
46 //
47 template <class K, class V>
48 class MutableHashTableOfScalars final : public LookupInterface {
49  public:
MutableHashTableOfScalars(OpKernelContext * ctx,OpKernel * kernel)50   MutableHashTableOfScalars(OpKernelContext* ctx, OpKernel* kernel) {}
51 
size() const52   size_t size() const override {
53     tf_shared_lock l(mu_);
54     return table_.size();
55   }
56 
Find(OpKernelContext * ctx,const Tensor & key,Tensor * value,const Tensor & default_value)57   Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
58               const Tensor& default_value) override {
59     const V default_val = default_value.flat<V>()(0);
60     const auto key_values = key.flat<K>();
61     auto value_values = value->flat<V>();
62 
63     tf_shared_lock l(mu_);
64     for (int64 i = 0; i < key_values.size(); ++i) {
65       value_values(i) = gtl::FindWithDefault(
66           table_, SubtleMustCopyIfIntegral(key_values(i)), default_val);
67     }
68 
69     return Status::OK();
70   }
71 
DoInsert(bool clear,const Tensor & keys,const Tensor & values)72   Status DoInsert(bool clear, const Tensor& keys, const Tensor& values) {
73     const auto key_values = keys.flat<K>();
74     const auto value_values = values.flat<V>();
75 
76     mutex_lock l(mu_);
77     if (clear) {
78       table_.clear();
79     }
80     for (int64 i = 0; i < key_values.size(); ++i) {
81       gtl::InsertOrUpdate(&table_, SubtleMustCopyIfIntegral(key_values(i)),
82                           SubtleMustCopyIfIntegral(value_values(i)));
83     }
84     return Status::OK();
85   }
86 
Insert(OpKernelContext * ctx,const Tensor & keys,const Tensor & values)87   Status Insert(OpKernelContext* ctx, const Tensor& keys,
88                 const Tensor& values) override {
89     return DoInsert(false, keys, values);
90   }
91 
Remove(OpKernelContext * ctx,const Tensor & keys)92   Status Remove(OpKernelContext* ctx, const Tensor& keys) override {
93     const auto key_values = keys.flat<K>();
94 
95     mutex_lock l(mu_);
96     for (int64 i = 0; i < key_values.size(); ++i) {
97       table_.erase(SubtleMustCopyIfIntegral(key_values(i)));
98     }
99     return Status::OK();
100   }
101 
ImportValues(OpKernelContext * ctx,const Tensor & keys,const Tensor & values)102   Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
103                       const Tensor& values) override {
104     return DoInsert(true, keys, values);
105   }
106 
ExportValues(OpKernelContext * ctx)107   Status ExportValues(OpKernelContext* ctx) override {
108     tf_shared_lock l(mu_);
109     int64 size = table_.size();
110 
111     Tensor* keys;
112     Tensor* values;
113     TF_RETURN_IF_ERROR(
114         ctx->allocate_output("keys", TensorShape({size}), &keys));
115     TF_RETURN_IF_ERROR(
116         ctx->allocate_output("values", TensorShape({size}), &values));
117 
118     auto keys_data = keys->flat<K>();
119     auto values_data = values->flat<V>();
120     int64 i = 0;
121     for (auto it = table_.begin(); it != table_.end(); ++it, ++i) {
122       keys_data(i) = it->first;
123       values_data(i) = it->second;
124     }
125     return Status::OK();
126   }
127 
key_dtype() const128   DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }
129 
value_dtype() const130   DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }
131 
key_shape() const132   TensorShape key_shape() const final { return TensorShape(); }
133 
value_shape() const134   TensorShape value_shape() const override { return TensorShape(); }
135 
MemoryUsed() const136   int64 MemoryUsed() const override {
137     int64 ret = 0;
138     tf_shared_lock l(mu_);
139     for (unsigned i = 0; i < table_.bucket_count(); ++i) {
140       size_t bucket_size = table_.bucket_size(i);
141       if (bucket_size == 0) {
142         ret++;
143       } else {
144         ret += bucket_size;
145       }
146     }
147     return sizeof(MutableHashTableOfScalars) + ret;
148   }
149 
150  private:
151   mutable mutex mu_;
152   std::unordered_map<K, V> table_ GUARDED_BY(mu_);
153 };
154 
155 // Lookup table that wraps an unordered_map. Behaves identical to
156 // MutableHashTableOfScalars except that each value must be a vector.
157 template <class K, class V>
158 class MutableHashTableOfTensors final : public LookupInterface {
159  public:
MutableHashTableOfTensors(OpKernelContext * ctx,OpKernel * kernel)160   MutableHashTableOfTensors(OpKernelContext* ctx, OpKernel* kernel) {
161     OP_REQUIRES_OK(ctx,
162                    GetNodeAttr(kernel->def(), "value_shape", &value_shape_));
163     OP_REQUIRES(
164         ctx, TensorShapeUtils::IsVector(value_shape_),
165         errors::InvalidArgument("Default value must be a vector, got shape ",
166                                 value_shape_.DebugString()));
167   }
168 
size() const169   size_t size() const override {
170     tf_shared_lock l(mu_);
171     return table_.size();
172   }
173 
Find(OpKernelContext * ctx,const Tensor & key,Tensor * value,const Tensor & default_value)174   Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
175               const Tensor& default_value) override {
176     const auto default_flat = default_value.flat<V>();
177     const auto key_values = key.flat<K>();
178     auto value_values = value->flat_inner_dims<V, 2>();
179     int64 value_dim = value_shape_.dim_size(0);
180 
181     tf_shared_lock l(mu_);
182     for (int64 i = 0; i < key_values.size(); ++i) {
183       ValueArray* value_vec =
184           gtl::FindOrNull(table_, SubtleMustCopyIfIntegral(key_values(i)));
185       if (value_vec != nullptr) {
186         for (int64 j = 0; j < value_dim; j++) {
187           value_values(i, j) = value_vec->at(j);
188         }
189       } else {
190         for (int64 j = 0; j < value_dim; j++) {
191           value_values(i, j) = default_flat(j);
192         }
193       }
194     }
195 
196     return Status::OK();
197   }
198 
DoInsert(bool clear,const Tensor & keys,const Tensor & values)199   Status DoInsert(bool clear, const Tensor& keys, const Tensor& values) {
200     const auto key_values = keys.flat<K>();
201     const auto value_values = values.flat_inner_dims<V, 2>();
202     int64 value_dim = value_shape_.dim_size(0);
203 
204     mutex_lock l(mu_);
205     if (clear) {
206       table_.clear();
207     }
208     for (int64 i = 0; i < key_values.size(); ++i) {
209       ValueArray value_vec;
210       for (int64 j = 0; j < value_dim; j++) {
211         V value = value_values(i, j);
212         value_vec.push_back(value);
213       }
214       gtl::InsertOrUpdate(&table_, SubtleMustCopyIfIntegral(key_values(i)),
215                           value_vec);
216     }
217     return Status::OK();
218   }
219 
Insert(OpKernelContext * ctx,const Tensor & keys,const Tensor & values)220   Status Insert(OpKernelContext* ctx, const Tensor& keys,
221                 const Tensor& values) override {
222     return DoInsert(false, keys, values);
223   }
224 
Remove(OpKernelContext * ctx,const Tensor & keys)225   Status Remove(OpKernelContext* ctx, const Tensor& keys) override {
226     const auto key_values = keys.flat<K>();
227 
228     mutex_lock l(mu_);
229     for (int64 i = 0; i < key_values.size(); ++i) {
230       table_.erase(SubtleMustCopyIfIntegral(key_values(i)));
231     }
232     return Status::OK();
233   }
234 
ImportValues(OpKernelContext * ctx,const Tensor & keys,const Tensor & values)235   Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
236                       const Tensor& values) override {
237     return DoInsert(true, keys, values);
238   }
239 
ExportValues(OpKernelContext * ctx)240   Status ExportValues(OpKernelContext* ctx) override {
241     tf_shared_lock l(mu_);
242     int64 size = table_.size();
243     int64 value_dim = value_shape_.dim_size(0);
244 
245     Tensor* keys;
246     Tensor* values;
247     TF_RETURN_IF_ERROR(
248         ctx->allocate_output("keys", TensorShape({size}), &keys));
249     TF_RETURN_IF_ERROR(ctx->allocate_output(
250         "values", TensorShape({size, value_dim}), &values));
251 
252     auto keys_data = keys->flat<K>();
253     auto values_data = values->matrix<V>();
254     int64 i = 0;
255     for (auto it = table_.begin(); it != table_.end(); ++it, ++i) {
256       K key = it->first;
257       ValueArray value = it->second;
258       keys_data(i) = key;
259       for (int64 j = 0; j < value_dim; j++) {
260         values_data(i, j) = value[j];
261       }
262     }
263     return Status::OK();
264   }
265 
key_dtype() const266   DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }
267 
value_dtype() const268   DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }
269 
key_shape() const270   TensorShape key_shape() const final { return TensorShape(); }
271 
value_shape() const272   TensorShape value_shape() const override { return value_shape_; }
273 
MemoryUsed() const274   int64 MemoryUsed() const override {
275     int64 ret = 0;
276     tf_shared_lock l(mu_);
277     for (unsigned i = 0; i < table_.bucket_count(); ++i) {
278       size_t bucket_size = table_.bucket_size(i);
279       if (bucket_size == 0) {
280         ret++;
281       } else {
282         ret += bucket_size;
283       }
284     }
285     return sizeof(MutableHashTableOfTensors) + ret;
286   }
287 
288  private:
289   TensorShape value_shape_;
290   mutable mutex mu_;
291   typedef gtl::InlinedVector<V, 4> ValueArray;
292   std::unordered_map<K, ValueArray> table_ GUARDED_BY(mu_);
293 };
294 
295 namespace {
296 
297 template <typename T>
HashScalar(const T & key)298 inline uint64 HashScalar(const T& key) {
299   return static_cast<uint64>(key);
300 }
301 
HashScalar(const string & key)302 inline uint64 HashScalar(const string& key) { return Hash64(key); }
303 
304 // If the given shape is a scalar return {1} instead. Otherwise leave it alone.
MaybeVectorizeShape(const TensorShape & shape)305 TensorShape MaybeVectorizeShape(const TensorShape& shape) {
306   if (shape.dims() == 0) {
307     return TensorShape({1});
308   }
309   return shape;
310 }
311 
312 }  // namespace
313 
314 // Modeled after densehashtable in https://github.com/sparsehash/sparsehash
315 template <class K, class V>
316 class MutableDenseHashTable final : public LookupInterface {
317  public:
MutableDenseHashTable(OpKernelContext * ctx,OpKernel * kernel)318   MutableDenseHashTable(OpKernelContext* ctx, OpKernel* kernel) {
319     OP_REQUIRES_OK(
320         ctx, GetNodeAttr(kernel->def(), "max_load_factor", &max_load_factor_));
321     OP_REQUIRES(ctx, max_load_factor_ > 0 && max_load_factor_ < 1,
322                 errors::InvalidArgument(
323                     "max_load_factor must be between 0 and 1, got: ",
324                     max_load_factor_));
325 
326     OP_REQUIRES_OK(ctx,
327                    GetNodeAttr(kernel->def(), "value_shape", &value_shape_));
328     OP_REQUIRES(ctx,
329                 TensorShapeUtils::IsScalar(value_shape_) ||
330                     TensorShapeUtils::IsVector(value_shape_),
331                 errors::InvalidArgument(
332                     "Empty value must be a scalar or a vector, got shape ",
333                     value_shape_.DebugString()));
334 
335     const Tensor* empty_key_input;
336     OP_REQUIRES_OK(ctx, ctx->input("empty_key", &empty_key_input));
337     key_shape_ = empty_key_input->shape();
338     OP_REQUIRES(ctx,
339                 TensorShapeUtils::IsScalar(key_shape_) ||
340                     TensorShapeUtils::IsVector(key_shape_),
341                 errors::InvalidArgument(
342                     "Empty key must be a scalar or a vector, got shape ",
343                     key_shape_.DebugString()));
344     empty_key_ = PersistentTensor(*empty_key_input);
345     empty_key_hash_ = HashKey(
346         empty_key_input->template shaped<K, 2>({1, key_shape_.num_elements()}),
347         0);
348 
349     const Tensor* deleted_key_input;
350     OP_REQUIRES_OK(ctx, ctx->input("deleted_key", &deleted_key_input));
351     OP_REQUIRES(ctx, key_shape_.IsSameSize(deleted_key_input->shape()),
352                 errors::InvalidArgument(
353                     "Empty and deleted keys must have same shape, got shapes: ",
354                     key_shape_.DebugString(), " and ",
355                     deleted_key_input->shape().DebugString()));
356     deleted_key_ = PersistentTensor(*deleted_key_input);
357     deleted_key_hash_ = HashKey(deleted_key_input->template shaped<K, 2>(
358                                     {1, key_shape_.num_elements()}),
359                                 0);
360 
361     if (empty_key_hash_ == deleted_key_hash_) {
362       const int64 key_size = key_shape_.num_elements();
363       const auto empty_key_matrix =
364           empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
365       const auto deleted_key_matrix =
366           deleted_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
367       OP_REQUIRES(
368           ctx, !IsEqualKey(empty_key_matrix, 0, deleted_key_matrix, 0),
369           errors::InvalidArgument("Empty and deleted keys cannot be equal"));
370     }
371 
372     int64 initial_num_buckets;
373     OP_REQUIRES_OK(ctx, GetNodeAttr(kernel->def(), "initial_num_buckets",
374                                     &initial_num_buckets));
375     OP_REQUIRES_OK(ctx, AllocateBuckets(ctx, initial_num_buckets));
376   }
377 
size() const378   size_t size() const override LOCKS_EXCLUDED(mu_) {
379     tf_shared_lock l(mu_);
380     return num_entries_;
381   }
382 
Find(OpKernelContext * ctx,const Tensor & key,Tensor * value,const Tensor & default_value)383   Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
384               const Tensor& default_value) override LOCKS_EXCLUDED(mu_) {
385     const int64 num_elements = (key.dims() == 0) ? 1 : key.dim_size(0);
386     const int64 key_size = key_shape_.num_elements();
387     const int64 value_size = value_shape_.num_elements();
388     if (key.NumElements() != num_elements * key_size) {
389       TensorShape expected_shape({num_elements});
390       expected_shape.AppendShape(key_shape_);
391       return errors::InvalidArgument("Expected key shape ",
392                                      expected_shape.DebugString(), " got ",
393                                      key.shape().DebugString());
394     }
395     const auto key_matrix = key.shaped<K, 2>({num_elements, key_size});
396     auto value_matrix = value->shaped<V, 2>({num_elements, value_size});
397     const auto default_flat = default_value.flat<V>();
398 
399     tf_shared_lock l(mu_);
400     const auto key_buckets_matrix =
401         key_buckets_.AccessTensor(ctx)->template matrix<K>();
402     const auto value_buckets_matrix =
403         value_buckets_.AccessTensor(ctx)->template matrix<V>();
404     const auto empty_key_matrix =
405         empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
406     const auto deleted_key_matrix =
407         deleted_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
408     const int64 bit_mask = num_buckets_ - 1;
409     // TODO(andreasst): parallelize using work_sharder
410     for (int64 i = 0; i < num_elements; ++i) {
411       const uint64 key_hash = HashKey(key_matrix, i);
412       if (empty_key_hash_ == key_hash &&
413           IsEqualKey(empty_key_matrix, 0, key_matrix, i)) {
414         return errors::InvalidArgument(
415             "Using the empty_key as a table key is not allowed");
416       }
417       if (deleted_key_hash_ == key_hash &&
418           IsEqualKey(deleted_key_matrix, 0, key_matrix, i)) {
419         return errors::InvalidArgument(
420             "Using the deleted_key as a table key is not allowed");
421       }
422       int64 bucket_index = key_hash & bit_mask;
423       int64 num_probes = 0;
424       while (true) {
425         if (IsEqualKey(key_buckets_matrix, bucket_index, key_matrix, i)) {
426           for (int64 j = 0; j < value_size; ++j) {
427             // TODO(andreasst): check if we can get rid of SubtleMustCopy
428             // here and elsewhere in this file.
429             value_matrix(i, j) =
430                 SubtleMustCopyIfIntegral(value_buckets_matrix(bucket_index, j));
431           }
432           break;
433         }
434         if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_matrix, 0)) {
435           for (int64 j = 0; j < value_size; ++j) {
436             value_matrix(i, j) = SubtleMustCopyIfIntegral(default_flat(j));
437           }
438           break;
439         }
440         ++num_probes;
441         bucket_index =
442             (bucket_index + num_probes) & bit_mask;  // quadratic probing
443         if (num_probes >= num_buckets_) {
444           return errors::Internal(
445               "Internal error in MutableDenseHashTable lookup");
446         }
447       }
448     }
449     return Status::OK();
450   }
451 
Insert(OpKernelContext * ctx,const Tensor & key,const Tensor & value)452   Status Insert(OpKernelContext* ctx, const Tensor& key,
453                 const Tensor& value) override LOCKS_EXCLUDED(mu_) {
454     const int64 batch_size = (key.dims() == 0) ? 1 : key.dim_size(0);
455     if (key.NumElements() != batch_size * key_shape_.num_elements()) {
456       TensorShape expected_shape({batch_size});
457       expected_shape.AppendShape(key_shape_);
458       return errors::InvalidArgument("Expected key shape ",
459                                      expected_shape.DebugString(), " got ",
460                                      key.shape().DebugString());
461     }
462     mutex_lock l(mu_);
463     // For simplicity we assume that all keys in the input result in inserts
464     // rather than updates. That means we may grow the table even though we
465     // don't need to. As long as the number of keys inserted in one call is
466     // small compared to the size of the map, the impact of this is minimal.
467     const int64 pending_num_entries = num_entries_ + batch_size;
468     if (pending_num_entries > num_buckets_ * max_load_factor_) {
469       int64 new_num_buckets = num_buckets_;
470       do {
471         new_num_buckets <<= 1;
472       } while (pending_num_entries > new_num_buckets * max_load_factor_);
473       TF_RETURN_IF_ERROR(Rebucket(ctx, new_num_buckets));
474     }
475     return DoInsert(ctx, key, value, false);
476   }
477 
Remove(OpKernelContext * ctx,const Tensor & key)478   Status Remove(OpKernelContext* ctx, const Tensor& key) override
479       LOCKS_EXCLUDED(mu_) {
480     if (key.NumElements() != key.dim_size(0) * key_shape_.num_elements()) {
481       TensorShape expected_shape({key.dim_size(0)});
482       expected_shape.AppendShape(key_shape_);
483       return errors::InvalidArgument("Expected key shape ",
484                                      expected_shape.DebugString(), " got ",
485                                      key.shape().DebugString());
486     }
487     mutex_lock l(mu_);
488     return DoRemove(ctx, key);
489   }
490 
ImportValues(OpKernelContext * ctx,const Tensor & keys,const Tensor & values)491   Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
492                       const Tensor& values) override LOCKS_EXCLUDED(mu_) {
493     mutex_lock l(mu_);
494     num_buckets_ = keys.dim_size(0);
495     key_buckets_ = PersistentTensor(keys);
496     value_buckets_ = PersistentTensor(values);
497     // Count the number of keys that are not the empty_key or deleted_key.
498     // This requires iterating through the whole table but that is OK as we
499     // only execute it during checkpoint restore.
500     num_entries_ = 0;
501     const auto empty_key_tensor =
502         empty_key_.AccessTensor(ctx)->template shaped<K, 2>(
503             {1, key_shape_.num_elements()});
504     const auto deleted_key_tensor =
505         deleted_key_.AccessTensor(ctx)->template shaped<K, 2>(
506             {1, key_shape_.num_elements()});
507     const auto key_buckets_tensor =
508         key_buckets_.AccessTensor(ctx)->template matrix<K>();
509     for (int64 i = 0; i < num_buckets_; ++i) {
510       if (!IsEqualKey(key_buckets_tensor, i, empty_key_tensor, 0) &&
511           !IsEqualKey(key_buckets_tensor, i, deleted_key_tensor, 0)) {
512         ++num_entries_;
513       }
514     }
515     return Status::OK();
516   }
517 
ExportValues(OpKernelContext * ctx)518   Status ExportValues(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) {
519     tf_shared_lock l(mu_);
520     Tensor key_buckets_tensor = *key_buckets_.AccessTensor(ctx);
521     Tensor value_buckets_tensor = *value_buckets_.AccessTensor(ctx);
522     TF_RETURN_IF_ERROR(ctx->set_output("keys", key_buckets_tensor));
523     TF_RETURN_IF_ERROR(ctx->set_output("values", value_buckets_tensor));
524     return Status::OK();
525   }
526 
CheckKeyAndValueTensorsForImport(const Tensor & keys,const Tensor & values)527   Status CheckKeyAndValueTensorsForImport(const Tensor& keys,
528                                           const Tensor& values) override {
529     TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(keys, values));
530     TF_RETURN_IF_ERROR(CheckKeyShape(keys.shape()));
531 
532     // The storage format in key_buckets_ and value_buckets_ is always vectors,
533     // even if the inputs are scalars. This is what eventually gets exported
534     // and is expected by the import method as well.
535     TensorShape key_shape = MaybeVectorizeShape(key_shape_);
536     TensorShape value_shape = MaybeVectorizeShape(value_shape_);
537 
538     // Compute the final expected shape of the value by starting with the shape
539     // of all keys, removing the dimensions particular to each key and then
540     // appending the shape of a single value.
541     TensorShape expected_value_shape = keys.shape();
542     expected_value_shape.RemoveLastDims(key_shape.dims());
543     expected_value_shape.AppendShape(value_shape);
544     if (values.shape() != expected_value_shape) {
545       return errors::InvalidArgument(
546           "Expected shape ", expected_value_shape.DebugString(),
547           " for value, got ", values.shape().DebugString());
548     }
549     return Status::OK();
550   }
551 
key_dtype() const552   DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }
553 
value_dtype() const554   DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }
555 
key_shape() const556   TensorShape key_shape() const override { return key_shape_; }
557 
value_shape() const558   TensorShape value_shape() const override { return value_shape_; }
559 
MemoryUsed() const560   int64 MemoryUsed() const override {
561     tf_shared_lock l(mu_);
562     return sizeof(MutableDenseHashTable) + key_buckets_.AllocatedBytes() +
563            value_buckets_.AllocatedBytes() + empty_key_.AllocatedBytes();
564   }
565 
566  private:
DoInsert(OpKernelContext * ctx,const Tensor & key,const Tensor & value,bool ignore_empty_and_deleted_key)567   Status DoInsert(OpKernelContext* ctx, const Tensor& key, const Tensor& value,
568                   bool ignore_empty_and_deleted_key)
569       EXCLUSIVE_LOCKS_REQUIRED(mu_) {
570     const int64 num_elements = (key.dims() == 0) ? 1 : key.dim_size(0);
571     const int64 value_size = value_shape_.num_elements();
572     const int64 key_size = key_shape_.num_elements();
573     const auto key_matrix = key.shaped<K, 2>({num_elements, key_size});
574     auto value_matrix = value.shaped<V, 2>({num_elements, value_size});
575 
576     auto key_buckets_matrix =
577         key_buckets_.AccessTensor(ctx)->template matrix<K>();
578     auto value_buckets_matrix =
579         value_buckets_.AccessTensor(ctx)->template matrix<V>();
580     const auto empty_key_tensor =
581         empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
582     const auto deleted_key_tensor =
583         deleted_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
584     const int64 bit_mask = num_buckets_ - 1;
585     for (int64 i = 0; i < num_elements; ++i) {
586       const uint64 key_hash = HashKey(key_matrix, i);
587       if (empty_key_hash_ == key_hash &&
588           IsEqualKey(empty_key_tensor, 0, key_matrix, i)) {
589         if (ignore_empty_and_deleted_key) {
590           continue;
591         }
592         return errors::InvalidArgument(
593             "Using the empty_key as a table key is not allowed");
594       }
595       if (deleted_key_hash_ == key_hash &&
596           IsEqualKey(deleted_key_tensor, 0, key_matrix, i)) {
597         if (ignore_empty_and_deleted_key) {
598           continue;
599         }
600         return errors::InvalidArgument(
601             "Using the deleted_key as a table key is not allowed");
602       }
603       int64 bucket_index = key_hash & bit_mask;
604       int64 num_probes = 0;
605       while (true) {
606         if (IsEqualKey(key_buckets_matrix, bucket_index, key_matrix, i)) {
607           for (int64 j = 0; j < value_size; ++j) {
608             value_buckets_matrix(bucket_index, j) =
609                 SubtleMustCopyIfIntegral(value_matrix(i, j));
610           }
611           break;
612         }
613         if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_tensor, 0) ||
614             IsEqualKey(key_buckets_matrix, bucket_index, deleted_key_tensor,
615                        0)) {
616           ++num_entries_;
617           for (int64 j = 0; j < key_size; ++j) {
618             key_buckets_matrix(bucket_index, j) =
619                 SubtleMustCopyIfIntegral(key_matrix(i, j));
620           }
621           for (int64 j = 0; j < value_size; ++j) {
622             value_buckets_matrix(bucket_index, j) =
623                 SubtleMustCopyIfIntegral(value_matrix(i, j));
624           }
625           break;
626         }
627         ++num_probes;
628         bucket_index =
629             (bucket_index + num_probes) & bit_mask;  // quadratic probing
630         if (num_probes >= num_buckets_) {
631           return errors::Internal(
632               "Internal error in MutableDenseHashTable insert");
633         }
634       }
635     }
636     return Status::OK();
637   }
638 
DoRemove(OpKernelContext * ctx,const Tensor & key)639   Status DoRemove(OpKernelContext* ctx, const Tensor& key)
640       EXCLUSIVE_LOCKS_REQUIRED(mu_) {
641     const int64 num_elements = key.dim_size(0);
642     const int64 key_size = key_shape_.num_elements();
643     const auto key_matrix = key.shaped<K, 2>({num_elements, key_size});
644 
645     auto key_buckets_matrix =
646         key_buckets_.AccessTensor(ctx)->template matrix<K>();
647     const auto empty_key_tensor =
648         empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
649     const auto deleted_key_tensor =
650         deleted_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
651     const auto deleted_key_flat =
652         deleted_key_.AccessTensor(ctx)->template flat<K>();
653     const int64 bit_mask = num_buckets_ - 1;
654     for (int64 i = 0; i < num_elements; ++i) {
655       const uint64 key_hash = HashKey(key_matrix, i);
656       if (empty_key_hash_ == key_hash &&
657           IsEqualKey(empty_key_tensor, 0, key_matrix, i)) {
658         return errors::InvalidArgument(
659             "Using the empty_key as a table key is not allowed");
660       }
661       if (deleted_key_hash_ == key_hash &&
662           IsEqualKey(deleted_key_tensor, 0, key_matrix, i)) {
663         return errors::InvalidArgument(
664             "Using the deleted_key as a table key is not allowed");
665       }
666       int64 bucket_index = key_hash & bit_mask;
667       int64 num_probes = 0;
668       while (true) {
669         if (IsEqualKey(key_buckets_matrix, bucket_index, key_matrix, i)) {
670           --num_entries_;
671           for (int64 j = 0; j < key_size; ++j) {
672             key_buckets_matrix(bucket_index, j) =
673                 SubtleMustCopyIfIntegral(deleted_key_flat(j));
674           }
675           break;
676         }
677         if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_tensor, 0)) {
678           break;
679         }
680         ++num_probes;
681         bucket_index =
682             (bucket_index + num_probes) & bit_mask;  // quadratic probing
683         if (num_probes >= num_buckets_) {
684           return errors::Internal(
685               "Internal error in MutableDenseHashTable remove");
686         }
687       }
688     }
689     return Status::OK();
690   }
691 
AllocateBuckets(OpKernelContext * ctx,int64 new_num_buckets)692   Status AllocateBuckets(OpKernelContext* ctx, int64 new_num_buckets)
693       EXCLUSIVE_LOCKS_REQUIRED(mu_) {
694     if (new_num_buckets < 4 ||
695         ((new_num_buckets & (new_num_buckets - 1)) != 0)) {
696       return errors::InvalidArgument(
697           "Number of buckets must be at least 4 and a power of 2, got: ",
698           new_num_buckets);
699     }
700     num_buckets_ = new_num_buckets;
701     num_entries_ = 0;
702 
703     const int64 key_size = key_shape_.num_elements();
704     Tensor* key_buckets_tensor;
705     TF_RETURN_IF_ERROR(ctx->allocate_persistent(
706         key_dtype(), TensorShape({num_buckets_, key_size}), &key_buckets_,
707         &key_buckets_tensor));
708     auto key_buckets_matrix = key_buckets_tensor->matrix<K>();
709     const auto empty_key_flat =
710         empty_key_.AccessTensor(ctx)->template flat<K>();
711     for (int64 i = 0; i < num_buckets_; ++i) {
712       for (int64 j = 0; j < key_size; ++j) {
713         key_buckets_matrix(i, j) = empty_key_flat(j);
714       }
715     }
716 
717     const int64 value_size = value_shape_.num_elements();
718     Tensor* value_buckets_tensor;
719     TF_RETURN_IF_ERROR(ctx->allocate_persistent(
720         value_dtype(), TensorShape({num_buckets_, value_size}), &value_buckets_,
721         &value_buckets_tensor));
722     auto value_buckets_matrix = value_buckets_tensor->matrix<V>();
723     for (int64 i = 0; i < num_buckets_; ++i) {
724       for (int64 j = 0; j < value_size; ++j) {
725         // Initialize values to the default value for the type to avoid
726         // exposing uninitialized memory in ExportValues().
727         value_buckets_matrix(i, j) = V();
728       }
729     }
730     return Status::OK();
731   }
732 
Rebucket(OpKernelContext * ctx,int64 num_new_buckets)733   Status Rebucket(OpKernelContext* ctx, int64 num_new_buckets)
734       EXCLUSIVE_LOCKS_REQUIRED(mu_) {
735     Tensor old_key_buckets = *key_buckets_.AccessTensor(ctx);
736     Tensor old_value_buckets = *value_buckets_.AccessTensor(ctx);
737     TF_RETURN_IF_ERROR(AllocateBuckets(ctx, num_new_buckets));
738     return DoInsert(ctx, old_key_buckets, old_value_buckets, true);
739   }
740 
HashKey(typename TTypes<K>::ConstMatrix key,int64 index) const741   uint64 HashKey(typename TTypes<K>::ConstMatrix key, int64 index) const {
742     if (key_shape_.num_elements() == 1) {
743       return HashScalar(key(index, 0));
744     }
745     uint64 result = 0;
746     for (int64 i = 0; i < key_shape_.num_elements(); ++i) {
747       result = Hash64Combine(result, HashScalar(key(index, i)));
748     }
749     return result;
750   }
751 
752   // Use a template to allow this function to be used both with Matrix and
753   // ConstMatrix types.
754   template <typename MT2>
IsEqualKey(typename TTypes<K>::Matrix tensor1,int64 index1,MT2 tensor2,int64 index2) const755   bool IsEqualKey(typename TTypes<K>::Matrix tensor1, int64 index1, MT2 tensor2,
756                   int64 index2) const {
757     for (int64 i = 0; i < key_shape_.num_elements(); ++i) {
758       if (tensor1(index1, i) != tensor2(index2, i)) {
759         return false;
760       }
761     }
762     return true;
763   }
764 
765   TensorShape key_shape_;
766   TensorShape value_shape_;
767   float max_load_factor_;
768   mutable mutex mu_;
769   int64 num_entries_ GUARDED_BY(mu_);
770   int64 num_buckets_ GUARDED_BY(mu_);
771   PersistentTensor key_buckets_ GUARDED_BY(mu_);
772   PersistentTensor value_buckets_ GUARDED_BY(mu_);
773   PersistentTensor empty_key_;
774   uint64 empty_key_hash_;
775   PersistentTensor deleted_key_;
776   uint64 deleted_key_hash_;
777 };
778 
779 }  // namespace lookup
780 
781 // Table lookup op. Perform the lookup operation on the given table.
782 class LookupTableFindOp : public OpKernel {
783  public:
LookupTableFindOp(OpKernelConstruction * ctx)784   explicit LookupTableFindOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
785 
Compute(OpKernelContext * ctx)786   void Compute(OpKernelContext* ctx) override {
787     lookup::LookupInterface* table;
788     OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
789     core::ScopedUnref unref_me(table);
790 
791     // Input 0 could be a STRING_REF or a RESOURCE
792     DataType expected_input_0 =
793         (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
794     DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(),
795                                       table->value_dtype()};
796     DataTypeVector expected_outputs = {table->value_dtype()};
797     OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs));
798 
799     const Tensor& key = ctx->input(1);
800     const Tensor& default_value = ctx->input(2);
801     OP_REQUIRES_OK(ctx, table->CheckFindArguments(key, default_value));
802 
803     TensorShape output_shape = key.shape();
804     output_shape.RemoveLastDims(table->key_shape().dims());
805     output_shape.AppendShape(table->value_shape());
806     Tensor* out;
807     OP_REQUIRES_OK(ctx, ctx->allocate_output("values", output_shape, &out));
808 
809     OP_REQUIRES_OK(ctx, table->Find(ctx, key, out, default_value));
810   }
811 };
812 
813 REGISTER_KERNEL_BUILDER(Name("LookupTableFind").Device(DEVICE_CPU),
814                         LookupTableFindOp);
815 REGISTER_KERNEL_BUILDER(Name("LookupTableFindV2").Device(DEVICE_CPU),
816                         LookupTableFindOp);
817 
818 // Table insert op.
819 class LookupTableInsertOp : public OpKernel {
820  public:
LookupTableInsertOp(OpKernelConstruction * ctx)821   explicit LookupTableInsertOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
822 
Compute(OpKernelContext * ctx)823   void Compute(OpKernelContext* ctx) override {
824     lookup::LookupInterface* table;
825     OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
826     core::ScopedUnref unref_me(table);
827 
828     DataType expected_input_0 =
829         (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
830     DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(),
831                                       table->value_dtype()};
832     OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));
833 
834     const Tensor& keys = ctx->input(1);
835     const Tensor& values = ctx->input(2);
836     OP_REQUIRES_OK(ctx, table->CheckKeyAndValueTensorsForInsert(keys, values));
837 
838     int64 memory_used_before = 0;
839     if (ctx->track_allocations()) {
840       memory_used_before = table->MemoryUsed();
841     }
842     OP_REQUIRES_OK(ctx, table->Insert(ctx, keys, values));
843     if (ctx->track_allocations()) {
844       ctx->record_persistent_memory_allocation(table->MemoryUsed() -
845                                                memory_used_before);
846     }
847   }
848 };
849 
850 REGISTER_KERNEL_BUILDER(Name("LookupTableInsert").Device(DEVICE_CPU),
851                         LookupTableInsertOp);
852 REGISTER_KERNEL_BUILDER(Name("LookupTableInsertV2").Device(DEVICE_CPU),
853                         LookupTableInsertOp);
854 
855 // Table remove op.
856 class LookupTableRemoveOp : public OpKernel {
857  public:
LookupTableRemoveOp(OpKernelConstruction * ctx)858   explicit LookupTableRemoveOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
859 
Compute(OpKernelContext * ctx)860   void Compute(OpKernelContext* ctx) override {
861     lookup::LookupInterface* table;
862     OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
863     core::ScopedUnref unref_me(table);
864 
865     DataType expected_input_0 =
866         (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
867     DataTypeVector expected_inputs = {expected_input_0, table->key_dtype()};
868     OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));
869 
870     const Tensor& key = ctx->input(1);
871     OP_REQUIRES_OK(ctx, table->CheckKeyTensorForRemove(key));
872 
873     int64 memory_used_before = 0;
874     if (ctx->track_allocations()) {
875       memory_used_before = table->MemoryUsed();
876     }
877     OP_REQUIRES_OK(ctx, table->Remove(ctx, key));
878     if (ctx->track_allocations()) {
879       ctx->record_persistent_memory_allocation(table->MemoryUsed() -
880                                                memory_used_before);
881     }
882   }
883 };
884 
885 REGISTER_KERNEL_BUILDER(Name("LookupTableRemoveV2").Device(DEVICE_CPU),
886                         LookupTableRemoveOp);
887 
888 // Op that returns the size of the given table.
889 class LookupTableSizeOp : public OpKernel {
890  public:
LookupTableSizeOp(OpKernelConstruction * ctx)891   explicit LookupTableSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
892 
Compute(OpKernelContext * ctx)893   void Compute(OpKernelContext* ctx) override {
894     lookup::LookupInterface* table;
895     OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
896     core::ScopedUnref unref_me(table);
897 
898     Tensor* out;
899     OP_REQUIRES_OK(ctx, ctx->allocate_output("size", TensorShape({}), &out));
900     out->flat<int64>().setConstant(table->size());
901   }
902 };
903 
904 REGISTER_KERNEL_BUILDER(Name("LookupTableSize").Device(DEVICE_CPU),
905                         LookupTableSizeOp);
906 REGISTER_KERNEL_BUILDER(Name("LookupTableSizeV2").Device(DEVICE_CPU),
907                         LookupTableSizeOp);
908 
909 // Op that outputs tensors of all keys and all values.
910 class LookupTableExportOp : public OpKernel {
911  public:
LookupTableExportOp(OpKernelConstruction * ctx)912   explicit LookupTableExportOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
913 
Compute(OpKernelContext * ctx)914   void Compute(OpKernelContext* ctx) override {
915     lookup::LookupInterface* table;
916     OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
917     core::ScopedUnref unref_me(table);
918 
919     OP_REQUIRES_OK(ctx, table->ExportValues(ctx));
920   }
921 };
922 
923 REGISTER_KERNEL_BUILDER(Name("LookupTableExport").Device(DEVICE_CPU),
924                         LookupTableExportOp);
925 REGISTER_KERNEL_BUILDER(Name("LookupTableExportV2").Device(DEVICE_CPU),
926                         LookupTableExportOp);
927 
928 // Clear the table and insert data.
929 class LookupTableImportOp : public OpKernel {
930  public:
LookupTableImportOp(OpKernelConstruction * ctx)931   explicit LookupTableImportOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
932 
Compute(OpKernelContext * ctx)933   void Compute(OpKernelContext* ctx) override {
934     lookup::LookupInterface* table;
935     OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
936     core::ScopedUnref unref_me(table);
937 
938     DataType expected_input_0 =
939         (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
940     DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(),
941                                       table->value_dtype()};
942     OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));
943 
944     const Tensor& keys = ctx->input(1);
945     const Tensor& values = ctx->input(2);
946     OP_REQUIRES_OK(ctx, table->CheckKeyAndValueTensorsForImport(keys, values));
947 
948     int memory_used_before = 0;
949     if (ctx->track_allocations()) {
950       memory_used_before = table->MemoryUsed();
951     }
952     OP_REQUIRES_OK(ctx, table->ImportValues(ctx, keys, values));
953     if (ctx->track_allocations()) {
954       ctx->record_persistent_memory_allocation(table->MemoryUsed() -
955                                                memory_used_before);
956     }
957   }
958 };
959 
960 REGISTER_KERNEL_BUILDER(Name("LookupTableImport").Device(DEVICE_CPU),
961                         LookupTableImportOp);
962 REGISTER_KERNEL_BUILDER(Name("LookupTableImportV2").Device(DEVICE_CPU),
963                         LookupTableImportOp);
964 
// Each REGISTER_KERNEL(K, V) invocation below registers one CPU kernel for
// the "HashTable" op and one for the "HashTableV2" op. Both are constrained
// to key_dtype == K and value_dtype == V and are implemented by
// LookupTableOp<lookup::HashTable<K, V>, K, V>. Only the (K, V) pairs listed
// here get a HashTable CPU kernel from this file. REGISTER_KERNEL is
// #undef'd at the end so the following sections can redefine it.
965 // Register the HashTable op with the currently supported key and value types.
966 #define REGISTER_KERNEL(key_dtype, value_dtype)                           \
967   REGISTER_KERNEL_BUILDER(                                                \
968       Name("HashTable")                                                   \
969           .Device(DEVICE_CPU)                                             \
970           .TypeConstraint<key_dtype>("key_dtype")                         \
971           .TypeConstraint<value_dtype>("value_dtype"),                    \
972       LookupTableOp<lookup::HashTable<key_dtype, value_dtype>, key_dtype, \
973                     value_dtype>)                                         \
974   REGISTER_KERNEL_BUILDER(                                                \
975       Name("HashTableV2")                                                 \
976           .Device(DEVICE_CPU)                                             \
977           .TypeConstraint<key_dtype>("key_dtype")                         \
978           .TypeConstraint<value_dtype>("value_dtype"),                    \
979       LookupTableOp<lookup::HashTable<key_dtype, value_dtype>, key_dtype, \
980                     value_dtype>)
981 
982 REGISTER_KERNEL(int32, double);
983 REGISTER_KERNEL(int32, float);
984 REGISTER_KERNEL(int32, int32);
985 REGISTER_KERNEL(int32, string);
986 REGISTER_KERNEL(int64, double);
987 REGISTER_KERNEL(int64, float);
988 REGISTER_KERNEL(int64, int32);
989 REGISTER_KERNEL(int64, int64);
990 REGISTER_KERNEL(int64, string);
991 REGISTER_KERNEL(string, bool);
992 REGISTER_KERNEL(string, double);
993 REGISTER_KERNEL(string, float);
994 REGISTER_KERNEL(string, int32);
995 REGISTER_KERNEL(string, int64);
996 REGISTER_KERNEL(string, string);
997 
998 #undef REGISTER_KERNEL
999 
// Each REGISTER_KERNEL(K, V) invocation below registers one CPU kernel for
// the "MutableHashTable" op and one for the "MutableHashTableV2" op. Both are
// constrained to key_dtype == K and value_dtype == V and are implemented by
// LookupTableOp<lookup::MutableHashTableOfScalars<K, V>, K, V>. This list
// differs from the HashTable one: it includes (int64, Variant) and omits
// (int32, string) and (string, string). REGISTER_KERNEL is #undef'd at the
// end of the section.
1000 // Register the MutableHashTable op.
1001 #define REGISTER_KERNEL(key_dtype, value_dtype)                                \
1002   REGISTER_KERNEL_BUILDER(                                                     \
1003       Name("MutableHashTable")                                                 \
1004           .Device(DEVICE_CPU)                                                  \
1005           .TypeConstraint<key_dtype>("key_dtype")                              \
1006           .TypeConstraint<value_dtype>("value_dtype"),                         \
1007       LookupTableOp<lookup::MutableHashTableOfScalars<key_dtype, value_dtype>, \
1008                     key_dtype, value_dtype>)                                   \
1009   REGISTER_KERNEL_BUILDER(                                                     \
1010       Name("MutableHashTableV2")                                               \
1011           .Device(DEVICE_CPU)                                                  \
1012           .TypeConstraint<key_dtype>("key_dtype")                              \
1013           .TypeConstraint<value_dtype>("value_dtype"),                         \
1014       LookupTableOp<lookup::MutableHashTableOfScalars<key_dtype, value_dtype>, \
1015                     key_dtype, value_dtype>)
1016 
1017 REGISTER_KERNEL(int32, double);
1018 REGISTER_KERNEL(int32, float);
1019 REGISTER_KERNEL(int32, int32);
1020 REGISTER_KERNEL(int64, double);
1021 REGISTER_KERNEL(int64, float);
1022 REGISTER_KERNEL(int64, int32);
1023 REGISTER_KERNEL(int64, int64);
1024 REGISTER_KERNEL(int64, string);
1025 REGISTER_KERNEL(int64, Variant);
1026 REGISTER_KERNEL(string, bool);
1027 REGISTER_KERNEL(string, double);
1028 REGISTER_KERNEL(string, float);
1029 REGISTER_KERNEL(string, int32);
1030 REGISTER_KERNEL(string, int64);
1031 
1032 #undef REGISTER_KERNEL
1033 
// Each REGISTER_KERNEL(K, V) invocation below registers one CPU kernel for
// the "MutableHashTableOfTensors" op and one for the
// "MutableHashTableOfTensorsV2" op. Both are constrained to key_dtype == K
// and value_dtype == V and are implemented by
// LookupTableOp<lookup::MutableHashTableOfTensors<K, V>, K, V>. Only the
// (K, V) pairs listed here are registered by this file. REGISTER_KERNEL is
// #undef'd at the end of the section.
1034 // Register the MutableHashTableOfTensors op.
1035 #define REGISTER_KERNEL(key_dtype, value_dtype)                                \
1036   REGISTER_KERNEL_BUILDER(                                                     \
1037       Name("MutableHashTableOfTensors")                                        \
1038           .Device(DEVICE_CPU)                                                  \
1039           .TypeConstraint<key_dtype>("key_dtype")                              \
1040           .TypeConstraint<value_dtype>("value_dtype"),                         \
1041       LookupTableOp<lookup::MutableHashTableOfTensors<key_dtype, value_dtype>, \
1042                     key_dtype, value_dtype>)                                   \
1043   REGISTER_KERNEL_BUILDER(                                                     \
1044       Name("MutableHashTableOfTensorsV2")                                      \
1045           .Device(DEVICE_CPU)                                                  \
1046           .TypeConstraint<key_dtype>("key_dtype")                              \
1047           .TypeConstraint<value_dtype>("value_dtype"),                         \
1048       LookupTableOp<lookup::MutableHashTableOfTensors<key_dtype, value_dtype>, \
1049                     key_dtype, value_dtype>)
1050 
1051 REGISTER_KERNEL(int32, double);
1052 REGISTER_KERNEL(int32, float);
1053 REGISTER_KERNEL(int32, int32);
1054 REGISTER_KERNEL(int64, double);
1055 REGISTER_KERNEL(int64, float);
1056 REGISTER_KERNEL(int64, int32);
1057 REGISTER_KERNEL(int64, int64);
1058 REGISTER_KERNEL(int64, string);
1059 REGISTER_KERNEL(string, bool);
1060 REGISTER_KERNEL(string, double);
1061 REGISTER_KERNEL(string, float);
1062 REGISTER_KERNEL(string, int32);
1063 REGISTER_KERNEL(string, int64);
1064 
1065 #undef REGISTER_KERNEL
1066 
// Each REGISTER_KERNEL(K, V) invocation below registers one CPU kernel for
// the "MutableDenseHashTable" op and one for the "MutableDenseHashTableV2"
// op. Both are constrained to key_dtype == K and value_dtype == V and are
// implemented by LookupTableOp<lookup::MutableDenseHashTable<K, V>, K, V>.
// This list includes (int64, bool) and (int64, Variant). REGISTER_KERNEL is
// #undef'd at the end of the section.
1067 // Register the MutableDenseHashTable op.
1068 #define REGISTER_KERNEL(key_dtype, value_dtype)                            \
1069   REGISTER_KERNEL_BUILDER(                                                 \
1070       Name("MutableDenseHashTable")                                        \
1071           .Device(DEVICE_CPU)                                              \
1072           .TypeConstraint<key_dtype>("key_dtype")                          \
1073           .TypeConstraint<value_dtype>("value_dtype"),                     \
1074       LookupTableOp<lookup::MutableDenseHashTable<key_dtype, value_dtype>, \
1075                     key_dtype, value_dtype>)                               \
1076   REGISTER_KERNEL_BUILDER(                                                 \
1077       Name("MutableDenseHashTableV2")                                      \
1078           .Device(DEVICE_CPU)                                              \
1079           .TypeConstraint<key_dtype>("key_dtype")                          \
1080           .TypeConstraint<value_dtype>("value_dtype"),                     \
1081       LookupTableOp<lookup::MutableDenseHashTable<key_dtype, value_dtype>, \
1082                     key_dtype, value_dtype>)
1083 
1084 REGISTER_KERNEL(int32, double);
1085 REGISTER_KERNEL(int32, float);
1086 REGISTER_KERNEL(int32, int32);
1087 REGISTER_KERNEL(int64, bool);
1088 REGISTER_KERNEL(int64, double);
1089 REGISTER_KERNEL(int64, float);
1090 REGISTER_KERNEL(int64, int32);
1091 REGISTER_KERNEL(int64, int64);
1092 REGISTER_KERNEL(int64, Variant);
1093 REGISTER_KERNEL(string, bool);
1094 REGISTER_KERNEL(string, double);
1095 REGISTER_KERNEL(string, float);
1096 REGISTER_KERNEL(string, int32);
1097 REGISTER_KERNEL(string, int64);
1098 
1099 #undef REGISTER_KERNEL
1100 
1101 }  // namespace tensorflow
1102