• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14#==============================================================================
15"""Lookup operations."""
16# pylint: disable=g-bad-name
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import functools
23import uuid
24
25import six
26
27from tensorflow.python.eager import context
28from tensorflow.python.framework import constant_op
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import sparse_tensor
32from tensorflow.python.framework import tensor_shape
33from tensorflow.python.framework import tensor_util
34from tensorflow.python.ops import array_ops
35from tensorflow.python.ops import control_flow_ops
36from tensorflow.python.ops import gen_lookup_ops
37from tensorflow.python.ops import math_ops
38from tensorflow.python.ops import string_ops
39# go/tf-wildcard-import
40# pylint: disable=wildcard-import
41from tensorflow.python.ops.gen_lookup_ops import *
42from tensorflow.python.ops.ragged import ragged_tensor
43from tensorflow.python.training.saver import BaseSaverBuilder
44# pylint: enable=wildcard-import
45from tensorflow.python.training.tracking import base as trackable_base
46from tensorflow.python.training.tracking import tracking as trackable
47from tensorflow.python.util import compat as compat_util
48from tensorflow.python.util.deprecation import deprecated
49from tensorflow.python.util.tf_export import tf_export
50
51
@tf_export(v1=["initialize_all_tables"])
@deprecated(None, "Use `tf.tables_initializer` instead.")
def initialize_all_tables(name="init_all_tables"):
  """Returns an Op that initializes all tables of the default graph.

  Deprecated alias for `tf.compat.v1.tables_initializer`.

  Args:
    name: Optional name for the initialization op.

  Returns:
    An Op that initializes all tables. If there are no tables, the
    returned Op is a NoOp.
  """
  return tables_initializer(name=name)
65
66
@tf_export(v1=["initializers.tables_initializer", "tables_initializer"])
def tables_initializer(name="init_all_tables"):
  """Returns an Op that initializes all tables of the default graph.

  Args:
    name: Optional name for the initialization op.

  Returns:
    An Op that initializes all tables. If there are no tables, the
    returned Op is a NoOp.

  @compatibility(TF2)
  `tf.compat.v1.tables_initializer` is no longer needed with eager execution and
  `tf.function`. In TF2, when creating an initializable table like a
  `tf.lookup.StaticHashTable`, the table will automatically be initialized on
  creation.

  #### Before & After Usage Example

  Before:

  >>> with tf.compat.v1.Session():
  ...   init = tf.compat.v1.lookup.KeyValueTensorInitializer(['a', 'b'], [1, 2])
  ...   table = tf.compat.v1.lookup.StaticHashTable(init, default_value=-1)
  ...   tf.compat.v1.tables_initializer().run()
  ...   result = table.lookup(tf.constant(['a', 'c'])).eval()
  >>> result
  array([ 1, -1], dtype=int32)

  After:

  >>> init = tf.lookup.KeyValueTensorInitializer(['a', 'b'], [1, 2])
  >>> table = tf.lookup.StaticHashTable(init, default_value=-1)
  >>> table.lookup(tf.constant(['a', 'c'])).numpy()
  array([ 1, -1], dtype=int32)

  @end_compatibility
  """
  # Tables register their init ops in the TABLE_INITIALIZERS collection.
  table_init_ops = ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS)
  if not table_init_ops:
    # Nothing to initialize; return a NoOp so callers can always run it.
    return control_flow_ops.no_op(name=name)
  return control_flow_ops.group(*table_init_ops, name=name)
109
110
def check_table_dtypes(table, key_dtype, value_dtype):
  """Validates that `key_dtype` and `value_dtype` match the table's dtypes.

  Args:
    table: The table to check types against to.
    key_dtype: The key data type to check.
    value_dtype: The value data type to check.

  Raises:
    TypeError: when 'key_dtype' or 'value_dtype' doesn't match the table data
      types.
  """
  # Compare against the base dtype so reference dtypes are accepted.
  if table.key_dtype != key_dtype.base_dtype:
    raise TypeError(f"Invalid key dtype for table, expected {table.key_dtype} "
                    f"but got {key_dtype}.")
  if table.value_dtype != value_dtype.base_dtype:
    raise TypeError("Invalid value dtype for table, expected "
                    f"{table.value_dtype} but got {value_dtype}.")
129
130
class LookupInterface(trackable.TrackableResource):
  """Represent a lookup table that persists across different steps."""

  def __init__(self, key_dtype, value_dtype):
    """Construct a lookup table interface.

    Args:
      key_dtype: The table key type.
      value_dtype: The table value type.
    """
    self._key_dtype = dtypes.as_dtype(key_dtype)
    self._value_dtype = dtypes.as_dtype(value_dtype)
    super(LookupInterface, self).__init__()

  def _create_resource(self):
    """Creates the table resource handle. Must be implemented by subclasses."""
    raise NotImplementedError

  @property
  def key_dtype(self):
    """The table key dtype."""
    return self._key_dtype

  @property
  def value_dtype(self):
    """The table value dtype."""
    return self._value_dtype

  @property
  def name(self):
    """The name of the table."""
    # Bug fix: this previously did `return NotImplementedError`, which
    # silently handed callers the exception *class* instead of raising it,
    # producing nonsense strings in name_scope formatting downstream.
    raise NotImplementedError

  def size(self, name=None):
    """Compute the number of elements in this table."""
    raise NotImplementedError

  def lookup(self, keys, name=None):
    """Looks up `keys` in a table, outputs the corresponding values."""
    raise NotImplementedError

  def __getitem__(self, keys):
    """Looks up `keys` in a table, outputs the corresponding values."""
    return self.lookup(keys)
174
175
class InitializableLookupTableBase(LookupInterface):
  """Initializable lookup table interface.

  An initializable lookup tables persist across different steps.
  """

  def __init__(self, default_value, initializer):
    """Construct a table object from a table reference.

    If requires a table initializer object (subclass of `TableInitializerBase`).
    It provides the table key and value types, as well as the op to initialize
    the table. The caller is responsible to execute the initialization op.

    Args:
      default_value: The value to use if a key is missing in the table.
      initializer: The table initializer to use.
    """
    super(InitializableLookupTableBase, self).__init__(initializer.key_dtype,
                                                       initializer.value_dtype)
    # The default value must be a scalar; merge_with raises otherwise.
    self._default_value = ops.convert_to_tensor(
        default_value, dtype=self._value_dtype)
    self._default_value.get_shape().merge_with(tensor_shape.TensorShape([]))
    # Only track the initializer when it participates in object-based
    # checkpointing; plain initializers are kept untracked.
    if isinstance(initializer, trackable_base.Trackable):
      self._initializer = self._track_trackable(initializer, "_initializer")
    # Create the resource handle in the outermost context so the table is not
    # captured inside a function-building graph.
    with ops.init_scope():
      self._resource_handle = self._create_resource()
    if (not context.executing_eagerly() and
        ops.get_default_graph()._get_control_flow_context() is not None):  # pylint: disable=protected-access
      # When constructed inside a control-flow context (e.g. a while_loop
      # body), lift the initialization op out of that frame as well.
      with ops.init_scope():
        self._init_op = self._initialize()
    else:
      self._init_op = self._initialize()

  def _initialize(self):
    # Delegates to the initializer, which also registers the op in the
    # TABLE_INITIALIZERS collection (see initializer implementations below).
    return self._initializer.initialize(self)

  @property
  def default_value(self):
    """The default value of the table."""
    return self._default_value

  def size(self, name=None):
    """Compute the number of elements in this table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A scalar tensor containing the number of elements in this table.
    """
    with ops.name_scope(name, "%s_Size" % self.name, [self.resource_handle]):
      return gen_lookup_ops.lookup_table_size_v2(self.resource_handle)

  def lookup(self, keys, name=None):
    """Looks up `keys` in a table, outputs the corresponding values.

    The `default_value` is used for keys not present in the table.

    Args:
      keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`.
      name: A name for the operation (optional).

    Returns:
      A `SparseTensor` if keys are sparse, a `RaggedTensor` if keys are ragged,
      otherwise a dense `Tensor`.

    Raises:
      TypeError: when `keys` or `default_value` doesn't match the table data
        types.
    """
    # For composite tensors, look up only the flat values; the sparse/ragged
    # structure is re-attached to the results below.
    key_tensor = keys
    if isinstance(keys,
                  (sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor)):
      key_tensor = keys.values

    if keys.dtype.base_dtype != self._key_dtype:
      raise TypeError(f"Dtype of argument `keys` must be {self._key_dtype}, "
                      f"received: {keys.dtype}")

    with ops.name_scope(
        name, "%s_Lookup" % self.name,
        (self.resource_handle, key_tensor, self._default_value)):
      values = gen_lookup_ops.lookup_table_find_v2(self.resource_handle,
                                                   key_tensor,
                                                   self._default_value)

    # The op returns one value per key, so the output shape matches the input.
    values.set_shape(key_tensor.get_shape())
    if isinstance(keys, sparse_tensor.SparseTensor):
      return sparse_tensor.SparseTensor(keys.indices, values, keys.dense_shape)
    elif isinstance(keys, ragged_tensor.RaggedTensor):
      return keys.with_values(values)
    else:
      return values
269
270
class InitializableLookupTableBaseV1(InitializableLookupTableBase):
  """V1 base for initializable tables; exposes the table init op."""

  @property
  def initializer(self):
    """The op that initializes this table."""
    return self._init_op
