# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Lookup operations."""

from tensorflow.python.data.experimental.ops.cardinality import assert_cardinality
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export("data.experimental.DatasetInitializer")
class DatasetInitializer(lookup_ops.TableInitializerBase):
  """Creates a table initializer from a `tf.data.Dataset`.

  Sample usage:

  >>> keys = tf.data.Dataset.range(100)
  >>> values = tf.data.Dataset.range(100).map(
  ...     lambda x: tf.strings.as_string(x * 2))
  >>> ds = tf.data.Dataset.zip((keys, values))
  >>> init = tf.data.experimental.DatasetInitializer(ds)
  >>> table = tf.lookup.StaticHashTable(init, "")
  >>> table.lookup(tf.constant([0, 1, 2], dtype=tf.int64)).numpy()
  array([b'0', b'2', b'4'], dtype=object)

  Attributes:
    dataset: A `tf.data.Dataset` object that produces tuples of scalars. The
      first scalar is treated as the key and the second as the value.

  Raises:
    ValueError: If `dataset` doesn't conform to the specifications above.
  """

  def __init__(self, dataset):
    """Creates a table initializer from a `tf.data.Dataset`.

    Args:
      dataset: A `tf.data.Dataset` object that produces tuples of scalars. The
        first scalar is treated as the key and the second as the value.

    Raises:
      ValueError: If `dataset` doesn't conform to the specifications above.
    """
    # Assert that the dataset element spec is a tuple of TensorSpecs where
    # each tensor is a scalar.
    self.dataset = dataset
    elem_spec = self.dataset.element_spec
    if len(elem_spec) != 2:
      raise ValueError("element spec size should be 2")
    if not isinstance(elem_spec[0], tensor_spec.TensorSpec):
      raise ValueError("elem_spec[0] should be of type TensorSpec, got: %s" %
                       type(elem_spec[0]))
    if not isinstance(elem_spec[1], tensor_spec.TensorSpec):
      raise ValueError("elem_spec[1] should be of type TensorSpec, got: %s" %
                       type(elem_spec[1]))
    if elem_spec[0].shape.rank not in (None, 0):
      raise ValueError("key tensor should be a scalar")
    if elem_spec[1].shape.rank not in (None, 0):
      raise ValueError("value tensor should be a scalar")

    key_type = elem_spec[0].dtype
    value_type = elem_spec[1].dtype
    super(DatasetInitializer, self).__init__(key_type, value_type)

  def initialize(self, table):
    lookup_ops.check_table_dtypes(table, self._key_dtype, self._value_dtype)
    init_op = ged_ops.initialize_table_from_dataset(
        table.resource_handle, self.dataset._variant_tensor)  # pylint: disable=protected-access
    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
    return init_op
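

# A minimal graph-mode usage sketch (illustrative, not part of this module's
# API): `initialize` adds its op to the `TABLE_INITIALIZERS` collection, so in
# TF1-style graphs the table can be initialized with
# `tf.compat.v1.tables_initializer()`. All names below are hypothetical.
#
#   import tensorflow as tf
#
#   graph = tf.Graph()
#   with graph.as_default():
#     keys = tf.data.Dataset.range(3)
#     values = tf.data.Dataset.range(3).map(tf.strings.as_string)
#     init = tf.data.experimental.DatasetInitializer(
#         tf.data.Dataset.zip((keys, values)))
#     table = tf.lookup.StaticHashTable(init, default_value="")
#     lookup = table.lookup(tf.constant([0, 2], dtype=tf.int64))
#   with tf.compat.v1.Session(graph=graph) as sess:
#     sess.run(tf.compat.v1.tables_initializer())  # runs the collected init op
#     print(sess.run(lookup))  # [b'0' b'2']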


@tf_export("data.experimental.table_from_dataset")
def table_from_dataset(dataset=None,
                       num_oov_buckets=0,
                       vocab_size=None,
                       default_value=None,
                       hasher_spec=lookup_ops.FastHashSpec,
                       key_dtype=dtypes.string,
                       name=None):
  """Returns a lookup table based on the given dataset.

  This operation constructs a lookup table based on the given dataset of
  (key, value) pairs.

  Any lookup of an out-of-vocabulary token will return a bucket ID based on its
  hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
  `default_value`. The bucket ID range is
  `[vocabulary size, vocabulary size + num_oov_buckets - 1]`.

  Sample Usages:

  >>> keys = tf.data.Dataset.range(100)
  >>> values = tf.data.Dataset.range(100).map(
  ...     lambda x: tf.strings.as_string(x * 2))
  >>> ds = tf.data.Dataset.zip((keys, values))
  >>> table = tf.data.experimental.table_from_dataset(
  ...     ds, default_value='n/a', key_dtype=tf.int64)
  >>> table.lookup(tf.constant([0, 1, 2], dtype=tf.int64)).numpy()
  array([b'0', b'2', b'4'], dtype=object)

  Args:
    dataset: A dataset containing (key, value) pairs.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    vocab_size: Number of elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assignment of out-of-vocabulary buckets.
    key_dtype: The `key` data type.
    name: A name for this op (optional).

  Returns:
    The lookup table based on the given dataset.

  Raises:
    ValueError: If
      * `dataset` does not contain pairs,
      * the second item of the `dataset` pairs has a dtype which is
        incompatible with `default_value`,
      * `num_oov_buckets` is negative,
      * `vocab_size` is not greater than zero, or
      * `key_dtype` is not an integer or string type.
  """
  elem_spec = dataset.element_spec
  if len(elem_spec) != 2:
    raise ValueError("The given dataset must contain pairs.")
  if default_value is None:
    default_value = -1
    if not (elem_spec[1].dtype.is_integer or elem_spec[1].dtype.is_floating):
      raise ValueError("The dtype of the values requires manually setting a "
                       "compatible default_value.")
  if num_oov_buckets < 0:
    raise ValueError(
        "num_oov_buckets must be greater than or equal to 0, got %d." %
        num_oov_buckets)
  if (not isinstance(vocab_size, ops.Tensor) and vocab_size is not None and
      vocab_size < 1):
    raise ValueError("vocab_size must be greater than 0, got %d." % vocab_size)
  if (not key_dtype.is_integer) and (dtypes.string != key_dtype.base_dtype):
    raise TypeError("Only integer and string keys are supported.")
  if vocab_size is not None:
    if isinstance(vocab_size, ops.Tensor):
      vocab_size = math_ops.cast(vocab_size, dtypes.int64)
    dataset = dataset.take(vocab_size)
    dataset = dataset.apply(assert_cardinality(vocab_size))
  with ops.name_scope(name, "string_to_index"):
    initializer = DatasetInitializer(dataset)
    with ops.name_scope(None, "hash_table"):
      table = lookup_ops.StaticHashTableV1(initializer, default_value)
      if num_oov_buckets:
        table = lookup_ops.IdTableWithHashBuckets(
            table,
            num_oov_buckets=num_oov_buckets,
            hasher_spec=hasher_spec,
            key_dtype=key_dtype)
      return table
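

# A hedged sketch of the out-of-vocabulary behavior described above (the data
# here is hypothetical): with `num_oov_buckets > 0`, a key that is not in the
# table hashes into an ID in `[vocab_size, vocab_size + num_oov_buckets - 1]`
# rather than being mapped to `default_value`.
#
#   import tensorflow as tf
#
#   words = tf.data.Dataset.from_tensor_slices(["a", "b", "c"])
#   pairs = tf.data.Dataset.zip((words, tf.data.Dataset.range(3)))
#   table = tf.data.experimental.table_from_dataset(
#       pairs, num_oov_buckets=2, vocab_size=3)
#   ids = table.lookup(tf.constant(["a", "zzz"]))
#   # ids[0] is 0; ids[1] falls in [3, 4] (an OOV bucket), not -1.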


@tf_export("data.experimental.index_table_from_dataset")
def index_table_from_dataset(dataset=None,
                             num_oov_buckets=0,
                             vocab_size=None,
                             default_value=-1,
                             hasher_spec=lookup_ops.FastHashSpec,
                             key_dtype=dtypes.string,
                             name=None):
  """Returns an index lookup table based on the given dataset.

  This operation constructs a lookup table based on the given dataset of keys.

  Any lookup of an out-of-vocabulary token will return a bucket ID based on its
  hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
  `default_value`. The bucket ID range is
  `[vocabulary size, vocabulary size + num_oov_buckets - 1]`.

  Sample Usages:

  >>> ds = tf.data.Dataset.range(100).map(
  ...     lambda x: tf.strings.as_string(x * 2))
  >>> table = tf.data.experimental.index_table_from_dataset(
  ...     ds, key_dtype=tf.string)
  >>> table.lookup(tf.constant(['0', '2', '4'], dtype=tf.string)).numpy()
  array([0, 1, 2])

  Args:
    dataset: A dataset of keys.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    vocab_size: Number of elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assignment of out-of-vocabulary buckets.
    key_dtype: The `key` data type.
    name: A name for this op (optional).

  Returns:
    The lookup table based on the given dataset.

  Raises:
    ValueError: If
      * `num_oov_buckets` is negative,
      * `vocab_size` is not greater than zero, or
      * `key_dtype` is not an integer or string type.
  """
  # `enumerate` yields (index, key) pairs; the `map` swaps them so each key
  # maps to its index in the dataset.
  return table_from_dataset(dataset.enumerate().map(lambda v, k: (k, v)),
                            num_oov_buckets, vocab_size, default_value,
                            hasher_spec, key_dtype, name)
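

# A short usage sketch for `index_table_from_dataset` (the vocabulary below is
# hypothetical): each key is paired with its position via `enumerate`, so
# lookups return the key's index in the dataset.
#
#   import tensorflow as tf
#
#   vocab = tf.data.Dataset.from_tensor_slices(["emerson", "lake", "palmer"])
#   table = tf.data.experimental.index_table_from_dataset(vocab)
#   print(table.lookup(tf.constant(["lake", "palmer"])).numpy())  # [1 2]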