# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#==============================================================================
"""Lookup operations."""
# pylint: disable=g-bad-name
import collections
import functools
import uuid

from tensorflow.python.checkpoint import saveable_compat
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_lookup_ops import *
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.saved_model import registration
from tensorflow.python.trackable import asset
# pylint: enable=wildcard-import
from tensorflow.python.trackable import base as trackable_base
from tensorflow.python.trackable import resource
from tensorflow.python.training.saver import BaseSaverBuilder
from tensorflow.python.util import compat as compat_util
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["initialize_all_tables"])
@deprecated(None, "Use `tf.tables_initializer` instead.")
def initialize_all_tables(name="init_all_tables"):
  """Returns an Op that initializes all tables of the default graph.

  Args:
    name: Optional name for the initialization op.

  Returns:
    An Op that initializes all tables.  Note that if there are
    no tables the returned Op is a NoOp.
  """
  return tables_initializer(name)


@tf_export(v1=["initializers.tables_initializer", "tables_initializer"])
def tables_initializer(name="init_all_tables"):
  """Returns an Op that initializes all tables of the default graph.

  Args:
    name: Optional name for the initialization op.

  Returns:
    An Op that initializes all tables.  Note that if there are
    no tables the returned Op is a NoOp.

  @compatibility(TF2)
  `tf.compat.v1.tables_initializer` is no longer needed with eager execution and
  `tf.function`. In TF2, when creating an initializable table like a
  `tf.lookup.StaticHashTable`, the table will automatically be initialized on
  creation.

  #### Before & After Usage Example

  Before:

  >>> with tf.compat.v1.Session():
  ...   init = tf.compat.v1.lookup.KeyValueTensorInitializer(['a', 'b'], [1, 2])
  ...   table = tf.compat.v1.lookup.StaticHashTable(init, default_value=-1)
  ...   tf.compat.v1.tables_initializer().run()
  ...   result = table.lookup(tf.constant(['a', 'c'])).eval()
  >>> result
  array([ 1, -1], dtype=int32)

  After:

  >>> init = tf.lookup.KeyValueTensorInitializer(['a', 'b'], [1, 2])
  >>> table = tf.lookup.StaticHashTable(init, default_value=-1)
  >>> table.lookup(tf.constant(['a', 'c'])).numpy()
  array([ 1, -1], dtype=int32)

  @end_compatibility
  """
  initializers = ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS)
  if initializers:
    return control_flow_ops.group(*initializers, name=name)
  return control_flow_ops.no_op(name=name)


def check_table_dtypes(table, key_dtype, value_dtype):
  """Check that the given key_dtype and value_dtype match the table dtypes.

  Args:
    table: The table to check types against.
    key_dtype: The key data type to check.
    value_dtype: The value data type to check.

  Raises:
    TypeError: when 'key_dtype' or 'value_dtype' doesn't match the table data
      types.
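
  For example, a mismatched key dtype raises a `TypeError` (an illustrative
  sketch):

  >>> table = tf.lookup.StaticHashTable(
  ...     tf.lookup.KeyValueTensorInitializer(['a'], [1]), default_value=-1)
  >>> check_table_dtypes(table, tf.int64, tf.int64)
  Traceback (most recent call last):
  ...
  TypeError: Invalid key dtype for table, expected <dtype: 'string'> but got <dtype: 'int64'>.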
  """
  if key_dtype.base_dtype != table.key_dtype:
    raise TypeError(f"Invalid key dtype for table, expected {table.key_dtype} "
                    f"but got {key_dtype}.")
  if value_dtype.base_dtype != table.value_dtype:
    raise TypeError("Invalid value dtype for table, expected "
                    f"{table.value_dtype} but got {value_dtype}.")


class LookupInterface(resource.TrackableResource):
  """Represent a lookup table that persists across different steps."""

  def __init__(self, key_dtype, value_dtype):
    """Construct a lookup table interface.

    Args:
      key_dtype: The table key type.
      value_dtype: The table value type.
    """
    self._key_dtype = dtypes.as_dtype(key_dtype)
    self._value_dtype = dtypes.as_dtype(value_dtype)
    super(LookupInterface, self).__init__()

  def _create_resource(self):
    raise NotImplementedError

  @property
  def key_dtype(self):
    """The table key dtype."""
    return self._key_dtype

  @property
  def value_dtype(self):
    """The table value dtype."""
    return self._value_dtype

  @property
  def name(self):
    """The name of the table."""
    raise NotImplementedError

  def size(self, name=None):
    """Compute the number of elements in this table."""
    raise NotImplementedError

  def lookup(self, keys, name=None):
    """Looks up `keys` in a table, outputs the corresponding values."""
    raise NotImplementedError

  def __getitem__(self, keys):
    """Looks up `keys` in a table, outputs the corresponding values."""
    return self.lookup(keys)


class InitializableLookupTableBase(LookupInterface):
  """Initializable lookup table interface.

  Initializable lookup tables persist across different steps.
  """

  def __init__(self, default_value, initializer):
    """Construct a table object from a table reference.

    It requires a table initializer object (a subclass of
    `TableInitializerBase`) that provides the table key and value types, as
    well as the op to initialize the table. The caller is responsible for
    executing the initialization op.

    Args:
      default_value: The value to use if a key is missing in the table.
      initializer: The table initializer to use.
    """
    super(InitializableLookupTableBase, self).__init__(initializer.key_dtype,
                                                       initializer.value_dtype)
    self._default_value = ops.convert_to_tensor(
        default_value, dtype=self._value_dtype)
    self._default_value.get_shape().merge_with(tensor_shape.TensorShape([]))
    if isinstance(initializer, trackable_base.Trackable):
      self._initializer = self._track_trackable(initializer, "_initializer")
    with ops.init_scope():
      self._resource_handle = self._create_resource()
    if (not context.executing_eagerly() and
        ops.get_default_graph()._get_control_flow_context() is not None):  # pylint: disable=protected-access
      with ops.init_scope():
        self._init_op = self._initialize()
    else:
      self._init_op = self._initialize()

  def _initialize(self):
    return self._initializer.initialize(self)

  @property
  def default_value(self):
    """The default value of the table."""
    return self._default_value

  def size(self, name=None):
    """Compute the number of elements in this table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A scalar tensor containing the number of elements in this table.
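
    For example (an illustrative sketch):

    >>> table = tf.lookup.StaticHashTable(
    ...     tf.lookup.KeyValueTensorInitializer(['a', 'b'], [1, 2]),
    ...     default_value=-1)
    >>> table.size().numpy()
    2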
    """
    with ops.name_scope(name, "%s_Size" % self.name, [self.resource_handle]):
      return gen_lookup_ops.lookup_table_size_v2(self.resource_handle)

  def lookup(self, keys, name=None):
    """Looks up `keys` in a table, outputs the corresponding values.

    The `default_value` is used for keys not present in the table.

    Args:
      keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`.
      name: A name for the operation (optional).

    Returns:
      A `SparseTensor` if keys are sparse, a `RaggedTensor` if keys are ragged,
      otherwise a dense `Tensor`.

    Raises:
      TypeError: when `keys` or `default_value` doesn't match the table data
        types.
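
    For example, sparse keys are looked up value-wise (an illustrative sketch;
    `tf.sparse.from_dense` drops the empty-string entry):

    >>> table = tf.lookup.StaticHashTable(
    ...     tf.lookup.KeyValueTensorInitializer(['a', 'b'], [1, 2]),
    ...     default_value=-1)
    >>> keys = tf.sparse.from_dense(tf.constant([['a', ''], ['b', 'c']]))
    >>> table.lookup(keys).values.numpy()
    array([ 1,  2, -1], dtype=int32)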
    """
    key_tensor = keys
    if isinstance(keys,
                  (sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor)):
      key_tensor = keys.values

    if keys.dtype.base_dtype != self._key_dtype:
      raise TypeError(f"Dtype of argument `keys` must be {self._key_dtype}, "
                      f"received: {keys.dtype}")

    with ops.name_scope(
        name, "%s_Lookup" % self.name,
        (self.resource_handle, key_tensor, self._default_value)):
      values = gen_lookup_ops.lookup_table_find_v2(self.resource_handle,
                                                   key_tensor,
                                                   self._default_value)

    values.set_shape(key_tensor.get_shape())
    if isinstance(keys, sparse_tensor.SparseTensor):
      return sparse_tensor.SparseTensor(keys.indices, values, keys.dense_shape)
    elif isinstance(keys, ragged_tensor.RaggedTensor):
      return keys.with_values(values)
    else:
      return values


class InitializableLookupTableBaseV1(InitializableLookupTableBase):

  @property
  def initializer(self):
    return self._init_op


@registration.register_tf_serializable(
    predicate=lambda obj: isinstance(obj, StaticHashTable))
@tf_export("lookup.StaticHashTable", v1=[])
class StaticHashTable(InitializableLookupTableBase):
  """A generic hash table that is immutable once initialized.

  Example usage:

  >>> keys_tensor = tf.constant(['a', 'b', 'c'])
  >>> vals_tensor = tf.constant([7, 8, 9])
  >>> input_tensor = tf.constant(['a', 'f'])
  >>> table = tf.lookup.StaticHashTable(
  ...     tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor),
  ...     default_value=-1)
  >>> table.lookup(input_tensor).numpy()
  array([ 7, -1], dtype=int32)

  Or for more pythonic code:

  >>> table[input_tensor].numpy()
  array([ 7, -1], dtype=int32)

  The result of a lookup operation has the same shape as the argument:

  >>> input_tensor = tf.constant([['a', 'b'], ['c', 'd']])
  >>> table[input_tensor].numpy()
  array([[ 7,  8],
         [ 9, -1]], dtype=int32)
  """

  def __init__(self,
               initializer,
               default_value,
               name=None,
               experimental_is_anonymous=False):
    """Creates a non-initialized `HashTable` object.

    Creates a table whose key and value types are specified by the
    initializer. Before using the table you will have to initialize it. After
    initialization the table will be immutable.

    Args:
      initializer: The table initializer to use. See `HashTable` kernel for
        supported key and value types.
      default_value: The value to use if a key is missing in the table.
      name: A name for the operation (optional).
      experimental_is_anonymous: Whether to use anonymous mode for the
        table (default is False). In anonymous mode, the table
        resource can only be accessed via a resource handle. It can't
        be looked up by a name. When all resource handles pointing to
        that resource are gone, the resource will be deleted
        automatically.

    Returns:
      A `HashTable` object.
    """
    self._initializer = initializer
    self._default_value = default_value
    self._is_anonymous = experimental_is_anonymous
    if not self._is_anonymous:
      self._shared_name = self._initializer._shared_name  # pylint: disable=protected-access
      if not self._shared_name:
        # Force using a shared name so that StaticHashTable resources can be
        # shared across different kernels. If no "shared_name" is set and
        # "use_node_name_sharing" is False, then each kernel gets its own local
        # resource.
        self._shared_name = "hash_table_%s" % (str(uuid.uuid4()),)
    self._name = name or "hash_table"
    self._table_name = None
    super(StaticHashTable, self).__init__(default_value, initializer)
    self._value_shape = self._default_value.get_shape()

  def _create_resource(self):
    if self._is_anonymous:
      table_ref = gen_lookup_ops.anonymous_hash_table(
          key_dtype=self._initializer.key_dtype,
          value_dtype=self._initializer.value_dtype,
          name=self._name)
    else:
      table_ref = gen_lookup_ops.hash_table_v2(
          shared_name=self._shared_name,
          key_dtype=self._initializer.key_dtype,
          value_dtype=self._initializer.value_dtype,
          name=self._name)
    if context.executing_eagerly():
      self._table_name = None
    else:
      self._table_name = table_ref.op.name.split("/")[-1]
    return table_ref

  @property
  def name(self):
    return self._table_name

  def export(self, name=None):
    """Returns tensors of all keys and values in the table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A pair of tensors with the first tensor containing all keys and the
        second tensor containing all values in the table.
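
    For example (an illustrative sketch; a single-entry table is used because
    export order is not otherwise guaranteed):

    >>> table = tf.lookup.StaticHashTable(
    ...     tf.lookup.KeyValueTensorInitializer(['a'], [1]), default_value=-1)
    >>> keys, values = table.export()
    >>> keys.numpy(), values.numpy()
    (array([b'a'], dtype=object), array([1], dtype=int32))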
    """
    with ops.name_scope(name, "%s_Export" % self.name, [self.resource_handle]):
      exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(
          self.resource_handle, self._key_dtype, self._value_dtype)

    exported_values.set_shape(exported_keys.get_shape().concatenate(
        self._value_shape))
    return exported_keys, exported_values

  def _serialize_to_proto(self, **unused_kwargs):
    return None

  def _add_trackable_child(self, name, value):
    setattr(self, name, value)
    if isinstance(value, trackable_base.Trackable):
      self._track_trackable(value, name)  # pylint:disable=protected-access

  @classmethod
  def _deserialize_from_proto(cls, **kwargs):

    class _RestoredStaticHashTable(resource.RestoredResource):  # pylint: disable=protected-access

      @classmethod
      def _resource_type(cls):
        return "RestoredStaticHashTable"

    return _RestoredStaticHashTable._deserialize_from_proto(**kwargs)  # pylint: disable=protected-access


@tf_export(v1=["lookup.StaticHashTable"])
class StaticHashTableV1(StaticHashTable):
  """A generic hash table that is immutable once initialized.

  When running in graph mode, you must evaluate the tensor returned by
  `tf.tables_initializer()` before evaluating the tensor returned by
  this class's `lookup()` method. Example usage in graph mode:

  ```python
  keys_tensor = tf.constant([1, 2])
  vals_tensor = tf.constant([3, 4])
  input_tensor = tf.constant([1, 5])
  table = tf.lookup.StaticHashTable(
      tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor), -1)
  out = table.lookup(input_tensor)
  with tf.Session() as sess:
      sess.run(tf.tables_initializer())
      print(sess.run(out))
  ```

  Note that in graph mode if you set `experimental_is_anonymous` to
  `True`, you should only call `Session.run` once; otherwise each
  `Session.run` will create (and destroy) a new, unrelated table,
  leading to errors such as "Table not initialized". You can do so
  like this:

  ```python
  keys_tensor = tf.constant([1, 2])
  vals_tensor = tf.constant([3, 4])
  input_tensor = tf.constant([1, 5])
  table = tf.lookup.StaticHashTable(
      tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor), -1,
      experimental_is_anonymous=True)
  with tf.control_dependencies([tf.tables_initializer()]):
    out = table.lookup(input_tensor)
  with tf.Session() as sess:
    print(sess.run(out))
  ```

  In eager mode, no special code is needed to initialize the table.
  Example usage in eager mode:

  ```python
  tf.enable_eager_execution()
  keys_tensor = tf.constant([1, 2])
  vals_tensor = tf.constant([3, 4])
  input_tensor = tf.constant([1, 5])
  table = tf.lookup.StaticHashTable(
      tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor), -1)
  print(table.lookup(input_tensor))
  ```
  """

  @property
  def initializer(self):
    return self._init_op


# For backwards compatibility. This will be removed in TF 2.0.
class HashTable(StaticHashTableV1):

  @property
  def init(self):
    return self.initializer


class TableInitializerBase(trackable_base.Trackable):
  """Base class for lookup table initializers."""

  def __init__(self, key_dtype, value_dtype):
    """Construct a table initializer object.

    Args:
      key_dtype: Type of the table keys.
      value_dtype: Type of the table values.
    """
    self._key_dtype = dtypes.as_dtype(key_dtype)
    self._value_dtype = dtypes.as_dtype(value_dtype)

  @property
  def key_dtype(self):
    """The expected table key dtype."""
    return self._key_dtype

  @property
  def value_dtype(self):
    """The expected table value dtype."""
    return self._value_dtype

  def initialize(self, table):
    """Returns the table initialization op."""
    raise NotImplementedError

  @property
  def _shared_name(self):
    """Returns a shared name to be used by the table."""
    shared_name = ""
    if context.executing_eagerly():
      # Ensure a unique name when eager execution is enabled to avoid spurious
      # sharing issues.
      # TODO(rohanj): Use context.anonymous_name() instead.
      shared_name += str(ops.uid())
    return shared_name


@tf_export("lookup.KeyValueTensorInitializer")
class KeyValueTensorInitializer(TableInitializerBase):
  """Table initializers given `keys` and `values` tensors.

  >>> keys_tensor = tf.constant(['a', 'b', 'c'])
  >>> vals_tensor = tf.constant([7, 8, 9])
  >>> input_tensor = tf.constant(['a', 'f'])
  >>> init = tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor)
  >>> table = tf.lookup.StaticHashTable(
  ...     init,
  ...     default_value=-1)
  >>> table.lookup(input_tensor).numpy()
  array([ 7, -1], dtype=int32)

  """

  def __init__(self, keys, values, key_dtype=None, value_dtype=None, name=None):
    """Constructs a table initializer object based on keys and values tensors.

    Args:
      keys: The tensor for the keys.
      values: The tensor for the values.
      key_dtype: The `keys` data type. Used when `keys` is a python array.
      value_dtype: The `values` data type. Used when `values` is a python array.
      name: A name for the operation (optional).
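
    For example, `value_dtype` overrides the type inferred from a python list
    (an illustrative sketch):

    >>> init = tf.lookup.KeyValueTensorInitializer(
    ...     ['a', 'b'], [1, 2], value_dtype=tf.int64)
    >>> init.value_dtype
    tf.int64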
    """
    if (not context.executing_eagerly() and
        ops.get_default_graph()._get_control_flow_context() is not None):  # pylint: disable=protected-access
      with ops.init_scope():
        self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys")
        self._values = ops.convert_to_tensor(
            values, dtype=value_dtype, name="values")
    else:
      self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys")
      self._values = ops.convert_to_tensor(
          values, dtype=value_dtype, name="values")
    self._name = name if name is not None else "key_value_init"
    if context.executing_eagerly():
      # Ensure a unique name when eager execution is enabled to avoid spurious
      # sharing issues.
      # TODO(rohanj): Use context.anonymous_name() instead.
      self._name += str(ops.uid())

    super(KeyValueTensorInitializer, self).__init__(self._keys.dtype,
                                                    self._values.dtype)

  def initialize(self, table):
    """Initializes the given `table` with `keys` and `values` tensors.

    Args:
      table: The table to initialize.

    Returns:
      The operation that initializes the table.

    Raises:
      TypeError: when the keys and values data types do not match the table
      key and value data types.
    """
    check_table_dtypes(table, self._keys.dtype, self._values.dtype)
    with ops.name_scope(
        self._name, values=(table.resource_handle, self._keys, self._values)):
      init_op = gen_lookup_ops.lookup_table_import_v2(table.resource_handle,
                                                      self._keys, self._values)
    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
    return init_op


@tf_export("lookup.TextFileIndex")
class TextFileIndex:
  """The key and value content to get from each line.

  This class defines the key and value used for `tf.lookup.TextFileInitializer`.

  The key and value content to get from each line is specified either by one
  of the following, or by a value `>= 0`.
  * `TextFileIndex.LINE_NUMBER` means use the line number starting from zero,
    expects data type int64.
  * `TextFileIndex.WHOLE_LINE` means use the whole line content, expects data
    type string.

  A value `>= 0` means use the index (starting at zero) of the split line
  based on `delimiter`.
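
  The constants themselves are plain integers (an illustrative check):

  >>> tf.lookup.TextFileIndex.WHOLE_LINE
  -2
  >>> tf.lookup.TextFileIndex.LINE_NUMBER
  -1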
  """
  WHOLE_LINE = -2
  LINE_NUMBER = -1


@tf_export("lookup.TextFileInitializer")
class TextFileInitializer(TableInitializerBase):
  r"""Table initializers from a text file.

  This initializer assigns one entry in the table for each line in the file.

  The key and value type of the table to initialize is given by `key_dtype` and
  `value_dtype`.

  The key and value content to get from each line is specified by
  the `key_index` and `value_index`.

  * `TextFileIndex.LINE_NUMBER` means use the line number starting from zero,
    expects data type int64.
  * `TextFileIndex.WHOLE_LINE` means use the whole line content, expects data
    type string.
  * A value `>= 0` means use the index (starting at zero) of the split line
    based on `delimiter`.

  For example if we have a file with the following content:

  >>> import tempfile
  >>> f = tempfile.NamedTemporaryFile(delete=False)
  >>> content='\n'.join(["emerson 10", "lake 20", "palmer 30",])
  >>> f.file.write(content.encode('utf-8'))
  >>> f.file.close()

  The following snippet initializes a table with the first column as keys and
  second column as values:

  * `emerson -> 10`
  * `lake -> 20`
  * `palmer -> 30`

  >>> init = tf.lookup.TextFileInitializer(
  ...    filename=f.name,
  ...    key_dtype=tf.string, key_index=0,
  ...    value_dtype=tf.int64, value_index=1,
  ...    delimiter=" ")
  >>> table = tf.lookup.StaticHashTable(init, default_value=-1)
  >>> table.lookup(tf.constant(['palmer','lake','tarkus'])).numpy()
  array([30, 20, -1])

  Similarly, to initialize the whole line as keys and the line number as
  values:

  * `emerson 10 -> 0`
  * `lake 20 -> 1`
  * `palmer 30 -> 2`

  >>> init = tf.lookup.TextFileInitializer(
  ...   filename=f.name,
  ...   key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
  ...   value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER)
  >>> table = tf.lookup.StaticHashTable(init, -1)
  >>> table.lookup(tf.constant('palmer 30')).numpy()
  2
  """

  def __init__(self,
               filename,
               key_dtype,
               key_index,
               value_dtype,
               value_index,
               vocab_size=None,
               delimiter="\t",
               name=None,
               value_index_offset=0):
    """Constructs a table initializer object to populate from a text file.

    It generates one key-value pair per line. The type of table key and
    value are specified by `key_dtype` and `value_dtype`, respectively.
    Similarly the content of the key and value are specified by the key_index
    and value_index.

    - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
      expects data type int64.
    - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
      type string or int64.
    - A value >=0 means use the index (starting at zero) of the split line based
      on `delimiter`.

    Args:
      filename: The filename of the text file to be used for initialization.
        The path must be accessible from wherever the graph is initialized
        (e.g. trainer or eval workers). The filename may be a scalar `Tensor`.
      key_dtype: The `key` data type.
      key_index: the index that represents information of a line to get the
        table 'key' values from.
      value_dtype: The `value` data type.
      value_index: the index that represents information of a line to get the
        table 'value' values from.
      vocab_size: The number of elements in the file, if known.
      delimiter: The delimiter to separate fields in a line.
      name: A name for the operation (optional).
      value_index_offset: A number to add to all indices extracted from the
        file. This is useful for cases where a user would like to reserve one
        or more low index values for control characters. For instance, if you
        would like to ensure that no vocabulary item is mapped to index 0 (so
        you can reserve 0 for a masking value), you can set value_index_offset
        to 1; this will mean that the first vocabulary element is mapped to 1
        instead of 0.

    Raises:
      ValueError: when the filename is empty, or when the table key and value
      data types do not match the expected data types.
    """
    if not isinstance(filename, ops.Tensor) and not filename:
      raise ValueError("`filename` argument required for "
                       "tf.lookup.TextFileInitializer")

    self._filename_arg = filename
    key_dtype = dtypes.as_dtype(key_dtype)
    value_dtype = dtypes.as_dtype(value_dtype)

    if key_index < -2:
      raise ValueError(f"`key_index` should be >= -2, received: {key_index}.")

    if key_index == TextFileIndex.LINE_NUMBER and key_dtype != dtypes.int64:
      raise ValueError("`key_dtype` must be int64 if `key_index` is "
                       f"{TextFileIndex.LINE_NUMBER}, received: {key_dtype}")
    if ((key_index == TextFileIndex.WHOLE_LINE) and
        (not key_dtype.is_integer) and (key_dtype != dtypes.string)):
      raise ValueError(
          "`key_dtype` should be either integer or string for `key_index` "
          f"{TextFileIndex.WHOLE_LINE}, received: {key_dtype}")
    if value_index < -2:
      raise ValueError("`value_index` should be >= -2, received: "
                       f"{value_index}")

    if value_index == TextFileIndex.LINE_NUMBER and value_dtype != dtypes.int64:
      raise ValueError("`value_dtype` must be int64 for `value_index` "
                       f"{TextFileIndex.LINE_NUMBER}, received: {value_dtype}")
    if ((value_index == TextFileIndex.WHOLE_LINE) and
        (not value_dtype.is_integer) and (value_dtype != dtypes.string)):
      raise ValueError(
          "`value_dtype` should be either integer or string for `value_index` "
          f"{TextFileIndex.WHOLE_LINE}, received: {value_dtype}")

    if (vocab_size is not None) and (vocab_size <= 0):
      raise ValueError(f"`vocab_size` should be > 0, received: {vocab_size}")

    self._key_index = key_index
    self._value_index = value_index
    self._vocab_size = vocab_size
    self._delimiter = delimiter
    self._name = name
    self._filename = self._track_trackable(
        asset.Asset(filename), "_filename")
    self._offset = value_index_offset

    super(TextFileInitializer, self).__init__(key_dtype, value_dtype)

  def initialize(self, table):
    """Initializes the table from a text file.

    Args:
      table: The table to be initialized.

    Returns:
      The operation that initializes the table.

    Raises:
      TypeError: when the keys and values data types do not match the table
      key and value data types.
    """
    check_table_dtypes(table, self.key_dtype, self.value_dtype)
    with ops.name_scope(self._name, "text_file_init", (table.resource_handle,)):
      filename = ops.convert_to_tensor(
          self._filename, dtypes.string, name="asset_filepath")
      init_op = gen_lookup_ops.initialize_table_from_text_file_v2(
          table.resource_handle, filename, self._key_index, self._value_index,
          -1 if self._vocab_size is None else self._vocab_size, self._delimiter,
          self._offset)
    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
    # If the filename tensor is anything other than a string constant (e.g.,
    # if it is a placeholder) then it does not make sense to track it as an
    # asset.
    if not context.executing_eagerly() and constant_op.is_constant(filename):
      ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename)
    return init_op

  @property
  def _shared_name(self):
    if self._vocab_size:
      # Keep the shared_name:
      # <table_type>_<filename>_<vocab_size>_<key_index>_<value_index>_<offset>
      if self._offset:
        shared_name = "hash_table_%s_%d_%s_%s_%s" % (
            self._filename_arg, self._vocab_size, self._key_index,
            self._value_index, self._offset)
      else:
        shared_name = "hash_table_%s_%d_%s_%s" % (
            self._filename_arg, self._vocab_size, self._key_index,
            self._value_index)
    else:
      # Keep the shared_name
      # <table_type>_<filename>_<key_index>_<value_index>_<offset>
      if self._offset:
        shared_name = "hash_table_%s_%s_%s_%s" % (
            self._filename_arg, self._key_index, self._value_index,
            self._offset)
      else:
        shared_name = "hash_table_%s_%s_%s" % (
            self._filename_arg, self._key_index, self._value_index)

    return shared_name


class TextFileStringTableInitializer(TextFileInitializer):
  """Table initializer for `int64` IDs to string tables from a text file."""

  def __init__(self,
               filename,
               key_column_index=TextFileIndex.LINE_NUMBER,
               value_column_index=TextFileIndex.WHOLE_LINE,
               vocab_size=None,
               delimiter="\t",
               name="text_file_string_table_init"):
    """Constructs an initializer for an id-to-string table from a text file.

    It populates a table whose key and value types are int64 and string,
    respectively. It generates one key-value pair per line.
    The content of the key and value are specified by `key_column_index`
    and `value_column_index`.

    - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
      expects data type int64.
    - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
      type string or int64.
    - A value >=0 means use the index (starting at zero) of the split line based
      on `delimiter`.

    Args:
      filename: The filename of the text file to be used for initialization.
        The path must be accessible from wherever the graph is initialized
        (e.g. trainer or eval workers). The filename may be a scalar `Tensor`.
      key_column_index: The column index from the text file to get the keys
        from. The default is to use the line number, starting from zero.
      value_column_index: The column index from the text file to get the values
        from. The default is to use the whole line content.
      vocab_size: The number of elements in the file, if known.
      delimiter: The delimiter to separate fields in a line.
      name: Optional name for the op.

    Raises:
      TypeError: when the filename is empty, or when the table key and value
      data types do not match the expected data types.
    """
    super(TextFileStringTableInitializer, self).__init__(
        filename,
        dtypes.int64,
        key_column_index,
        dtypes.string,
        value_column_index,
        vocab_size=vocab_size,
        delimiter=delimiter,
        name=name)


class TextFileIdTableInitializer(TextFileInitializer):
  """Table initializer for string to `int64` IDs tables from a text file."""

  def __init__(self,
               filename,
               key_column_index=TextFileIndex.WHOLE_LINE,
               value_column_index=TextFileIndex.LINE_NUMBER,
               vocab_size=None,
               delimiter="\t",
               name="text_file_id_table_init",
               key_dtype=dtypes.string):
    """Constructs an initializer for a string-to-id table from a text file.

    It populates a table whose key and value types are string and int64,
    respectively. It generates one key-value pair per line.
    The content of the key and value are specified by the key_index
    and value_index.

    - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
      expects data type int64.
    - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
      type string.
    - A value >=0 means use the index (starting at zero) of the split line based
      on `delimiter`.

    Args:
      filename: The filename of the text file to be used for initialization.
        The path must be accessible from wherever the graph is initialized
        (e.g. trainer or eval workers). The filename may be a scalar `Tensor`.
      key_column_index: The column index from the text file to get the `key`
        values from. The default is to use the whole line content.
      value_column_index: The column index from the text file to get the `value`
        values from. The default is to use the line number, starting from zero.
      vocab_size: The number of elements in the file, if known.
      delimiter: The delimiter to separate fields in a line.
      name: Optional name for the op.
      key_dtype: The `key` data type.

    Raises:
      TypeError: when the filename is empty, or when the table key and value
      data types do not match the expected data types.
    """
    super(TextFileIdTableInitializer, self).__init__(
        filename,
        key_dtype,
        key_column_index,
        dtypes.int64,
        value_column_index,
        vocab_size=vocab_size,
        delimiter=delimiter,
        name=name)


class HasherSpec(collections.namedtuple("HasherSpec", ["hasher", "key"])):
  """A structure for the spec of the hashing function to use for hash buckets.

  `hasher` is the name of the hashing function to use (e.g. "fasthash",
  "stronghash").
  `key` is optional and specifies the key to use for the hash function if
  supported; it is currently only used by the strong hash.

  Fields:
    hasher: The hasher name to use.
    key: The key to be used by the hashing function, if required.
  """
  __slots__ = ()


FastHashSpec = HasherSpec("fasthash", None)  # pylint: disable=invalid-name


class StrongHashSpec(HasherSpec):
  """A structure to specify a key of the strong keyed hash spec.

  The strong hash requires a `key`, which is a list of 2 unsigned integer
  numbers. These should be non-zero; random numbers generated from random.org
  would be a fine choice.

  Fields:
    key: The key to be used by the keyed hashing function.
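
  For example (an illustrative sketch):

  >>> spec = StrongHashSpec((1, 2))
  >>> spec.hasher
  'stronghash'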
  """
  __slots__ = ()

  def __new__(cls, key):
    if len(key) != 2:
      raise ValueError(f"`key` must have size 2, received {len(key)}")

    if not isinstance(key[0], compat_util.integral_types) or not isinstance(
        key[1], compat_util.integral_types):
      raise TypeError("Invalid key %s. Must be unsigned integer values." % key)

    return super(StrongHashSpec, cls).__new__(cls, "stronghash", key)


def _as_string(tensor):
  if dtypes.string == tensor.dtype.base_dtype:
    return tensor
  return string_ops.as_string(tensor)


class IdTableWithHashBuckets(LookupInterface):
  r"""String to Id table wrapper that assigns out-of-vocabulary keys to buckets.

  For example, if an instance of `IdTableWithHashBuckets` is initialized with a
  string-to-id table that maps:

  * `emerson -> 0`
  * `lake -> 1`
  * `palmer -> 2`

  The `IdTableWithHashBuckets` object performs the following mapping:

  * `emerson -> 0`
  * `lake -> 1`
  * `palmer -> 2`
  * `<other term> -> bucket_id`, where bucket_id will be between `3` and
  `3 + num_oov_buckets - 1`, calculated by:
  `hash(<term>) % num_oov_buckets + vocab_size`

  If input_tensor is `["emerson", "lake", "palmer", "king", "crimson"]`,
  the lookup result is `[0, 1, 2, 4, 7]`.

  If `table` is None, only out-of-vocabulary buckets are used.

  Example usage:

  ```python
  num_oov_buckets = 3
  input_tensor = tf.constant(["emerson", "lake", "palmer", "king", "crimson"])
  table = tf.IdTableWithHashBuckets(
      tf.StaticHashTable(
          tf.lookup.TextFileInitializer(
              filename,
              key_dtype=tf.string,
              key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
              value_dtype=tf.int64,
              value_index=tf.lookup.TextFileIndex.LINE_NUMBER,
              delimiter="\t"),
          default_value),
      num_oov_buckets)
  out = table.lookup(input_tensor)
  table.init.run()
  print(out.eval())
  ```

  The hash function used for generating out-of-vocabulary bucket IDs is
  handled by `hasher_spec`.
  """

  def __init__(self,
               table,
               num_oov_buckets,
               hasher_spec=FastHashSpec,
               name=None,
               key_dtype=None):
    """Construct an `IdTableWithHashBuckets` object.

    Args:
      table: Table that maps `tf.string` or `tf.int64` keys to `tf.int64` ids.
      num_oov_buckets: Number of buckets to use for out-of-vocabulary keys.
      hasher_spec: A `HasherSpec` to specify the hash function to use for
        assigning out-of-vocabulary buckets (optional).
      name: A name for the operation (optional).
      key_dtype: Data type of keys passed to `lookup`. Defaults to
        `table.key_dtype` if `table` is specified, otherwise `tf.string`. Must
        be string or integer, and must be castable to `table.key_dtype`.

    Raises:
      ValueError: when `table` is None and `num_oov_buckets` is not positive.
      TypeError: when `hasher_spec` is invalid.
    """
    # If a name ends with a '/' it is a "name scope", remove all trailing '/'
    # characters to use as table name.
    if name:
      name = name.rstrip("/")
    if table:
      if key_dtype is None:
        key_dtype = table.key_dtype
      supported_table_key_dtypes = (dtypes.int64, dtypes.string)
      if table.key_dtype not in supported_table_key_dtypes:
        raise TypeError("Invalid `key_dtype`, expected one of "
                        f"{supported_table_key_dtypes}, received {key_dtype}.")
      if table.key_dtype.is_integer != key_dtype.is_integer:
        raise TypeError("Invalid `key_dtype`, expected %s but got %s." %
                        ("integer" if key_dtype.is_integer else "non-integer",
                         table.key_dtype))
      if table.value_dtype != dtypes.int64:
        raise TypeError("Invalid `value_dtype`: expected int64 but got %s." %
                        (table.value_dtype))
      self._table = table
      name = name or self._table.name
    else:
      if num_oov_buckets <= 0:
        raise ValueError(
            "`num_oov_buckets` must be > 0 if no `table` is supplied.")
      key_dtype = dtypes.string if key_dtype is None else key_dtype
      self._table = None
      name = name or "hash_bucket"
    if (not key_dtype.is_integer) and (dtypes.string != key_dtype):
      raise TypeError("Invalid `key_dtype`, expected integer or string, got "
                      f"{key_dtype}.")
    self._num_oov_buckets = num_oov_buckets

    if not isinstance(hasher_spec, HasherSpec):
      raise TypeError("`hasher_spec` must be of type HasherSpec, got "
                      f"{type(hasher_spec)}.")
    self._hasher_spec = hasher_spec
    if name:
      self._table_name = name.split("/")[-1]
    else:
      self._table_name = None
    super(IdTableWithHashBuckets, self).__init__(key_dtype, dtypes.int64)

  def _create_resource(self):
    if self._table is not None:
      return self._table._create_resource()  # pylint: disable=protected-access
    return None

  def _initialize(self):
    if self._table is not None:
      return self._table._initialize()  # pylint: disable=protected-access
    with ops.name_scope(None, "init"):
      return control_flow_ops.no_op()

  @property
  def initializer(self):
    if self._table is not None:
      return self._table._init_op  # pylint: disable=protected-access
    with ops.name_scope(None, "init"):
      return control_flow_ops.no_op()

  @property
  @deprecated("2018-12-15", "Use `initializer` instead.")
  def init(self):
    return self.initializer

  @property
  def resource_handle(self):
    if self._table is not None:
      return self._table.resource_handle
    return None

  @property
  def name(self):
    return self._table_name

  def size(self, name=None):
    """Compute the number of elements in this table."""
    with ops.name_scope(name, "%s_Size" % self.name):
      if self._table:
        tsize = self._table.size()
      else:
        tsize = ops.convert_to_tensor(0, dtype=dtypes.int64)
      return tsize + self._num_oov_buckets

  def _get_string_to_hash_bucket_fn(self, hasher_spec):
    """Returns the string_to_hash_bucket op to use based on `hasher_spec`."""
    if not isinstance(hasher_spec, HasherSpec):
      raise TypeError("`hasher_spec` must be of type HasherSpec, got "
                      f"{type(hasher_spec)}.")
    if hasher_spec.hasher == "fasthash":
      return string_ops.string_to_hash_bucket_fast
    if hasher_spec.hasher == "legacy":
      return string_ops.string_to_hash_bucket
    if hasher_spec.hasher == "stronghash":
      return functools.partial(
          string_ops.string_to_hash_bucket_strong, key=hasher_spec.key)
    raise ValueError(
        f"Found unknown hasher {hasher_spec.hasher} in `hasher_spec`")

  def lookup(self, keys, name=None):
    """Looks up `keys` in the table, outputs the corresponding values.

    It assigns out-of-vocabulary keys to buckets based on their hashes.

    Args:
      keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`.
      name: Optional name for the op.

    Returns:
      A `SparseTensor` if keys are sparse, a `RaggedTensor` if keys are ragged,
      otherwise a dense `Tensor`.

    Raises:
      TypeError: when `keys` doesn't match the table key data type.
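
    For example, out-of-vocabulary keys land in the bucket range (an
    illustrative sketch; with `num_oov_buckets=1` the result is
    deterministic):

    >>> table = IdTableWithHashBuckets(
    ...     tf.lookup.StaticHashTable(
    ...         tf.lookup.KeyValueTensorInitializer(
    ...             ['a'], tf.constant([0], dtype=tf.int64)), -1),
    ...     num_oov_buckets=1)
    >>> table.lookup(tf.constant(['a', 'oov'])).numpy()
    array([0, 1])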
    """
    if keys.dtype.base_dtype != self._key_dtype:
      raise TypeError(f"Dtype of argument `keys` must be {self._key_dtype}, "
                      f"received: {keys.dtype}")
    values = keys
    if isinstance(keys,
                  (sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor)):
      values = keys.values
    if self._table and (self._table.key_dtype.base_dtype == dtypes.int64):
      values = math_ops.cast(values, dtypes.int64)

    if self._num_oov_buckets == 0:
      ids = self._table.lookup(values, name=name)
    else:
      # TODO(yleon): Consider moving this functionality to its own kernel.
      with ops.name_scope(name, "%s_Lookup" % self.name):
        str_to_hash_bucket = self._get_string_to_hash_bucket_fn(
            self._hasher_spec)
        buckets = str_to_hash_bucket(
            _as_string(values),
            num_buckets=self._num_oov_buckets,
            name="hash_bucket")
        if self._table:
          ids = self._table.lookup(values)
          buckets = math_ops.add(buckets, self._table.size())
          is_id_non_default = math_ops.not_equal(ids, self._table.default_value)
          ids = array_ops.where_v2(is_id_non_default, ids, buckets)
        else:
          ids = buckets
    if isinstance(keys, sparse_tensor.SparseTensor):
      return sparse_tensor.SparseTensor(keys.indices, ids, keys.dense_shape)
    elif isinstance(keys, ragged_tensor.RaggedTensor):
      return keys.with_values(ids)
    return ids


@tf_export("lookup.StaticVocabularyTable", v1=[])
class StaticVocabularyTable(LookupInterface):
  r"""String to Id table that assigns out-of-vocabulary keys to hash buckets.

  For example, if an instance of `StaticVocabularyTable` is initialized with a
  string-to-id initializer that maps:

  >>> init = tf.lookup.KeyValueTensorInitializer(
  ...     keys=tf.constant(['emerson', 'lake', 'palmer']),
  ...     values=tf.constant([0, 1, 2], dtype=tf.int64))
  >>> table = tf.lookup.StaticVocabularyTable(
  ...    init,
  ...    num_oov_buckets=5)

  The `StaticVocabularyTable` object performs the following mapping:

  * `emerson -> 0`
  * `lake -> 1`
  * `palmer -> 2`
  * `<other term> -> bucket_id`, where `bucket_id` will be between `3` and
  `3 + num_oov_buckets - 1 = 7`, calculated by:
  `hash(<term>) % num_oov_buckets + vocab_size`

  If input_tensor is:

  >>> input_tensor = tf.constant(["emerson", "lake", "palmer",
  ...                             "king", "crimson"])
  >>> table[input_tensor].numpy()
  array([0, 1, 2, 6, 7])

  If `initializer` is None, only out-of-vocabulary buckets are used.

  Example usage:

  >>> num_oov_buckets = 3
  >>> vocab = ["emerson", "lake", "palmer", "crimnson"]
  >>> import tempfile
  >>> f = tempfile.NamedTemporaryFile(delete=False)
  >>> f.write('\n'.join(vocab).encode('utf-8'))
  >>> f.close()

  >>> init = tf.lookup.TextFileInitializer(
  ...     f.name,
  ...     key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
  ...     value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER)
  >>> table = tf.lookup.StaticVocabularyTable(init, num_oov_buckets)
  >>> table.lookup(tf.constant(["palmer", "crimnson", "king",
  ...                           "tarkus", "black", "moon"])).numpy()
  array([2, 3, 5, 6, 6, 4])

  The hash function used for generating out-of-vocabulary bucket IDs is
  Fingerprint64.

  Note that the out-of-vocabulary bucket IDs always range from the table `size`
  up to `size + num_oov_buckets - 1` regardless of the table values, which could
  cause unexpected collisions:

  >>> init = tf.lookup.KeyValueTensorInitializer(
  ...     keys=tf.constant(["emerson", "lake", "palmer"]),
  ...     values=tf.constant([1, 2, 3], dtype=tf.int64))
  >>> table = tf.lookup.StaticVocabularyTable(
  ...     init,
  ...     num_oov_buckets=1)
  >>> input_tensor = tf.constant(["emerson", "lake", "palmer", "king"])
  >>> table[input_tensor].numpy()
  array([1, 2, 3, 3])
  """

  def __init__(self,
               initializer,
               num_oov_buckets,
               lookup_key_dtype=None,
               name=None,
               experimental_is_anonymous=False):
    """Construct a `StaticVocabularyTable` object.

    Args:
      initializer: A `TableInitializerBase` object that contains the data used
        to initialize the table. If None, then we only use out-of-vocab buckets.
      num_oov_buckets: Number of buckets to use for out-of-vocabulary keys. Must
        be greater than zero. If out-of-vocab buckets are not required, use
        `StaticHashTable` instead.
      lookup_key_dtype: Data type of keys passed to `lookup`. Defaults to
        `initializer.key_dtype` if `initializer` is specified, otherwise
        `tf.string`. Must be string or integer, and must be castable to
        `initializer.key_dtype`.
      name: A name for the operation (optional).
      experimental_is_anonymous: Whether to use anonymous mode for the
        table (default is False). In anonymous mode, the table
        resource can only be accessed via a resource handle. It can't
        be looked up by a name. When all resource handles pointing to
        that resource are gone, the resource will be deleted
        automatically.

    Raises:
      ValueError: when `num_oov_buckets` is not positive.
      TypeError: when lookup_key_dtype or initializer.key_dtype are not
        integer or string. Also when initializer.value_dtype != int64.
    """
    if num_oov_buckets <= 0:
      raise ValueError("`num_oov_buckets` must be > 0; use StaticHashTable.")
    # If a name ends with a '/' it is a "name scope", remove all trailing '/'
    # characters to use as table name.
    if name:
      name = name.rstrip("/")
    if initializer:
      if lookup_key_dtype is None:
        lookup_key_dtype = initializer.key_dtype
      supported_table_key_dtypes = (dtypes.int64, dtypes.string)
      if initializer.key_dtype not in supported_table_key_dtypes:
        raise TypeError("Invalid `key_dtype`, expected one of %s, but got %s." %
                        (supported_table_key_dtypes, initializer.key_dtype))
      if initializer.key_dtype.is_integer != lookup_key_dtype.is_integer:
        raise TypeError(
            "Invalid `key_dtype`, expected %s but got %s." %
            ("integer" if lookup_key_dtype.is_integer else "non-integer",
             initializer.key_dtype))
      if initializer.value_dtype != dtypes.int64:
        raise TypeError("Invalid `value_dtype`, expected %s but got %s." %
                        (dtypes.int64, initializer.value_dtype))
      if isinstance(initializer, trackable_base.Trackable):
        self._initializer = self._track_trackable(initializer, "_initializer")
      self._table = HashTable(
          initializer,
          default_value=-1,
          experimental_is_anonymous=experimental_is_anonymous)
      name = name or self._table.name
    else:
      lookup_key_dtype = dtypes.string
      self._table = None
      name = name or "hash_bucket"
    if (not lookup_key_dtype.is_integer) and (dtypes.string !=
                                              lookup_key_dtype):
      raise TypeError("Invalid `key_dtype`, expected integer or string, got "
                      f"{lookup_key_dtype}")
    self._num_oov_buckets = num_oov_buckets

    self._table_name = None
    if name is not None:
      self._table_name = name.split("/")[-1]
    super(StaticVocabularyTable, self).__init__(lookup_key_dtype, dtypes.int64)

  def _create_resource(self):
    if self._table is not None:
      return self._table._create_resource()  # pylint: disable=protected-access
    return None

  def _initialize(self):
    if self._table is not None:
      return self._table._initialize()  # pylint: disable=protected-access
    with ops.name_scope(None, "init"):
      return control_flow_ops.no_op()

  @property
  def resource_handle(self):
    if self._table is not None:
      return self._table.resource_handle
    return None

  @property
  def name(self):
    return self._table_name

  def size(self, name=None):
    """Compute the number of elements in this table.
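
    For example, the reported size includes the out-of-vocabulary buckets (an
    illustrative sketch):

    >>> init = tf.lookup.KeyValueTensorInitializer(
    ...     ['a', 'b'], tf.constant([0, 1], dtype=tf.int64))
    >>> table = tf.lookup.StaticVocabularyTable(init, num_oov_buckets=2)
    >>> table.size().numpy()
    4
    """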
    with ops.name_scope(name, "%s_Size" % self.name):
      if self._table:
        tsize = self._table.size()
      else:
        tsize = ops.convert_to_tensor(0, dtype=dtypes.int64)
      return tsize + self._num_oov_buckets

  def lookup(self, keys, name=None):
    """Looks up `keys` in the table, outputs the corresponding values.

    It assigns out-of-vocabulary keys to buckets based on their hashes.

    Args:
      keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`.
      name: Optional name for the op.

    Returns:
      A `SparseTensor` if keys are sparse, a `RaggedTensor` if keys are ragged,
      otherwise a dense `Tensor`.

    Raises:
      TypeError: when `keys` doesn't match the table key data type.
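
    For example (an illustrative sketch; with `num_oov_buckets=1` the
    out-of-vocabulary ID is deterministic):

    >>> init = tf.lookup.KeyValueTensorInitializer(
    ...     ['a', 'b'], tf.constant([0, 1], dtype=tf.int64))
    >>> table = tf.lookup.StaticVocabularyTable(init, num_oov_buckets=1)
    >>> table.lookup(tf.constant(['a', 'oov'])).numpy()
    array([0, 2])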
    """
    if keys.dtype.base_dtype != self._key_dtype:
      raise TypeError(f"Dtype of argument `keys` must be {self._key_dtype}, "
                      f"received: {keys.dtype}")
    values = keys
    if isinstance(keys,
                  (sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor)):
      values = keys.values
    if self._table and (self._table.key_dtype.base_dtype == dtypes.int64):
      values = math_ops.cast(values, dtypes.int64)

    # TODO(yleon): Consider moving this functionality to its own kernel.
    with ops.name_scope(name, "%s_Lookup" % self.name):
      buckets = string_ops.string_to_hash_bucket_fast(
          _as_string(values),
          num_buckets=self._num_oov_buckets,
          name="hash_bucket")
      if self._table:
        ids = self._table.lookup(values)
        buckets = math_ops.add(buckets, self._table.size())
        is_id_non_default = math_ops.not_equal(ids, self._table.default_value)
        ids = array_ops.where_v2(is_id_non_default, ids, buckets)
      else:
        ids = buckets
    if isinstance(keys, sparse_tensor.SparseTensor):
      return sparse_tensor.SparseTensor(keys.indices, ids, keys.dense_shape)
    elif isinstance(keys, ragged_tensor.RaggedTensor):
      return keys.with_values(ids)
    return ids
1397
1398
@tf_export(v1=["lookup.StaticVocabularyTable"])
class StaticVocabularyTableV1(StaticVocabularyTable):
  """TF1 variant of `StaticVocabularyTable` that exposes an `initializer`."""

  @property
  def initializer(self):
    if self._table is not None:
      return self._table._init_op  # pylint: disable=protected-access
    with ops.name_scope(None, "init"):
      return control_flow_ops.no_op()


def index_table_from_file(vocabulary_file=None,
                          num_oov_buckets=0,
                          vocab_size=None,
                          default_value=-1,
                          hasher_spec=FastHashSpec,
                          key_dtype=dtypes.string,
                          name=None,
                          key_column_index=TextFileIndex.WHOLE_LINE,
                          value_column_index=TextFileIndex.LINE_NUMBER,
                          delimiter="\t"):
  """Returns a lookup table that converts a string tensor into int64 IDs.

  This operation constructs a lookup table to convert a tensor of strings into
  int64 IDs. The mapping can be initialized from a vocabulary file specified in
  `vocabulary_file`, where the whole line is the key and the zero-based line
  number is the ID.

  Any lookup of an out-of-vocabulary token will return a bucket ID based on its
  hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
  `default_value`.
  The bucket ID range is
  `[vocabulary size, vocabulary size + num_oov_buckets - 1]`.

  The underlying table must be initialized by calling
  `session.run(tf.compat.v1.tables_initializer())` or
  `session.run(table.init())` once.

  To specify multi-column vocabulary files, use `key_column_index`,
  `value_column_index` and `delimiter`.

  - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
    expects data type int64.
  - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
    type string.
  - A value >=0 means use the index (starting at zero) of the split line based
    on `delimiter`.

  Sample Usages:

  If we have a vocabulary file "test.txt" with the following content:

  ```
  emerson
  lake
  palmer
  ```

  ```python
  features = tf.constant(["emerson", "lake", "and", "palmer"])
  table = tf.lookup.index_table_from_file(
      vocabulary_file="test.txt", num_oov_buckets=1)
  ids = table.lookup(features)
  ...
  tf.compat.v1.tables_initializer().run()

  ids.eval()  ==> [0, 1, 3, 2]  # where 3 is the out-of-vocabulary bucket
  ```
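
  For a multi-column vocabulary file, the key and value columns are selected
  with `key_column_index`, `value_column_index` and `delimiter`. An
  illustrative sketch, assuming a hypothetical comma-separated file
  "vocab.csv" whose first column holds the token:

  ```python
  # vocab.csv:
  #   emerson,10
  #   lake,20
  #   palmer,30
  table = tf.lookup.index_table_from_file(
      vocabulary_file="vocab.csv",
      key_column_index=0,  # token taken from the first column
      value_column_index=tf.lookup.TextFileIndex.LINE_NUMBER,
      delimiter=",")
  ```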

  Args:
    vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    vocab_size: Number of the elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assignment of out-of-vocabulary buckets.
    key_dtype: The `key` data type.
    name: A name for this op (optional).
    key_column_index: The column index from the text file to get the `key`
      values from. The default is to use the whole line content.
    value_column_index: The column index from the text file to get the `value`
      values from. The default is to use the line number, starting from zero.
    delimiter: The delimiter to separate fields in a line.

  Returns:
    The lookup table mapping a `key_dtype` `Tensor` to an `int64` index
    `Tensor`.

  Raises:
    ValueError: If `vocabulary_file` is not set.
    ValueError: If `num_oov_buckets` is negative or `vocab_size` is not greater
      than zero.
  """
  if vocabulary_file is None or (isinstance(vocabulary_file, str) and
                                 not vocabulary_file):
    raise ValueError(
        "`vocabulary_file` must be specified and must not be empty.")
  if num_oov_buckets < 0:
    raise ValueError(
        "`num_oov_buckets` must be greater than or equal to 0, got %d." %
        num_oov_buckets)
  if vocab_size is not None and vocab_size < 1:
    vocab_file_value = vocabulary_file
    if isinstance(vocabulary_file, ops.Tensor):
      vocab_file_value = tensor_util.constant_value(vocabulary_file) or "?"
    raise ValueError("`vocab_size` must be greater than 0, got %d for "
                     "vocabulary_file: %s." % (vocab_size, vocab_file_value))
  if (not key_dtype.is_integer) and (dtypes.string != key_dtype.base_dtype):
    raise TypeError("Dtype for `keys` should be either integer or string.")

  with ops.name_scope(name, "string_to_index"):
    table = None
    with ops.name_scope(None, "hash_table"):
      init = TextFileIdTableInitializer(
          vocabulary_file,
          vocab_size=vocab_size,
          key_dtype=dtypes.int64 if key_dtype.is_integer else key_dtype,
          name="table_init",
          key_column_index=key_column_index,
          value_column_index=value_column_index,
          delimiter=delimiter)

      table = StaticHashTableV1(init, default_value)
    if num_oov_buckets:
      table = IdTableWithHashBuckets(
          table,
          num_oov_buckets=num_oov_buckets,
          hasher_spec=hasher_spec,
          key_dtype=key_dtype)

    return table


def index_table_from_tensor(vocabulary_list,
                            num_oov_buckets=0,
                            default_value=-1,
                            hasher_spec=FastHashSpec,
                            dtype=dtypes.string,
                            name=None):
  """Returns a lookup table that converts a string tensor into int64 IDs.

  This operation constructs a lookup table to convert a tensor of strings into
  int64 IDs. The mapping can be initialized from a string `vocabulary_list` 1-D
  tensor where each element is a key and the corresponding index within the
  tensor is the value.

  Any lookup of an out-of-vocabulary token will return a bucket ID based on its
  hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
  `default_value`. The bucket ID range is
  `[vocabulary list size, vocabulary list size + num_oov_buckets - 1]`.

  The underlying table must be initialized by calling
  `session.run(tf.compat.v1.tables_initializer())` or
  `session.run(table.init())` once.

  Elements in `vocabulary_list` must be unique; otherwise the table initializer
  op will raise a `FailedPreconditionError` when executed.

  Sample Usages:

  ```python
  vocabulary_list = tf.constant(["emerson", "lake", "palmer"])
  table = tf.lookup.index_table_from_tensor(
      vocabulary_list=vocabulary_list, num_oov_buckets=1, default_value=-1)
  features = tf.constant(["emerson", "lake", "and", "palmer"])
  ids = table.lookup(features)
  ...
  tf.compat.v1.tables_initializer().run()

  ids.eval()  ==> [0, 1, 4, 2]
  ```
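
  Integer vocabularies work the same way when `dtype` is an integer type; a
  minimal sketch:

  ```python
  vocabulary_list = tf.constant([42, 7, 99], dtype=tf.int64)
  table = tf.lookup.index_table_from_tensor(
      vocabulary_list=vocabulary_list, dtype=tf.int64, default_value=-1)
  ids = table.lookup(tf.constant([7, 8], dtype=tf.int64))
  ...
  tf.compat.v1.tables_initializer().run()

  ids.eval()  ==> [1, -1]  # 8 is out of vocabulary, so default_value is used
  ```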

  Args:
    vocabulary_list: A 1-D `Tensor` that specifies the mapping of keys to
      indices. The type of this object must be castable to `dtype`.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assignment of out-of-vocabulary buckets.
    dtype: The type of values passed to `lookup`. Only string and integers are
      supported.
    name: A name for this op (optional).

  Returns:
    The lookup table mapping an input `Tensor` to an `int64` index `Tensor`.

  Raises:
    ValueError: If `vocabulary_list` is invalid.
    ValueError: If `num_oov_buckets` is negative.
  """
  if vocabulary_list is None:
    raise ValueError("`vocabulary_list` must be specified.")

  if num_oov_buckets < 0:
    raise ValueError(
        "`num_oov_buckets` must be greater than or equal to 0, got %d." %
        num_oov_buckets)

  if (not dtype.is_integer) and (dtypes.string != dtype.base_dtype):
    raise TypeError("`dtype` must either be integer or string.")

  with ops.name_scope(name, "string_to_index"):
    keys = ops.convert_to_tensor(vocabulary_list)
    if keys.dtype.is_integer != dtype.is_integer:
      raise ValueError(
          "Invalid `dtype`: Expected %s, got %s." %
          ("integer" if dtype.is_integer else "non-integer", keys.dtype))
    if (not dtype.is_integer) and (keys.dtype.base_dtype != dtype):
      raise ValueError("Invalid `dtype`: Expected %s, got %s." %
                       (dtype, keys.dtype))
    num_elements = array_ops.size(keys)
    values = math_ops.cast(math_ops.range(num_elements), dtypes.int64)

    with ops.name_scope(None, "hash_table"):
      table_keys = math_ops.cast(
          keys, dtypes.int64) if keys.dtype.is_integer else keys
      init = KeyValueTensorInitializer(
          table_keys,
          values,
          table_keys.dtype.base_dtype,
          dtypes.int64,
          name="table_init")
      table = StaticHashTableV1(init, default_value)
    if num_oov_buckets:
      table = IdTableWithHashBuckets(
          table,
          num_oov_buckets=num_oov_buckets,
          hasher_spec=hasher_spec,
          key_dtype=dtype)
    return table


def index_to_string_table_from_file(vocabulary_file,
                                    vocab_size=None,
                                    default_value="UNK",
                                    name=None,
                                    key_column_index=TextFileIndex.LINE_NUMBER,
                                    value_column_index=TextFileIndex.WHOLE_LINE,
                                    delimiter="\t"):
  """Returns a lookup table that maps a `Tensor` of indices into strings.

  This operation constructs a lookup table to map int64 indices into string
  values. The table is initialized from a vocabulary file specified in
  `vocabulary_file`, where the whole line is the value and the
  zero-based line number is the index.

  Any input which does not have a corresponding index in the vocabulary file
  (an out-of-vocabulary entry) is assigned the `default_value`.

  The underlying table must be initialized by calling
  `session.run(tf.compat.v1.tables_initializer())` or
  `session.run(table.init())` once.

  To specify multi-column vocabulary files, use `key_column_index`,
  `value_column_index` and `delimiter`.

  - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
    expects data type int64.
  - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
    type string.
  - A value >=0 means use the index (starting at zero) of the split line based
    on `delimiter`.

  Sample Usages:

  If we have a vocabulary file "test.txt" with the following content:

  ```
  emerson
  lake
  palmer
  ```

  ```python
  indices = tf.constant([1, 5], tf.int64)
  table = tf.lookup.index_to_string_table_from_file(
      vocabulary_file="test.txt", default_value="UNKNOWN")
  values = table.lookup(indices)
  ...
  tf.compat.v1.tables_initializer().run()

  values.eval() ==> ["lake", "UNKNOWN"]
  ```
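
  A multi-column file works analogously; an illustrative sketch, assuming a
  hypothetical comma-separated file "vocab.csv" whose first column holds the
  string to return:

  ```python
  # vocab.csv:
  #   emerson,10
  #   lake,20
  table = tf.lookup.index_to_string_table_from_file(
      vocabulary_file="vocab.csv",
      key_column_index=tf.lookup.TextFileIndex.LINE_NUMBER,
      value_column_index=0,
      delimiter=",")
  ```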

  Args:
    vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`.
    vocab_size: Number of the elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary indices.
    name: A name for this op (optional).
    key_column_index: The column index from the text file to get the `key`
      values from. The default is to use the line number, starting from zero.
    value_column_index: The column index from the text file to get the `value`
      values from. The default is to use the whole line content.
    delimiter: The delimiter to separate fields in a line.

  Returns:
    The lookup table mapping an `int64` index `Tensor` to the associated string
    values.

  Raises:
    ValueError: when `vocabulary_file` is empty.
    ValueError: when `vocab_size` is invalid.
  """
  if vocabulary_file is None or (isinstance(vocabulary_file, str) and
                                 not vocabulary_file):
    raise ValueError(
        "`vocabulary_file` must be specified and must not be empty.")

  if vocab_size is not None and vocab_size < 1:
    raise ValueError(f"`vocab_size` must be greater than 0, got {vocab_size}.")

  with ops.name_scope(name, "index_to_string"):
    init = TextFileStringTableInitializer(
        vocabulary_file,
        vocab_size=vocab_size,
        name="table_init",
        key_column_index=key_column_index,
        value_column_index=value_column_index,
        delimiter=delimiter)

    # TODO(yleon): Use a more efficient structure.
    return StaticHashTableV1(init, default_value)


def index_to_string_table_from_tensor(vocabulary_list,
                                      default_value="UNK",
                                      name=None):
  """Returns a lookup table that maps a `Tensor` of indices into strings.

  This operation constructs a lookup table to map int64 indices into string
  values. The mapping is initialized from a string `vocabulary_list` 1-D
  `Tensor` where each element is a value and the corresponding index within the
  tensor is the key.

  Any input which does not have a corresponding index in `vocabulary_list`
  (an out-of-vocabulary entry) is assigned the `default_value`.

  The underlying table must be initialized by calling
  `session.run(tf.compat.v1.tables_initializer())` or
  `session.run(table.init())` once.

  Elements in `vocabulary_list` must be unique; otherwise the table initializer
  op will raise a `FailedPreconditionError` when executed.

  Sample Usages:

  ```python
  vocabulary_list = tf.constant(["emerson", "lake", "palmer"])
  indices = tf.constant([1, 5], tf.int64)
  table = tf.lookup.index_to_string_table_from_tensor(
      vocabulary_list, default_value="UNKNOWN")
  values = table.lookup(indices)
  ...
  tf.compat.v1.tables_initializer().run()

  values.eval() ==> ["lake", "UNKNOWN"]
  ```

  Args:
    vocabulary_list: A 1-D string `Tensor` that specifies the strings to map
      from indices.
    default_value: The value to use for out-of-vocabulary indices.
    name: A name for this op (optional).

  Returns:
    The lookup table mapping an `int64` index `Tensor` to the associated string
    values.

  Raises:
    ValueError: when `vocabulary_list` is not set.
  """

  if vocabulary_list is None:
    raise ValueError("`vocabulary_list` argument must be specified.")

  with ops.name_scope(name, "index_to_string"):
    vocabulary_list = ops.convert_to_tensor(vocabulary_list, dtypes.string)
    num_elements = array_ops.size(vocabulary_list)
    keys = math_ops.cast(math_ops.range(num_elements), dtypes.int64)

    init = KeyValueTensorInitializer(
        keys, vocabulary_list, dtypes.int64, dtypes.string, name="table_init")
    # TODO(yleon): Use a more efficient structure.
    return StaticHashTableV1(init, default_value)


@tf_export("lookup.experimental.MutableHashTable")
@saveable_compat.legacy_saveable_name("table")
class MutableHashTable(LookupInterface):
  """A generic mutable hash table implementation.

  Data can be inserted by calling the `insert` method and removed by calling
  the `remove` method. It does not support initialization via an init method.

  `MutableHashTable` requires additional memory during checkpointing and restore
  operations to create temporary key and value tensors.

  Example usage:

  >>> table = tf.lookup.experimental.MutableHashTable(key_dtype=tf.string,
  ...                                                 value_dtype=tf.int64,
  ...                                                 default_value=-1)
  >>> keys_tensor = tf.constant(['a', 'b', 'c'])
  >>> vals_tensor = tf.constant([7, 8, 9], dtype=tf.int64)
  >>> input_tensor = tf.constant(['a', 'f'])
  >>> table.insert(keys_tensor, vals_tensor)
  >>> table.lookup(input_tensor).numpy()
  array([ 7, -1])
  >>> table.remove(tf.constant(['c']))
  >>> table.lookup(keys_tensor).numpy()
  array([ 7,  8, -1])
  >>> sorted(table.export()[0].numpy())
  [b'a', b'b']
  >>> sorted(table.export()[1].numpy())
  [7, 8]
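
  Since the table implements the `Trackable` checkpointing protocols, its
  contents can be saved and restored with `tf.train.Checkpoint`; a minimal
  sketch (the checkpoint path is illustrative):

  ```python
  ckpt = tf.train.Checkpoint(table=table)
  path = ckpt.save('/tmp/table-ckpt')  # illustrative path
  table.remove(keys_tensor)            # mutate the table...
  ckpt.restore(path)                   # ...then restore the saved contents
  ```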
  """

  def __init__(self,
               key_dtype,
               value_dtype,
               default_value,
               name="MutableHashTable",
               checkpoint=True,
               experimental_is_anonymous=False):
    """Creates an empty `MutableHashTable` object.

    Creates a table whose key and value types are specified by `key_dtype`
    and `value_dtype`, respectively.

    Args:
      key_dtype: the type of the key tensors.
      value_dtype: the type of the value tensors.
      default_value: The value to use if a key is missing in the table.
      name: A name for the operation (optional).
      checkpoint: if True, the contents of the table are saved to and restored
        from checkpoints. If `shared_name` is empty for a checkpointed table, it
        is shared using the table node name.
      experimental_is_anonymous: Whether to use anonymous mode for the
        table (default is False). In anonymous mode, the table
        resource can only be accessed via a resource handle. It can't
        be looked up by a name. When all resource handles pointing to
        that resource are gone, the resource will be deleted
        automatically.

    Returns:
      A `MutableHashTable` object.

    Raises:
      ValueError: If checkpoint is True and no name was specified.
    """
    self._default_value = ops.convert_to_tensor(
        default_value, dtype=value_dtype)
    self._value_shape = self._default_value.get_shape()
    self._checkpoint = checkpoint
    self._key_dtype = key_dtype
    self._value_dtype = value_dtype
    self._name = name
    self._is_anonymous = experimental_is_anonymous
    if not self._is_anonymous:
      self._shared_name = None
      if context.executing_eagerly():
        # TODO(allenl): This will leak memory due to kernel caching by
        # the shared_name attribute value (but is better than the
        # alternative of sharing everything by default when executing
        # eagerly; hopefully creating tables in a loop is uncommon).
        self._shared_name = "table_%d" % (ops.uid(),)
    super(MutableHashTable, self).__init__(key_dtype, value_dtype)
    self._resource_handle = self._create_resource()
    if checkpoint:
      saveable = MutableHashTable._Saveable(self, name)
      if not context.executing_eagerly():
        ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)

  def _create_resource(self):
    if self._is_anonymous:
      if self._default_value.get_shape().ndims == 0:
        table_ref = gen_lookup_ops.anonymous_mutable_hash_table(
            key_dtype=self._key_dtype,
            value_dtype=self._value_dtype,
            name=self._name)
      else:
        table_ref = gen_lookup_ops.anonymous_mutable_hash_table_of_tensors(
            key_dtype=self._key_dtype,
            value_dtype=self._value_dtype,
            value_shape=self._default_value.get_shape(),
            name=self._name)
    else:
      # The table must be shared if checkpointing is requested for multi-worker
      # training to work correctly. Use the node name if no shared_name has been
      # explicitly specified.
      use_node_name_sharing = self._checkpoint and self._shared_name is None
      if self._default_value.get_shape().ndims == 0:
        table_ref = gen_lookup_ops.mutable_hash_table_v2(
            shared_name=self._shared_name,
            use_node_name_sharing=use_node_name_sharing,
            key_dtype=self._key_dtype,
            value_dtype=self._value_dtype,
            name=self._name)
      else:
        table_ref = gen_lookup_ops.mutable_hash_table_of_tensors_v2(
            shared_name=self._shared_name,
            use_node_name_sharing=use_node_name_sharing,
            key_dtype=self._key_dtype,
            value_dtype=self._value_dtype,
            value_shape=self._default_value.get_shape(),
            name=self._name)

    if context.executing_eagerly():
      self._table_name = None
    else:
      self._table_name = table_ref.op.name.split("/")[-1]
    return table_ref

  @property
  def name(self):
    return self._table_name

  def size(self, name=None):
    """Compute the number of elements in this table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A scalar tensor containing the number of elements in this table.
    """
    with ops.name_scope(name, "%s_Size" % self.name, [self.resource_handle]):
      with ops.colocate_with(self.resource_handle):
        return gen_lookup_ops.lookup_table_size_v2(self.resource_handle)

  def remove(self, keys, name=None):
    """Removes `keys` and their associated values from the table.

    If a key is not present in the table, it is silently ignored.

    Args:
      keys: Keys to remove. Can be a tensor of any shape. Must match the table's
        key type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    if keys.dtype != self._key_dtype:
      raise TypeError(f"Dtype of argument `keys` must be {self._key_dtype}, "
                      f"received: {keys.dtype}")

    with ops.name_scope(name, "%s_lookup_table_remove" % self.name,
                        (self.resource_handle, keys, self._default_value)):
      op = gen_lookup_ops.lookup_table_remove_v2(self.resource_handle, keys)

    return op

  def lookup(self, keys, dynamic_default_values=None, name=None):
    """Looks up `keys` in a table, outputs the corresponding values.

    The `default_value` is used for keys not present in the table.

    Args:
      keys: Keys to look up. Can be a tensor of any shape. Must match the
        table's key_dtype.
      dynamic_default_values: The values to use if a key is missing in the
        table. If None (by default), the `table.default_value` will be used.
        The shape of `dynamic_default_values` must match the shape of
        `table.default_value` or of the lookup result tensor.
        In the latter case, each key will have a different default value.

        For example:

          ```python
          keys = [0, 1, 3]
          dynamic_default_values = [[1, 3, 4], [2, 3, 9], [8, 3, 0]]

          # The key '0' will use [1, 3, 4] as default value.
          # The key '1' will use [2, 3, 9] as default value.
          # The key '3' will use [8, 3, 0] as default value.
          values = table.lookup(keys, dynamic_default_values)
          ```

      name: A name for the operation (optional).

    Returns:
      A tensor containing the values in the same shape as `keys` using the
        table's value type.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    with ops.name_scope(name, "%s_lookup_table_find" % self.name,
                        (self.resource_handle, keys, self._default_value)):
      keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
      with ops.colocate_with(self.resource_handle):
        values = gen_lookup_ops.lookup_table_find_v2(
            self.resource_handle, keys, dynamic_default_values
            if dynamic_default_values is not None else self._default_value)
    return values

  def insert(self, keys, values, name=None):
    """Associates `keys` with `values`.

    Args:
      keys: Keys to insert. Can be a tensor of any shape. Must match the table's
        key type.
      values: Values to be associated with keys. Must be a tensor of the same
        shape as `keys` and match the table's value type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` or `values` doesn't match the table data
        types.
    """
    with ops.name_scope(name, "%s_lookup_table_insert" % self.name,
                        [self.resource_handle, keys, values]):
      keys = ops.convert_to_tensor(keys, self._key_dtype, name="keys")
      values = ops.convert_to_tensor(values, self._value_dtype, name="values")
      with ops.colocate_with(self.resource_handle):
        op = gen_lookup_ops.lookup_table_insert_v2(self.resource_handle, keys,
                                                   values)
    return op

  def export(self, name=None):
    """Returns tensors of all keys and values in the table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A pair of tensors with the first tensor containing all keys and the
        second tensor containing all values in the table.
    """
    with ops.name_scope(name, "%s_lookup_table_export_values" % self.name,
                        [self.resource_handle]):
      with ops.colocate_with(self.resource_handle):
        exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(
            self.resource_handle, self._key_dtype, self._value_dtype)
    return exported_keys, exported_values

  def _serialize_to_tensors(self):
    """Implements checkpointing protocols for `Trackable`."""
    tensors = self.export()
    return {"-keys": tensors[0], "-values": tensors[1]}

  def _restore_from_tensors(self, restored_tensors):
    """Implements checkpointing protocols for `Trackable`."""
    with ops.name_scope("%s_table_restore" % self._name):
      with ops.colocate_with(self.resource_handle):
        return gen_lookup_ops.lookup_table_import_v2(
            self.resource_handle,
            restored_tensors["-keys"],
            restored_tensors["-values"])

  # This class is needed for `MutableHashTable(checkpoint=True)`.
  class _Saveable(BaseSaverBuilder.SaveableObject):
    """SaveableObject implementation for MutableHashTable."""

    def __init__(self, table, name, table_name=None):
      tensors = table.export()
      specs = [
          BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
          BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
      ]
      self.table_name = table_name or name
      # pylint: disable=protected-access
      super(MutableHashTable._Saveable, self).__init__(table, specs, name)

    def restore(self, restored_tensors, restored_shapes):
      del restored_shapes  # unused
      # pylint: disable=protected-access
      with ops.name_scope("%s_table_restore" % self.table_name):
        with ops.colocate_with(self.op.resource_handle):
          return gen_lookup_ops.lookup_table_import_v2(self.op.resource_handle,
                                                       restored_tensors[0],
                                                       restored_tensors[1])


@tf_export("lookup.experimental.DenseHashTable")
@saveable_compat.legacy_saveable_name("table")
class DenseHashTable(LookupInterface):
  """A mutable hash table with faster lookups and higher memory usage.

  Data can be inserted by calling the `insert` method and removed by calling
  the `remove` method. It does not support initialization via an init method.

  Compared to `MutableHashTable`, `DenseHashTable` offers generally faster
  `insert`, `remove` and `lookup` operations, in exchange for a higher overall
  memory footprint.

  It uses "open addressing" with quadratic reprobing to resolve collisions. This
  requires specifying two keys in the key space, `empty_key` and `deleted_key`,
  that can never be inserted into the table.

  Unlike `MutableHashTable`, `DenseHashTable` does not require additional memory
  for temporary tensors created during checkpointing and restore operations.

  Example usage:

  >>> table = tf.lookup.experimental.DenseHashTable(
  ...     key_dtype=tf.string,
  ...     value_dtype=tf.int64,
  ...     default_value=-1,
  ...     empty_key='',
  ...     deleted_key='$')
  >>> keys = tf.constant(['a', 'b', 'c'])
  >>> values = tf.constant([0, 1, 2], dtype=tf.int64)
  >>> table.insert(keys, values)
  >>> table.remove(tf.constant(['c']))
  >>> table.lookup(tf.constant(['a', 'b', 'c','d'])).numpy()
  array([ 0,  1, -1, -1])
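
  When many short-lived tables are created (for example with
  `experimental_is_anonymous=True`), the default bucket count of 2^17 can be
  lowered to reduce per-table memory; a minimal sketch:

  ```python
  table = tf.lookup.experimental.DenseHashTable(
      key_dtype=tf.int64,
      value_dtype=tf.int64,
      default_value=-1,
      empty_key=0,     # reserved: must never be inserted, removed or looked up
      deleted_key=-1,  # reserved, and distinct from empty_key
      initial_num_buckets=8,
      experimental_is_anonymous=True)
  ```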
  """

  # TODO(andreasst): consider extracting common code with MutableHashTable into
  # a common superclass.
  def __init__(self,
               key_dtype,
               value_dtype,
               default_value,
               empty_key,
               deleted_key,
               initial_num_buckets=None,
               name="MutableDenseHashTable",
               checkpoint=True,
               experimental_is_anonymous=False):
    """Creates an empty `DenseHashTable` object.

    Creates a table whose key and value types are specified by `key_dtype`
    and `value_dtype`, respectively.

    Args:
      key_dtype: the type of the key tensors.
      value_dtype: the type of the value tensors.
      default_value: The value to use if a key is missing in the table.
      empty_key: the key to use to represent empty buckets internally. Must not
        be used in insert, remove or lookup operations.
      deleted_key: the key to use to represent deleted buckets internally. Must
        not be used in insert, remove or lookup operations, and must be
        different from `empty_key`.
      initial_num_buckets: the initial number of buckets (optional,
        defaults to 2^17=131072). Note that the default value is
        relatively large (~1MB), so if you are going to create many
        tables (likely the case when `experimental_is_anonymous` is
        `True`), you should set `initial_num_buckets` to a smaller
        value to reduce memory usage.
      name: A name for the operation (optional).
      checkpoint: if True, the contents of the table are saved to and restored
        from checkpoints. If `shared_name` is empty for a checkpointed table, it
        is shared using the table node name.
      experimental_is_anonymous: Whether to use anonymous mode for the
        table (default is False). In anonymous mode, the table
        resource can only be accessed via a resource handle. It can't
        be looked up by a name. When all resource handles pointing to
        that resource are gone, the resource will be deleted
        automatically.

    Returns:
      A `DenseHashTable` object.

    Raises:
      ValueError: If checkpoint is True and no name was specified.
    """
    self._default_value = ops.convert_to_tensor(
        default_value, dtype=value_dtype, name="default_value")
    self._key_dtype = key_dtype
    self._value_dtype = value_dtype
    # TODO(b/201578996): Pick a good default for initial_num_buckets
    #   other than 2^17.
    self._initial_num_buckets = initial_num_buckets
    self._value_shape = self._default_value.get_shape()
    self._checkpoint = checkpoint
    self._name = name
    self._empty_key = empty_key
    self._deleted_key = deleted_key
    self._is_anonymous = experimental_is_anonymous
    if not self._is_anonymous:
      self._shared_name = None
      if context.executing_eagerly():
        # TODO(allenl): This will leak memory due to kernel caching by
        # the shared_name attribute value (but is better than the
        # alternative of sharing everything by default when executing
        # eagerly; hopefully creating tables in a loop is uncommon).
        self._shared_name = "table_%d" % (ops.uid(),)
    super(DenseHashTable, self).__init__(key_dtype, value_dtype)
    self._resource_handle = self._create_resource()
    if checkpoint:
      saveable = DenseHashTable._Saveable(self, name)
      if not context.executing_eagerly():
        ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)

  def _create_resource(self):
    empty_key = ops.convert_to_tensor(
        self._empty_key, dtype=self._key_dtype, name="empty_key")
    deleted_key = ops.convert_to_tensor(
        self._deleted_key, dtype=self._key_dtype, name="deleted_key")
    if self._is_anonymous:
      table_ref = gen_lookup_ops.anonymous_mutable_dense_hash_table(
          empty_key=empty_key,
          deleted_key=deleted_key,
          value_dtype=self._value_dtype,
          value_shape=self._value_shape,
          initial_num_buckets=self._initial_num_buckets,
          name=self._name)
    else:
      # The table must be shared if checkpointing is requested for multi-worker
      # training to work correctly. Use the node name if no shared_name has been
      # explicitly specified.
      use_node_name_sharing = self._checkpoint and self._shared_name is None
      table_ref = gen_lookup_ops.mutable_dense_hash_table_v2(
          empty_key=empty_key,
          deleted_key=deleted_key,
          shared_name=self._shared_name,
          use_node_name_sharing=use_node_name_sharing,
          value_dtype=self._value_dtype,
          value_shape=self._value_shape,
          initial_num_buckets=self._initial_num_buckets,
          name=self._name)
    if context.executing_eagerly():
      self._table_name = None
    else:
      self._table_name = table_ref.op.name.split("/")[-1]
    return table_ref

  @property
  def name(self):
    return self._table_name

  def size(self, name=None):
    """Compute the number of elements in this table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A scalar tensor containing the number of elements in this table.
    """
    with ops.name_scope(name, "%s_Size" % self.name, [self.resource_handle]):
      with ops.colocate_with(self.resource_handle):
        return gen_lookup_ops.lookup_table_size_v2(self.resource_handle)

  def lookup(self, keys, name=None):
    """Looks up `keys` in a table, outputs the corresponding values.

    The `default_value` is used for keys not present in the table.

    Args:
      keys: Keys to look up. Can be a tensor of any shape. Must match the
        table's key_dtype.
      name: A name for the operation (optional).

    Returns:
      A tensor containing the values in the same shape as `keys` using the
        table's value type.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    with ops.name_scope(name, "%s_lookup_table_find" % self.name,
                        [self.resource_handle, keys]):
      keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
      with ops.colocate_with(self.resource_handle):
        values = gen_lookup_ops.lookup_table_find_v2(self.resource_handle, keys,
                                                     self._default_value)

    return values

  def insert_or_assign(self, keys, values, name=None):
    """Associates `keys` with `values`.

    Args:
      keys: Keys to insert. Can be a tensor of any shape. Must match the table's
        key type.
      values: Values to be associated with keys. Must be a tensor of the same
        shape as `keys` and match the table's value type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` or `values` doesn't match the table data
        types.
    """
    with ops.name_scope(name, "%s_lookup_table_insert" % self.name,
                        [self.resource_handle, keys, values]):
      keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
      values = ops.convert_to_tensor(
          values, dtype=self._value_dtype, name="values")
      with ops.colocate_with(self.resource_handle):
        op = gen_lookup_ops.lookup_table_insert_v2(self.resource_handle, keys,
                                                   values)
      return op

  def insert(self, keys, values, name=None):
    """Associates `keys` with `values`.

    Args:
      keys: Keys to insert. Can be a tensor of any shape. Must match the table's
        key type.
      values: Values to be associated with keys. Must be a tensor of the same
        shape as `keys` and match the table's value type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` or `values` doesn't match the table data
        types.
    """
    return self.insert_or_assign(keys, values, name)

  def erase(self, keys, name=None):
    """Removes `keys` and their associated values from the table.

    If a key is not present in the table, it is silently ignored.

    Args:
      keys: Keys to remove. Can be a tensor of any shape. Must match the table's
        key type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    if keys.dtype != self._key_dtype:
      raise TypeError(f"Dtype of argument `keys` must be {self._key_dtype}, "
                      f"received: {keys.dtype}")

    with ops.name_scope(name, "%s_lookup_table_remove" % self.name,
                        (self.resource_handle, keys, self._default_value)):
      op = gen_lookup_ops.lookup_table_remove_v2(self.resource_handle, keys)

    return op

  def remove(self, keys, name=None):
    """Removes `keys` and their associated values from the table.

    If a key is not present in the table, it is silently ignored.

    Args:
      keys: Keys to remove. Can be a tensor of any shape. Must match the table's
        key type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    return self.erase(keys, name)

  def export(self, name=None):
    """Returns tensors of all keys and values in the table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A pair of tensors with the first tensor containing all keys and the
        second tensor containing all values in the table.
    """
    with ops.name_scope(name, "%s_lookup_table_export_values" % self.name,
                        [self.resource_handle]):
      with ops.colocate_with(self.resource_handle):
        exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(
            self.resource_handle, self._key_dtype, self._value_dtype)

    return exported_keys, exported_values

  def _serialize_to_tensors(self):
    """Implements checkpointing interface in `Trackable`."""
    tensors = self.export()
    return {"-keys": tensors[0], "-values": tensors[1]}

  def _restore_from_tensors(self, restored_tensors):
    """Implements checkpointing interface in `Trackable`."""
    with ops.name_scope("%s_table_restore" % self._name):
      with ops.colocate_with(self.resource_handle):
        return gen_lookup_ops.lookup_table_import_v2(
            self.resource_handle,
            restored_tensors["-keys"],
            restored_tensors["-values"])

  # This class is needed for `DenseHashTable(checkpoint=True)`.
  class _Saveable(BaseSaverBuilder.SaveableObject):
    """SaveableObject implementation for DenseHashTable."""

    def __init__(self, table, name, table_name=None):
      tensors = table.export()
      specs = [
          BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
          BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
      ]
      self.table_name = table_name or name
      # pylint: disable=protected-access
      super(DenseHashTable._Saveable, self).__init__(table, specs, name)

    def restore(self, restored_tensors, restored_shapes):
      del restored_shapes  # unused
      # pylint: disable=protected-access
      with ops.name_scope("%s_table_restore" % self.table_name):
        with ops.colocate_with(self.op.resource_handle):
          return gen_lookup_ops.lookup_table_import_v2(self.op.resource_handle,
                                                       restored_tensors[0],
                                                       restored_tensors[1])


ops.NotDifferentiable("LookupTableFind")
ops.NotDifferentiable("LookupTableFindV2")
ops.NotDifferentiable("LookupTableInsert")
ops.NotDifferentiable("LookupTableInsertV2")
ops.NotDifferentiable("LookupTableSize")
ops.NotDifferentiable("LookupTableSizeV2")
ops.NotDifferentiable("HashTable")
ops.NotDifferentiable("HashTableV2")
ops.NotDifferentiable("InitializeTable")
ops.NotDifferentiable("InitializeTableV2")
ops.NotDifferentiable("InitializeTableFromTextFile")
ops.NotDifferentiable("InitializeTableFromTextFileV2")
ops.NotDifferentiable("MutableDenseHashTable")
ops.NotDifferentiable("MutableDenseHashTableV2")
ops.NotDifferentiable("MutableHashTable")
ops.NotDifferentiable("MutableHashTableV2")
ops.NotDifferentiable("MutableHashTableOfTensors")
ops.NotDifferentiable("MutableHashTableOfTensorsV2")