276
277
@tf_export("lookup.StaticHashTable", v1=[])
class StaticHashTable(InitializableLookupTableBase):
  """A generic hash table that is immutable once initialized.

  Example usage:

  >>> keys_tensor = tf.constant(['a', 'b', 'c'])
  >>> vals_tensor = tf.constant([7, 8, 9])
  >>> input_tensor = tf.constant(['a', 'f'])
  >>> table = tf.lookup.StaticHashTable(
  ...     tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor),
  ...     default_value=-1)
  >>> table.lookup(input_tensor).numpy()
  array([ 7, -1], dtype=int32)

  Or for more pythonic code:

  >>> table[input_tensor].numpy()
  array([ 7, -1], dtype=int32)

  The result of a lookup operation has the same shape as the argument:

  >>> input_tensor = tf.constant([['a', 'b'], ['c', 'd']])
  >>> table[input_tensor].numpy()
  array([[ 7,  8],
         [ 9, -1]], dtype=int32)


  """

  def __init__(self, initializer, default_value, name=None):
    """Creates a non-initialized `HashTable` object.

    Creates a table, the type of its keys and values are specified by the
    initializer.
    Before using the table you will have to initialize it. After initialization
    the table will be immutable.

    Args:
      initializer: The table initializer to use. See `HashTable` kernel for
        supported key and value types.
      default_value: The value to use if a key is missing in the table.
      name: A name for the operation (optional).

    Returns:
      A `HashTable` object.
    """
    # Keep a reference before super().__init__, which calls _create_resource
    # and relies on these attributes being set.
    self._initializer = initializer
    self._default_value = default_value
    self._shared_name = self._initializer._shared_name  # pylint: disable=protected-access
    if not self._shared_name:
      # Force using a shared name so that StaticHashTable resources can be
      # shared across different kernels. If no "shared_name" is set and
      # "use_node_name_sharing" is False, then each kernel gets its own local
      # resource.
      self._shared_name = f"hash_table_{uuid.uuid4()}"
    self._name = name or "hash_table"
    self._table_name = None
    super(StaticHashTable, self).__init__(default_value, initializer)
    self._value_shape = self._default_value.get_shape()

  def _create_resource(self):
    """Creates the hash table resource and records its graph-level name."""
    table_ref = gen_lookup_ops.hash_table_v2(
        shared_name=self._shared_name,
        key_dtype=self._initializer.key_dtype,
        value_dtype=self._initializer.value_dtype,
        name=self._name)
    # Eager handles carry no op name; in graph mode keep the last path part.
    self._table_name = (None if context.executing_eagerly() else
                        table_ref.op.name.split("/")[-1])
    return table_ref

  @property
  def name(self):
    """The operation name of the table (None in eager mode)."""
    return self._table_name

  def export(self, name=None):
    """Returns tensors of all keys and values in the table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A pair of tensors with the first tensor containing all keys and the
        second tensors containing all values in the table.
    """
    with ops.name_scope(name, "%s_Export" % self.name, [self.resource_handle]):
      all_keys, all_values = gen_lookup_ops.lookup_table_export_v2(
          self.resource_handle, self._key_dtype, self._value_dtype)

    # Each exported key maps to one value of shape `self._value_shape`.
    all_values.set_shape(all_keys.get_shape().concatenate(self._value_shape))
    return all_keys, all_values
372
373
@tf_export(v1=["lookup.StaticHashTable"])
class StaticHashTableV1(StaticHashTable):
  """A generic hash table that is immutable once initialized.

  When running in graph mode, you must evaluate the tensor returned by
  `tf.tables_initializer()` before evaluating the tensor returned by
  this class's `lookup()` method. Example usage in graph mode:

  ```python
  keys_tensor = tf.constant([1, 2])
  vals_tensor = tf.constant([3, 4])
  input_tensor = tf.constant([1, 5])
  table = tf.lookup.StaticHashTable(
      tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor), -1)
  out = table.lookup(input_tensor)
  with tf.Session() as sess:
      sess.run(tf.tables_initializer())
      print(sess.run(out))
  ```

  In eager mode, no special code is needed to initialize the table.
  Example usage in eager mode:

  ```python
  tf.enable_eager_execution()
  keys_tensor = tf.constant([1, 2])
  vals_tensor = tf.constant([3, 4])
  input_tensor = tf.constant([1, 5])
  table = tf.lookup.StaticHashTable(
      tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor), -1)
  print(table.lookup(input_tensor))
  ```
  """

  @property
  def initializer(self):
    """The op that initializes this table."""
    return self._init_op
411
412
413# For backwards compatibility. This will be removed in TF 2.0.
class HashTable(StaticHashTableV1):
  """Backwards-compatibility alias for `StaticHashTableV1`."""

  @property
  def init(self):
    """Deprecated alias for `initializer`."""
    return self.initializer
419
420
class TableInitializerBase(trackable_base.Trackable):
  """Base class for lookup table initializers."""

  def __init__(self, key_dtype, value_dtype):
    """Construct a table initializer object.

    Args:
      key_dtype: Type of the table keys.
      value_dtype: Type of the table values.
    """
    self._key_dtype = dtypes.as_dtype(key_dtype)
    self._value_dtype = dtypes.as_dtype(value_dtype)

  @property
  def key_dtype(self):
    """The expected table key dtype."""
    return self._key_dtype

  @property
  def value_dtype(self):
    """The expected table value dtype."""
    return self._value_dtype

  def initialize(self, table):
    """Returns the table initialization op."""
    raise NotImplementedError

  @property
  def _shared_name(self):
    """Returns a shared name to be used by the table."""
    if not context.executing_eagerly():
      return ""
    # Ensure a unique name when eager execution is enabled to avoid spurious
    # sharing issues.
    # TODO(rohanj): Use context.shared_name() instead.
    return str(ops.uid())
458
459
@tf_export("lookup.KeyValueTensorInitializer")
class KeyValueTensorInitializer(TableInitializerBase):
  """Table initializers given `keys` and `values` tensors.

  >>> keys_tensor = tf.constant(['a', 'b', 'c'])
  >>> vals_tensor = tf.constant([7, 8, 9])
  >>> input_tensor = tf.constant(['a', 'f'])
  >>> init = tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor)
  >>> table = tf.lookup.StaticHashTable(
  ...     init,
  ...     default_value=-1)
  >>> table.lookup(input_tensor).numpy()
  array([ 7, -1], dtype=int32)

  """

  def __init__(self, keys, values, key_dtype=None, value_dtype=None, name=None):
    """Constructs a table initializer object based on keys and values tensors.

    Args:
      keys: The tensor for the keys.
      values: The tensor for the values.
      key_dtype: The `keys` data type. Used when `keys` is a python array.
      value_dtype: The `values` data type. Used when `values` is a python array.
      name: A name for the operation (optional).
    """
    # Inside a graph-mode control-flow context the constant tensors must be
    # lifted out with init_scope; otherwise convert them in place.
    inside_control_flow = (
        not context.executing_eagerly() and
        ops.get_default_graph()._get_control_flow_context() is not None)  # pylint: disable=protected-access
    if inside_control_flow:
      with ops.init_scope():
        self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys")
        self._values = ops.convert_to_tensor(
            values, dtype=value_dtype, name="values")
    else:
      self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys")
      self._values = ops.convert_to_tensor(
          values, dtype=value_dtype, name="values")
    self._name = "key_value_init" if name is None else name
    if context.executing_eagerly():
      # Ensure a unique name when eager execution is enabled to avoid spurious
      # sharing issues.
      # TODO(rohanj): Use context.shared_name() instead.
      self._name += str(ops.uid())

    super(KeyValueTensorInitializer, self).__init__(self._keys.dtype,
                                                    self._values.dtype)

  def initialize(self, table):
    """Initializes the given `table` with `keys` and `values` tensors.

    Args:
      table: The table to initialize.

    Returns:
      The operation that initializes the table.

    Raises:
      TypeError: when the keys and values data types do not match the table
      key and value data types.
    """
    check_table_dtypes(table, self._keys.dtype, self._values.dtype)
    with ops.name_scope(
        self._name, values=(table.resource_handle, self._keys, self._values)):
      import_op = gen_lookup_ops.lookup_table_import_v2(
          table.resource_handle, self._keys, self._values)
    # Register so tables_initializer() picks this op up in graph mode.
    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, import_op)
    return import_op
526
527
@tf_export("lookup.TextFileIndex")
class TextFileIndex(object):
  """The key and value content to get from each line.

  This class defines the key and value used for `tf.lookup.TextFileInitializer`.

  The key and value content to get from each line is specified either
  by the following, or a value `>=0`.
  * `TextFileIndex.LINE_NUMBER` means use the line number starting from zero,
    expects data type int64.
  * `TextFileIndex.WHOLE_LINE` means use the whole line content, expects data
    type string.

  A value `>=0` means use the index (starting at zero) of the split line based
      on `delimiter`.
  """
  # Sentinel: use the whole line content (expects data type string).
  WHOLE_LINE = -2
  # Sentinel: use the zero-based line number (expects data type int64).
  LINE_NUMBER = -1
