# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for working with tf.lookup tables in Keras."""

import collections
import os
import numpy as np

from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.platform import gfile


class TableHandler(object):
  """Wrapper object that holds a lookup table and provides accessors."""

  def __init__(self,
               table,
               oov_tokens=None,
               mask_token=None,
               mask_value=0):
    """Wraps `table` and records the OOV and masking configuration.

    Args:
      table: The lookup table to wrap. It is treated as mutable only when it
        is a `lookup_ops.MutableHashTable`.
      oov_tokens: Optional value, or list/tuple/ndarray of values, substituted
        for lookup results that equal the table's default value. A single
        value is wrapped in a list. The values are cast to the table's value
        dtype. If None, no OOV replacement is performed.
      mask_token: Optional input value to mask. If None, lookups are returned
        without masking.
      mask_value: Value written to the output wherever the input equals
        `mask_token`.
    """
    self.table = table
    self.mutable = isinstance(table, lookup_ops.MutableHashTable)
    self.mask_token = mask_token
    self.mask_value = mask_value

    if oov_tokens is None:
      self.oov_tokens = oov_tokens
    else:
      if not isinstance(oov_tokens, (list, tuple, np.ndarray)):
        oov_tokens = [oov_tokens]
      self.oov_tokens = math_ops.cast(oov_tokens, table._value_dtype)  # pylint: disable=protected-access

  def table_size(self):
    """Returns the table's size as a numpy value.

    NOTE(review): this calls `.numpy()` on the size tensor, so it presumably
    requires eager execution -- confirm against callers.
    """
    return self.table.size().numpy()

  def clear(self):
    """Removes every entry from the backing table.

    Raises:
      RuntimeError: If the wrapped table is not a `MutableHashTable`.
    """
    if not self.mutable:
      # This must be raised, not returned: returning the exception object
      # would skip the guard and let callers believe the table was cleared.
      raise RuntimeError("Unable to clear a statically-backed table.")

    # Export the current keys and remove all of them.
    keys, _ = self.table.export()
    self.table.remove(keys)

  def insert(self, keys, values):
    """Insert values into the backed table.

    Args:
      keys: Keys to insert; converted to a tensor of the table's key dtype.
      values: Values to insert; converted to a tensor of the table's value
        dtype. Must have the same length as `keys` and be 1-dimensional.

    Raises:
      RuntimeError: If the table is not mutable, or if `keys` and `values`
        have different lengths.
      ValueError: If `values` is not 1-dimensional after conversion.
    """
    if not self.mutable:
      raise RuntimeError("Unable to insert into a statically-backed table.")

    if len(values) != len(keys):
      raise RuntimeError("Size mismatch between values and key arrays. "
                         "Keys had size %s, values had size %s." %
                         (len(keys), len(values)))
    keys = ops.convert_to_tensor_v2_with_dispatch(
        keys, dtype=self.table._key_dtype)  # pylint: disable=protected-access
    values = ops.convert_to_tensor_v2_with_dispatch(
        values, dtype=self.table._value_dtype)  # pylint: disable=protected-access
    if values.shape.ndims != 1:
      raise ValueError("`values` must be 1-dimensional, got an input with "
                       " %s dimensions." % values.shape.ndims)
    self.table.insert(keys, values)

  def _replace_oov_buckets(self, inputs, lookups):
    """Replace the default OOV value with one of the OOV bucket values.

    Args:
      inputs: The tensor of keys that was looked up.
      lookups: The tensor of lookup results for `inputs`.

    Returns:
      `lookups` unchanged if no OOV tokens are configured. Otherwise a tensor
      in which every position whose lookup result equals the table's default
      value is replaced by one of the configured OOV tokens.
    """
    if self.oov_tokens is None:
      return lookups

    num_oov_elements = self.oov_tokens.shape.num_elements()
    # Choose an OOV bucket per input: integer keys are bucketed by modulo,
    # all other keys by a string hash.
    if inputs.dtype.is_integer:
      oov_indices = math_ops.floormod(inputs, num_oov_elements)
    else:
      oov_indices = string_ops.string_to_hash_bucket_fast(
          inputs, num_buckets=num_oov_elements)

    oov_values = array_ops.gather(self.oov_tokens, oov_indices)
    # A position counts as OOV when its result equals the table default.
    oov_locations = math_ops.equal(lookups, self.table._default_value)  # pylint: disable=protected-access

    return array_ops.where(oov_locations, oov_values, lookups)

  def _lookup_and_mask(self, inputs):
    """Return a lookup with `mask_token` locations set to `mask_value`.

    Args:
      inputs: The tensor of keys to look up.

    Returns:
      The table lookup of `inputs`. When a mask token is configured, every
      position where `inputs` equals it holds `mask_value` (cast to the
      table's value dtype) instead of the looked-up value.
    """
    lookups = self.table.lookup(inputs)
    # If we don't need to handle masking, return the lookup values directly.
    if self.mask_token is None:
      return lookups

    # Inject the mask value wherever the mask token was in the inputs.
    mask_locations = math_ops.equal(inputs, self.mask_token)
    return array_ops.where_v2(
        mask_locations,
        math_ops.cast(self.mask_value, self.table._value_dtype),  # pylint: disable=protected-access
        lookups)

  def _ragged_lookup(self, inputs):
    """Perform a table lookup on a ragged tensor.

    Args:
      inputs: The ragged tensor of keys to look up.

    Returns:
      The looked-up (masked and OOV-replaced) ragged data, passed through an
      identity op.
    """
    # The table lookup ops don't natively support ragged tensors, so if we have
    # a RT we need to use map_flat_values to look up every element.
    indexed_data = ragged_functional_ops.map_flat_values(
        self._lookup_and_mask, inputs)
    indexed_data = ragged_functional_ops.map_flat_values(
        self._replace_oov_buckets, inputs, indexed_data)
    # table.lookup is not shape-preserving, so we need to set the shape here.
    indexed_data._set_shape(inputs.shape)  # pylint: disable=protected-access
    # Composite tensors can pass tensor values through, which will cause
    # errors if all operations in the TF graph do so. We can break this chain
    # with an identity here.
    return array_ops.identity(indexed_data)

  def _sparse_lookup(self, inputs):
    """Perform a table lookup on a sparse tensor.

    Args:
      inputs: A sparse input exposing `values`, `indices` and `dense_shape`.

    Returns:
      A `SparseTensor` with the indices and dense shape of `inputs` and the
      looked-up (masked and OOV-replaced) values, passed through an identity
      op.
    """
    values = self._lookup_and_mask(inputs.values)
    values = self._replace_oov_buckets(inputs.values, values)
    indexed_data = sparse_tensor.SparseTensor(inputs.indices, values,
                                              inputs.dense_shape)
    # Composite tensors can pass tensor values through, which will cause
    # errors if all operations in the TF graph do so. We can break this chain
    # with an identity here.
    return array_ops.identity(indexed_data)

  def _tensor_lookup(self, inputs):
    """Perform a table lookup on a tf.tensor.

    Args:
      inputs: The dense tensor of keys to look up.

    Returns:
      The looked-up (masked and OOV-replaced) tensor, with its static shape
      set to that of `inputs`.
    """
    values = self._lookup_and_mask(inputs)
    indexed_data = self._replace_oov_buckets(inputs, values)
    # (b/149446477): output does not preserve input shape.
    indexed_data.set_shape(inputs.shape)
    return indexed_data

  def lookup(self, inputs):
    """Perform a table lookup.

    Dispatches on the input kind: sparse inputs, ragged inputs, and
    everything else (which is converted to a dense tensor).

    Args:
      inputs: A `SparseTensor`/`SparseTensorValue`, a ragged input (as
        reported by `tf_utils.is_ragged`), or a value convertible to a tensor.

    Returns:
      The looked-up data in the same kind of container as the input.
    """
    # Sparse tensors don't play nicely with tensor conversion, so we handle
    # them before attempting to convert lists or arrays to tensors.
    if isinstance(
        inputs, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
      return self._sparse_lookup(inputs)

    if tf_utils.is_ragged(inputs):
      if isinstance(inputs, ragged_tensor_value.RaggedTensorValue):
        # Rebuild a RaggedTensor from the value object's flat values and
        # nested row splits before looking it up.
        flat_values = ops.convert_to_tensor_v2_with_dispatch(
            value=inputs.flat_values, name="flat_values")
        inputs = ragged_tensor.RaggedTensor.from_nested_row_splits(
            flat_values, inputs.nested_row_splits, validate=False)
      return self._ragged_lookup(inputs)

    # For normal tensor inputs
    inputs = ops.convert_to_tensor_v2_with_dispatch(inputs)
    return self._tensor_lookup(inputs)


def num_tokens_in_file(vocabulary_path):
  """Count the number of lines in a vocab file to get the number of tokens.

  Args:
    vocabulary_path: Path of the vocabulary file, opened with `gfile.GFile`.

  Returns:
    The number of lines read from the file.
  """
  num_tokens = 0
  with gfile.GFile(vocabulary_path, "r") as reader:
    text = reader.readline()
    while text:
      num_tokens += 1
      text = reader.readline()

  return num_tokens


def get_vocabulary_from_file(vocabulary_path, encoding="utf-8"):
  """Read a vocabulary in from a file.

  Args:
    vocabulary_path: Path of the vocabulary file, opened with `gfile.GFile`.
    encoding: Encoding used to decode a line if it is returned as bytes;
      undecodable bytes are ignored.

  Returns:
    A list with one token per line, in file order, with trailing
    `os.linesep` characters stripped from each line.
  """
  vocab = []
  with gfile.GFile(vocabulary_path, "r") as reader:
    while True:
      # Get the next line (incl. \n), and break if nothing is left to read.
      text = reader.readline()
      if not text:
        break

      # Decode raw bytes if necessary, then strip the trailing line separator.
      if isinstance(text, bytes):
        text = text.decode(encoding, "ignore")
      vocab.append(text.rstrip(os.linesep))
  return vocab


def find_repeated_tokens(vocabulary):
  """Return all repeated tokens in a vocabulary.

  Args:
    vocabulary: A sized iterable of hashable tokens.

  Returns:
    A list of the tokens that occur more than once, in order of first
    occurrence; an empty list if every token is unique.
  """
  return [
      item for item, count in collections.Counter(vocabulary).items()
      if count > 1
  ]