# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Lookup operations."""

from tensorflow.python.data.experimental.ops.cardinality import assert_cardinality
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export("data.experimental.DatasetInitializer")
class DatasetInitializer(lookup_ops.TableInitializerBase):
  """Creates a table initializer from a `tf.data.Dataset`.

  Sample usage:

  >>> keys = tf.data.Dataset.range(100)
  >>> values = tf.data.Dataset.range(100).map(
  ...     lambda x: tf.strings.as_string(x * 2))
  >>> ds = tf.data.Dataset.zip((keys, values))
  >>> init = tf.data.experimental.DatasetInitializer(ds)
  >>> table = tf.lookup.StaticHashTable(init, "")
  >>> table.lookup(tf.constant([0, 1, 2], dtype=tf.int64)).numpy()
  array([b'0', b'2', b'4'], dtype=object)

  Attributes:
    dataset: A `tf.data.Dataset` object that produces tuples of scalars. The
      first scalar is treated as the key and the second as the value.

  Raises:
    ValueError: If `dataset` doesn't conform to specifications.
  """

  def __init__(self, dataset):
    """Creates a table initializer from a `tf.data.Dataset`.

    Args:
      dataset: A `tf.data.Dataset` object that produces tuples of scalars. The
        first scalar is treated as the key and the second as the value.

    Returns:
      A `DatasetInitializer` object.

    Raises:
      ValueError: If `dataset` doesn't conform to specifications.
    """
    # Assert that the dataset element spec is a tuple of two scalar
    # TensorSpecs: one for the keys and one for the values.
    self.dataset = dataset
    elem_spec = self.dataset.element_spec
    if len(elem_spec) != 2:
      raise ValueError("element spec size should be 2")
    if not isinstance(elem_spec[0], tensor_spec.TensorSpec):
      raise ValueError("elem_spec[0] should be of type TensorSpec, got: %s" %
                       type(elem_spec[0]))
    if not isinstance(elem_spec[1], tensor_spec.TensorSpec):
      raise ValueError("elem_spec[1] should be of type TensorSpec, got: %s" %
                       type(elem_spec[1]))
    if elem_spec[0].shape.rank not in (None, 0):
      raise ValueError("key tensor should be a scalar")
    if elem_spec[1].shape.rank not in (None, 0):
      raise ValueError("value tensor should be a scalar")

    key_type = elem_spec[0].dtype
    value_type = elem_spec[1].dtype
    super(DatasetInitializer, self).__init__(key_type, value_type)

  def initialize(self, table):
    """Returns the table initialization op for `table`."""
    lookup_ops.check_table_dtypes(table, self._key_dtype, self._value_dtype)
    init_op = ged_ops.initialize_table_from_dataset(
        table.resource_handle, self.dataset._variant_tensor)  # pylint: disable=protected-access
    # Register the op so that graph-mode sessions can run it together with
    # the other table initializers.
    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
    return init_op


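# Illustrative sketch (not part of the original module): using
# `DatasetInitializer` directly with `tf.lookup.StaticHashTable`. The dataset
# must yield (scalar key, scalar value) pairs; anything else makes the
# constructor above raise `ValueError`. `_demo_dataset_initializer` is a
# hypothetical name used only for illustration; assumes eager execution.
def _demo_dataset_initializer():
  import tensorflow as tf
  ds = tf.data.Dataset.from_tensor_slices(
      (tf.constant(["emerson", "lake", "palmer"]),
       tf.constant([0, 1, 2], dtype=tf.int64)))
  init = tf.data.experimental.DatasetInitializer(ds)
  table = tf.lookup.StaticHashTable(init, default_value=-1)
  return table.lookup(tf.constant(["lake", "king"])).numpy()  # => [1, -1]

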
@tf_export("data.experimental.table_from_dataset")
def table_from_dataset(dataset=None,
                       num_oov_buckets=0,
                       vocab_size=None,
                       default_value=None,
                       hasher_spec=lookup_ops.FastHashSpec,
                       key_dtype=dtypes.string,
                       name=None):
  """Returns a lookup table based on the given dataset.

  This operation constructs a lookup table from a dataset of (key, value)
  pairs.

  Any lookup of an out-of-vocabulary token will return a bucket ID based on
  its hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned
  the `default_value`. The bucket ID range is
  `[vocabulary size, vocabulary size + num_oov_buckets - 1]`.

  Sample usage:

  >>> keys = tf.data.Dataset.range(100)
  >>> values = tf.data.Dataset.range(100).map(
  ...     lambda x: tf.strings.as_string(x * 2))
  >>> ds = tf.data.Dataset.zip((keys, values))
  >>> table = tf.data.experimental.table_from_dataset(
  ...                               ds, default_value='n/a', key_dtype=tf.int64)
  >>> table.lookup(tf.constant([0, 1, 2], dtype=tf.int64)).numpy()
  array([b'0', b'2', b'4'], dtype=object)

  Args:
    dataset: A dataset containing (key, value) pairs.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    vocab_size: The number of elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assignment of out-of-vocabulary buckets.
    key_dtype: The `key` data type.
    name: A name for this op (optional).

  Returns:
    The lookup table based on the given dataset.

  Raises:
    ValueError: If
      * `dataset` does not contain pairs,
      * the second item in the `dataset` pairs has a dtype which is
        incompatible with `default_value`,
      * `num_oov_buckets` is negative,
      * `vocab_size` is not greater than zero, or
      * `key_dtype` is not an integer or string type.
  """
  elem_spec = dataset.element_spec
  if len(elem_spec) != 2:
    raise ValueError("The given dataset must contain pairs.")
  if default_value is None:
    default_value = -1
    if not (elem_spec[1].dtype.is_integer or elem_spec[1].dtype.is_floating):
      raise ValueError("The dtype of the values requires manually setting a "
                       "compatible default_value.")
  if num_oov_buckets < 0:
    raise ValueError(
        "num_oov_buckets must be greater than or equal to 0, got %d." %
        num_oov_buckets)
  if (not isinstance(vocab_size, ops.Tensor) and vocab_size is not None and
      vocab_size < 1):
    raise ValueError("vocab_size must be greater than 0, got %d." % vocab_size)
  if (not key_dtype.is_integer) and (dtypes.string != key_dtype.base_dtype):
    raise TypeError("Only integer and string keys are supported.")
  if vocab_size is not None:
    if isinstance(vocab_size, ops.Tensor):
      vocab_size = math_ops.cast(vocab_size, dtypes.int64)
    # Restrict the table to the first `vocab_size` entries and let the
    # initializer verify that the dataset really yields that many elements.
    dataset = dataset.take(vocab_size)
    dataset = dataset.apply(assert_cardinality(vocab_size))
  with ops.name_scope(name, "string_to_index"):
    initializer = DatasetInitializer(dataset)
    with ops.name_scope(None, "hash_table"):
      table = lookup_ops.StaticHashTableV1(initializer, default_value)
      if num_oov_buckets:
        table = lookup_ops.IdTableWithHashBuckets(
            table,
            num_oov_buckets=num_oov_buckets,
            hasher_spec=hasher_spec,
            key_dtype=key_dtype)
      return table


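# Illustrative sketch (not part of the original module): out-of-vocabulary
# handling in `table_from_dataset`. With `num_oov_buckets > 0`, unknown keys
# hash into IDs in `[vocab_size, vocab_size + num_oov_buckets - 1]` instead of
# mapping to `default_value`. `_demo_oov_buckets` is a hypothetical name used
# only for illustration; assumes eager execution.
def _demo_oov_buckets():
  import tensorflow as tf
  words = tf.data.Dataset.from_tensor_slices(["a", "b", "c"])
  pairs = words.enumerate().map(lambda i, w: (w, i))  # (key, value) pairs
  table = tf.data.experimental.table_from_dataset(
      pairs, num_oov_buckets=2, vocab_size=3)
  # "a" -> 0; "zzz" is out of vocabulary, so it hashes into {3, 4}.
  return table.lookup(tf.constant(["a", "zzz"])).numpy()

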
@tf_export("data.experimental.index_table_from_dataset")
def index_table_from_dataset(dataset=None,
                             num_oov_buckets=0,
                             vocab_size=None,
                             default_value=-1,
                             hasher_spec=lookup_ops.FastHashSpec,
                             key_dtype=dtypes.string,
                             name=None):
  """Returns an index lookup table based on the given dataset.

  This operation constructs a lookup table from a dataset of keys, mapping
  each key to its index in the dataset.

  Any lookup of an out-of-vocabulary token will return a bucket ID based on
  its hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned
  the `default_value`. The bucket ID range is
  `[vocabulary size, vocabulary size + num_oov_buckets - 1]`.

  Sample usage:

  >>> ds = tf.data.Dataset.range(100).map(lambda x: tf.strings.as_string(x * 2))
  >>> table = tf.data.experimental.index_table_from_dataset(
  ...                                     ds, key_dtype=tf.int64)
  >>> table.lookup(tf.constant(['0', '2', '4'], dtype=tf.string)).numpy()
  array([0, 1, 2])

  Args:
    dataset: A dataset of keys.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    vocab_size: The number of elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assignment of out-of-vocabulary buckets.
    key_dtype: The `key` data type.
    name: A name for this op (optional).

  Returns:
    The lookup table based on the given dataset.

  Raises:
    ValueError: If
      * `num_oov_buckets` is negative,
      * `vocab_size` is not greater than zero, or
      * `key_dtype` is not an integer or string type.
  """
  # `Dataset.enumerate` yields (index, key) pairs; swap them so that each key
  # maps to its index.
  return table_from_dataset(dataset.enumerate().map(lambda v, k: (k, v)),
                            num_oov_buckets, vocab_size, default_value,
                            hasher_spec, key_dtype, name)

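# Illustrative sketch (not part of the original module): the enumerate-and-swap
# above assigns each dataset element its position as the table value, so the
# resulting table maps keys to their indices. `_demo_index_table` is a
# hypothetical name used only for illustration; assumes eager execution.
def _demo_index_table():
  import tensorflow as tf
  keys = tf.data.Dataset.from_tensor_slices(["emerson", "lake", "palmer"])
  table = tf.data.experimental.index_table_from_dataset(keys)
  # "emerson" -> 0 and "palmer" -> 2; unseen keys map to default_value (-1).
  return table.lookup(tf.constant(["palmer", "king"])).numpy()  # => [2, -1]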