• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Keras CategoryEncoding preprocessing layer."""
16# pylint: disable=g-classes-have-attributes
17
18import numpy as np
19
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import sparse_tensor
23from tensorflow.python.framework import tensor_shape
24from tensorflow.python.framework import tensor_spec
25from tensorflow.python.keras import backend
26from tensorflow.python.keras.engine import base_layer
27from tensorflow.python.keras.utils import layer_utils
28from tensorflow.python.keras.utils import tf_utils
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import bincount_ops
31from tensorflow.python.ops import control_flow_ops
32from tensorflow.python.ops import math_ops
33from tensorflow.python.ops import sparse_ops
34from tensorflow.python.platform import tf_logging as logging
35from tensorflow.python.util.tf_export import keras_export
36
# String identifiers for the layer's `output_mode` argument.
# `INT` is declared here but is not among the modes CategoryEncoding accepts.
INT = "int"
# Each input element becomes its own indicator vector.
ONE_HOT = "one_hot"
# Each sample becomes one indicator vector marking the tokens present.
MULTI_HOT = "multi_hot"
# Like multi-hot, but each slot holds the (optionally weighted) token count.
COUNT = "count"
41
42
@keras_export("keras.layers.experimental.preprocessing.CategoryEncoding")
class CategoryEncoding(base_layer.Layer):
  """Category encoding layer.

  This layer provides options for condensing data into a categorical encoding
  when the total number of tokens is known in advance. It accepts integer
  values as inputs and outputs a dense representation (one sample = 1-index
  tensor of float values representing data about the sample's tokens) of those
  inputs. For integer inputs where the total number of tokens is not known, see
  `tf.keras.layers.experimental.preprocessing.IntegerLookup`.

  Examples:

  **One-hot encoding data**

  >>> layer = tf.keras.layers.experimental.preprocessing.CategoryEncoding(
  ...           num_tokens=4, output_mode="one_hot")
  >>> layer([3, 2, 0, 1])
  <tf.Tensor: shape=(4, 4), dtype=float32, numpy=
    array([[0., 0., 0., 1.],
           [0., 0., 1., 0.],
           [1., 0., 0., 0.],
           [0., 1., 0., 0.]], dtype=float32)>

  **Multi-hot encoding data**

  >>> layer = tf.keras.layers.experimental.preprocessing.CategoryEncoding(
  ...           num_tokens=4, output_mode="multi_hot")
  >>> layer([[0, 1], [0, 0], [1, 2], [3, 1]])
  <tf.Tensor: shape=(4, 4), dtype=float32, numpy=
    array([[1., 1., 0., 0.],
           [1., 0., 0., 0.],
           [0., 1., 1., 0.],
           [0., 1., 0., 1.]], dtype=float32)>

  **Using weighted inputs in `"count"` mode**

  >>> layer = tf.keras.layers.experimental.preprocessing.CategoryEncoding(
  ...           num_tokens=4, output_mode="count")
  >>> count_weights = np.array([[.1, .2], [.1, .1], [.2, .3], [.4, .2]])
  >>> layer([[0, 1], [0, 0], [1, 2], [3, 1]], count_weights=count_weights)
  <tf.Tensor: shape=(4, 4), dtype=float64, numpy=
    array([[0.1, 0.2, 0. , 0. ],
           [0.2, 0. , 0. , 0. ],
           [0. , 0.2, 0.3, 0. ],
           [0. , 0.2, 0. , 0.4]])>

  Args:
    num_tokens: The total number of tokens the layer should support. All inputs
      to the layer must be integers in the range 0 <= value < num_tokens or an
      error will be thrown.
    output_mode: Specification for the output of the layer.
      Defaults to `"multi_hot"`. Values can be `"one_hot"`, `"multi_hot"` or
      `"count"`, configuring the layer as follows:
        - `"one_hot"`: Encodes each individual element in the input into an
          array of `num_tokens` size, containing a 1 at the element index. If
          the last dimension is size 1, will encode on that dimension. If the
          last dimension is not size 1, will append a new dimension for the
          encoded output.
        - `"multi_hot"`: Encodes each sample in the input into a single array
          of `num_tokens` size, containing a 1 for each vocabulary term present
          in the sample. Treats the last dimension as the sample dimension, if
          input shape is (..., sample_length), output shape will be
          (..., num_tokens).
        - `"count"`: As `"multi_hot"`, but the int array contains a count of the
          number of times the token at that index appeared in the sample.
    sparse: Boolean. If true, returns a `SparseTensor` instead of a dense
      `Tensor`. Defaults to `False`.

  Call arguments:
    inputs: A tensor (or list / NumPy array) of integer token indices,
      typically 2D `(samples, timesteps)`. Scalars are upranked to a single
      sample. The encoded output is currently limited to rank 2.
    count_weights: A tensor in the same shape as `inputs` indicating the
      weight for each sample value when summing up in `count` mode. Only valid
      in `"count"` mode; passing it in any other mode raises a `ValueError`.
  """

  def __init__(self,
               num_tokens=None,
               output_mode=MULTI_HOT,
               sparse=False,
               **kwargs):
    """Validates and stores the layer configuration.

    Raises:
      ValueError: If `num_tokens` is missing or smaller than 1. Validation of
        `output_mode` is delegated to `layer_utils.validate_string_arg`.
    """
    # max_tokens is an old name for the num_tokens arg we continue to support
    # because of usage.
    if "max_tokens" in kwargs:
      logging.warning(
          "max_tokens is deprecated, please use num_tokens instead.")
      # Remove the key so the base Layer constructor never sees it.
      num_tokens = kwargs.pop("max_tokens")

    super(CategoryEncoding, self).__init__(**kwargs)

    # Support deprecated names for output_modes.
    if output_mode == "binary":
      output_mode = MULTI_HOT
    # 'output_mode' must be one of (COUNT, ONE_HOT, MULTI_HOT)
    layer_utils.validate_string_arg(
        output_mode,
        allowable_strings=(COUNT, ONE_HOT, MULTI_HOT),
        layer_name="CategoryEncoding",
        arg_name="output_mode")

    if num_tokens is None:
      raise ValueError("num_tokens must be set to use this layer. If the "
                       "number of tokens is not known beforehand, use the "
                       "IntegerLookup layer instead.")
    if num_tokens < 1:
      raise ValueError("num_tokens must be >= 1.")

    self.num_tokens = num_tokens
    self.output_mode = output_mode
    self.sparse = sparse

  def compute_output_shape(self, input_shape):
    """Returns the static output shape for a given input shape.

    Args:
      input_shape: The input shape as a tuple, list, `TensorShape`, or `None`.

    Returns:
      A `TensorShape` whose last dimension is `num_tokens`. In `"one_hot"` mode
      with a last input dimension other than 1, `num_tokens` is appended as a
      new dimension; otherwise it replaces the last input dimension.
    """
    # Normalize first: concatenating a tuple with a list raises TypeError, so
    # the arithmetic below is done on a plain list of dimensions.
    shape = tensor_shape.TensorShape(input_shape)
    # Unknown rank (None) and scalar (rank 0) inputs both encode to a single
    # vector of size num_tokens.
    if not shape.rank:
      return tensor_shape.TensorShape([self.num_tokens])
    dims = shape.as_list()
    if self.output_mode == ONE_HOT and dims[-1] != 1:
      return tensor_shape.TensorShape(dims + [self.num_tokens])
    return tensor_shape.TensorShape(dims[:-1] + [self.num_tokens])

  def compute_output_signature(self, input_spec):
    """Returns a (sparse) tensor spec describing the layer output."""
    output_shape = self.compute_output_shape(input_spec.shape.as_list())
    # NOTE(review): this advertises int64, while `dense_bincount` in this file
    # requests `backend.floatx()` output -- confirm which dtype callers rely on
    # before changing either side.
    if self.sparse:
      return sparse_tensor.SparseTensorSpec(
          shape=output_shape, dtype=dtypes.int64)
    else:
      return tensor_spec.TensorSpec(shape=output_shape, dtype=dtypes.int64)

  def get_config(self):
    """Returns the layer config, including the base `Layer` entries."""
    config = {
        "num_tokens": self.num_tokens,
        "output_mode": self.output_mode,
        "sparse": self.sparse,
    }
    base_config = super(CategoryEncoding, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs, count_weights=None):
    """Encodes integer `inputs` according to `output_mode`.

    Args:
      inputs: Integer token indices as a tensor, sparse tensor, list or NumPy
        array.
      count_weights: Optional per-element weights; only accepted when
        `output_mode` is `"count"`.

    Returns:
      The encoded values, as a `SparseTensor` when `self.sparse` is true and a
      dense tensor otherwise.

    Raises:
      ValueError: If the encoded output would exceed rank 2, or if
        `count_weights` is passed in a mode other than `"count"`.
    """
    if isinstance(inputs, (list, np.ndarray)):
      inputs = ops.convert_to_tensor_v2_with_dispatch(inputs)

    def expand_dims(tensor, axis):
      # Sparse and dense tensors need different expand_dims ops.
      if tf_utils.is_sparse(tensor):
        return sparse_ops.sparse_expand_dims(tensor, axis)
      else:
        return array_ops.expand_dims(tensor, axis)

    # Kept for the error message below, before any upranking happens.
    original_shape = inputs.shape
    # In all cases, we should uprank scalar input to a single sample.
    if inputs.shape.rank == 0:
      inputs = expand_dims(inputs, -1)
    # One hot will uprank only if the final output dimension is not already 1.
    if self.output_mode == ONE_HOT:
      if inputs.shape[-1] != 1:
        inputs = expand_dims(inputs, -1)

    # TODO(b/190445202): remove output rank restriction.
    if inputs.shape.rank > 2:
      raise ValueError(
          "Received input shape {}, which would result in output rank {}. "
          "Currently only outputs up to rank 2 are supported.".format(
              original_shape, inputs.shape.rank))

    if count_weights is not None and self.output_mode != COUNT:
      raise ValueError(
          "`count_weights` is not used when `output_mode` is not `'count'`. "
          "Received `count_weights={}`.".format(count_weights))

    out_depth = self.num_tokens
    binary_output = self.output_mode in (MULTI_HOT, ONE_HOT)
    # Sparse inputs are range-checked on their stored values only.
    if isinstance(inputs, sparse_tensor.SparseTensor):
      max_value = math_ops.reduce_max(inputs.values)
      min_value = math_ops.reduce_min(inputs.values)
    else:
      max_value = math_ops.reduce_max(inputs)
      min_value = math_ops.reduce_min(inputs)
    # Runtime check that 0 <= min(inputs) and max(inputs) < num_tokens.
    condition = math_ops.logical_and(
        math_ops.greater(
            math_ops.cast(out_depth, max_value.dtype), max_value),
        math_ops.greater_equal(
            min_value, math_ops.cast(0, min_value.dtype)))
    assertion = control_flow_ops.Assert(condition, [
        "Input values must be in the range 0 <= values < num_tokens"
        " with num_tokens={}".format(out_depth)
    ])
    # Tie the encoding ops to the assertion so the check runs first.
    with ops.control_dependencies([assertion]):
      if self.sparse:
        return sparse_bincount(inputs, out_depth, binary_output,
                               count_weights)
      else:
        return dense_bincount(inputs, out_depth, binary_output,
                              count_weights)
235
236
def sparse_bincount(inputs, out_depth, binary_output, count_weights=None):
  """Apply binary or count encoding to an input and return a sparse tensor.

  Args:
    inputs: Integer token indices to encode; rank 1 is treated as one sample,
      any other rank as a batch of samples.
    out_depth: Size of the encoded (last) output dimension.
    binary_output: If true, emit presence indicators rather than counts.
    count_weights: Optional per-element weights forwarded to the bincount op.

  Returns:
    A `SparseTensor` with values in `backend.floatx()` and dense shape
    `(out_depth,)` for rank-1 inputs or `(batch_size, out_depth)` otherwise.
  """
  result = bincount_ops.sparse_bincount(
      inputs,
      weights=count_weights,
      minlength=out_depth,
      maxlength=out_depth,
      axis=-1,
      binary_output=binary_output)
  # Cast regardless of input rank so single-sample and batched inputs yield
  # the same value dtype.
  result = math_ops.cast(result, backend.floatx())
  if inputs.shape.rank == 1:
    output_shape = (out_depth,)
  else:
    # The batch dimension may only be known at runtime, so read it dynamically.
    batch_size = array_ops.shape(result)[0]
    output_shape = (batch_size, out_depth)
  # Rebuild the tensor to pin the last dimension of the dense shape to
  # out_depth.
  return sparse_tensor.SparseTensor(
      indices=result.indices,
      values=result.values,
      dense_shape=output_shape)
257
258
def dense_bincount(inputs, out_depth, binary_output, count_weights=None):
  """Encode `inputs` as dense binary indicators or counts.

  Args:
    inputs: Integer token indices to encode.
    out_depth: Size of the encoded (last) output dimension.
    binary_output: If true, emit presence indicators rather than counts.
    count_weights: Optional per-element weights forwarded to the bincount op.

  Returns:
    The bincount result with its static shape set to `(out_depth,)` for rank-1
    inputs, or `(batch_size, out_depth)` otherwise.
  """
  encoded = bincount_ops.bincount(
      inputs,
      weights=count_weights,
      minlength=out_depth,
      maxlength=out_depth,
      dtype=backend.floatx(),
      axis=-1,
      binary_output=binary_output)
  # Work out the static shape first, then attach it in a single place.
  if inputs.shape.rank == 1:
    static_dims = (out_depth,)
  else:
    # The leading dimension is taken from the input's static shape as-is.
    static_dims = (inputs.shape.as_list()[0], out_depth)
  encoded.set_shape(tensor_shape.TensorShape(static_dims))
  return encoded
275