# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#==============================================================================
"""Lookup operations."""
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import functools
import uuid

import six

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_lookup_ops import *
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.training.saver import BaseSaverBuilder
# pylint: enable=wildcard-import
from tensorflow.python.training.tracking import base as trackable_base
from tensorflow.python.training.tracking import tracking as trackable
from tensorflow.python.util import compat
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["initialize_all_tables"])
@deprecated(None, "Use `tf.tables_initializer` instead.")
def initialize_all_tables(name="init_all_tables"):
  """Returns an Op that initializes all tables of the default graph.

  Args:
    name: Optional name for the initialization op.

  Returns:
    An Op that initializes all tables.  Note that if there are
    no tables the returned Op is a NoOp.
  """
  return tables_initializer(name)


@tf_export(v1=["initializers.tables_initializer", "tables_initializer"])
def tables_initializer(name="init_all_tables"):
  """Returns an Op that initializes all tables of the default graph.

  See the [Low Level
  Intro](https://www.tensorflow.org/guide/low_level_intro#feature_columns)
  guide for an example of usage.

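  For example, a minimal graph-mode sketch (the table contents here are
  illustrative):

  ```python
  table = tf.lookup.StaticHashTable(
      tf.lookup.KeyValueTensorInitializer(['a', 'b'], [1, 2]),
      default_value=-1)
  with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.tables_initializer())
    print(sess.run(table.lookup(tf.constant(['b', 'z']))))  # [ 2 -1]
  ```
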
  Args:
    name: Optional name for the initialization op.

  Returns:
    An Op that initializes all tables.  Note that if there are
    no tables the returned Op is a NoOp.
  """
  initializers = ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS)
  if initializers:
    return control_flow_ops.group(*initializers, name=name)
  return control_flow_ops.no_op(name=name)


def _check_table_dtypes(table, key_dtype, value_dtype):
  """Check that the given key_dtype and value_dtype match the table dtypes.

  Args:
    table: The table to check types against.
    key_dtype: The key data type to check.
    value_dtype: The value data type to check.

  Raises:
    TypeError: when 'key_dtype' or 'value_dtype' doesn't match the table data
      types.
  """
  if key_dtype.base_dtype != table.key_dtype:
    raise TypeError("Invalid key dtype, expected %s but got %s." %
                    (table.key_dtype, key_dtype))
  if value_dtype.base_dtype != table.value_dtype:
    raise TypeError("Invalid value dtype, expected %s but got %s." %
                    (table.value_dtype, value_dtype))


class LookupInterface(trackable.TrackableResource):
  """Represents a lookup table that persists across different steps."""

  def __init__(self, key_dtype, value_dtype):
    """Construct a lookup table interface.

    Args:
      key_dtype: The table key type.
      value_dtype: The table value type.
    """
    self._key_dtype = dtypes.as_dtype(key_dtype)
    self._value_dtype = dtypes.as_dtype(value_dtype)
    super(LookupInterface, self).__init__()

  def _create_resource(self):
    raise NotImplementedError

  @property
  def key_dtype(self):
    """The table key dtype."""
    return self._key_dtype

  @property
  def value_dtype(self):
    """The table value dtype."""
    return self._value_dtype

  @property
  def name(self):
    """The name of the table."""
    raise NotImplementedError

  def size(self, name=None):
    """Compute the number of elements in this table."""
    raise NotImplementedError

  def lookup(self, keys, name=None):
    """Looks up `keys` in a table, outputs the corresponding values."""
    raise NotImplementedError

  def __getitem__(self, keys):
    """Looks up `keys` in a table, outputs the corresponding values."""
    return self.lookup(keys)


class InitializableLookupTableBase(LookupInterface):
  """Initializable lookup table interface.

  An initializable lookup table persists across different steps.
  """

  def __init__(self, default_value, initializer):
    """Construct a table object from a table reference.

    It requires a table initializer object (subclass of
    `TableInitializerBase`), which provides the table key and value types, as
    well as the op to initialize the table. The caller is responsible for
    executing the initialization op.

    Args:
      default_value: The value to use if a key is missing in the table.
      initializer: The table initializer to use.
    """
    super(InitializableLookupTableBase, self).__init__(initializer.key_dtype,
                                                       initializer.value_dtype)
    self._default_value = ops.convert_to_tensor(
        default_value, dtype=self._value_dtype)
    self._default_value.get_shape().merge_with(tensor_shape.TensorShape([]))
    if isinstance(initializer, trackable_base.Trackable):
      self._initializer = self._track_trackable(initializer, "_initializer")
    with ops.init_scope():
      self._resource_handle = self._create_resource()
    if (not context.executing_eagerly() and
        ops.get_default_graph()._get_control_flow_context() is not None):  # pylint: disable=protected-access
      with ops.init_scope():
        self._init_op = self._initialize()
    else:
      self._init_op = self._initialize()

  def _initialize(self):
    return self._initializer.initialize(self)

  @property
  def default_value(self):
    """The default value of the table."""
    return self._default_value

  def size(self, name=None):
    """Compute the number of elements in this table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A scalar tensor containing the number of elements in this table.
    """
    with ops.name_scope(name, "%s_Size" % self.name, [self.resource_handle]):
      return gen_lookup_ops.lookup_table_size_v2(self.resource_handle)

  def lookup(self, keys, name=None):
    """Looks up `keys` in a table, outputs the corresponding values.

    The `default_value` is used for keys not present in the table.

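    For example, a minimal sketch in eager mode:

    >>> table = tf.lookup.StaticHashTable(
    ...     tf.lookup.KeyValueTensorInitializer(['a', 'b'], [0, 1]),
    ...     default_value=-1)
    >>> table.lookup(tf.constant(['b', 'z'])).numpy()
    array([ 1, -1], dtype=int32)
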
    Args:
      keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`.
      name: A name for the operation (optional).

    Returns:
      A `SparseTensor` if keys are sparse, a `RaggedTensor` if keys are ragged,
      otherwise a dense `Tensor`.

    Raises:
      TypeError: when `keys` or `default_value` doesn't match the table data
        types.
    """
    key_tensor = keys
    if isinstance(keys,
                  (sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor)):
      key_tensor = keys.values

    if keys.dtype.base_dtype != self._key_dtype:
      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
                      (self._key_dtype, keys.dtype))

    with ops.name_scope(
        name, "%s_Lookup" % self.name,
        (self.resource_handle, key_tensor, self._default_value)):
      values = gen_lookup_ops.lookup_table_find_v2(self.resource_handle,
                                                   key_tensor,
                                                   self._default_value)

    values.set_shape(key_tensor.get_shape())
    if isinstance(keys, sparse_tensor.SparseTensor):
      return sparse_tensor.SparseTensor(keys.indices, values, keys.dense_shape)
    elif isinstance(keys, ragged_tensor.RaggedTensor):
      return keys.with_values(values)
    else:
      return values


class InitializableLookupTableBaseV1(InitializableLookupTableBase):

  @property
  def initializer(self):
    return self._init_op


@tf_export("lookup.StaticHashTable", v1=[])
class StaticHashTable(InitializableLookupTableBase):
  """A generic hash table that is immutable once initialized.

  Example usage:

  >>> keys_tensor = tf.constant(['a', 'b', 'c'])
  >>> vals_tensor = tf.constant([7, 8, 9])
  >>> input_tensor = tf.constant(['a', 'f'])
  >>> table = tf.lookup.StaticHashTable(
  ...     tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor),
  ...     default_value=-1)
  >>> table.lookup(input_tensor).numpy()
  array([ 7, -1], dtype=int32)

  Or for more pythonic code:

  >>> table[input_tensor].numpy()
  array([ 7, -1], dtype=int32)

  The result of a lookup operation has the same shape as the argument:

  >>> input_tensor = tf.constant([['a', 'b'], ['c', 'd']])
  >>> table[input_tensor].numpy()
  array([[ 7,  8],
         [ 9, -1]], dtype=int32)
  """

  def __init__(self, initializer, default_value, name=None):
    """Creates an uninitialized `HashTable` object.

    Creates a table whose key and value types are specified by the
    initializer. Before using the table you will have to initialize it.
    After initialization the table will be immutable.

    Args:
      initializer: The table initializer to use. See `HashTable` kernel for
        supported key and value types.
      default_value: The value to use if a key is missing in the table.
      name: A name for the operation (optional).

    Returns:
      A `HashTable` object.
    """
    self._initializer = initializer
    self._default_value = default_value
    self._shared_name = self._initializer._shared_name  # pylint: disable=protected-access
    if not self._shared_name:
      # Force using a shared name so that StaticHashTable resources can be
      # shared across different kernels. If no "shared_name" is set and
      # "use_node_name_sharing" is False, then each kernel gets its own local
      # resource.
      self._shared_name = "hash_table_%s" % (str(uuid.uuid4()),)
    self._name = name or "hash_table"
    self._table_name = None
    super(StaticHashTable, self).__init__(default_value, initializer)
    self._value_shape = self._default_value.get_shape()

  def _create_resource(self):
    table_ref = gen_lookup_ops.hash_table_v2(
        shared_name=self._shared_name,
        key_dtype=self._initializer.key_dtype,
        value_dtype=self._initializer.value_dtype,
        name=self._name)
    if context.executing_eagerly():
      self._table_name = None
    else:
      self._table_name = table_ref.op.name.split("/")[-1]
    return table_ref

  @property
  def name(self):
    return self._table_name

  def export(self, name=None):
    """Returns tensors of all keys and values in the table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A pair of tensors with the first tensor containing all keys and the
        second tensor containing all values in the table.
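
    For example, a minimal sketch:

    >>> table = tf.lookup.StaticHashTable(
    ...     tf.lookup.KeyValueTensorInitializer(['a'], [1]), default_value=-1)
    >>> keys, values = table.export()
    >>> keys.numpy(), values.numpy()
    (array([b'a'], dtype=object), array([1], dtype=int32))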
    """
    with ops.name_scope(name, "%s_Export" % self.name, [self.resource_handle]):
      exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(
          self.resource_handle, self._key_dtype, self._value_dtype)

    exported_values.set_shape(exported_keys.get_shape().concatenate(
        self._value_shape))
    return exported_keys, exported_values


@tf_export(v1=["lookup.StaticHashTable"])
class StaticHashTableV1(StaticHashTable):
  """A generic hash table that is immutable once initialized.

  When running in graph mode, you must evaluate the tensor returned by
  `tf.tables_initializer()` before evaluating the tensor returned by
  this class's `lookup()` method. Example usage in graph mode:

  ```python
  keys_tensor = tf.constant([1, 2])
  vals_tensor = tf.constant([3, 4])
  input_tensor = tf.constant([1, 5])
  table = tf.lookup.StaticHashTable(
      tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor), -1)
  out = table.lookup(input_tensor)
  with tf.Session() as sess:
      sess.run(tf.tables_initializer())
      print(sess.run(out))
  ```

  In eager mode, no special code is needed to initialize the table.
  Example usage in eager mode:

  ```python
  tf.enable_eager_execution()
  keys_tensor = tf.constant([1, 2])
  vals_tensor = tf.constant([3, 4])
  input_tensor = tf.constant([1, 5])
  table = tf.lookup.StaticHashTable(
      tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor), -1)
  print(table.lookup(input_tensor))
  ```
  """

  @property
  def initializer(self):
    return self._init_op


# For backwards compatibility. This will be removed in TF 2.0.
class HashTable(StaticHashTableV1):

  @property
  def init(self):
    return self.initializer


class TableInitializerBase(trackable_base.Trackable):
  """Base class for lookup table initializers."""

  def __init__(self, key_dtype, value_dtype):
    """Construct a table initializer object.

    Args:
      key_dtype: Type of the table keys.
      value_dtype: Type of the table values.
    """
    self._key_dtype = dtypes.as_dtype(key_dtype)
    self._value_dtype = dtypes.as_dtype(value_dtype)

  @property
  def key_dtype(self):
    """The expected table key dtype."""
    return self._key_dtype

  @property
  def value_dtype(self):
    """The expected table value dtype."""
    return self._value_dtype

  def initialize(self, table):
    """Returns the table initialization op."""
    raise NotImplementedError

  @property
  def _shared_name(self):
    """Returns a shared name to be used by the table."""
    shared_name = ""
    if context.executing_eagerly():
      # Ensure a unique name when eager execution is enabled to avoid spurious
      # sharing issues.
      # TODO(rohanj): Use context.shared_name() instead.
      shared_name += str(ops.uid())
    return shared_name


@tf_export("lookup.experimental.DatasetInitializer")
class DatasetInitializer(TableInitializerBase):
  """Creates a table initializer from a `tf.data.Dataset`.

  Sample usage:

  >>> keys = tf.data.Dataset.range(100)
  >>> values = tf.data.Dataset.range(100).map(
  ...     lambda x: string_ops.as_string(x * 2))
  >>> ds = tf.data.Dataset.zip((keys, values))
  >>> init = tf.lookup.experimental.DatasetInitializer(ds)
  >>> table = tf.lookup.StaticHashTable(init, "")
  >>> table.lookup(tf.constant([0, 1, 2], dtype=tf.int64)).numpy()
  array([b'0', b'2', b'4'], dtype=object)

  Attributes:
    dataset: A `tf.data.Dataset` object that produces tuples of scalars. The
      first scalar is treated as a key and the second as value.

  Raises:
    ValueError: if `dataset` doesn't conform to specifications.
  """

  def __init__(self, dataset):
    """Creates a table initializer from a `tf.data.Dataset`.

    Args:
      dataset: A `tf.data.Dataset` object that produces tuples of scalars. The
        first scalar is treated as a key and the second as value.

    Raises:
      ValueError: if `dataset` doesn't conform to specifications.

    Returns:
      A `DatasetInitializer` object.
    """
    # Assert that the dataset element spec is a tuple of TensorSpecs where
    # each tensor is a scalar.
    self.dataset = dataset
    elem_spec = self.dataset.element_spec
    if len(elem_spec) != 2:
      raise ValueError("element spec size should be 2")
    if not isinstance(elem_spec[0], tensor_spec.TensorSpec):
      raise ValueError("elem_spec[0] should be of type TensorSpec")
    if not isinstance(elem_spec[1], tensor_spec.TensorSpec):
      raise ValueError("elem_spec[1] should be of type TensorSpec")
    if elem_spec[0].shape.rank not in (None, 0):
      raise ValueError("key tensor should be a scalar")
    if elem_spec[1].shape.rank not in (None, 0):
      raise ValueError("value tensor should be a scalar")

    key_type = elem_spec[0].dtype
    value_type = elem_spec[1].dtype
    super(DatasetInitializer, self).__init__(key_type, value_type)

  def initialize(self, table):
    _check_table_dtypes(table, self._key_dtype, self._value_dtype)
    init_op = gen_lookup_ops.initialize_table_from_dataset(
        table.resource_handle, self.dataset._variant_tensor)  # pylint: disable=protected-access
    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
    return init_op


@tf_export("lookup.KeyValueTensorInitializer")
class KeyValueTensorInitializer(TableInitializerBase):
  """Table initializers given `keys` and `values` tensors.

  >>> keys_tensor = tf.constant(['a', 'b', 'c'])
  >>> vals_tensor = tf.constant([7, 8, 9])
  >>> input_tensor = tf.constant(['a', 'f'])
  >>> init = tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor)
  >>> table = tf.lookup.StaticHashTable(
  ...     init,
  ...     default_value=-1)
  >>> table.lookup(input_tensor).numpy()
  array([ 7, -1], dtype=int32)

  """

  def __init__(self, keys, values, key_dtype=None, value_dtype=None, name=None):
    """Constructs a table initializer object based on keys and values tensors.

    Args:
      keys: The tensor for the keys.
      values: The tensor for the values.
      key_dtype: The `keys` data type. Used when `keys` is a python array.
      value_dtype: The `values` data type. Used when `values` is a python array.
      name: A name for the operation (optional).
    """
    if (not context.executing_eagerly() and
        ops.get_default_graph()._get_control_flow_context() is not None):  # pylint: disable=protected-access
      with ops.init_scope():
        self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys")
        self._values = ops.convert_to_tensor(
            values, dtype=value_dtype, name="values")
    else:
      self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys")
      self._values = ops.convert_to_tensor(
          values, dtype=value_dtype, name="values")
    self._name = name if name is not None else "key_value_init"
    if context.executing_eagerly():
      # Ensure a unique name when eager execution is enabled to avoid spurious
      # sharing issues.
      # TODO(rohanj): Use context.shared_name() instead.
      self._name += str(ops.uid())

    super(KeyValueTensorInitializer, self).__init__(self._keys.dtype,
                                                    self._values.dtype)

  def initialize(self, table):
    """Initializes the given `table` with `keys` and `values` tensors.

    Args:
      table: The table to initialize.

    Returns:
      The operation that initializes the table.

    Raises:
      TypeError: when the keys and values data types do not match the table
      key and value data types.
    """
    _check_table_dtypes(table, self._keys.dtype, self._values.dtype)
    with ops.name_scope(
        self._name, values=(table.resource_handle, self._keys, self._values)):
      init_op = gen_lookup_ops.lookup_table_import_v2(table.resource_handle,
                                                      self._keys, self._values)
    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
    return init_op


@tf_export("lookup.TextFileIndex")
class TextFileIndex(object):
  """The key and value content to get from each line.

  This class defines the key and value used for `tf.lookup.TextFileInitializer`.

  The key and value content to get from each line is specified either
  by one of the following, or a value `>=0`.
  * `TextFileIndex.LINE_NUMBER` means use the line number starting from zero,
    expects data type int64.
  * `TextFileIndex.WHOLE_LINE` means use the whole line content, expects data
    type string.

  A value `>=0` means use the index (starting at zero) of the split line based
  on `delimiter`.
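
  For example, with `delimiter=","` and the line `emerson,10`, a
  `key_index` of `0` selects `emerson` and a `value_index` of `1`
  selects `10`.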
  """
  WHOLE_LINE = -2
  LINE_NUMBER = -1


@tf_export("lookup.TextFileInitializer")
class TextFileInitializer(TableInitializerBase):
  r"""Table initializers from a text file.

  This initializer assigns one entry in the table for each line in the file.

  The key and value type of the table to initialize is given by `key_dtype` and
  `value_dtype`.

  The key and value content to get from each line is specified by
  the `key_index` and `value_index`.

  * `TextFileIndex.LINE_NUMBER` means use the line number starting from zero,
    expects data type int64.
  * `TextFileIndex.WHOLE_LINE` means use the whole line content, expects data
    type string.
  * A value `>=0` means use the index (starting at zero) of the split line
    based on `delimiter`.

  For example, if we have a file with the following content:

  >>> import tempfile
  >>> f = tempfile.NamedTemporaryFile(delete=False)
  >>> content = '\n'.join(["emerson 10", "lake 20", "palmer 30",])
  >>> f.file.write(content.encode('utf-8'))
  >>> f.file.close()

  The following snippet initializes a table with the first column as keys and
  second column as values:

  * `emerson -> 10`
  * `lake -> 20`
  * `palmer -> 30`

  >>> init = tf.lookup.TextFileInitializer(
  ...    filename=f.name,
  ...    key_dtype=tf.string, key_index=0,
  ...    value_dtype=tf.int64, value_index=1,
  ...    delimiter=" ")
  >>> table = tf.lookup.StaticHashTable(init, default_value=-1)
  >>> table.lookup(tf.constant(['palmer','lake','tarkus'])).numpy()
  array([30, 20, -1])

  Similarly, to initialize the whole line as keys and the line number as
  values:

  * `emerson 10 -> 0`
  * `lake 20 -> 1`
  * `palmer 30 -> 2`

  >>> init = tf.lookup.TextFileInitializer(
  ...   filename=f.name,
  ...   key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
  ...   value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER)
  >>> table = tf.lookup.StaticHashTable(init, -1)
  >>> table.lookup(tf.constant('palmer 30')).numpy()
  2
  """

  def __init__(self,
               filename,
               key_dtype,
               key_index,
               value_dtype,
               value_index,
               vocab_size=None,
               delimiter="\t",
               name=None,
               value_index_offset=0):
    """Constructs a table initializer object to populate from a text file.

    It generates one key-value pair per line. The type of table key and
    value are specified by `key_dtype` and `value_dtype`, respectively.
    Similarly, the content of the key and value are specified by `key_index`
    and `value_index`.

    - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
      expects data type int64.
    - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
      type string.
    - A value >=0 means use the index (starting at zero) of the split line based
      on `delimiter`.

    Args:
      filename: The filename of the text file to be used for initialization. The
        path must be accessible from wherever the graph is initialized (e.g.
        trainer or eval workers). The filename may be a scalar `Tensor`.
      key_dtype: The `key` data type.
      key_index: the index that represents information of a line to get the
        table 'key' values from.
      value_dtype: The `value` data type.
      value_index: the index that represents information of a line to get the
        table 'value' values from.
      vocab_size: The number of elements in the file, if known.
      delimiter: The delimiter to separate fields in a line.
      name: A name for the operation (optional).
      value_index_offset: A number to add to all indices extracted from the
        file. This is useful for cases where a user would like to reserve one
        or more low index values for control characters. For instance, if you
        would like to ensure that no vocabulary item is mapped to index 0 (so
        you can reserve 0 for a masking value), you can set
        `value_index_offset` to 1; this will mean that the first vocabulary
        element is mapped to 1 instead of 0.

    Raises:
      ValueError: when the filename is empty, or when the table key and value
      data types do not match the expected data types.
    """
    if not isinstance(filename, ops.Tensor) and not filename:
      raise ValueError("Filename required for %s." % name)

    self._filename_arg = filename
    key_dtype = dtypes.as_dtype(key_dtype)
    value_dtype = dtypes.as_dtype(value_dtype)

    if key_index < -2:
      raise ValueError("Invalid key index %s." % (key_index))

    if key_index == TextFileIndex.LINE_NUMBER and key_dtype != dtypes.int64:
      raise ValueError("Signature mismatch. Keys must be dtype %s, got %s." %
                       (dtypes.int64, key_dtype))
    if ((key_index == TextFileIndex.WHOLE_LINE) and
        (not key_dtype.is_integer) and (key_dtype != dtypes.string)):
      raise ValueError(
          "Signature mismatch. Keys must be integer or string, got %s." %
          key_dtype)
    if value_index < -2:
      raise ValueError("Invalid value index %s." % (value_index))

    if value_index == TextFileIndex.LINE_NUMBER and value_dtype != dtypes.int64:
      raise ValueError("Signature mismatch. Values must be dtype %s, got %s." %
                       (dtypes.int64, value_dtype))
    if value_index == TextFileIndex.WHOLE_LINE and value_dtype != dtypes.string:
      raise ValueError("Signature mismatch. Values must be dtype %s, got %s." %
                       (dtypes.string, value_dtype))

    if (vocab_size is not None) and (vocab_size <= 0):
      raise ValueError("Invalid vocab_size %s." % vocab_size)

    self._key_index = key_index
    self._value_index = value_index
    self._vocab_size = vocab_size
    self._delimiter = delimiter
    self._name = name
    self._filename = self._track_trackable(
        trackable.Asset(filename), "_filename")
    self._offset = value_index_offset

    super(TextFileInitializer, self).__init__(key_dtype, value_dtype)

  def initialize(self, table):
    """Initializes the table from a text file.

    Args:
      table: The table to be initialized.

    Returns:
      The operation that initializes the table.

    Raises:
      TypeError: when the keys and values data types do not match the table
      key and value data types.
    """
    _check_table_dtypes(table, self.key_dtype, self.value_dtype)
    with ops.name_scope(self._name, "text_file_init", (table.resource_handle,)):
      filename = ops.convert_to_tensor(
          self._filename, dtypes.string, name="asset_filepath")
      init_op = gen_lookup_ops.initialize_table_from_text_file_v2(
          table.resource_handle, filename, self._key_index, self._value_index,
          -1 if self._vocab_size is None else self._vocab_size, self._delimiter,
          self._offset)
    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
    # If the filename tensor is anything other than a string constant (e.g.,
    # if it is a placeholder) then it does not make sense to track it as an
    # asset.
    if not context.executing_eagerly() and constant_op.is_constant(filename):
      ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename)
    return init_op

  @property
  def _shared_name(self):
    if self._vocab_size:
      # Keep the shared_name:
      # <table_type>_<filename>_<vocab_size>_<key_index>_<value_index>
      shared_name = "hash_table_%s_%d_%s_%s" % (
          self._filename_arg, self._vocab_size, self._key_index,
          self._value_index)
    else:
      # Keep the shared_name
      # <table_type>_<filename>_<key_index>_<value_index>
      shared_name = "hash_table_%s_%s_%s" % (self._filename_arg,
                                             self._key_index, self._value_index)
    return shared_name


class TextFileStringTableInitializer(TextFileInitializer):
  """Table initializer for `int64` IDs to string tables from a text file."""

  def __init__(self,
               filename,
               key_column_index=TextFileIndex.LINE_NUMBER,
               value_column_index=TextFileIndex.WHOLE_LINE,
               vocab_size=None,
               delimiter="\t",
               name="text_file_string_table_init"):
    """Constructs an initializer for an id-to-string table from a text file.

    It populates a table whose key and value types are int64 and string,
    respectively. It generates one key-value pair per line.
    The content of the key and value are specified by `key_column_index`
    and `value_column_index`.

    - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
      expects data type int64.
    - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
      type string.
    - A value >=0 means use the index (starting at zero) of the split line based
      on `delimiter`.

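    For example, a minimal sketch (the vocabulary path is illustrative):

    ```python
    init = TextFileStringTableInitializer("/tmp/vocab.txt")
    table = StaticHashTableV1(init, default_value="UNK")
    ```
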
    Args:
      filename: The filename of the text file to be used for initialization. The
        path must be accessible from wherever the graph is initialized (e.g.
        trainer or eval workers). The filename may be a scalar `Tensor`.
      key_column_index: The column index from the text file to get the keys
        from. The default is to use the line number, starting from zero.
      value_column_index: The column index from the text file to get the values
        from. The default is to use the whole line content.
      vocab_size: The number of elements in the file, if known.
      delimiter: The delimiter to separate fields in a line.
      name: Optional name for the op.

    Raises:
      TypeError: when the filename is empty, or when the table key and value
      data types do not match the expected data types.
    """
    super(TextFileStringTableInitializer, self).__init__(
        filename,
        dtypes.int64,
        key_column_index,
        dtypes.string,
        value_column_index,
        vocab_size=vocab_size,
        delimiter=delimiter,
        name=name)


class TextFileIdTableInitializer(TextFileInitializer):
  """Table initializer for string to `int64` IDs tables from a text file."""

  def __init__(self,
               filename,
               key_column_index=TextFileIndex.WHOLE_LINE,
               value_column_index=TextFileIndex.LINE_NUMBER,
               vocab_size=None,
               delimiter="\t",
               name="text_file_id_table_init",
               key_dtype=dtypes.string):
    """Constructs an initializer for a string-to-id table from a text file.

    It populates a table whose key and value types are string and int64,
    respectively. It generates one key-value pair per line.
    The content of the key and value are specified by `key_column_index`
    and `value_column_index`.

    - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
      expects data type int64.
    - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
      type string.
    - A value >=0 means use the index (starting at zero) of the split line based
      on `delimiter`.

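    For example, a minimal sketch (the vocabulary path is illustrative):

    ```python
    init = TextFileIdTableInitializer("/tmp/vocab.txt")
    table = StaticHashTableV1(init, default_value=-1)
    ```
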
    Args:
      filename: The filename of the text file to be used for initialization. The
        path must be accessible from wherever the graph is initialized (e.g.
        trainer or eval workers). The filename may be a scalar `Tensor`.
      key_column_index: The column index from the text file to get the `key`
        values from. The default is to use the whole line content.
      value_column_index: The column index from the text file to get the `value`
        values from. The default is to use the line number, starting from zero.
      vocab_size: The number of elements in the file, if known.
      delimiter: The delimiter to separate fields in a line.
      name: Optional name for the op.
      key_dtype: The `key` data type.

    Raises:
      TypeError: when the filename is empty, or when the table key and value
      data types do not match the expected data types.
    """
    super(TextFileIdTableInitializer, self).__init__(
        filename,
        key_dtype,
        key_column_index,
        dtypes.int64,
        value_column_index,
        vocab_size=vocab_size,
        delimiter=delimiter,
        name=name)


class HasherSpec(collections.namedtuple("HasherSpec", ["hasher", "key"])):
  """A structure for the spec of the hashing function to use for hash buckets.

  `hasher` is the name of the hashing function to use (e.g. "fasthash",
  "stronghash").
  `key` is optional and specifies the key to use for the hash function if
  supported; currently it is only used by the strong hash.

  Fields:
    hasher: The hasher name to use.
    key: The key to be used by the hashing function, if required.
  """
  __slots__ = ()


FastHashSpec = HasherSpec("fasthash", None)  # pylint: disable=invalid-name


class StrongHashSpec(HasherSpec):
  """A structure to specify a key of the strong keyed hash spec.

  The strong hash requires a `key`, which is a list of 2 unsigned integer
  numbers. These should be non-zero; random numbers generated from random.org
  would be a fine choice.

  Fields:
    key: The key to be used by the keyed hashing function.
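
  Example (a minimal sketch; the key values are illustrative):

  ```python
  spec = StrongHashSpec(key=(11, 137))
  ```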
  """
  __slots__ = ()

  def __new__(cls, key):
    if len(key) != 2:
      raise ValueError("key must have size 2, got %s." % len(key))

    if not isinstance(key[0], compat.integral_types) or not isinstance(
        key[1], compat.integral_types):
      raise TypeError("Invalid key %s. Must be unsigned integer values." % key)

    return super(StrongHashSpec, cls).__new__(cls, "stronghash", key)


def _as_string(tensor):
  if dtypes.string == tensor.dtype.base_dtype:
    return tensor
  return string_ops.as_string(tensor)


class IdTableWithHashBuckets(LookupInterface):
  r"""String to Id table wrapper that assigns out-of-vocabulary keys to buckets.

  For example, if an instance of `IdTableWithHashBuckets` is initialized with a
  string-to-id table that maps:

  * `emerson -> 0`
  * `lake -> 1`
  * `palmer -> 2`

  The `IdTableWithHashBuckets` object will perform the following mapping:

  * `emerson -> 0`
  * `lake -> 1`
  * `palmer -> 2`
  * `<other term> -> bucket_id`, where bucket_id will be between `3` and
  `3 + num_oov_buckets - 1`, calculated by:
  `hash(<term>) % num_oov_buckets + vocab_size`

  If input_tensor is `["emerson", "lake", "palmer", "king", "crimson"]`,
  the lookup result is `[0, 1, 2, 4, 7]`.

  If `table` is None, only out-of-vocabulary buckets are used.

  Example usage:

  ```python
  num_oov_buckets = 3
  input_tensor = tf.constant(["emerson", "lake", "palmer", "king", "crimson"])
  table = tf.IdTableWithHashBuckets(
      tf.StaticHashTable(
          tf.lookup.TextFileInitializer(
              filename,
              key_dtype=tf.string,
              key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
              value_dtype=tf.int64,
              value_index=tf.lookup.TextFileIndex.LINE_NUMBER,
              delimiter="\t"),
          default_value),
      num_oov_buckets)
  out = table.lookup(input_tensor)
  table.init.run()
  print(out.eval())
  ```

  The hash function used for generating out-of-vocabulary buckets ID is handled
  by `hasher_spec`.
  """

  def __init__(self,
               table,
               num_oov_buckets,
               hasher_spec=FastHashSpec,
               name=None,
               key_dtype=None):
    """Construct an `IdTableWithHashBuckets` object.

    Args:
      table: Table that maps `tf.string` or `tf.int64` keys to `tf.int64` ids.
      num_oov_buckets: Number of buckets to use for out-of-vocabulary keys.
      hasher_spec: A `HasherSpec` to specify the hash function to use for
        assignment of out-of-vocabulary buckets (optional).
      name: A name for the operation (optional).
      key_dtype: Data type of keys passed to `lookup`. Defaults to
        `table.key_dtype` if `table` is specified, otherwise `tf.string`. Must
        be string or integer, and must be castable to `table.key_dtype`.

    Raises:
      ValueError: when `table` is None and `num_oov_buckets` is not positive.
      TypeError: when `hasher_spec` is invalid.
    """
    # If a name ends with a '/' it is a "name scope", remove all trailing '/'
    # characters to use as table name.
    if name:
      name = name.rstrip("/")
    if table:
      if key_dtype is None:
        key_dtype = table.key_dtype
      supported_table_key_dtypes = (dtypes.int64, dtypes.string)
      if table.key_dtype not in supported_table_key_dtypes:
        raise TypeError("Invalid key dtype, expected one of %s, but got %s." %
                        (supported_table_key_dtypes, table.key_dtype))
      if table.key_dtype.is_integer != key_dtype.is_integer:
        raise TypeError("Invalid key dtype, expected %s but got %s." %
                        ("integer" if key_dtype.is_integer else "non-integer",
                         table.key_dtype))
      if table.value_dtype != dtypes.int64:
        raise TypeError("Invalid value dtype, expected %s but got %s." %
                        (dtypes.int64, table.value_dtype))
      self._table = table
      name = name or self._table.name
    else:
      if num_oov_buckets <= 0:
        raise ValueError("num_oov_buckets must be > 0 if no table is supplied.")
      key_dtype = dtypes.string if key_dtype is None else key_dtype
      self._table = None
      name = name or "hash_bucket"
    if (not key_dtype.is_integer) and (dtypes.string != key_dtype):
      raise TypeError("Invalid key_dtype, expected integer or string, got %s." %
                      key_dtype)
    self._num_oov_buckets = num_oov_buckets

    if not isinstance(hasher_spec, HasherSpec):
      raise TypeError("hasher_spec must be of type HasherSpec, got %s" %
                      hasher_spec)
    self._hasher_spec = hasher_spec
    if name:
      self._table_name = name.split("/")[-1]
    else:
      self._table_name = None
    super(IdTableWithHashBuckets, self).__init__(key_dtype, dtypes.int64)

  def _create_resource(self):
    if self._table is not None:
      return self._table._create_resource()  # pylint: disable=protected-access
    return None

  def _initialize(self):
    if self._table is not None:
      return self._table._initialize()  # pylint: disable=protected-access
    with ops.name_scope(None, "init"):
      return control_flow_ops.no_op()

  @property
  def initializer(self):
    if self._table is not None:
      return self._table._init_op  # pylint: disable=protected-access
    with ops.name_scope(None, "init"):
      return control_flow_ops.no_op()

  @property
  @deprecated("2018-12-15", "Use `initializer` instead.")
  def init(self):
    return self.initializer

  @property
  def resource_handle(self):
    if self._table is not None:
      return self._table.resource_handle
    return None

  @property
  def name(self):
    return self._table_name

  def size(self, name=None):
    """Compute the number of elements in this table."""
    with ops.name_scope(name, "%s_Size" % self.name):
      if self._table:
        tsize = self._table.size()
      else:
        tsize = ops.convert_to_tensor(0, dtype=dtypes.int64)
      return tsize + self._num_oov_buckets

  def _get_string_to_hash_bucket_fn(self, hasher_spec):
    """Returns the string_to_hash_bucket op to use based on `hasher_spec`."""
    if not isinstance(hasher_spec, HasherSpec):
      raise TypeError("hasher_spec must be of type HasherSpec, got %s" %
                      hasher_spec)
    if hasher_spec.hasher == "fasthash":
      return string_ops.string_to_hash_bucket_fast
    if hasher_spec.hasher == "legacy":
      return string_ops.string_to_hash_bucket
    if hasher_spec.hasher == "stronghash":
      return functools.partial(
          string_ops.string_to_hash_bucket_strong, key=hasher_spec.key)
    raise ValueError("Unknown hasher %s" % hasher_spec.hasher)

  def lookup(self, keys, name=None):
    """Looks up `keys` in the table, outputs the corresponding values.

    It assigns out-of-vocabulary keys to buckets based on their hashes.

    Args:
      keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`.
      name: Optional name for the op.

    Returns:
      A `SparseTensor` if keys are sparse, a `RaggedTensor` if keys are ragged,
      otherwise a dense `Tensor`.

    Raises:
      TypeError: when `keys` doesn't match the table key data type.
    """
    if keys.dtype.base_dtype != self._key_dtype:
      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
                      (self._key_dtype, keys.dtype))
    values = keys
    if isinstance(keys,
                  (sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor)):
      values = keys.values
    if self._table and (self._table.key_dtype.base_dtype == dtypes.int64):
      values = math_ops.cast(values, dtypes.int64)

    if self._num_oov_buckets == 0:
      ids = self._table.lookup(values, name=name)
    else:
      # TODO(yleon): Consider moving this functionality to its own kernel.
      with ops.name_scope(name, "%s_Lookup" % self.name):
        str_to_hash_bucket = self._get_string_to_hash_bucket_fn(
            self._hasher_spec)
        buckets = str_to_hash_bucket(
            _as_string(values),
            num_buckets=self._num_oov_buckets,
            name="hash_bucket")
        if self._table:
          ids = self._table.lookup(values)
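          # Shift bucket ids past the vocabulary so that in-vocab ids and
          # out-of-vocab bucket ids do not collide.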
          buckets = math_ops.add(buckets, self._table.size())
          is_id_non_default = math_ops.not_equal(ids, self._table.default_value)
          ids = array_ops.where_v2(is_id_non_default, ids, buckets)
        else:
          ids = buckets
    if isinstance(keys, sparse_tensor.SparseTensor):
      return sparse_tensor.SparseTensor(keys.indices, ids, keys.dense_shape)
    elif isinstance(keys, ragged_tensor.RaggedTensor):
      return keys.with_values(ids)
    return ids


@tf_export("lookup.StaticVocabularyTable", v1=[])
class StaticVocabularyTable(LookupInterface):
  r"""String to Id table that assigns out-of-vocabulary keys to hash buckets.

  For example, if an instance of `StaticVocabularyTable` is initialized with a
  string-to-id initializer that maps:

  >>> init = tf.lookup.KeyValueTensorInitializer(
  ...     keys=tf.constant(['emerson', 'lake', 'palmer']),
  ...     values=tf.constant([0, 1, 2], dtype=tf.int64))
  >>> table = tf.lookup.StaticVocabularyTable(
  ...    init,
  ...    num_oov_buckets=5)

  The `StaticVocabularyTable` object will perform the following mapping:

  * `emerson -> 0`
  * `lake -> 1`
  * `palmer -> 2`
  * `<other term> -> bucket_id`, where `bucket_id` will be between `3` and
  `3 + num_oov_buckets - 1 = 7`, calculated by:
  `hash(<term>) % num_oov_buckets + vocab_size`

  If `input_tensor` is:

  >>> input_tensor = tf.constant(["emerson", "lake", "palmer",
  ...                             "king", "crimson"])
  >>> table[input_tensor].numpy()
  array([0, 1, 2, 6, 7])

  If `initializer` is None, only out-of-vocabulary buckets are used.

  Example usage:

  >>> num_oov_buckets = 3
  >>> vocab = ["emerson", "lake", "palmer", "crimson"]
  >>> import tempfile
  >>> f = tempfile.NamedTemporaryFile(delete=False)
  >>> f.write('\n'.join(vocab).encode('utf-8'))
  >>> f.close()

  >>> init = tf.lookup.TextFileInitializer(
  ...     f.name,
  ...     key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
  ...     value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER)
  >>> table = tf.lookup.StaticVocabularyTable(init, num_oov_buckets)
  >>> table.lookup(tf.constant(["palmer", "crimson", "king",
  ...                           "tarkus", "black", "moon"])).numpy()
  array([2, 3, 5, 6, 6, 4])

  The hash function used for generating out-of-vocabulary buckets ID is
  Fingerprint64.
  """

  def __init__(self,
               initializer,
               num_oov_buckets,
               lookup_key_dtype=None,
               name=None):
    """Construct a `StaticVocabularyTable` object.

    Args:
      initializer: A `TableInitializerBase` object that contains the data used
        to initialize the table. If None, then we only use out-of-vocab buckets.
      num_oov_buckets: Number of buckets to use for out-of-vocabulary keys. Must
        be greater than zero.
      lookup_key_dtype: Data type of keys passed to `lookup`. Defaults to
        `initializer.key_dtype` if `initializer` is specified, otherwise
        `tf.string`. Must be string or integer, and must be castable to
        `initializer.key_dtype`.
      name: A name for the operation (optional).

    Raises:
      ValueError: when `num_oov_buckets` is not positive.
      TypeError: when lookup_key_dtype or initializer.key_dtype are not
        integer or string. Also when initializer.value_dtype != int64.
    """
    if num_oov_buckets <= 0:
      raise ValueError("num_oov_buckets must be > 0.")
    # If a name ends with a '/' it is a "name scope", remove all trailing '/'
    # characters to use as table name.
    if name:
      name = name.rstrip("/")
    if initializer:
      if lookup_key_dtype is None:
        lookup_key_dtype = initializer.key_dtype
      supported_table_key_dtypes = (dtypes.int64, dtypes.string)
      if initializer.key_dtype not in supported_table_key_dtypes:
        raise TypeError("Invalid key dtype, expected one of %s, but got %s." %
                        (supported_table_key_dtypes, initializer.key_dtype))
      if initializer.key_dtype.is_integer != lookup_key_dtype.is_integer:
        raise TypeError(
            "Invalid key dtype, expected %s but got %s." %
            ("integer" if lookup_key_dtype.is_integer else "non-integer",
             initializer.key_dtype))
      if initializer.value_dtype != dtypes.int64:
        raise TypeError("Invalid value dtype, expected %s but got %s." %
                        (dtypes.int64, initializer.value_dtype))
      if isinstance(initializer, trackable_base.Trackable):
        self._initializer = self._track_trackable(initializer, "_initializer")
      self._table = HashTable(initializer, default_value=-1)
      name = name or self._table.name
    else:
      lookup_key_dtype = dtypes.string
      self._table = None
      name = name or "hash_bucket"
    if (not lookup_key_dtype.is_integer) and (dtypes.string !=
                                              lookup_key_dtype):
      raise TypeError("Invalid key_dtype, expected integer or string, got %s." %
                      lookup_key_dtype)
    self._num_oov_buckets = num_oov_buckets

    self._table_name = None
    if name is not None:
      self._table_name = name.split("/")[-1]
    super(StaticVocabularyTable, self).__init__(lookup_key_dtype, dtypes.int64)

  def _create_resource(self):
    if self._table is not None:
      return self._table._create_resource()  # pylint: disable=protected-access
    return None

  def _initialize(self):
    if self._table is not None:
      return self._table._initialize()  # pylint: disable=protected-access
    with ops.name_scope(None, "init"):
      return control_flow_ops.no_op()

  @property
  def resource_handle(self):
    if self._table is not None:
      return self._table.resource_handle
    return None

  @property
  def name(self):
    return self._table_name

  def size(self, name=None):
    """Compute the number of elements in this table."""
    with ops.name_scope(name, "%s_Size" % self.name):
      if self._table:
        tsize = self._table.size()
      else:
        tsize = ops.convert_to_tensor(0, dtype=dtypes.int64)
      return tsize + self._num_oov_buckets

  def lookup(self, keys, name=None):
    """Looks up `keys` in the table, outputs the corresponding values.

    It assigns out-of-vocabulary keys to buckets based on their hashes.

    Args:
      keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`.
      name: Optional name for the op.

    Returns:
      A `SparseTensor` if keys are sparse, a `RaggedTensor` if keys are ragged,
      otherwise a dense `Tensor`.

    Raises:
      TypeError: when `keys` doesn't match the table key data type.
    """
    if keys.dtype.base_dtype != self._key_dtype:
      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
                      (self._key_dtype, keys.dtype))
    values = keys
    if isinstance(keys,
                  (sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor)):
      values = keys.values
    if self._table and (self._table.key_dtype.base_dtype == dtypes.int64):
      values = math_ops.cast(values, dtypes.int64)

    # TODO(yleon): Consider moving this functionality to its own kernel.
    with ops.name_scope(name, "%s_Lookup" % self.name):
      buckets = string_ops.string_to_hash_bucket_fast(
          _as_string(values),
          num_buckets=self._num_oov_buckets,
          name="hash_bucket")
      if self._table:
        ids = self._table.lookup(values)
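        # Shift bucket ids past the vocabulary so that in-vocab ids and
        # out-of-vocab bucket ids do not collide.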
        buckets = math_ops.add(buckets, self._table.size())
        is_id_non_default = math_ops.not_equal(ids, self._table.default_value)
        ids = array_ops.where_v2(is_id_non_default, ids, buckets)
      else:
        ids = buckets
    if isinstance(keys, sparse_tensor.SparseTensor):
      return sparse_tensor.SparseTensor(keys.indices, ids, keys.dense_shape)
    elif isinstance(keys, ragged_tensor.RaggedTensor):
      return keys.with_values(ids)
    return ids


@tf_export(v1=["lookup.StaticVocabularyTable"])
class StaticVocabularyTableV1(StaticVocabularyTable):

  @property
  def initializer(self):
    if self._table is not None:
      return self._table._init_op  # pylint: disable=protected-access
    with ops.name_scope(None, "init"):
      return control_flow_ops.no_op()


def index_table_from_file(vocabulary_file=None,
                          num_oov_buckets=0,
                          vocab_size=None,
                          default_value=-1,
                          hasher_spec=FastHashSpec,
                          key_dtype=dtypes.string,
                          name=None,
                          key_column_index=TextFileIndex.WHOLE_LINE,
                          value_column_index=TextFileIndex.LINE_NUMBER,
                          delimiter="\t"):
  """Returns a lookup table that converts a string tensor into int64 IDs.

  This operation constructs a lookup table to convert a tensor of strings into
  int64 IDs. The mapping can be initialized from a vocabulary file specified in
  `vocabulary_file`, where the whole line is the key and the zero-based line
  number is the ID.

  Any lookup of an out-of-vocabulary token will return a bucket ID based on its
  hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
  `default_value`.
  The bucket ID range is
  `[vocabulary size, vocabulary size + num_oov_buckets - 1]`.

  The underlying table must be initialized by calling
  `session.run(tf.compat.v1.tables_initializer())` or
  `session.run(table.init())` once.

  To specify multi-column vocabulary files, use `key_column_index`,
  `value_column_index` and `delimiter`.

  - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
    expects data type int64.
  - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
    type string.
  - A value >=0 means use the index (starting at zero) of the split line based
    on `delimiter`.

  Sample Usages:

  If we have a vocabulary file "test.txt" with the following content:

  ```
  emerson
  lake
  palmer
  ```

  ```python
  features = tf.constant(["emerson", "lake", "and", "palmer"])
  table = tf.lookup.index_table_from_file(
      vocabulary_file="test.txt", num_oov_buckets=1)
  ids = table.lookup(features)
  ...
  tf.compat.v1.tables_initializer().run()

  ids.eval()  ==> [0, 1, 3, 2]  # where 3 is the out-of-vocabulary bucket
  ```

  Args:
    vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    vocab_size: Number of the elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assignment of out-of-vocabulary buckets.
    key_dtype: The `key` data type.
    name: A name for this op (optional).
    key_column_index: The column index from the text file to get the `key`
      values from. The default is to use the whole line content.
    value_column_index: The column index from the text file to get the `value`
      values from. The default is to use the line number, starting from zero.
    delimiter: The delimiter to separate fields in a line.

  Returns:
    The lookup table to map a `key_dtype` `Tensor` to index `int64` `Tensor`.

  Raises:
    ValueError: If `vocabulary_file` is not set.
    ValueError: If `num_oov_buckets` is negative or `vocab_size` is not greater
      than zero.
  """
  if vocabulary_file is None or (isinstance(vocabulary_file, six.string_types)
                                 and not vocabulary_file):
    raise ValueError("vocabulary_file must be specified and must not be empty.")
  if num_oov_buckets < 0:
    raise ValueError(
        "num_oov_buckets must be greater than or equal to 0, got %d." %
        num_oov_buckets)
  if vocab_size is not None and vocab_size < 1:
    vocab_file_value = vocabulary_file
    if isinstance(vocabulary_file, ops.Tensor):
      vocab_file_value = tensor_util.constant_value(vocabulary_file) or "?"
    raise ValueError("vocab_size must be greater than 0, got %d. "
                     "vocabulary_file: %s" % (vocab_size, vocab_file_value))
  if (not key_dtype.is_integer) and (dtypes.string != key_dtype.base_dtype):
    raise TypeError("Only integer and string keys are supported.")

  with ops.name_scope(name, "string_to_index"):
    table = None
    with ops.name_scope(None, "hash_table"):
      init = TextFileIdTableInitializer(
          vocabulary_file,
          vocab_size=vocab_size,
          key_dtype=dtypes.int64 if key_dtype.is_integer else key_dtype,
          name="table_init",
          key_column_index=key_column_index,
          value_column_index=value_column_index,
          delimiter=delimiter)

      table = StaticHashTableV1(init, default_value)
    if num_oov_buckets:
      table = IdTableWithHashBuckets(
          table,
          num_oov_buckets=num_oov_buckets,
          hasher_spec=hasher_spec,
          key_dtype=key_dtype)

    return table


def index_table_from_tensor(vocabulary_list,
                            num_oov_buckets=0,
                            default_value=-1,
                            hasher_spec=FastHashSpec,
                            dtype=dtypes.string,
                            name=None):
  """Returns a lookup table that converts a string tensor into int64 IDs.

  This operation constructs a lookup table to convert a tensor of strings into
  int64 IDs. The mapping can be initialized from a string `vocabulary_list` 1-D
  tensor where each element is a key and the corresponding index within the
  tensor is the value.

  Any lookup of an out-of-vocabulary token will return a bucket ID based on its
  hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
  `default_value`. The bucket ID range is
  `[vocabulary list size, vocabulary list size + num_oov_buckets - 1]`.

  The underlying table must be initialized by calling
  `session.run(tf.compat.v1.tables_initializer())` or
  `session.run(table.init())` once.

  Elements in `vocabulary_list` must not have duplicates; otherwise, executing
  the table initializer op will raise a `FailedPreconditionError`.

  Sample Usages:

  ```python
  vocabulary_list = tf.constant(["emerson", "lake", "palmer"])
  table = tf.lookup.index_table_from_tensor(
      vocabulary_list=vocabulary_list, num_oov_buckets=1, default_value=-1)
  features = tf.constant(["emerson", "lake", "and", "palmer"])
  ids = table.lookup(features)
  ...
  tf.compat.v1.tables_initializer().run()

  ids.eval()  ==> [0, 1, 3, 2]  # where 3 is the out-of-vocabulary bucket
  ```

  Args:
    vocabulary_list: A 1-D `Tensor` that specifies the mapping of keys to
      indices. The type of this object must be castable to `dtype`.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assignment of out-of-vocabulary buckets.
    dtype: The type of values passed to `lookup`. Only string and integers are
      supported.
    name: A name for this op (optional).

  Returns:
    The lookup table to map an input `Tensor` to index `int64` `Tensor`.

  Raises:
    ValueError: If `vocabulary_list` is invalid.
    ValueError: If `num_oov_buckets` is negative.
  """
  if vocabulary_list is None:
    raise ValueError("vocabulary_list must be specified.")

  if num_oov_buckets < 0:
    raise ValueError(
        "num_oov_buckets must be greater than or equal to 0, got %d." %
        num_oov_buckets)

  if (not dtype.is_integer) and (dtypes.string != dtype.base_dtype):
    raise TypeError("Only integer and string keys are supported.")

  with ops.name_scope(name, "string_to_index"):
    keys = ops.convert_to_tensor(vocabulary_list)
    if keys.dtype.is_integer != dtype.is_integer:
      raise ValueError(
          "Expected %s, got %s." %
          ("integer" if dtype.is_integer else "non-integer", keys.dtype))
    if (not dtype.is_integer) and (keys.dtype.base_dtype != dtype):
      raise ValueError("Expected %s, got %s." % (dtype, keys.dtype))
    num_elements = array_ops.size(keys)
    values = math_ops.cast(math_ops.range(num_elements), dtypes.int64)

    with ops.name_scope(None, "hash_table"):
      table_keys = math_ops.cast(
          keys, dtypes.int64) if keys.dtype.is_integer else keys
      init = KeyValueTensorInitializer(
          table_keys,
          values,
          table_keys.dtype.base_dtype,
          dtypes.int64,
          name="table_init")
      table = StaticHashTableV1(init, default_value)
    if num_oov_buckets:
      table = IdTableWithHashBuckets(
          table,
          num_oov_buckets=num_oov_buckets,
          hasher_spec=hasher_spec,
          key_dtype=dtype)
    return table


def index_to_string_table_from_file(vocabulary_file,
                                    vocab_size=None,
                                    default_value="UNK",
                                    name=None,
                                    key_column_index=TextFileIndex.LINE_NUMBER,
                                    value_column_index=TextFileIndex.WHOLE_LINE,
                                    delimiter="\t"):
  """Returns a lookup table that maps a `Tensor` of indices into strings.

  This operation constructs a lookup table to map int64 indices into string
  values. The table is initialized from a vocabulary file specified in
  `vocabulary_file`, where the whole line is the value and the
  zero-based line number is the index.

  Any input which does not have a corresponding index in the vocabulary file
  (an out-of-vocabulary entry) is assigned the `default_value`.

  The underlying table must be initialized by calling
  `session.run(tf.compat.v1.tables_initializer())` or
  `session.run(table.init())` once.

  To specify multi-column vocabulary files, use `key_column_index`,
  `value_column_index` and `delimiter` (see the sketch after this list):

  - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
    expects data type int64.
  - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
    type string.
  - A value >=0 means use the index (starting at zero) of the split line based
    on `delimiter`.
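
  For example, a minimal sketch of the reverse mapping: assume a hypothetical
  tab-delimited file "vocab.tsv" whose first column is the token and whose
  second column is its int64 ID (e.g. a line reads `emerson\t0`). Then:

  ```python
  # Illustrative only: keys come from the int64 ID column, values from the
  # token column.
  table = tf.lookup.index_to_string_table_from_file(
      vocabulary_file="vocab.tsv",
      key_column_index=1,
      value_column_index=0,
      delimiter="\t")
  ```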

  Sample Usages:

  If we have a vocabulary file "test.txt" with the following content:

  ```
  emerson
  lake
  palmer
  ```

  ```python
  indices = tf.constant([1, 5], tf.int64)
  table = tf.lookup.index_to_string_table_from_file(
      vocabulary_file="test.txt", default_value="UNKNOWN")
  values = table.lookup(indices)
  ...
  tf.compat.v1.tables_initializer().run()

  values.eval() ==> ["lake", "UNKNOWN"]
  ```

  Args:
    vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`.
    vocab_size: Number of elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary indices.
    name: A name for this op (optional).
    key_column_index: The column index from the text file to get the `key`
      values from. The default is to use the line number, starting from zero.
    value_column_index: The column index from the text file to get the `value`
      values from. The default is to use the whole line content.
    delimiter: The delimiter to separate fields in a line.

  Returns:
    The lookup table mapping `int64` index `Tensor`s to their associated
    string values.

  Raises:
    ValueError: when `vocabulary_file` is empty.
    ValueError: when `vocab_size` is invalid.
  """
  if vocabulary_file is None or (isinstance(vocabulary_file, six.string_types)
                                 and not vocabulary_file):
    raise ValueError("vocabulary_file must be specified and must not be empty.")

  if vocab_size is not None and vocab_size < 1:
    raise ValueError("vocab_size must be greater than 0, got %d." % vocab_size)

  with ops.name_scope(name, "index_to_string"):
    init = TextFileStringTableInitializer(
        vocabulary_file,
        vocab_size=vocab_size,
        name="table_init",
        key_column_index=key_column_index,
        value_column_index=value_column_index,
        delimiter=delimiter)

    # TODO(yleon): Use a more efficient structure.
    return StaticHashTableV1(init, default_value)


def index_to_string_table_from_tensor(vocabulary_list,
                                      default_value="UNK",
                                      name=None):
  """Returns a lookup table that maps a `Tensor` of indices into strings.

  This operation constructs a lookup table to map int64 indices into string
  values. The mapping is initialized from a string `vocabulary_list` 1-D
  `Tensor` where each element is a value and the corresponding index within the
  tensor is the key.

  Any input which does not have a corresponding index in `vocabulary_list`
  (an out-of-vocabulary entry) is assigned the `default_value`.

  The underlying table must be initialized by calling
  `session.run(tf.compat.v1.tables_initializer())` or
  `session.run(table.init())` once.

  Elements in `vocabulary_list` must not have duplicates; otherwise, executing
  the table initializer op will raise a `FailedPreconditionError`.

  Sample Usages:

  ```python
  vocabulary_list = tf.constant(["emerson", "lake", "palmer"])
  indices = tf.constant([1, 5], tf.int64)
  table = tf.lookup.index_to_string_table_from_tensor(
      vocabulary_list, default_value="UNKNOWN")
  values = table.lookup(indices)
  ...
  tf.compat.v1.tables_initializer().run()

  values.eval() ==> ["lake", "UNKNOWN"]
  ```

  Args:
    vocabulary_list: A 1-D string `Tensor` that specifies the strings to map
      from indices.
    default_value: The value to use for out-of-vocabulary indices.
    name: A name for this op (optional).

  Returns:
    The lookup table mapping `int64` index `Tensor`s to their associated
    string values.

  Raises:
    ValueError: when `vocabulary_list` is not set.
  """

  if vocabulary_list is None:
    raise ValueError("vocabulary_list must be specified.")

  with ops.name_scope(name, "index_to_string"):
    vocabulary_list = ops.convert_to_tensor(vocabulary_list, dtypes.string)
    num_elements = array_ops.size(vocabulary_list)
    keys = math_ops.cast(math_ops.range(num_elements), dtypes.int64)

    init = KeyValueTensorInitializer(
        keys, vocabulary_list, dtypes.int64, dtypes.string, name="table_init")
    # TODO(yleon): Use a more efficient structure.
    return StaticHashTableV1(init, default_value)


class MutableHashTable(LookupInterface):
  """A generic mutable hash table implementation.

  Data can be inserted by calling the insert method and removed by calling the
  remove method. It does not support initialization via the init method.

  Example usage:

  ```python
  table = tf.lookup.MutableHashTable(key_dtype=tf.string, value_dtype=tf.int64,
                                     default_value=-1)
  keys = tf.constant(["a", "b"])
  values = tf.constant([1, 2], dtype=tf.int64)
  sess.run(table.insert(keys, values))
  out = table.lookup(tf.constant(["a", "b", "c"]))
  print(out.eval())  # ==> [1, 2, -1]
  ```
  """

  def __init__(self,
               key_dtype,
               value_dtype,
               default_value,
               name="MutableHashTable",
               checkpoint=True):
    """Creates an empty `MutableHashTable` object.

    Creates a table whose key and value types are specified by `key_dtype`
    and `value_dtype`, respectively.

    Args:
      key_dtype: the type of the key tensors.
      value_dtype: the type of the value tensors.
      default_value: The value to use if a key is missing in the table.
      name: A name for the operation (optional).
      checkpoint: if True, the contents of the table are saved to and restored
        from checkpoints. If `shared_name` is empty for a checkpointed table, it
        is shared using the table node name.

    Returns:
      A `MutableHashTable` object.

    Raises:
      ValueError: If checkpoint is True and no name was specified.
    """
    self._default_value = ops.convert_to_tensor(
        default_value, dtype=value_dtype)
    self._value_shape = self._default_value.get_shape()
    self._checkpoint = checkpoint
    self._key_dtype = key_dtype
    self._value_dtype = value_dtype
    self._name = name

    self._shared_name = None
    if context.executing_eagerly():
      # TODO(allenl): This will leak memory due to kernel caching by the
      # shared_name attribute value (but is better than the alternative of
      # sharing everything by default when executing eagerly; hopefully
      # creating tables in a loop is uncommon).
      # TODO(rohanj): Use context.shared_name() instead.
      self._shared_name = "table_%d" % (ops.uid(),)
    super(MutableHashTable, self).__init__(key_dtype, value_dtype)

    self._resource_handle = self._create_resource()
    if checkpoint:
      saveable = MutableHashTable._Saveable(self, name)
      if not context.executing_eagerly():
        ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)

  def _create_resource(self):
    # The table must be shared if checkpointing is requested for multi-worker
    # training to work correctly. Use the node name if no shared_name has been
    # explicitly specified.
    use_node_name_sharing = self._checkpoint and self._shared_name is None
    if self._default_value.get_shape().ndims == 0:
      table_ref = gen_lookup_ops.mutable_hash_table_v2(
          shared_name=self._shared_name,
          use_node_name_sharing=use_node_name_sharing,
          key_dtype=self._key_dtype,
          value_dtype=self._value_dtype,
          name=self._name)
    else:
      table_ref = gen_lookup_ops.mutable_hash_table_of_tensors_v2(
          shared_name=self._shared_name,
          use_node_name_sharing=use_node_name_sharing,
          key_dtype=self._key_dtype,
          value_dtype=self._value_dtype,
          value_shape=self._default_value.get_shape(),
          name=self._name)

    if context.executing_eagerly():
      self._table_name = None
    else:
      self._table_name = table_ref.op.name.split("/")[-1]
    return table_ref

  @property
  def name(self):
    return self._table_name

  def size(self, name=None):
    """Compute the number of elements in this table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A scalar tensor containing the number of elements in this table.
    """
    with ops.name_scope(name, "%s_Size" % self.name, [self.resource_handle]):
      with ops.colocate_with(self.resource_handle):
        return gen_lookup_ops.lookup_table_size_v2(self.resource_handle)

  def remove(self, keys, name=None):
    """Removes `keys` and their associated values from the table.

    If a key is not present in the table, it is silently ignored.

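    For example, a minimal graph-mode sketch (assuming the `table` and `sess`
    from the class example above):

    ```python
    # Removing a key that was never inserted is a no-op.
    remove_op = table.remove(tf.constant(["a", "never_inserted"]))
    sess.run(remove_op)
    ```
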
    Args:
      keys: Keys to remove. Can be a tensor of any shape. Must match the table's
        key type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    if keys.dtype != self._key_dtype:
      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
                      (self._key_dtype, keys.dtype))

    with ops.name_scope(name, "%s_lookup_table_remove" % self.name,
                        (self.resource_handle, keys, self._default_value)):
      op = gen_lookup_ops.lookup_table_remove_v2(self.resource_handle, keys)

    return op

  def lookup(self, keys, dynamic_default_values=None, name=None):
    """Looks up `keys` in a table, outputs the corresponding values.

    The `default_value` is used for keys not present in the table.

    Args:
      keys: Keys to look up. Can be a tensor of any shape. Must match the
        table's key_dtype.
      dynamic_default_values: The values to use if a key is missing in the
        table. If None (by default), the `table.default_value` will be used.
        The shape of `dynamic_default_values` must be the same as the shape of
        `table.default_value` or of the lookup result tensor.
        In the latter case, each key will have a different default value.

        For example:

          ```python
          keys = [0, 1, 3]
          dynamic_default_values = [[1, 3, 4], [2, 3, 9], [8, 3, 0]]

          # The key '0' will use [1, 3, 4] as default value.
          # The key '1' will use [2, 3, 9] as default value.
          # The key '3' will use [8, 3, 0] as default value.
          ```

      name: A name for the operation (optional).

    Returns:
      A tensor containing the values in the same shape as `keys` using the
        table's value type.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    with ops.name_scope(name, "%s_lookup_table_find" % self.name,
                        (self.resource_handle, keys, self._default_value)):
      keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
      with ops.colocate_with(self.resource_handle):
        values = gen_lookup_ops.lookup_table_find_v2(
            self.resource_handle, keys, dynamic_default_values
            if dynamic_default_values is not None else self._default_value)
    return values

  def insert(self, keys, values, name=None):
    """Associates `keys` with `values`.

    Args:
      keys: Keys to insert. Can be a tensor of any shape. Must match the table's
        key type.
      values: Values to be associated with keys. Must be a tensor of the same
        shape as `keys` and match the table's value type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` or `values` do not match the table data
        types.
    """
    with ops.name_scope(name, "%s_lookup_table_insert" % self.name,
                        [self.resource_handle, keys, values]):
      keys = ops.convert_to_tensor(keys, self._key_dtype, name="keys")
      values = ops.convert_to_tensor(values, self._value_dtype, name="values")
      with ops.colocate_with(self.resource_handle):
        # pylint: disable=protected-access
        op = gen_lookup_ops.lookup_table_insert_v2(self.resource_handle, keys,
                                                   values)
    return op

  def export(self, name=None):
    """Returns tensors of all keys and values in the table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A pair of tensors with the first tensor containing all keys and the
        second tensor containing all values in the table.
    """
    with ops.name_scope(name, "%s_lookup_table_export_values" % self.name,
                        [self.resource_handle]):
      with ops.colocate_with(self.resource_handle):
        exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(
            self.resource_handle, self._key_dtype, self._value_dtype)
    return exported_keys, exported_values

  def _gather_saveables_for_checkpoint(self):
    """For object-based checkpointing."""
    return {
        "table":
            functools.partial(
                MutableHashTable._Saveable, table=self, name=self._name,
                table_name=self._name)
    }

  class _Saveable(BaseSaverBuilder.SaveableObject):
    """SaveableObject implementation for MutableHashTable."""

    def __init__(self, table, name, table_name=None):
      tensors = table.export()
      specs = [
          BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
          BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
      ]
      self.table_name = table_name or name
      # pylint: disable=protected-access
      super(MutableHashTable._Saveable, self).__init__(table, specs, name)

    def restore(self, restored_tensors, restored_shapes):
      del restored_shapes  # unused
      # pylint: disable=protected-access
      with ops.name_scope("%s_table_restore" % self.table_name):
        with ops.colocate_with(self.op.resource_handle):
          return gen_lookup_ops.lookup_table_import_v2(self.op.resource_handle,
                                                       restored_tensors[0],
                                                       restored_tensors[1])


@tf_export("lookup.experimental.DenseHashTable")
class DenseHashTable(LookupInterface):
  """A generic mutable hash table implementation using tensors as backing store.

  Data can be inserted by calling the insert method and removed by calling the
  remove method. It does not support initialization via the init method.

  It uses "open addressing" with quadratic reprobing to resolve collisions.
  Compared to `MutableHashTable`, the insert, remove and lookup operations in a
  `DenseHashTable` are typically faster, but memory usage can be higher.
  However, `DenseHashTable` does not require additional memory for
  temporary tensors created during checkpointing and restore operations.

  Example usage:

  >>> table = tf.lookup.experimental.DenseHashTable(
  ...     key_dtype=tf.string,
  ...     value_dtype=tf.int64,
  ...     default_value=-1,
  ...     empty_key='',
  ...     deleted_key='$')
  >>> keys = tf.constant(['a', 'b', 'c'])
  >>> values = tf.constant([0, 1, 2], dtype=tf.int64)
  >>> table.insert(keys, values)
  >>> table.remove(tf.constant(['c']))
  >>> table.lookup(tf.constant(['a', 'b', 'c', 'd'])).numpy()
  array([ 0,  1, -1, -1])
  """

  # TODO(andreasst): consider extracting common code with MutableHashTable into
  # a common superclass.
  def __init__(self,
               key_dtype,
               value_dtype,
               default_value,
               empty_key,
               deleted_key,
               initial_num_buckets=None,
               name="MutableDenseHashTable",
               checkpoint=True):
    """Creates an empty `DenseHashTable` object.

    Creates a table whose key and value types are specified by `key_dtype`
    and `value_dtype`, respectively.

    Args:
      key_dtype: the type of the key tensors.
      value_dtype: the type of the value tensors.
      default_value: The value to use if a key is missing in the table.
      empty_key: the key to use to represent empty buckets internally. Must not
        be used in insert, remove or lookup operations.
      deleted_key: the key to use to represent deleted buckets internally. Must
        not be used in insert, remove or lookup operations, and must differ
        from `empty_key`.
      initial_num_buckets: the initial number of buckets.
      name: A name for the operation (optional).
      checkpoint: if True, the contents of the table are saved to and restored
        from checkpoints. If `shared_name` is empty for a checkpointed table, it
        is shared using the table node name.

    Returns:
      A `DenseHashTable` object.

    Raises:
      ValueError: If checkpoint is True and no name was specified.
    """
    self._default_value = ops.convert_to_tensor(
        default_value, dtype=value_dtype, name="default_value")
    self._key_dtype = key_dtype
    self._value_dtype = value_dtype
    self._initial_num_buckets = initial_num_buckets
    self._value_shape = self._default_value.get_shape()
    self._checkpoint = checkpoint
    self._name = name

    self._empty_key = empty_key
    self._deleted_key = deleted_key
    self._shared_name = None
    if context.executing_eagerly():
      # TODO(allenl): This will leak memory due to kernel caching by the
      # shared_name attribute value (but is better than the alternative of
      # sharing everything by default when executing eagerly; hopefully
      # creating tables in a loop is uncommon).
      # TODO(rohanj): Use context.shared_name() instead.
      self._shared_name = "table_%d" % (ops.uid(),)
    super(DenseHashTable, self).__init__(key_dtype, value_dtype)

    self._resource_handle = self._create_resource()
    if checkpoint:
      saveable = DenseHashTable._Saveable(self, name)
      if not context.executing_eagerly():
        ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)

  def _create_resource(self):
    # The table must be shared if checkpointing is requested for multi-worker
    # training to work correctly. Use the node name if no shared_name has been
    # explicitly specified.
    use_node_name_sharing = self._checkpoint and self._shared_name is None
    empty_key = ops.convert_to_tensor(
        self._empty_key, dtype=self._key_dtype, name="empty_key")
    deleted_key = ops.convert_to_tensor(
        self._deleted_key, dtype=self._key_dtype, name="deleted_key")
    table_ref = gen_lookup_ops.mutable_dense_hash_table_v2(
        empty_key=empty_key,
        deleted_key=deleted_key,
        shared_name=self._shared_name,
        use_node_name_sharing=use_node_name_sharing,
        value_dtype=self._value_dtype,
        value_shape=self._value_shape,
        initial_num_buckets=self._initial_num_buckets,
        name=self._name)
    if context.executing_eagerly():
      self._table_name = None
    else:
      self._table_name = table_ref.op.name.split("/")[-1]
    return table_ref

  @property
  def name(self):
    return self._table_name

  def size(self, name=None):
    """Compute the number of elements in this table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A scalar tensor containing the number of elements in this table.
    """
    with ops.name_scope(name, "%s_Size" % self.name, [self.resource_handle]):
      with ops.colocate_with(self.resource_handle):
        return gen_lookup_ops.lookup_table_size_v2(self.resource_handle)

  def lookup(self, keys, name=None):
    """Looks up `keys` in a table, outputs the corresponding values.

    The `default_value` is used for keys not present in the table.

    Args:
      keys: Keys to look up. Can be a tensor of any shape. Must match the
        table's key_dtype.
      name: A name for the operation (optional).

    Returns:
      A tensor containing the values in the same shape as `keys` using the
        table's value type.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    with ops.name_scope(name, "%s_lookup_table_find" % self.name,
                        [self.resource_handle, keys]):
      keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
      with ops.colocate_with(self.resource_handle):
        values = gen_lookup_ops.lookup_table_find_v2(self.resource_handle, keys,
                                                     self._default_value)

    return values

  def insert_or_assign(self, keys, values, name=None):
    """Associates `keys` with `values`.

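    For example, a minimal eager-mode sketch (illustrative values only):

    >>> table = tf.lookup.experimental.DenseHashTable(
    ...     key_dtype=tf.int64, value_dtype=tf.int64, default_value=-1,
    ...     empty_key=0, deleted_key=-1)
    >>> table.insert_or_assign(tf.constant([1, 2], dtype=tf.int64),
    ...                        tf.constant([10, 20], dtype=tf.int64))
    >>> # Re-inserting an existing key overwrites its previous value.
    >>> table.insert_or_assign(tf.constant([2], dtype=tf.int64),
    ...                        tf.constant([99], dtype=tf.int64))
    >>> table.lookup(tf.constant([1, 2, 3], dtype=tf.int64)).numpy()
    array([10, 99, -1])
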
    Args:
      keys: Keys to insert. Can be a tensor of any shape. Must match the table's
        key type.
      values: Values to be associated with keys. Must be a tensor of the same
        shape as `keys` and match the table's value type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` or `values` do not match the table data
        types.
    """
    with ops.name_scope(name, "%s_lookup_table_insert" % self.name,
                        [self.resource_handle, keys, values]):
      keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
      values = ops.convert_to_tensor(
          values, dtype=self._value_dtype, name="values")
      with ops.colocate_with(self.resource_handle):
        op = gen_lookup_ops.lookup_table_insert_v2(self.resource_handle, keys,
                                                   values)
      return op

  def insert(self, keys, values, name=None):
    """Associates `keys` with `values`; an alias for `insert_or_assign`.

    Args:
      keys: Keys to insert. Can be a tensor of any shape. Must match the table's
        key type.
      values: Values to be associated with keys. Must be a tensor of the same
        shape as `keys` and match the table's value type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` or `values` do not match the table data
        types.
    """
    return self.insert_or_assign(keys, values, name)

  def erase(self, keys, name=None):
    """Removes `keys` and their associated values from the table.

    If a key is not present in the table, it is silently ignored.

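    For example, a minimal eager-mode sketch (illustrative values only):

    >>> table = tf.lookup.experimental.DenseHashTable(
    ...     key_dtype=tf.int64, value_dtype=tf.int64, default_value=-1,
    ...     empty_key=0, deleted_key=-1)
    >>> table.insert(tf.constant([1, 2], dtype=tf.int64),
    ...              tf.constant([10, 20], dtype=tf.int64))
    >>> table.erase(tf.constant([2, 3], dtype=tf.int64))  # 3 was never present
    >>> table.lookup(tf.constant([1, 2], dtype=tf.int64)).numpy()
    array([10, -1])
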
    Args:
      keys: Keys to remove. Can be a tensor of any shape. Must match the table's
        key type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    if keys.dtype != self._key_dtype:
      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
                      (self._key_dtype, keys.dtype))

    with ops.name_scope(name, "%s_lookup_table_remove" % self.name,
                        (self.resource_handle, keys, self._default_value)):
      # pylint: disable=protected-access
      op = gen_lookup_ops.lookup_table_remove_v2(self.resource_handle, keys)

    return op

  def remove(self, keys, name=None):
    """Removes `keys` and their associated values; an alias for `erase`.

    If a key is not present in the table, it is silently ignored.

    Args:
      keys: Keys to remove. Can be a tensor of any shape. Must match the table's
        key type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    return self.erase(keys, name)

  def export(self, name=None):
    """Returns tensors of all keys and values in the table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A pair of tensors with the first tensor containing all keys and the
        second tensor containing all values in the table.
    """
    with ops.name_scope(name, "%s_lookup_table_export_values" % self.name,
                        [self.resource_handle]):
      with ops.colocate_with(self.resource_handle):
        exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(
            self.resource_handle, self._key_dtype, self._value_dtype)

    return exported_keys, exported_values

  def _gather_saveables_for_checkpoint(self):
    """For object-based checkpointing."""
    return {
        "table":
            functools.partial(
                DenseHashTable._Saveable, table=self, name=self._name,
                table_name=self._name)
    }

  class _Saveable(BaseSaverBuilder.SaveableObject):
    """SaveableObject implementation for DenseHashTable."""

    def __init__(self, table, name, table_name=None):
      tensors = table.export()
      specs = [
          BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
          BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
      ]
      self.table_name = table_name or name
      # pylint: disable=protected-access
      super(DenseHashTable._Saveable, self).__init__(table, specs, name)

    def restore(self, restored_tensors, restored_shapes):
      del restored_shapes  # unused
      # pylint: disable=protected-access
      with ops.name_scope("%s_table_restore" % self.table_name):
        with ops.colocate_with(self.op.resource_handle):
          return gen_lookup_ops.lookup_table_import_v2(self.op.resource_handle,
                                                       restored_tensors[0],
                                                       restored_tensors[1])


ops.NotDifferentiable("LookupTableFind")
ops.NotDifferentiable("LookupTableFindV2")
ops.NotDifferentiable("LookupTableInsert")
ops.NotDifferentiable("LookupTableInsertV2")
ops.NotDifferentiable("LookupTableSize")
ops.NotDifferentiable("LookupTableSizeV2")
ops.NotDifferentiable("HashTable")
ops.NotDifferentiable("HashTableV2")
ops.NotDifferentiable("InitializeTable")
ops.NotDifferentiable("InitializeTableV2")
ops.NotDifferentiable("InitializeTableFromTextFile")
ops.NotDifferentiable("InitializeTableFromTextFileV2")
ops.NotDifferentiable("MutableDenseHashTable")
ops.NotDifferentiable("MutableDenseHashTableV2")
ops.NotDifferentiable("MutableHashTable")
ops.NotDifferentiable("MutableHashTableV2")
ops.NotDifferentiable("MutableHashTableOfTensors")
ops.NotDifferentiable("MutableHashTableOfTensorsV2")