546
547
@tf_export("lookup.TextFileInitializer")
class TextFileInitializer(TableInitializerBase):
  r"""Table initializers from a text file.

  This initializer assigns one entry in the table for each line in the file.

  The key and value type of the table to initialize is given by `key_dtype` and
  `value_dtype`.

  The key and value content to get from each line is specified by
  the `key_index` and `value_index`.

  * `TextFileIndex.LINE_NUMBER` means use the line number starting from zero,
    expects data type int64.
  * `TextFileIndex.WHOLE_LINE` means use the whole line content, expects data
    type string.
  * A value `>=0` means use the index (starting at zero) of the split line based
      on `delimiter`.

  For example if we have a file with the following content:

  >>> import tempfile
  >>> f = tempfile.NamedTemporaryFile(delete=False)
  >>> content='\n'.join(["emerson 10", "lake 20", "palmer 30",])
  >>> f.file.write(content.encode('utf-8'))
  >>> f.file.close()

  The following snippet initializes a table with the first column as keys and
  second column as values:

  * `emerson -> 10`
  * `lake -> 20`
  * `palmer -> 30`

  >>> init= tf.lookup.TextFileInitializer(
  ...    filename=f.name,
  ...    key_dtype=tf.string, key_index=0,
  ...    value_dtype=tf.int64, value_index=1,
  ...    delimiter=" ")
  >>> table = tf.lookup.StaticHashTable(init, default_value=-1)
  >>> table.lookup(tf.constant(['palmer','lake','tarkus'])).numpy()
  array([30, 20, -1])

  Similarly to initialize the whole line as keys and the line number as values.

  * `emerson 10 -> 0`
  * `lake 20 -> 1`
  * `palmer 30 -> 2`

  >>> init = tf.lookup.TextFileInitializer(
  ...   filename=f.name,
  ...   key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
  ...   value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER)
  >>> table = tf.lookup.StaticHashTable(init, -1)
  >>> table.lookup(tf.constant('palmer 30')).numpy()
  2
  """

  def __init__(self,
               filename,
               key_dtype,
               key_index,
               value_dtype,
               value_index,
               vocab_size=None,
               delimiter="\t",
               name=None,
               value_index_offset=0):
    """Constructs a table initializer object to populate from a text file.

    It generates one key-value pair per line. The type of table key and
    value are specified by `key_dtype` and `value_dtype`, respectively.
    Similarly the content of the key and value are specified by the key_index
    and value_index.

    - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
      expects data type int64.
    - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
      type string or int64.
    - A value >=0 means use the index (starting at zero) of the split line based
      on `delimiter`.

    Args:
      filename: The filename of the text file to be used for initialization. The
        path must be accessible from wherever the graph is initialized (eg.
        trainer or eval workers). The filename may be a scalar `Tensor`.
      key_dtype: The `key` data type.
      key_index: the index that represents information of a line to get the
        table 'key' values from.
      value_dtype: The `value` data type.
      value_index: the index that represents information of a line to get the
        table 'value' values from.'
      vocab_size: The number of elements in the file, if known.
      delimiter: The delimiter to separate fields in a line.
      name: A name for the operation (optional).
      value_index_offset: A number to add to all indices extracted from the file
        This is useful for cases where a user would like to reserve one or more
        low index values for control characters. For instance, if you would
        like to ensure that no vocabulary item is mapped to index 0 (so you can
        reserve 0 for a masking value), you can set value_index_offset to 1;
        this will mean that the first vocabulary element is mapped to 1
        instead of 0.

    Raises:
      ValueError: when the filename is empty, or when the table key and value
      data types do not match the expected data types.
    """
    if not isinstance(filename, ops.Tensor) and not filename:
      raise ValueError("`filename` argument required for tf.lookup.TextFileInitializer")

    self._filename_arg = filename
    key_dtype = dtypes.as_dtype(key_dtype)
    value_dtype = dtypes.as_dtype(value_dtype)

    if key_index < -2:
      # Bug fix: this message was previously missing the f-prefix, so the
      # literal text "{key_index}" was printed instead of the value.
      raise ValueError(f"`key_index` should be >= -2, received: {key_index}.")

    if key_index == TextFileIndex.LINE_NUMBER and key_dtype != dtypes.int64:
      raise ValueError("`key_dtype` must be int64 if `key_index` is "
                       f"{TextFileIndex.LINE_NUMBER}, received: {key_dtype}")
    if ((key_index == TextFileIndex.WHOLE_LINE) and
        (not key_dtype.is_integer) and (key_dtype != dtypes.string)):
      raise ValueError(
          "`key_dtype` should be either integer or string for `key_index` "
          f"{TextFileIndex.WHOLE_LINE}, received: {key_dtype}")
    if value_index < -2:
      raise ValueError("`value_index` should be >= -2, received: "
                       f"{value_index}")

    if value_index == TextFileIndex.LINE_NUMBER and value_dtype != dtypes.int64:
      # Bug fix: the message previously interpolated `key_dtype` here,
      # reporting the wrong dtype for a value_dtype error.
      raise ValueError("`value_dtype` must be int64 for `value_index` "
                       f"{TextFileIndex.LINE_NUMBER}, received: {value_dtype}")
    if ((value_index == TextFileIndex.WHOLE_LINE) and
        (not value_dtype.is_integer) and (value_dtype != dtypes.string)):
      raise ValueError(
          "`value_dtype` should be either integer or string for `value_index` "
          f"{TextFileIndex.WHOLE_LINE}, received: {value_dtype}")

    if (vocab_size is not None) and (vocab_size <= 0):
      raise ValueError(f"`vocab_size` should be > 0, received: {vocab_size}")

    self._key_index = key_index
    self._value_index = value_index
    self._vocab_size = vocab_size
    self._delimiter = delimiter
    self._name = name
    # Track the file as an asset so it is packaged with SavedModels.
    self._filename = self._track_trackable(
        trackable.Asset(filename), "_filename")
    self._offset = value_index_offset

    super(TextFileInitializer, self).__init__(key_dtype, value_dtype)

  def initialize(self, table):
    """Initializes the table from a text file.

    Args:
      table: The table to be initialized.

    Returns:
      The operation that initializes the table.

    Raises:
      TypeError: when the keys and values data types do not match the table
      key and value data types.
    """
    check_table_dtypes(table, self.key_dtype, self.value_dtype)
    with ops.name_scope(self._name, "text_file_init", (table.resource_handle,)):
      filename = ops.convert_to_tensor(
          self._filename, dtypes.string, name="asset_filepath")
      # -1 signals "unknown vocab size" to the kernel.
      init_op = gen_lookup_ops.initialize_table_from_text_file_v2(
          table.resource_handle, filename, self._key_index, self._value_index,
          -1 if self._vocab_size is None else self._vocab_size, self._delimiter,
          self._offset)
    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
    # If the filename tensor is anything other than a string constant (e.g.,
    # if it is a placeholder) then it does not make sense to track it as an
    # asset.
    if not context.executing_eagerly() and constant_op.is_constant(filename):
      ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename)
    return init_op

  @property
  def _shared_name(self):
    """Returns a deterministic shared name derived from the init parameters."""
    if self._vocab_size:
      # Keep the shared_name:
      # <table_type>_<filename>_<vocab_size>_<key_index>_<value_index>_<offset>
      if self._offset:
        shared_name = "hash_table_%s_%d_%s_%s_%s" % (
            self._filename_arg, self._vocab_size, self._key_index,
            self._value_index, self._offset)
      else:
        shared_name = "hash_table_%s_%d_%s_%s" % (
            self._filename_arg, self._vocab_size, self._key_index,
            self._value_index)
    else:
      # Keep the shared_name
      # <table_type>_<filename>_<key_index>_<value_index>_<offset>
      if self._offset:
        shared_name = "hash_table_%s_%s_%s_%s" % (
            self._filename_arg, self._key_index, self._value_index,
            self._offset)
      else:
        shared_name = "hash_table_%s_%s_%s" % (
            self._filename_arg, self._key_index, self._value_index)

    return shared_name
753
754
class TextFileStringTableInitializer(TextFileInitializer):
  """Table initializer for `int64` IDs to string tables from a text file."""

  def __init__(self,
               filename,
               key_column_index=TextFileIndex.LINE_NUMBER,
               value_column_index=TextFileIndex.WHOLE_LINE,
               vocab_size=None,
               delimiter="\t",
               name="text_file_string_table_init"):
    """Constructs an initializer for an id-to-string table from a text file.

    It populates a table that its key and value types are int64 and string,
    respectively. It generates one key-value pair per line.
    The content of the key and value are specified by `key_column_index`
    and `value_column_index`.

    - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
      expects data type int64.
    - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
      type string or int64.
    - A value >=0 means use the index (starting at zero) of the split line based
      on `delimiter`.

    Args:
      filename: The filename of the text file to be used for initialization. The
        path must be accessible from wherever the graph is initialized (eg.
        trainer or eval workers). The filename may be a scalar `Tensor`.
      key_column_index: The column index from the text file to get the keys
        from. The default is to use the line number, starting from zero.
      value_column_index: The column index from the text file to get the values
        from. The default is to use the whole line content.
      vocab_size: The number of elements in the file, if known.
      delimiter: The delimiter to separate fields in a line.
      name: Optional name for the op.

    Raises:
      TypeError: when the filename is empty, or when the table key and value
      data types do not match the expected data types.
    """
    # Keys are fixed to int64 and values to string for id-to-string lookup.
    super(TextFileStringTableInitializer, self).__init__(
        filename,
        key_dtype=dtypes.int64,
        key_index=key_column_index,
        value_dtype=dtypes.string,
        value_index=value_column_index,
        vocab_size=vocab_size,
        delimiter=delimiter,
        name=name)
804
805
class TextFileIdTableInitializer(TextFileInitializer):
  """Table initializer for string to `int64` IDs tables from a text file."""

  def __init__(self,
               filename,
               key_column_index=TextFileIndex.WHOLE_LINE,
               value_column_index=TextFileIndex.LINE_NUMBER,
               vocab_size=None,
               delimiter="\t",
               name="text_file_id_table_init",
               key_dtype=dtypes.string):
    """Constructs an initializer for a string-to-id table from a text file.

    It populates a table whose key and value types are `key_dtype`
    (string by default) and int64, respectively. It generates one
    key-value pair per line. The content of the key and value are
    specified by `key_column_index` and `value_column_index`.

    - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
      expects data type int64.
    - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
      type string.
    - A value >=0 means use the index (starting at zero) of the split line based
      on `delimiter`.

    Args:
      filename: The filename of the text file to be used for initialization. The
        path must be accessible from wherever the graph is initialized (eg.
        trainer or eval workers). The filename may be a scalar `Tensor`.
      key_column_index: The column index from the text file to get the `key`
        values from. The default is to use the whole line content.
      value_column_index: The column index from the text file to get the `value`
        values from. The default is to use the line number, starting from zero.
      vocab_size: The number of elements in the file, if known.
      delimiter: The delimiter to separate fields in a line.
      name: Optional name for the op.
      key_dtype: The `key` data type.

    Raises:
      TypeError: when the filename is empty, or when the table key and value
      data types do not match the expected data types.
    """
    # Values are always int64 ids; only the key dtype is configurable.
    super(TextFileIdTableInitializer, self).__init__(
        filename,
        key_dtype,
        key_column_index,
        dtypes.int64,
        value_column_index,
        vocab_size=vocab_size,
        delimiter=delimiter,
        name=name)
857
858
class HasherSpec(collections.namedtuple("HasherSpec", ["hasher", "key"])):
  """Specification of the hashing function used for hash buckets.

  The `hasher` field names the hashing function to use (eg. "fasthash",
  "stronghash"). The optional `key` field carries the key for keyed hash
  functions when the hasher supports one; currently only the strong hash
  makes use of it.

  Fields:
    hasher: The hasher name to use.
    key: The key to be used by the hashing function, if required.
  """
  __slots__ = ()


FastHashSpec = HasherSpec("fasthash", None)  # pylint: disable=invalid-name
875
876
class StrongHashSpec(HasherSpec):
  """A structure to specify a key of the strong keyed hash spec.

  The strong hash requires a `key`, which is a list of 2 unsigned integer
  numbers. These should be non-zero; random numbers generated from random.org
  would be a fine choice.

  Fields:
    key: The key to be used by the keyed hashing function.
  """
  __slots__ = ()

  def __new__(cls, key):
    """Creates a "stronghash" `HasherSpec` with the given 2-element key.

    Args:
      key: A sequence of exactly two unsigned integers.

    Returns:
      A new `StrongHashSpec` with `hasher` fixed to "stronghash".

    Raises:
      ValueError: If `key` does not have exactly two elements.
      TypeError: If either element of `key` is not an integral type.
    """
    if len(key) != 2:
      raise ValueError(f"`key` must have size 2, received {len(key)}")

    if not isinstance(key[0], compat_util.integral_types) or not isinstance(
        key[1], compat_util.integral_types):
      raise TypeError("Invalid key %s. Must be unsigned integer values." % key)

    # Fix: the arguments were previously reversed (`super(cls,
    # StrongHashSpec)`), which raises TypeError for any subclass of
    # StrongHashSpec since `super(type, obj_or_type)` requires the second
    # argument to be a subclass of the first. `super(StrongHashSpec, cls)`
    # is equivalent for direct instantiation and correct for subclasses.
    return super(StrongHashSpec, cls).__new__(cls, "stronghash", key)
