# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras string lookup preprocessing layer."""
# pylint: disable=g-classes-have-attributes

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.keras.layers.preprocessing import index_lookup
from tensorflow.python.keras.layers.preprocessing import table_utils
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.layers.experimental.preprocessing.StringLookup", v1=[])
class StringLookup(index_lookup.IndexLookup):
  """Maps strings from a vocabulary to integer indices.

  This layer translates a set of arbitrary strings into an integer output via a
  table-based vocabulary lookup.

  The vocabulary for the layer can be supplied on construction or learned via
  `adapt()`. During `adapt()`, the layer will analyze a data set, determine the
  frequency of individual strings tokens, and create a vocabulary from them. If
  the vocabulary is capped in size, the most frequent tokens will be used to
  create the vocabulary and all others will be treated as out-of-vocabulary
  (OOV).

  There are two possible output modes for the layer.
  When `output_mode` is `"int"`,
  input strings are converted to their index in the vocabulary (an integer).
  When `output_mode` is `"multi_hot"`, `"count"`, or `"tf_idf"`, input strings
  are encoded into an array where each dimension corresponds to an element in
  the vocabulary.

  The vocabulary can optionally contain a mask token as well as an OOV token
  (which can optionally occupy multiple indices in the vocabulary, as set
  by `num_oov_indices`).
  The position of these tokens in the vocabulary is fixed. When `output_mode` is
  `"int"`, the vocabulary will begin with the mask token (if set), followed by
  OOV indices, followed by the rest of the vocabulary. When `output_mode` is
  `"multi_hot"`, `"count"`, or `"tf_idf"` the vocabulary will begin with OOV
  indices and instances of the mask token will be dropped.

  Args:
    max_tokens: The maximum size of the vocabulary for this layer. If None,
      there is no cap on the size of the vocabulary. Note that this size
      includes the OOV and mask tokens. Defaults to None.
    num_oov_indices: The number of out-of-vocabulary tokens to use. If this
      value is more than 1, OOV inputs are hashed to determine their OOV value.
      If this value is 0, OOV inputs will cause an error when calling the layer.
      Defaults to 1.
    mask_token: A token that represents masked inputs. When `output_mode` is
      `"int"`, the token is included in vocabulary and mapped to index 0. In
      other output modes, the token will not appear in the vocabulary and
      instances of the mask token in the input will be dropped. If set to None,
      no mask term will be added. Defaults to `None`.
    oov_token: Only used when `invert` is True. The token to return for OOV
      indices. Defaults to `"[UNK]"`.
    vocabulary: An optional list of tokens, or a path to a text file containing
      a vocabulary to load into this layer. The file should contain one token
      per line.
      If the list or file contains the same token multiple times, an
      error will be thrown.
    invert: Only valid when `output_mode` is `"int"`. If True, this layer will
      map indices to vocabulary items instead of mapping vocabulary items to
      indices. Defaults to False.
    output_mode: Specification for the output of the layer. Defaults to `"int"`.
      Values can be `"int"`, `"one_hot"`, `"multi_hot"`, `"count"`, or
      `"tf_idf"` configuring the layer as follows:
        - `"int"`: Return the raw integer indices of the input tokens.
        - `"one_hot"`: Encodes each individual element in the input into an
          array the same size as the vocabulary, containing a 1 at the element
          index. If the last dimension is size 1, will encode on that dimension.
          If the last dimension is not size 1, will append a new dimension for
          the encoded output.
        - `"multi_hot"`: Encodes each sample in the input into a single array
          the same size as the vocabulary, containing a 1 for each vocabulary
          term present in the sample. Treats the last dimension as the sample
          dimension, if input shape is (..., sample_length), output shape will
          be (..., num_tokens).
        - `"count"`: As `"multi_hot"`, but the int array contains a count of the
          number of times the token at that index appeared in the sample.
        - `"tf_idf"`: As `"multi_hot"`, but the TF-IDF algorithm is applied to
          find the value in each token slot.
    pad_to_max_tokens: Only applicable when `output_mode` is `"multi_hot"`,
      `"count"`, or `"tf_idf"`. If True, the output will have its feature axis
      padded to `max_tokens` even if the number of unique tokens in the
      vocabulary is less than max_tokens, resulting in a tensor of shape
      [batch_size, max_tokens] regardless of vocabulary size. Defaults to False.
    sparse: Boolean. Only applicable when `output_mode` is `"multi_hot"`,
      `"count"`, or `"tf_idf"`. If True, returns a `SparseTensor` instead of a
      dense `Tensor`. Defaults to False.

  Examples:

  **Creating a lookup layer with a known vocabulary**

  This example creates a lookup layer with a pre-existing vocabulary.

  >>> vocab = ["a", "b", "c", "d"]
  >>> data = tf.constant([["a", "c", "d"], ["d", "z", "b"]])
  >>> layer = StringLookup(vocabulary=vocab)
  >>> layer(data)
  <tf.Tensor: shape=(2, 3), dtype=int64, numpy=
  array([[1, 3, 4],
         [4, 0, 2]])>

  **Creating a lookup layer with an adapted vocabulary**

  This example creates a lookup layer and generates the vocabulary by analyzing
  the dataset.

  >>> data = tf.constant([["a", "c", "d"], ["d", "z", "b"]])
  >>> layer = StringLookup()
  >>> layer.adapt(data)
  >>> layer.get_vocabulary()
  ['[UNK]', 'd', 'z', 'c', 'b', 'a']

  Note that the OOV token [UNK] has been added to the vocabulary. The remaining
  tokens are sorted by frequency ('d', which has 2 occurrences, is first) then
  by inverse sort order.

  >>> data = tf.constant([["a", "c", "d"], ["d", "z", "b"]])
  >>> layer = StringLookup()
  >>> layer.adapt(data)
  >>> layer(data)
  <tf.Tensor: shape=(2, 3), dtype=int64, numpy=
  array([[5, 3, 1],
         [1, 2, 4]])>

  **Lookups with multiple OOV indices**

  This example demonstrates how to use a lookup layer with multiple OOV indices.
  When a layer is created with more than one OOV index, any OOV values are
  hashed into the number of OOV buckets, distributing OOV values in a
  deterministic fashion across the set.

  >>> vocab = ["a", "b", "c", "d"]
  >>> data = tf.constant([["a", "c", "d"], ["m", "z", "b"]])
  >>> layer = StringLookup(vocabulary=vocab, num_oov_indices=2)
  >>> layer(data)
  <tf.Tensor: shape=(2, 3), dtype=int64, numpy=
  array([[2, 4, 5],
         [0, 1, 3]])>

  Note that the output for OOV value 'm' is 0, while the output for OOV value
  'z' is 1.
  The in-vocab terms have their output index increased by 1 from
  earlier examples (a maps to 2, etc) in order to make space for the extra OOV
  value.

  **One-hot output**

  Configure the layer with `output_mode='one_hot'`. Note that the first
  `num_oov_indices` dimensions in the one_hot encoding represent OOV values.

  >>> vocab = ["a", "b", "c", "d"]
  >>> data = tf.constant(["a", "b", "c", "d", "z"])
  >>> layer = StringLookup(vocabulary=vocab, output_mode='one_hot')
  >>> layer(data)
  <tf.Tensor: shape=(5, 5), dtype=float32, numpy=
    array([[0., 1., 0., 0., 0.],
           [0., 0., 1., 0., 0.],
           [0., 0., 0., 1., 0.],
           [0., 0., 0., 0., 1.],
           [1., 0., 0., 0., 0.]], dtype=float32)>

  **Multi-hot output**

  Configure the layer with `output_mode='multi_hot'`. Note that the first
  `num_oov_indices` dimensions in the multi_hot encoding represent OOV values.

  >>> vocab = ["a", "b", "c", "d"]
  >>> data = tf.constant([["a", "c", "d", "d"], ["d", "z", "b", "z"]])
  >>> layer = StringLookup(vocabulary=vocab, output_mode='multi_hot')
  >>> layer(data)
  <tf.Tensor: shape=(2, 5), dtype=float32, numpy=
    array([[0., 1., 0., 1., 1.],
           [1., 0., 1., 0., 1.]], dtype=float32)>

  **Token count output**

  Configure the layer with `output_mode='count'`. As with multi_hot output, the
  first `num_oov_indices` dimensions in the output represent OOV values.

  >>> vocab = ["a", "b", "c", "d"]
  >>> data = tf.constant([["a", "c", "d", "d"], ["d", "z", "b", "z"]])
  >>> layer = StringLookup(vocabulary=vocab, output_mode='count')
  >>> layer(data)
  <tf.Tensor: shape=(2, 5), dtype=float32, numpy=
    array([[0., 1., 0., 1., 2.],
           [2., 0., 1., 0., 1.]], dtype=float32)>

  **TF-IDF output**

  Configure the layer with `output_mode='tf_idf'`. As with multi_hot output, the
  first `num_oov_indices` dimensions in the output represent OOV values.

  Each token bin will output `token_count * idf_weight`, where the idf weights
  are the inverse document frequency weights per token. These should be provided
  along with the vocabulary. Note that the `idf_weight` for OOV values will
  default to the average of all idf weights passed in.

  >>> vocab = ["a", "b", "c", "d"]
  >>> idf_weights = [0.25, 0.75, 0.6, 0.4]
  >>> data = tf.constant([["a", "c", "d", "d"], ["d", "z", "b", "z"]])
  >>> layer = StringLookup(output_mode='tf_idf')
  >>> layer.set_vocabulary(vocab, idf_weights=idf_weights)
  >>> layer(data)
  <tf.Tensor: shape=(2, 5), dtype=float32, numpy=
    array([[0.  , 0.25, 0.  , 0.6 , 0.8 ],
           [1.0 , 0.  , 0.75, 0.  , 0.4 ]], dtype=float32)>

  To specify the idf weights for oov values, you will need to pass the entire
  vocabulary including the leading oov token.

  >>> vocab = ["[UNK]", "a", "b", "c", "d"]
  >>> idf_weights = [0.9, 0.25, 0.75, 0.6, 0.4]
  >>> data = tf.constant([["a", "c", "d", "d"], ["d", "z", "b", "z"]])
  >>> layer = StringLookup(output_mode='tf_idf')
  >>> layer.set_vocabulary(vocab, idf_weights=idf_weights)
  >>> layer(data)
  <tf.Tensor: shape=(2, 5), dtype=float32, numpy=
    array([[0.  , 0.25, 0.  , 0.6 , 0.8 ],
           [1.8 , 0.  , 0.75, 0.  , 0.4 ]], dtype=float32)>

  When adapting the layer in tf_idf mode, each input sample will be considered a
  document, and idf weight per token will be calculated as
  `log(1 + num_documents / (1 + token_document_count))`.

  **Inverse lookup**

  This example demonstrates how to map indices to strings using this layer. (You
  can also use adapt() with invert=True, but for simplicity we'll pass the
  vocab in this example.)

  >>> vocab = ["a", "b", "c", "d"]
  >>> data = tf.constant([[1, 3, 4], [4, 0, 2]])
  >>> layer = StringLookup(vocabulary=vocab, invert=True)
  >>> layer(data)
  <tf.Tensor: shape=(2, 3), dtype=string, numpy=
  array([[b'a', b'c', b'd'],
         [b'd', b'[UNK]', b'b']], dtype=object)>

  Note that the first index corresponds to the oov token by default.


  **Forward and inverse lookup pairs**

  This example demonstrates how to use the vocabulary of a standard lookup
  layer to create an inverse lookup layer.

  >>> vocab = ["a", "b", "c", "d"]
  >>> data = tf.constant([["a", "c", "d"], ["d", "z", "b"]])
  >>> layer = StringLookup(vocabulary=vocab)
  >>> i_layer = StringLookup(vocabulary=vocab, invert=True)
  >>> int_data = layer(data)
  >>> i_layer(int_data)
  <tf.Tensor: shape=(2, 3), dtype=string, numpy=
  array([[b'a', b'c', b'd'],
         [b'd', b'[UNK]', b'b']], dtype=object)>

  In this example, the input value 'z' resulted in an output of '[UNK]', since
  'z' was not in the vocabulary - it got represented as an OOV, and all OOV
  values are returned as '[UNK]' in the inverse layer. Also, note that for the
  inverse to work, you must have already set the forward layer vocabulary
  either directly or via adapt() before calling get_vocabulary().
  """

  def __init__(self,
               max_tokens=None,
               num_oov_indices=1,
               mask_token=None,
               oov_token="[UNK]",
               vocabulary=None,
               encoding=None,
               invert=False,
               output_mode=index_lookup.INT,
               sparse=False,
               pad_to_max_tokens=False,
               **kwargs):
    allowed_dtypes = [dtypes.string]

    # StringLookup only operates on string tensors; reject any other dtype up
    # front so the error names this layer instead of failing deeper in the
    # lookup table machinery.
    if "dtype" in kwargs and kwargs["dtype"] not in allowed_dtypes:
      raise ValueError("The value of the dtype argument for StringLookup may "
                       "only be one of %s." % (allowed_dtypes,))

    if "dtype" not in kwargs:
      kwargs["dtype"] = dtypes.string

    # The encoding is used when reading vocabulary files (`set_vocabulary`
    # with a path) and when decoding vocabulary tensors back to Python text
    # (`_tensor_vocab_to_numpy`). Default to UTF-8 when unspecified.
    if encoding is None:
      encoding = "utf-8"

    self.encoding = encoding

    super().__init__(
        max_tokens=max_tokens,
        num_oov_indices=num_oov_indices,
        mask_token=mask_token,
        oov_token=oov_token,
        vocabulary=vocabulary,
        invert=invert,
        output_mode=output_mode,
        sparse=sparse,
        pad_to_max_tokens=pad_to_max_tokens,
        **kwargs)

  def get_config(self):
    """Returns the layer config, extending the base config with `encoding`."""
    config = {"encoding": self.encoding}
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def set_vocabulary(self, vocabulary, idf_weights=None):
    """Sets vocabulary data for this layer.

    Args:
      vocabulary: A list of string tokens, or a path to a text file containing
        one token per line (read using `self.encoding`).
      idf_weights: Inverse document frequency weights for the tokens. Only
        valid when `output_mode` is `"tf_idf"`.

    Raises:
      RuntimeError: If `vocabulary` is a file path while `output_mode` is
        `"tf_idf"`, since idf weight data cannot be read from a vocabulary
        file.
    """
    if isinstance(vocabulary, str):
      if self.output_mode == index_lookup.TF_IDF:
        raise RuntimeError("Setting vocabulary directly from a file is not "
                           "supported in TF-IDF mode, since this layer cannot "
                           "read files containing TF-IDF weight data. Please "
                           "read the file using Python and set the vocabulary "
                           "and weights by passing lists or arrays to the "
                           "set_vocabulary function's `vocabulary` and "
                           "`idf_weights` args.")
      vocabulary = table_utils.get_vocabulary_from_file(vocabulary,
                                                        self.encoding)
    super().set_vocabulary(vocabulary, idf_weights=idf_weights)

  # Overridden methods from IndexLookup.
  def _tensor_vocab_to_numpy(self, vocabulary):
    """Converts a string vocabulary tensor to a numpy array of Python text."""
    vocabulary = vocabulary.numpy()
    # Decode each vocabulary entry with the layer's encoding so callers get
    # `str` values rather than raw bytes.
    return np.array([compat.as_text(x, self.encoding) for x in vocabulary])