• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for working with tf.lookup tables in Keras."""
16
17import collections
18import os
19import numpy as np
20
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import sparse_tensor
23from tensorflow.python.keras.utils import tf_utils
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import lookup_ops
26from tensorflow.python.ops import math_ops
27from tensorflow.python.ops import string_ops
28from tensorflow.python.ops.ragged import ragged_functional_ops
29from tensorflow.python.ops.ragged import ragged_tensor
30from tensorflow.python.ops.ragged import ragged_tensor_value
31from tensorflow.python.platform import gfile
32
33
class TableHandler(object):
  """Wrapper object that holds a lookup table and provides accessors.

  Handles dense, sparse, and ragged inputs, optional masking of a designated
  mask token, and optional redistribution of OOV lookups across a set of OOV
  bucket values.

  Args:
    table: The lookup table to wrap. Mutation (insert/clear) is only
      supported when this is a `lookup_ops.MutableHashTable`.
    oov_tokens: Optional scalar or 1-D list/tuple/ndarray of values to use
      for out-of-vocabulary lookups. If None, OOV replacement is disabled.
    mask_token: Optional input value whose lookups are forced to
      `mask_value`. If None, masking is disabled.
    mask_value: Value emitted wherever `mask_token` appears in the input.
      Defaults to 0.
  """

  def __init__(self,
               table,
               oov_tokens=None,
               mask_token=None,
               mask_value=0):
    self.table = table
    # Only MutableHashTables support insert/remove; anything else is treated
    # as statically backed and rejected by the mutating methods below.
    self.mutable = isinstance(table, lookup_ops.MutableHashTable)
    self.mask_token = mask_token
    self.mask_value = mask_value

    if oov_tokens is None:
      self.oov_tokens = oov_tokens
    else:
      # Normalize a scalar OOV token to a vector so downstream code can
      # always gather from `oov_tokens` by index.
      if not isinstance(oov_tokens, (list, tuple, np.ndarray)):
        oov_tokens = [oov_tokens]
      self.oov_tokens = math_ops.cast(oov_tokens, table._value_dtype)  # pylint: disable=protected-access

  def table_size(self):
    """Return the number of entries in the table as a Python int."""
    return self.table.size().numpy()

  def clear(self):
    """Remove all entries from the backing table.

    Raises:
      RuntimeError: If the table is statically backed (not mutable).
    """
    if not self.mutable:
      # Bug fix: this previously `return`ed the exception object instead of
      # raising it, so clearing a static table silently did nothing.
      raise RuntimeError("Unable to clear a statically-backed table.")

    keys, _ = self.table.export()
    self.table.remove(keys)

  def insert(self, keys, values):
    """Insert values into the backed table.

    Args:
      keys: Keys to insert; converted to the table's key dtype.
      values: 1-D values to insert; converted to the table's value dtype.

    Raises:
      RuntimeError: If the table is not mutable, or if `keys` and `values`
        have different lengths.
      ValueError: If `values` is not 1-dimensional.
    """
    if not self.mutable:
      raise RuntimeError("Unable to insert into a statically-backed table.")

    if len(values) != len(keys):
      raise RuntimeError("Size mismatch between values and key arrays. "
                         "Keys had size %s, values had size %s." %
                         (len(keys), len(values)))
    keys = ops.convert_to_tensor_v2_with_dispatch(
        keys, dtype=self.table._key_dtype)  # pylint: disable=protected-access
    values = ops.convert_to_tensor_v2_with_dispatch(
        values, dtype=self.table._value_dtype)  # pylint: disable=protected-access
    if values.shape.ndims != 1:
      raise ValueError("`values` must be 1-dimensional, got an input with "
                       " %s dimensions." % values.shape.ndims)
    self.table.insert(keys, values)

  def _replace_oov_buckets(self, inputs, lookups):
    """Replace the default OOV value with one of the OOV bucket values.

    Locations where `lookups` equals the table's default value are rewritten
    with a deterministically chosen entry from `self.oov_tokens` (by modulo
    for integer inputs, by string hashing otherwise).
    """
    if self.oov_tokens is None:
      return lookups

    num_oov_elements = self.oov_tokens.shape.num_elements()
    if inputs.dtype.is_integer:
      oov_indices = math_ops.floormod(inputs, num_oov_elements)
    else:
      oov_indices = string_ops.string_to_hash_bucket_fast(
          inputs, num_buckets=num_oov_elements)

    oov_values = array_ops.gather(self.oov_tokens, oov_indices)
    oov_locations = math_ops.equal(lookups, self.table._default_value)  # pylint: disable=protected-access

    return array_ops.where(oov_locations, oov_values, lookups)

  def _lookup_and_mask(self, inputs):
    """Return a lookup with any location with the mask_token masked to 0."""
    lookups = self.table.lookup(inputs)
    # If we don't need to handle masking, return the lookup values directly.
    if self.mask_token is None:
      return lookups

    # Inject the mask value wherever the mask token was in the inputs.
    mask_locations = math_ops.equal(inputs, self.mask_token)
    return array_ops.where_v2(
        mask_locations,
        math_ops.cast(self.mask_value, self.table._value_dtype),  # pylint: disable=protected-access
        lookups)

  def _ragged_lookup(self, inputs):
    """Perform a table lookup on a ragged tensor."""
    # The table lookup ops don't natively support ragged tensors, so if we
    # have a RT we need to use map_flat_values to look up every element.
    indexed_data = ragged_functional_ops.map_flat_values(
        self._lookup_and_mask, inputs)
    indexed_data = ragged_functional_ops.map_flat_values(
        self._replace_oov_buckets, inputs, indexed_data)
    # table.lookup is not shape-preserving, so we need to set the shape here.
    indexed_data._set_shape(inputs.shape)  # pylint: disable=protected-access
    # Composite tensors can pass tensor values through, which will cause
    # errors if all operations in the TF graph do so. We can break this chain
    # with an identity here.
    return array_ops.identity(indexed_data)

  def _sparse_lookup(self, inputs):
    """Perform a table lookup on a sparse tensor."""
    values = self._lookup_and_mask(inputs.values)
    values = self._replace_oov_buckets(inputs.values, values)
    indexed_data = sparse_tensor.SparseTensor(inputs.indices, values,
                                              inputs.dense_shape)
    # Composite tensors can pass tensor values through, which will cause
    # errors if all operations in the TF graph do so. We can break this chain
    # with an identity here.
    return array_ops.identity(indexed_data)

  def _tensor_lookup(self, inputs):
    """Perform a table lookup on a tf.tensor."""
    values = self._lookup_and_mask(inputs)
    indexed_data = self._replace_oov_buckets(inputs, values)
    # (b/149446477): output does not preserve input shape.
    indexed_data.set_shape(inputs.shape)
    return indexed_data

  def lookup(self, inputs):
    """Perform a table lookup, dispatching on the input's composite type."""
    # Sparse tensors don't play nicely with tensor conversion, so we handle
    # them before attempting to convert lists or arrays to tensors.
    if isinstance(
        inputs, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
      return self._sparse_lookup(inputs)

    if tf_utils.is_ragged(inputs):
      # RaggedTensorValues (eager-style values) are rebuilt as RaggedTensors
      # so the ragged lookup path can operate on them uniformly.
      if isinstance(inputs, ragged_tensor_value.RaggedTensorValue):
        flat_values = ops.convert_to_tensor_v2_with_dispatch(
            value=inputs.flat_values, name="flat_values")
        inputs = ragged_tensor.RaggedTensor.from_nested_row_splits(
            flat_values, inputs.nested_row_splits, validate=False)
      return self._ragged_lookup(inputs)

    # For normal tensor inputs
    inputs = ops.convert_to_tensor_v2_with_dispatch(inputs)
    return self._tensor_lookup(inputs)
166
167
def num_tokens_in_file(vocabulary_path):
  """Count the number of lines in a vocab file to get the number of tokens."""
  count = 0
  with gfile.GFile(vocabulary_path, "r") as reader:
    # Read line-by-line until readline returns an empty string (EOF).
    while True:
      line = reader.readline()
      if not line:
        break
      count += 1
  return count
178
179
def get_vocabulary_from_file(vocabulary_path, encoding="utf-8"):
  """Read a vocabulary in from a file."""
  tokens = []
  with gfile.GFile(vocabulary_path, "r") as reader:
    # Read line-by-line until readline returns an empty string (EOF).
    while True:
      line = reader.readline()
      if not line:
        break
      # GFile may yield bytes; decode those (dropping undecodable bytes),
      # then strip the trailing line separator.
      if isinstance(line, bytes):
        line = line.decode(encoding, "ignore")
      tokens.append(line.rstrip(os.linesep))
  return tokens
198
199
def find_repeated_tokens(vocabulary):
  """Return all repeated tokens in a vocabulary."""
  # Fast path: if deduplication doesn't shrink the vocabulary, nothing
  # repeats and we can skip building the full frequency count.
  if len(set(vocabulary)) == len(vocabulary):
    return []
  counts = collections.Counter(vocabulary)
  return [token for token, freq in counts.items() if freq > 1]
210