898
899
def _as_string(tensor):
  """Returns `tensor` unchanged if it is already a string tensor, else casts it.

  Args:
    tensor: A `Tensor` of any dtype.

  Returns:
    A string `Tensor`: the input itself when its base dtype is string,
    otherwise the result of `string_ops.as_string(tensor)`.
  """
  if tensor.dtype.base_dtype != dtypes.string:
    return string_ops.as_string(tensor)
  return tensor
904
905
class IdTableWithHashBuckets(LookupInterface):
  r"""String to Id table wrapper that assigns out-of-vocabulary keys to buckets.

  For example, if an instance of `IdTableWithHashBuckets` is initialized with a
  string-to-id table that maps:

  * `emerson -> 0`
  * `lake -> 1`
  * `palmer -> 2`

  The `IdTableWithHashBuckets` object will perform the following mapping:

  * `emerson -> 0`
  * `lake -> 1`
  * `palmer -> 2`
  * `<other term> -> bucket_id`, where bucket_id will be between `3` and
  `3 + num_oov_buckets - 1`, calculated by:
  `hash(<term>) % num_oov_buckets + vocab_size`

  If input_tensor is `["emerson", "lake", "palmer", "king", "crimson"]`,
  the lookup result is `[0, 1, 2, 4, 7]`.

  If `table` is None, only out-of-vocabulary buckets are used.

  Example usage:

  ```python
  num_oov_buckets = 3
  input_tensor = tf.constant(["emerson", "lake", "palmer", "king", "crimnson"])
  table = tf.IdTableWithHashBuckets(
      tf.StaticHashTable(
          tf.lookup.TextFileInitializer(
              filename,
              key_dtype=tf.string,
              key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
              value_dtype=tf.int64,
              value_index=tf.lookup.TextFileIndex.LINE_NUMBER,
              delimiter="\t"),
          default_value),
      num_oov_buckets)
  out = table.lookup(input_tensor)
  table.init.run()
  print(out.eval())
  ```

  The hash function used for generating out-of-vocabulary buckets ID is handled
  by `hasher_spec`.
  """

  def __init__(self,
               table,
               num_oov_buckets,
               hasher_spec=FastHashSpec,
               name=None,
               key_dtype=None):
    """Construct a `IdTableWithHashBuckets` object.

    Args:
      table: Table that maps `tf.string` or `tf.int64` keys to `tf.int64` ids.
      num_oov_buckets: Number of buckets to use for out-of-vocabulary keys.
      hasher_spec: A `HasherSpec` to specify the hash function to use for
        assignation of out-of-vocabulary buckets  (optional).
      name: A name for the operation (optional).
      key_dtype: Data type of keys passed to `lookup`. Defaults to
        `table.key_dtype` if `table` is specified, otherwise `tf.string`. Must
        be string or integer, and must be castable to `table.key_dtype`.

    Raises:
      ValueError: when `table` in None and `num_oov_buckets` is not positive.
      TypeError: when `hasher_spec` is invalid.
    """
    # If a name ends with a '/' it is a "name scope", remove all trailing '/'
    # characters to use as table name.
    if name:
      name = name.rstrip("/")
    if table:
      if key_dtype is None:
        key_dtype = table.key_dtype
      supported_table_key_dtypes = (dtypes.int64, dtypes.string)
      if table.key_dtype not in supported_table_key_dtypes:
        # NOTE(review): the check is on `table.key_dtype` but the message
        # interpolates `key_dtype`; the two may differ when `key_dtype` is
        # passed explicitly.
        raise TypeError("Invalid `key_dtype`, expected one of "
                        f"{supported_table_key_dtypes}, received {key_dtype}.")
      if table.key_dtype.is_integer != key_dtype.is_integer:
        # Lookup keys may be a different dtype than the table's keys, but
        # must at least agree on integer-ness so casting is meaningful.
        raise TypeError("Invalid `key dtype`, expected %s but got %s." %
                        ("integer" if key_dtype.is_integer else "non-integer",
                         table.key_dtype))
      if table.value_dtype != dtypes.int64:
        raise TypeError("Invalid `value_dtype`: expected int64 but got %s." %
                        (table.value_dtype))
      self._table = table
      name = name or self._table.name
    else:
      # With no vocabulary table, at least one OOV bucket is required so
      # every key can be assigned an id.
      if num_oov_buckets <= 0:
        raise ValueError("`oov_buckets` must be > 0 if no `table` is supplied.")
      key_dtype = dtypes.string if key_dtype is None else key_dtype
      self._table = None
      name = name or "hash_bucket"
    if (not key_dtype.is_integer) and (dtypes.string != key_dtype):
      raise TypeError("Invalid `key_dtype`, expected integer or string, got "
                      f"{key_dtype}.")
    self._num_oov_buckets = num_oov_buckets

    if not isinstance(hasher_spec, HasherSpec):
      raise TypeError("`hasher_spec` must be of type HasherSpec, got "
                      f"{type(hasher_spec)}.")
    self._hasher_spec = hasher_spec
    if name:
      # Use only the last path component as the table's display name.
      self._table_name = name.split("/")[-1]
    else:
      self._table_name = None
    super(IdTableWithHashBuckets, self).__init__(key_dtype, dtypes.int64)

  def _create_resource(self):
    # Delegate to the wrapped table; a buckets-only instance has no resource.
    if self._table is not None:
      return self._table._create_resource()  # pylint: disable=protected-access
    return None

  def _initialize(self):
    if self._table is not None:
      return self._table._initialize()  # pylint: disable=protected-access
    # Hashing requires no state, so initialization is a no-op.
    with ops.name_scope(None, "init"):
      return control_flow_ops.no_op()

  @property
  def initializer(self):
    # The wrapped table's init op, or a no-op when only buckets are used.
    if self._table is not None:
      return self._table._init_op  # pylint: disable=protected-access
    with ops.name_scope(None, "init"):
      return control_flow_ops.no_op()

  @property
  @deprecated("2018-12-15", "Use `initializer` instead.")
  def init(self):
    return self.initializer

  @property
  def resource_handle(self):
    if self._table is not None:
      return self._table.resource_handle
    return None

  @property
  def name(self):
    return self._table_name

  def size(self, name=None):
    """Compute the number of elements in this table."""
    with ops.name_scope(name, "%s_Size" % self.name):
      if self._table:
        tsize = self._table.size()
      else:
        tsize = ops.convert_to_tensor(0, dtype=dtypes.int64)
      # Total id space = vocabulary size + OOV buckets.
      return tsize + self._num_oov_buckets

  def _get_string_to_hash_bucket_fn(self, hasher_spec):
    """Returns the string_to_hash_bucket op to use based on `hasher_spec`."""
    if not isinstance(hasher_spec, HasherSpec):
      raise TypeError("`hasher_spec` must be of type HasherSpec, got "
                      f"{type(hasher_spec)}.")
    if hasher_spec.hasher == "fasthash":
      return string_ops.string_to_hash_bucket_fast
    if hasher_spec.hasher == "legacy":
      return string_ops.string_to_hash_bucket
    if hasher_spec.hasher == "stronghash":
      # The strong hash takes the user-provided key as a keyword argument.
      return functools.partial(
          string_ops.string_to_hash_bucket_strong, key=hasher_spec.key)
    raise ValueError(
        f"Found unknown hasher {hasher_spec.hasher} in `hasher_spec`")

  def lookup(self, keys, name=None):
    """Looks up `keys` in the table, outputs the corresponding values.

    It assigns out-of-vocabulary keys to buckets based in their hashes.

    Args:
      keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`.
      name: Optional name for the op.

    Returns:
      A `SparseTensor` if keys are sparse, a `RaggedTensor` if keys are ragged,
      otherwise a dense `Tensor`.

    Raises:
      TypeError: when `keys` doesn't match the table key data type.
    """
    if keys.dtype.base_dtype != self._key_dtype:
      raise TypeError(f"Dtype of argument `keys` must be {self._key_dtype}, "
                      f"received: {keys.dtype}")
    # Operate on the flat values; the sparse/ragged structure is restored at
    # the end.
    values = keys
    if isinstance(keys,
                  (sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor)):
      values = keys.values
    if self._table and (self._table.key_dtype.base_dtype == dtypes.int64):
      values = math_ops.cast(values, dtypes.int64)

    if self._num_oov_buckets == 0:
      # Constructor guarantees `self._table` is set when there are no buckets.
      ids = self._table.lookup(values, name=name)
    else:
      # TODO(yleon): Consider moving this functionality to its own kernel.
      with ops.name_scope(name, "%s_Lookup" % self.name):
        str_to_hash_bucket = self._get_string_to_hash_bucket_fn(
            self._hasher_spec)
        buckets = str_to_hash_bucket(
            _as_string(values),
            num_buckets=self._num_oov_buckets,
            name="hash_bucket")
        if self._table:
          ids = self._table.lookup(values)
          # Shift bucket ids past the vocabulary so they occupy
          # [vocab_size, vocab_size + num_oov_buckets).
          buckets = math_ops.add(buckets, self._table.size())
          is_id_non_default = math_ops.not_equal(ids, self._table.default_value)
          # Keep in-vocabulary ids; fall back to the hash bucket otherwise.
          ids = array_ops.where_v2(is_id_non_default, ids, buckets)
        else:
          ids = buckets
    if isinstance(keys, sparse_tensor.SparseTensor):
      return sparse_tensor.SparseTensor(keys.indices, ids, keys.dense_shape)
    elif isinstance(keys, ragged_tensor.RaggedTensor):
      return keys.with_values(ids)
    return ids
1124
1125
1126@tf_export("lookup.StaticVocabularyTable", v1=[])
1127class StaticVocabularyTable(LookupInterface):
1128  r"""String to Id table that assigns out-of-vocabulary keys to hash buckets.
1129
1130  For example, if an instance of `StaticVocabularyTable` is initialized with a
1131  string-to-id initializer that maps:
1132
1133  >>> init = tf.lookup.KeyValueTensorInitializer(
1134  ...     keys=tf.constant(['emerson', 'lake', 'palmer']),
1135  ...     values=tf.constant([0, 1, 2], dtype=tf.int64))
1136  >>> table = tf.lookup.StaticVocabularyTable(
1137  ...    init,
1138  ...    num_oov_buckets=5)
1139
1140  The `Vocabulary` object will performs the following mapping:
1141
1142  * `emerson -> 0`
1143  * `lake -> 1`
1144  * `palmer -> 2`
1145  * `<other term> -> bucket_id`, where `bucket_id` will be between `3` and
1146  `3 + num_oov_buckets - 1 = 7`, calculated by:
1147  `hash(<term>) % num_oov_buckets + vocab_size`
1148
1149  If input_tensor is:
1150
1151  >>> input_tensor = tf.constant(["emerson", "lake", "palmer",
1152  ...                             "king", "crimson"])
1153  >>> table[input_tensor].numpy()
1154  array([0, 1, 2, 6, 7])
1155
1156  If `initializer` is None, only out-of-vocabulary buckets are used.
1157
1158  Example usage:
1159
1160  >>> num_oov_buckets = 3
1161  >>> vocab = ["emerson", "lake", "palmer", "crimnson"]
1162  >>> import tempfile
1163  >>> f = tempfile.NamedTemporaryFile(delete=False)
1164  >>> f.write('\n'.join(vocab).encode('utf-8'))
1165  >>> f.close()
1166
1167  >>> init = tf.lookup.TextFileInitializer(
1168  ...     f.name,
1169  ...     key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
1170  ...     value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER)
1171  >>> table = tf.lookup.StaticVocabularyTable(init, num_oov_buckets)
1172  >>> table.lookup(tf.constant(["palmer", "crimnson" , "king",
1173  ...                           "tarkus", "black", "moon"])).numpy()
1174  array([2, 3, 5, 6, 6, 4])
1175
1176  The hash function used for generating out-of-vocabulary buckets ID is
1177  Fingerprint64.
1178
1179  Note that the out-of-vocabulary bucket IDs always range from the table `size`
1180  up to `size + num_oov_buckets - 1` regardless of the table values, which could
1181  cause unexpected collisions:
1182
1183  >>> init = tf.lookup.KeyValueTensorInitializer(
1184  ...     keys=tf.constant(["emerson", "lake", "palmer"]),
1185  ...     values=tf.constant([1, 2, 3], dtype=tf.int64))
1186  >>> table = tf.lookup.StaticVocabularyTable(
1187  ...     init,
1188  ...     num_oov_buckets=1)
1189  >>> input_tensor = tf.constant(["emerson", "lake", "palmer", "king"])
1190  >>> table[input_tensor].numpy()
1191  array([1, 2, 3, 3])
1192  """
1193
1194  def __init__(self,
1195               initializer,
1196               num_oov_buckets,
1197               lookup_key_dtype=None,
1198               name=None):
1199    """Construct a `StaticVocabularyTable` object.
1200
1201    Args:
1202      initializer: A `TableInitializerBase` object that contains the data used
1203        to initialize the table. If None, then we only use out-of-vocab buckets.
1204      num_oov_buckets: Number of buckets to use for out-of-vocabulary keys. Must
1205        be greater than zero.
1206      lookup_key_dtype: Data type of keys passed to `lookup`. Defaults to
1207        `initializer.key_dtype` if `initializer` is specified, otherwise
1208        `tf.string`. Must be string or integer, and must be castable to
1209        `initializer.key_dtype`.
1210      name: A name for the operation (optional).
1211
1212    Raises:
1213      ValueError: when `num_oov_buckets` is not positive.
1214      TypeError: when lookup_key_dtype or initializer.key_dtype are not
1215        integer or string. Also when initializer.value_dtype != int64.
1216    """
1217    if num_oov_buckets <= 0:
1218      raise ValueError("`num_oov_buckets` must be > 0.")
1219    # If a name ends with a '/' it is a "name scope", remove all trailing '/'
1220    # characters to use as table name.
1221    if name:
1222      name = name.rstrip("/")
1223    if initializer:
1224      if lookup_key_dtype is None:
1225        lookup_key_dtype = initializer.key_dtype
1226      supported_table_key_dtypes = (dtypes.int64, dtypes.string)
1227      if initializer.key_dtype not in supported_table_key_dtypes:
1228        raise TypeError("Invalid `key_dtype`, expected one of %s, but got %s." %
1229                        (supported_table_key_dtypes, initializer.key_dtype))
1230      if initializer.key_dtype.is_integer != lookup_key_dtype.is_integer:
1231        raise TypeError(
1232            "Invalid `key_dtype`, expected %s but got %s." %
1233            ("integer" if lookup_key_dtype.is_integer else "non-integer",
1234             initializer.key_dtype))
1235      if initializer.value_dtype != dtypes.int64:
1236        raise TypeError("Invalid `value_dtype`, expected %s but got %s." %
1237                        (dtypes.int64, initializer.value_dtype))
1238      if isinstance(initializer, trackable_base.Trackable):
1239        self._initializer = self._track_trackable(initializer, "_initializer")
1240      self._table = HashTable(initializer, default_value=-1)
1241      name = name or self._table.name
1242    else:
1243      lookup_key_dtype = dtypes.string
1244      self._table = None
1245      name = name or "hash_bucket"
1246    if (not lookup_key_dtype.is_integer) and (dtypes.string !=
1247                                              lookup_key_dtype):
1248      raise TypeError("Invalid `key_dtype`, expected integer or string, got "
1249                      f"{lookup_key_dtype}")
1250    self._num_oov_buckets = num_oov_buckets
1251
1252    self._table_name = None
1253    if name is not None:
1254      self._table_name = name.split("/")[-1]
1255    super(StaticVocabularyTable, self).__init__(lookup_key_dtype, dtypes.int64)
1256
1257  def _create_resource(self):
1258    if self._table is not None:
1259      return self._table._create_resource()  # pylint: disable=protected-access
1260    return None
1261
1262  def _initialize(self):
1263    if self._table is not None:
1264      return self._table._initialize()  # pylint: disable=protected-access
1265    with ops.name_scope(None, "init"):
1266      return control_flow_ops.no_op()
1267
1268  @property
1269  def resource_handle(self):
1270    if self._table is not None:
1271      return self._table.resource_handle
1272    return None
1273
1274  @property
1275  def name(self):
1276    return self._table_name
1277
1278  def size(self, name=None):
1279    """Compute the number of elements in this table."""
1280    with ops.name_scope(name, "%s_Size" % self.name):
1281      if self._table:
1282        tsize = self._table.size()
1283      else:
1284        tsize = ops.convert_to_tensor(0, dtype=dtypes.int64)
1285      return tsize + self._num_oov_buckets
1286
1287  def lookup(self, keys, name=None):
1288    """Looks up `keys` in the table, outputs the corresponding values.
1289
1290    It assigns out-of-vocabulary keys to buckets based in their hashes.
1291
1292    Args:
1293      keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`.
1294      name: Optional name for the op.
1295
1296    Returns:
1297      A `SparseTensor` if keys are sparse, a `RaggedTensor` if keys are ragged,
1298      otherwise a dense `Tensor`.
1299
1300    Raises:
1301      TypeError: when `keys` doesn't match the table key data type.
1302    """
1303    if keys.dtype.base_dtype != self._key_dtype:
1304      raise TypeError(f"Dtype of argument `keys` must be {self._key_dtype}, "
1305                      f"received: {keys.dtype}")
1306    values = keys
1307    if isinstance(keys,
1308                  (sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor)):
1309      values = keys.values
1310    if self._table and (self._table.key_dtype.base_dtype == dtypes.int64):
1311      values = math_ops.cast(values, dtypes.int64)
1312
1313    # TODO(yleon): Consider moving this functionality to its own kernel.
1314    with ops.name_scope(name, "%s_Lookup" % self.name):
1315      buckets = string_ops.string_to_hash_bucket_fast(
1316          _as_string(values),
1317          num_buckets=self._num_oov_buckets,
1318          name="hash_bucket")
1319      if self._table:
1320        ids = self._table.lookup(values)
1321        buckets = math_ops.add(buckets, self._table.size())
1322        is_id_non_default = math_ops.not_equal(ids, self._table.default_value)
1323        ids = array_ops.where_v2(is_id_non_default, ids, buckets)
1324      else:
1325        ids = buckets
1326    if isinstance(keys, sparse_tensor.SparseTensor):
1327      return sparse_tensor.SparseTensor(keys.indices, ids, keys.dense_shape)
1328    elif isinstance(keys, ragged_tensor.RaggedTensor):
1329      return keys.with_values(ids)
1330    return ids
1331
1332
1333@tf_export(v1=["lookup.StaticVocabularyTable"])
1334class StaticVocabularyTableV1(StaticVocabularyTable):
1335
1336  @property
1337  def initializer(self):
1338    if self._table is not None:
1339      return self._table._init_op  # pylint: disable=protected-access
1340    with ops.name_scope(None, "init"):
1341      return control_flow_ops.no_op()
1342
1343
def index_table_from_file(vocabulary_file=None,
                          num_oov_buckets=0,
                          vocab_size=None,
                          default_value=-1,
                          hasher_spec=FastHashSpec,
                          key_dtype=dtypes.string,
                          name=None,
                          key_column_index=TextFileIndex.WHOLE_LINE,
                          value_column_index=TextFileIndex.LINE_NUMBER,
                          delimiter="\t"):
  """Returns a lookup table that converts a string tensor into int64 IDs.

  This operation constructs a lookup table to convert tensor of strings into
  int64 IDs. The mapping can be initialized from a vocabulary file specified in
  `vocabulary_file`, where the whole line is the key and the zero-based line
  number is the ID.

  Any lookup of an out-of-vocabulary token will return a bucket ID based on its
  hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
  `default_value`.
  The bucket ID range is
  `[vocabulary size, vocabulary size + num_oov_buckets - 1]`.

  The underlying table must be initialized by calling
  `session.run(tf.compat.v1.tables_initializer())` or
  `session.run(table.initializer)` once.

  To specify multi-column vocabulary files, use key_column_index and
  value_column_index and delimiter.

  - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
    expects data type int64.
  - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
    type string.
  - A value >=0 means use the index (starting at zero) of the split line based
    on `delimiter`.

  Sample Usages:

  If we have a vocabulary file "test.txt" with the following content:

  ```
  emerson
  lake
  palmer
  ```

  ```python
  features = tf.constant(["emerson", "lake", "and", "palmer"])
  table = tf.lookup.index_table_from_file(
      vocabulary_file="test.txt", num_oov_buckets=1)
  ids = table.lookup(features)
  ...
  tf.compat.v1.tables_initializer().run()

  ids.eval()  ==> [0, 1, 3, 2]  # where 3 is the out-of-vocabulary bucket
  ```

  Args:
    vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    vocab_size: Number of the elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assignation of out-of-vocabulary buckets.
    key_dtype: The `key` data type.
    name: A name for this op (optional).
    key_column_index: The column index from the text file to get the `key`
      values from. The default is to use the whole line content.
    value_column_index: The column index from the text file to get the `value`
      values from. The default is to use the line number, starting from zero.
    delimiter: The delimiter to separate fields in a line.

  Returns:
    The lookup table to map a `key_dtype` `Tensor` to index `int64` `Tensor`.

  Raises:
    ValueError: If `vocabulary_file` is not set.
    ValueError: If `num_oov_buckets` is negative or `vocab_size` is not greater
      than zero.
  """
  # Reject a missing filename or an empty filename string; a tensor-valued
  # filename is validated later, at table-initialization time.
  if vocabulary_file is None or (isinstance(vocabulary_file, six.string_types)
                                 and not vocabulary_file):
    raise ValueError(
        "`vocabulary_file` must be specified and must not be empty.")
  if num_oov_buckets < 0:
    raise ValueError(
        "num_oov_buckets must be greater or equal than 0, got %d." %
        num_oov_buckets)
  if vocab_size is not None and vocab_size < 1:
    vocab_file_value = vocabulary_file
    if isinstance(vocabulary_file, ops.Tensor):
      # Best effort to recover the filename for the error message; "?" when
      # the tensor's constant value cannot be determined.
      vocab_file_value = tensor_util.constant_value(vocabulary_file) or "?"
    raise ValueError("`vocab_size` must be greater than 0, got %d for "
                     "vocabulary_file: %s." % (vocab_size, vocab_file_value))
  if (not key_dtype.is_integer) and (dtypes.string != key_dtype.base_dtype):
    raise TypeError("Dtype for `keys` should be either integer or string.")

  with ops.name_scope(name, "string_to_index"):
    table = None
    with ops.name_scope(None, "hash_table"):
      # Integer keys are normalized to int64 for the underlying table.
      init = TextFileIdTableInitializer(
          vocabulary_file,
          vocab_size=vocab_size,
          key_dtype=dtypes.int64 if key_dtype.is_integer else key_dtype,
          name="table_init",
          key_column_index=key_column_index,
          value_column_index=value_column_index,
          delimiter=delimiter)

      table = StaticHashTableV1(init, default_value)
    if num_oov_buckets:
      # Wrap the table so out-of-vocabulary keys hash into
      # [vocab_size, vocab_size + num_oov_buckets).
      table = IdTableWithHashBuckets(
          table,
          num_oov_buckets=num_oov_buckets,
          hasher_spec=hasher_spec,
          key_dtype=key_dtype)

    return table
1464
1465
def index_table_from_tensor(vocabulary_list,
                            num_oov_buckets=0,
                            default_value=-1,
                            hasher_spec=FastHashSpec,
                            dtype=dtypes.string,
                            name=None):
  """Returns a lookup table that converts a string tensor into int64 IDs.

  This operation constructs a lookup table to convert tensor of strings into
  int64 IDs. The mapping can be initialized from a string `vocabulary_list` 1-D
  tensor where each element is a key and corresponding index within the tensor
  is the value.

  Any lookup of an out-of-vocabulary token will return a bucket ID based on its
  hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
  `default_value`. The bucket ID range is
  `[vocabulary list size, vocabulary list size + num_oov_buckets - 1]`.

  The underlying table must be initialized by calling
  `session.run(tf.compat.v1.tables_initializer())` or
  `session.run(table.initializer)` once.

  Elements in `vocabulary_list` cannot have duplicates, otherwise when executing
  the table initializer op, it will throw a `FailedPreconditionError`.

  Sample Usages:

  ```python
  vocabulary_list = tf.constant(["emerson", "lake", "palmer"])
  table = tf.lookup.index_table_from_tensor(
      vocabulary_list=vocabulary_list, num_oov_buckets=1, default_value=-1)
  features = tf.constant(["emerson", "lake", "and", "palmer"])
  ids = table.lookup(features)
  ...
  tf.compat.v1.tables_initializer().run()

  ids.eval()  ==> [0, 1, 4, 2]
  ```

  Args:
    vocabulary_list: A 1-D `Tensor` that specifies the mapping of keys to
      indices. The type of this object must be castable to `dtype`.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assignment of out-of-vocabulary buckets.
    dtype: The type of values passed to `lookup`. Only string and integers are
      supported.
    name: A name for this op (optional).

  Returns:
    The lookup table to map an input `Tensor` to index `int64` `Tensor`.

  Raises:
    ValueError: If `vocabulary_list` is invalid.
    ValueError: If `num_oov_buckets` is negative.
  """
  if vocabulary_list is None:
    raise ValueError("`vocabulary_list` must be specified.")

  if num_oov_buckets < 0:
    raise ValueError(
        "`num_oov_buckets` must be greater or equal than 0, got %d." %
        num_oov_buckets)

  if (not dtype.is_integer) and (dtypes.string != dtype.base_dtype):
    raise TypeError("`dtype` must either be integer or string.")

  with ops.name_scope(name, "string_to_index"):
    keys = ops.convert_to_tensor(vocabulary_list)
    if keys.dtype.is_integer != dtype.is_integer:
      # Lookup dtype and vocabulary dtype must agree on integer-ness so
      # casting is meaningful.
      raise ValueError(
          "Invalid `dtype`: Expected %s, got %s." %
          ("integer" if dtype.is_integer else "non-integer", keys.dtype))
    if (not dtype.is_integer) and (keys.dtype.base_dtype != dtype):
      raise ValueError("Invalid `dtype`: Expected %s, got %s." %
                       (dtype, keys.dtype))
    num_elements = array_ops.size(keys)
    # Values are the positions of each key within the vocabulary tensor.
    values = math_ops.cast(math_ops.range(num_elements), dtypes.int64)

    with ops.name_scope(None, "hash_table"):
      # Integer keys are normalized to int64 for the underlying table.
      table_keys = math_ops.cast(
          keys, dtypes.int64) if keys.dtype.is_integer else keys
      init = KeyValueTensorInitializer(
          table_keys,
          values,
          table_keys.dtype.base_dtype,
          dtypes.int64,
          name="table_init")
      table = StaticHashTableV1(init, default_value)
    if num_oov_buckets:
      # Wrap the table so out-of-vocabulary keys hash into
      # [vocab_size, vocab_size + num_oov_buckets).
      table = IdTableWithHashBuckets(
          table,
          num_oov_buckets=num_oov_buckets,
          hasher_spec=hasher_spec,
          key_dtype=dtype)
    return table
1564
1565
def index_to_string_table_from_file(vocabulary_file,
                                    vocab_size=None,
                                    default_value="UNK",
                                    name=None,
                                    key_column_index=TextFileIndex.LINE_NUMBER,
                                    value_column_index=TextFileIndex.WHOLE_LINE,
                                    delimiter="\t"):
  """Returns a lookup table that maps a `Tensor` of indices into strings.

  The returned table maps int64 indices to string values read from
  `vocabulary_file`. By default each key is the zero-based line number of a
  line in the file and the corresponding value is that line's whole content.

  Any index without a matching entry in the vocabulary file (an
  out-of-vocabulary index) is mapped to `default_value`.

  The underlying table must be initialized once by running
  `session.run(tf.compat.v1.tables_initializer())` or
  `session.run(table.init())`.

  Multi-column vocabulary files are supported through `key_column_index`,
  `value_column_index` and `delimiter`:

  - `TextFileIndex.LINE_NUMBER` means use the line number starting from zero,
    expects data type int64.
  - `TextFileIndex.WHOLE_LINE` means use the whole line content, expects data
    type string.
  - A value >=0 means use the index (starting at zero) of the split line based
    on `delimiter`.

  Sample Usages:

  If we have a vocabulary file "test.txt" with the following content:

  ```
  emerson
  lake
  palmer
  ```

  ```python
  indices = tf.constant([1, 5], tf.int64)
  table = tf.lookup.index_to_string_table_from_file(
      vocabulary_file="test.txt", default_value="UNKNOWN")
  values = table.lookup(indices)
  ...
  tf.compat.v1.tables_initializer().run()

  values.eval() ==> ["lake", "UNKNOWN"]
  ```

  Args:
    vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`.
    vocab_size: Number of the elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary indices.
    name: A name for this op (optional).
    key_column_index: The column index from the text file to get the `key`
      values from. The default is to use the line number, starting from zero.
    value_column_index: The column index from the text file to get the `value`
      values from. The default is to use the whole line content.
    delimiter: The delimiter to separate fields in a line.

  Returns:
    The lookup table mapping `int64` index `Tensors` to their associated
    string values.

  Raises:
    ValueError: when `vocabulary_file` is empty.
    ValueError: when `vocab_size` is invalid.
  """
  # An empty string filename is as unusable as a missing one; reject both.
  file_is_empty_string = (
      isinstance(vocabulary_file, six.string_types) and not vocabulary_file)
  if vocabulary_file is None or file_is_empty_string:
    raise ValueError(
        "`vocabulary_file` must be specified and must not be empty.")

  if vocab_size is not None and vocab_size < 1:
    raise ValueError(f"`vocab_size` must be greater than 0, got {vocab_size}.")

  with ops.name_scope(name, "index_to_string"):
    initializer = TextFileStringTableInitializer(
        vocabulary_file,
        vocab_size=vocab_size,
        name="table_init",
        key_column_index=key_column_index,
        value_column_index=value_column_index,
        delimiter=delimiter)

    # TODO(yleon): Use a more efficient structure.
    return StaticHashTableV1(initializer, default_value)
1656
1657
def index_to_string_table_from_tensor(vocabulary_list,
                                      default_value="UNK",
                                      name=None):
  """Returns a lookup table that maps a `Tensor` of indices into strings.

  The returned table maps int64 indices to the strings in the 1-D
  `vocabulary_list` tensor: each element is a value and its position within
  the tensor is the corresponding key.

  Any index without a matching entry in `vocabulary_list` (an
  out-of-vocabulary index) is mapped to `default_value`.

  The underlying table must be initialized once by running
  `session.run(tf.compat.v1.tables_initializer())` or
  `session.run(table.init())`.

  Elements in `vocabulary_list` cannot have duplicates, otherwise when executing
  the table initializer op, it will throw a `FailedPreconditionError`.

  Sample Usages:

  ```python
  vocabulary_list = tf.constant(["emerson", "lake", "palmer"])
  indices = tf.constant([1, 5], tf.int64)
  table = tf.lookup.index_to_string_table_from_tensor(
      vocabulary_list, default_value="UNKNOWN")
  values = table.lookup(indices)
  ...
  tf.compat.v1.tables_initializer().run()

  values.eval() ==> ["lake", "UNKNOWN"]
  ```

  Args:
    vocabulary_list: A 1-D string `Tensor` that specifies the strings to map
      from indices.
    default_value: The value to use for out-of-vocabulary indices.
    name: A name for this op (optional).

  Returns:
    The lookup table mapping `int64` index `Tensors` to their associated
    string values.

  Raises:
    ValueError: when `vocabulary_list` is not set.
  """
  if vocabulary_list is None:
    raise ValueError("`vocabulary_list` argument must be specified.")

  with ops.name_scope(name, "index_to_string"):
    table_values = ops.convert_to_tensor(vocabulary_list, dtypes.string)
    # Keys are the positions 0..len(vocabulary_list)-1, as int64.
    table_keys = math_ops.cast(
        math_ops.range(array_ops.size(table_values)), dtypes.int64)

    initializer = KeyValueTensorInitializer(
        table_keys, table_values, dtypes.int64, dtypes.string,
        name="table_init")
    # TODO(yleon): Use a more efficient structure.
    return StaticHashTableV1(initializer, default_value)
1718
1719
@tf_export("lookup.experimental.MutableHashTable")
class MutableHashTable(LookupInterface):
  """A generic mutable hash table implementation.

  Data can be inserted by calling the `insert` method and removed by calling the
  `remove` method. It does not support initialization via the init method.

  `MutableHashTable` requires additional memory during checkpointing and restore
  operations to create temporary key and value tensors.

  Example usage:

  >>> table = tf.lookup.experimental.MutableHashTable(key_dtype=tf.string,
  ...                                                 value_dtype=tf.int64,
  ...                                                 default_value=-1)
  >>> keys_tensor = tf.constant(['a', 'b', 'c'])
  >>> vals_tensor = tf.constant([7, 8, 9], dtype=tf.int64)
  >>> input_tensor = tf.constant(['a', 'f'])
  >>> table.insert(keys_tensor, vals_tensor)
  >>> table.lookup(input_tensor).numpy()
  array([ 7, -1])
  >>> table.remove(tf.constant(['c']))
  >>> table.lookup(keys_tensor).numpy()
  array([ 7, 8, -1])
  >>> sorted(table.export()[0].numpy())
  [b'a', b'b']
  >>> sorted(table.export()[1].numpy())
  [7, 8]
  """

  def __init__(self,
               key_dtype,
               value_dtype,
               default_value,
               name="MutableHashTable",
               checkpoint=True):
    """Creates an empty `MutableHashTable` object.

    Creates a table, the type of its keys and values are specified by key_dtype
    and value_dtype, respectively.

    Args:
      key_dtype: the type of the key tensors.
      value_dtype: the type of the value tensors.
      default_value: The value to use if a key is missing in the table.
      name: A name for the operation (optional).
      checkpoint: if True, the contents of the table are saved to and restored
        from checkpoints. If `shared_name` is empty for a checkpointed table, it
        is shared using the table node name.

    Returns:
      A `MutableHashTable` object.

    Raises:
      ValueError: If checkpoint is True and no name was specified.
    """
    self._default_value = ops.convert_to_tensor(
        default_value, dtype=value_dtype)
    # The default value's static shape also determines which table kernel is
    # used and, for non-scalar values, the shape of every stored value (see
    # `_create_resource`).
    self._value_shape = self._default_value.get_shape()
    self._checkpoint = checkpoint
    self._key_dtype = key_dtype
    self._value_dtype = value_dtype
    self._name = name

    self._shared_name = None
    if context.executing_eagerly():
      # TODO(allenl): This will leak memory due to kernel caching by the
      # shared_name attribute value (but is better than the alternative of
      # sharing everything by default when executing eagerly; hopefully creating
      # tables in a loop is uncommon).
      # TODO(rohanj): Use context.shared_name() instead.
      self._shared_name = "table_%d" % (ops.uid(),)
    super(MutableHashTable, self).__init__(key_dtype, value_dtype)

    self._resource_handle = self._create_resource()
    # In graph mode, register a SaveableObject so name-based checkpointing
    # (tf.compat.v1.train.Saver) captures the table's contents; object-based
    # checkpointing instead goes through `_gather_saveables_for_checkpoint`.
    if checkpoint:
      saveable = MutableHashTable._Saveable(self, name)
      if not context.executing_eagerly():
        ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)

  def _create_resource(self):
    """Creates the underlying table resource and returns its handle."""
    # The table must be shared if checkpointing is requested for multi-worker
    # training to work correctly. Use the node name if no shared_name has been
    # explicitly specified.
    use_node_name_sharing = self._checkpoint and self._shared_name is None
    if self._default_value.get_shape().ndims == 0:
      # Scalar values: use the plain mutable hash table kernel.
      table_ref = gen_lookup_ops.mutable_hash_table_v2(
          shared_name=self._shared_name,
          use_node_name_sharing=use_node_name_sharing,
          key_dtype=self._key_dtype,
          value_dtype=self._value_dtype,
          name=self._name)
    else:
      # Non-scalar values: this kernel needs the fixed value shape up front.
      table_ref = gen_lookup_ops.mutable_hash_table_of_tensors_v2(
          shared_name=self._shared_name,
          use_node_name_sharing=use_node_name_sharing,
          key_dtype=self._key_dtype,
          value_dtype=self._value_dtype,
          value_shape=self._default_value.get_shape(),
          name=self._name)

    if context.executing_eagerly():
      self._table_name = None
    else:
      # Use the last component of the op name as the table's display name.
      self._table_name = table_ref.op.name.split("/")[-1]
    return table_ref

  @property
  def name(self):
    """The table's op-derived name, or `None` when executing eagerly."""
    return self._table_name

  def size(self, name=None):
    """Compute the number of elements in this table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A scalar tensor containing the number of elements in this table.
    """
    with ops.name_scope(name, "%s_Size" % self.name, [self.resource_handle]):
      with ops.colocate_with(self.resource_handle):
        return gen_lookup_ops.lookup_table_size_v2(self.resource_handle)

  def remove(self, keys, name=None):
    """Removes `keys` and its associated values from the table.

    If a key is not present in the table, it is silently ignored.

    Args:
      keys: Keys to remove. Can be a tensor of any shape. Must match the table's
        key type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    if keys.dtype != self._key_dtype:
      raise TypeError(f"Dtype of argument `keys` must be {self._key_dtype}, "
                      f"received: {keys.dtype}")

    with ops.name_scope(name, "%s_lookup_table_remove" % self.name,
                        (self.resource_handle, keys, self._default_value)):
      op = gen_lookup_ops.lookup_table_remove_v2(self.resource_handle, keys)

    return op

  def lookup(self, keys, dynamic_default_values=None, name=None):
    """Looks up `keys` in a table, outputs the corresponding values.

    The `default_value` is used for keys not present in the table.

    Args:
      keys: Keys to look up. Can be a tensor of any shape. Must match the
        table's key_dtype.
      dynamic_default_values: The values to use if a key is missing in the
        table. If None (by default), the `table.default_value` will be used.
        Shape of `dynamic_default_values` must be same with
        `table.default_value` or the lookup result tensor.
        In the latter case, each key will have a different default value.

        For example:

          ```python
          keys = [0, 1, 3]
          dynamic_default_values = [[1, 3, 4], [2, 3, 9], [8, 3, 0]]

          # The key '0' will use [1, 3, 4] as default value.
          # The key '1' will use [2, 3, 9] as default value.
          # The key '3' will use [8, 3, 0] as default value.
          ```

      name: A name for the operation (optional).

    Returns:
      A tensor containing the values in the same shape as `keys` using the
        table's value type.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    with ops.name_scope(name, "%s_lookup_table_find" % self.name,
                        (self.resource_handle, keys, self._default_value)):
      keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
      with ops.colocate_with(self.resource_handle):
        # Per-key defaults take precedence over the table-wide default value.
        values = gen_lookup_ops.lookup_table_find_v2(
            self.resource_handle, keys, dynamic_default_values
            if dynamic_default_values is not None else self._default_value)
    return values

  def insert(self, keys, values, name=None):
    """Associates `keys` with `values`.

    Args:
      keys: Keys to insert. Can be a tensor of any shape. Must match the table's
        key type.
      values: Values to be associated with keys. Must be a tensor of the same
        shape as `keys` and match the table's value type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` or `values` doesn't match the table data
        types.
    """
    with ops.name_scope(name, "%s_lookup_table_insert" % self.name,
                        [self.resource_handle, keys, values]):
      keys = ops.convert_to_tensor(keys, self._key_dtype, name="keys")
      values = ops.convert_to_tensor(values, self._value_dtype, name="values")
      with ops.colocate_with(self.resource_handle):
        # pylint: disable=protected-access
        op = gen_lookup_ops.lookup_table_insert_v2(self.resource_handle, keys,
                                                   values)
    return op

  def export(self, name=None):
    """Returns tensors of all keys and values in the table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A pair of tensors with the first tensor containing all keys and the
        second tensors containing all values in the table.
    """
    with ops.name_scope(name, "%s_lookup_table_export_values" % self.name,
                        [self.resource_handle]):
      with ops.colocate_with(self.resource_handle):
        exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(
            self.resource_handle, self._key_dtype, self._value_dtype)
    return exported_keys, exported_values

  def _gather_saveables_for_checkpoint(self):
    """For object-based checkpointing."""
    return {
        "table":
            functools.partial(
                MutableHashTable._Saveable, table=self, name=self._name,
                table_name=self._name)
    }

  class _Saveable(BaseSaverBuilder.SaveableObject):
    """SaveableObject implementation for MutableHashTable."""

    def __init__(self, table, name, table_name=None):
      # Exporting materializes all keys/values as tensors, which is where the
      # extra checkpointing memory mentioned in the class docstring comes from.
      tensors = table.export()
      specs = [
          BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
          BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
      ]
      self.table_name = table_name or name
      # pylint: disable=protected-access
      super(MutableHashTable._Saveable, self).__init__(table, specs, name)

    def restore(self, restored_tensors, restored_shapes):
      del restored_shapes  # unused
      # pylint: disable=protected-access
      # Re-imports the saved keys/values into the live table resource.
      with ops.name_scope("%s_table_restore" % self.table_name):
        with ops.colocate_with(self.op.resource_handle):
          return gen_lookup_ops.lookup_table_import_v2(self.op.resource_handle,
                                                       restored_tensors[0],
                                                       restored_tensors[1])
1987
1988
@tf_export("lookup.experimental.DenseHashTable")
class DenseHashTable(LookupInterface):
  """A mutable hash table with faster lookups and higher memory usage.

  Data can be inserted by calling the `insert` method and removed by calling the
  `remove` method. It does not support initialization via the init method.

  Compared to `MutableHashTable`, `DenseHashTable` offers generally faster
  `insert`, `remove` and `lookup` operations, in exchange for a higher overall
  memory footprint.

  It uses "open addressing" with quadratic reprobing to resolve collisions. This
  requires specifying two keys in the key space, `empty_key` and `deleted_key`,
  that can never be inserted into the table.

  Unlike `MutableHashTable`, `DenseHashTable` does not require additional memory
  for temporary tensors created during checkpointing and restore operations.

  Example usage:

  >>> table = tf.lookup.experimental.DenseHashTable(
  ...     key_dtype=tf.string,
  ...     value_dtype=tf.int64,
  ...     default_value=-1,
  ...     empty_key='',
  ...     deleted_key='$')
  >>> keys = tf.constant(['a', 'b', 'c'])
  >>> values = tf.constant([0, 1, 2], dtype=tf.int64)
  >>> table.insert(keys, values)
  >>> table.remove(tf.constant(['c']))
  >>> table.lookup(tf.constant(['a', 'b', 'c','d'])).numpy()
  array([ 0,  1, -1, -1])
  """

  # TODO(andreasst): consider extracting common code with MutableHashTable into
  # a common superclass.
  def __init__(self,
               key_dtype,
               value_dtype,
               default_value,
               empty_key,
               deleted_key,
               initial_num_buckets=None,
               name="MutableDenseHashTable",
               checkpoint=True):
    """Creates an empty `DenseHashTable` object.

    Creates a table, the type of its keys and values are specified by key_dtype
    and value_dtype, respectively.

    Args:
      key_dtype: the type of the key tensors.
      value_dtype: the type of the value tensors.
      default_value: The value to use if a key is missing in the table.
      empty_key: the key to use to represent empty buckets internally. Must not
        be used in insert, remove or lookup operations.
      deleted_key: the key to use to represent deleted buckets internally. Must
        not be used in insert, remove or lookup operations and be different from
        the empty_key.
      initial_num_buckets: the initial number of buckets.
      name: A name for the operation (optional).
      checkpoint: if True, the contents of the table are saved to and restored
        from checkpoints. If `shared_name` is empty for a checkpointed table, it
        is shared using the table node name.

    Returns:
      A `DenseHashTable` object.

    Raises:
      ValueError: If checkpoint is True and no name was specified.
    """
    self._default_value = ops.convert_to_tensor(
        default_value, dtype=value_dtype, name="default_value")
    self._key_dtype = key_dtype
    self._value_dtype = value_dtype
    self._initial_num_buckets = initial_num_buckets
    # The default value's static shape is passed to the kernel as the shape of
    # every stored value (see `_create_resource`).
    self._value_shape = self._default_value.get_shape()
    self._checkpoint = checkpoint
    self._name = name

    # Sentinel keys are kept as given and only converted to tensors when the
    # resource is (re)created.
    self._empty_key = empty_key
    self._deleted_key = deleted_key
    self._shared_name = None
    if context.executing_eagerly():
      # TODO(allenl): This will leak memory due to kernel caching by the
      # shared_name attribute value (but is better than the alternative of
      # sharing everything by default when executing eagerly; hopefully creating
      # tables in a loop is uncommon).
      # TODO(rohanj): Use context.shared_name() instead.
      self._shared_name = "table_%d" % (ops.uid(),)
    super(DenseHashTable, self).__init__(key_dtype, value_dtype)

    self._resource_handle = self._create_resource()
    # In graph mode, register a SaveableObject so name-based checkpointing
    # (tf.compat.v1.train.Saver) captures the table's contents; object-based
    # checkpointing instead goes through `_gather_saveables_for_checkpoint`.
    if checkpoint:
      saveable = DenseHashTable._Saveable(self, name)
      if not context.executing_eagerly():
        ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)

  def _create_resource(self):
    """Creates the underlying dense table resource and returns its handle."""
    # The table must be shared if checkpointing is requested for multi-worker
    # training to work correctly. Use the node name if no shared_name has been
    # explicitly specified.
    use_node_name_sharing = self._checkpoint and self._shared_name is None
    empty_key = ops.convert_to_tensor(
        self._empty_key, dtype=self._key_dtype, name="empty_key")
    deleted_key = ops.convert_to_tensor(
        self._deleted_key, dtype=self._key_dtype, name="deleted_key")
    table_ref = gen_lookup_ops.mutable_dense_hash_table_v2(
        empty_key=empty_key,
        deleted_key=deleted_key,
        shared_name=self._shared_name,
        use_node_name_sharing=use_node_name_sharing,
        value_dtype=self._value_dtype,
        value_shape=self._value_shape,
        initial_num_buckets=self._initial_num_buckets,
        name=self._name)
    if context.executing_eagerly():
      self._table_name = None
    else:
      # Use the last component of the op name as the table's display name.
      self._table_name = table_ref.op.name.split("/")[-1]
    return table_ref

  @property
  def name(self):
    """The table's op-derived name, or `None` when executing eagerly."""
    return self._table_name

  def size(self, name=None):
    """Compute the number of elements in this table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A scalar tensor containing the number of elements in this table.
    """
    with ops.name_scope(name, "%s_Size" % self.name, [self.resource_handle]):
      with ops.colocate_with(self.resource_handle):
        return gen_lookup_ops.lookup_table_size_v2(self.resource_handle)

  def lookup(self, keys, name=None):
    """Looks up `keys` in a table, outputs the corresponding values.

    The `default_value` is used for keys not present in the table.

    Args:
      keys: Keys to look up. Can be a tensor of any shape. Must match the
        table's key_dtype.
      name: A name for the operation (optional).

    Returns:
      A tensor containing the values in the same shape as `keys` using the
        table's value type.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    with ops.name_scope(name, "%s_lookup_table_find" % self.name,
                        [self.resource_handle, keys]):
      keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
      with ops.colocate_with(self.resource_handle):
        values = gen_lookup_ops.lookup_table_find_v2(self.resource_handle, keys,
                                                     self._default_value)

    return values

  def insert_or_assign(self, keys, values, name=None):
    """Associates `keys` with `values`.

    Args:
      keys: Keys to insert. Can be a tensor of any shape. Must match the table's
        key type.
      values: Values to be associated with keys. Must be a tensor of the same
        shape as `keys` and match the table's value type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` or `values` doesn't match the table data
        types.
    """
    with ops.name_scope(name, "%s_lookup_table_insert" % self.name,
                        [self.resource_handle, keys, values]):
      keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
      values = ops.convert_to_tensor(
          values, dtype=self._value_dtype, name="values")
      with ops.colocate_with(self.resource_handle):
        op = gen_lookup_ops.lookup_table_insert_v2(self.resource_handle, keys,
                                                   values)
      return op

  def insert(self, keys, values, name=None):
    """Associates `keys` with `values`.

    Args:
      keys: Keys to insert. Can be a tensor of any shape. Must match the table's
        key type.
      values: Values to be associated with keys. Must be a tensor of the same
        shape as `keys` and match the table's value type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` or `values` doesn't match the table data
        types.
    """
    # Alias of `insert_or_assign`.
    return self.insert_or_assign(keys, values, name)

  def erase(self, keys, name=None):
    """Removes `keys` and its associated values from the table.

    If a key is not present in the table, it is silently ignored.

    Args:
      keys: Keys to remove. Can be a tensor of any shape. Must match the table's
        key type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    if keys.dtype != self._key_dtype:
      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
                      (self._key_dtype, keys.dtype))

    with ops.name_scope(name, "%s_lookup_table_remove" % self.name,
                        (self.resource_handle, keys, self._default_value)):
      # pylint: disable=protected-access
      op = gen_lookup_ops.lookup_table_remove_v2(self.resource_handle, keys)

    return op

  def remove(self, keys, name=None):
    """Removes `keys` and its associated values from the table.

    If a key is not present in the table, it is silently ignored.

    Args:
      keys: Keys to remove. Can be a tensor of any shape. Must match the table's
        key type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    # Alias of `erase`.
    return self.erase(keys, name)

  def export(self, name=None):
    """Returns tensors of all keys and values in the table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A pair of tensors with the first tensor containing all keys and the
        second tensors containing all values in the table.
    """
    with ops.name_scope(name, "%s_lookup_table_export_values" % self.name,
                        [self.resource_handle]):
      with ops.colocate_with(self.resource_handle):
        exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(
            self.resource_handle, self._key_dtype, self._value_dtype)

    return exported_keys, exported_values

  def _gather_saveables_for_checkpoint(self):
    """For object-based checkpointing."""
    return {
        "table":
            functools.partial(
                DenseHashTable._Saveable, table=self, name=self._name,
                table_name=self._name)
    }

  class _Saveable(BaseSaverBuilder.SaveableObject):
    """SaveableObject implementation for DenseHashTable."""

    def __init__(self, table, name, table_name=None):
      tensors = table.export()
      specs = [
          BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
          BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
      ]
      self.table_name = table_name or name
      # pylint: disable=protected-access
      super(DenseHashTable._Saveable, self).__init__(table, specs, name)

    def restore(self, restored_tensors, restored_shapes):
      del restored_shapes  # unused
      # pylint: disable=protected-access
      # Re-imports the saved keys/values into the live table resource.
      with ops.name_scope("%s_table_restore" % self.table_name):
        with ops.colocate_with(self.op.resource_handle):
          return gen_lookup_ops.lookup_table_import_v2(self.op.resource_handle,
                                                       restored_tensors[0],
                                                       restored_tensors[1])
2293
2294
# Lookup table ops have no meaningful gradients; register every kernel
# (both the legacy and the V2 variants) as non-differentiable.
for _lookup_op_name in (
    "LookupTableFind",
    "LookupTableFindV2",
    "LookupTableInsert",
    "LookupTableInsertV2",
    "LookupTableSize",
    "LookupTableSizeV2",
    "HashTable",
    "HashTableV2",
    "InitializeTable",
    "InitializeTableV2",
    "InitializeTableFromTextFile",
    "InitializeTableFromTextFileV2",
    "MutableDenseHashTable",
    "MutableDenseHashTableV2",
    "MutableHashTable",
    "MutableHashTableV2",
    "MutableHashTableOfTensors",
    "MutableHashTableOfTensorsV2",
):
  ops.NotDifferentiable(_lookup_op_name)
del _lookup_op_name  # Keep the module namespace clean.
2313