# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras CategoryEncoding preprocessing layer."""
# pylint: disable=g-classes-have-attributes

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bincount_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export

INT = "int"
ONE_HOT = "one_hot"
MULTI_HOT = "multi_hot"
COUNT = "count"


@keras_export("keras.layers.experimental.preprocessing.CategoryEncoding")
class CategoryEncoding(base_layer.Layer):
  """Category encoding layer.

  This layer provides options for condensing data into a categorical encoding
  when the total number of tokens is known in advance. It accepts integer
  values as inputs and outputs a dense representation of those inputs (one
  sample = a 1-index tensor of float values representing data about the
  sample's tokens). For integer inputs where the total number of tokens is
  not known, see `tf.keras.layers.experimental.preprocessing.IntegerLookup`.

  Examples:

  **One-hot encoding data**

  >>> layer = tf.keras.layers.experimental.preprocessing.CategoryEncoding(
  ...           num_tokens=4, output_mode="one_hot")
  >>> layer([3, 2, 0, 1])
  <tf.Tensor: shape=(4, 4), dtype=float32, numpy=
    array([[0., 0., 0., 1.],
           [0., 0., 1., 0.],
           [1., 0., 0., 0.],
           [0., 1., 0., 0.]], dtype=float32)>

  **Multi-hot encoding data**

  >>> layer = tf.keras.layers.experimental.preprocessing.CategoryEncoding(
  ...           num_tokens=4, output_mode="multi_hot")
  >>> layer([[0, 1], [0, 0], [1, 2], [3, 1]])
  <tf.Tensor: shape=(4, 4), dtype=float32, numpy=
    array([[1., 1., 0., 0.],
           [1., 0., 0., 0.],
           [0., 1., 1., 0.],
           [0., 1., 0., 1.]], dtype=float32)>
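
  **Counting token occurrences in `"count"` mode**

  An illustrative sketch, assuming the default `float32` floatx: `"count"`
  mode records how many times each token appears in a sample.

  >>> layer = tf.keras.layers.experimental.preprocessing.CategoryEncoding(
  ...           num_tokens=4, output_mode="count")
  >>> layer([[0, 1], [0, 0], [1, 2], [3, 1]])
  <tf.Tensor: shape=(4, 4), dtype=float32, numpy=
    array([[1., 1., 0., 0.],
           [2., 0., 0., 0.],
           [0., 1., 1., 0.],
           [0., 1., 0., 1.]], dtype=float32)>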

  **Using weighted inputs in `"count"` mode**

  >>> layer = tf.keras.layers.experimental.preprocessing.CategoryEncoding(
  ...           num_tokens=4, output_mode="count")
  >>> count_weights = np.array([[.1, .2], [.1, .1], [.2, .3], [.4, .2]])
  >>> layer([[0, 1], [0, 0], [1, 2], [3, 1]], count_weights=count_weights)
  <tf.Tensor: shape=(4, 4), dtype=float64, numpy=
    array([[0.1, 0.2, 0. , 0. ],
           [0.2, 0. , 0. , 0. ],
           [0. , 0.2, 0.3, 0. ],
           [0. , 0.2, 0. , 0.4]])>

  Args:
    num_tokens: The total number of tokens the layer should support. All
      inputs to the layer must be integers in the range
      `0 <= value < num_tokens`, or an error will be thrown.
    output_mode: Specification for the output of the layer.
      Defaults to `"multi_hot"`. Values can be `"one_hot"`, `"multi_hot"` or
      `"count"`, configuring the layer as follows:
        - `"one_hot"`: Encodes each individual element in the input into an
          array of `num_tokens` size, containing a 1 at the element index. If
          the last dimension is size 1, will encode on that dimension. If the
          last dimension is not size 1, will append a new dimension for the
          encoded output.
        - `"multi_hot"`: Encodes each sample in the input into a single array
          of `num_tokens` size, containing a 1 for each vocabulary term
          present in the sample. Treats the last dimension as the sample
          dimension; if input shape is `(..., sample_length)`, output shape
          will be `(..., num_tokens)`.
        - `"count"`: Like `"multi_hot"`, but the returned array contains a
          count of the number of times the token at that index appeared in
          the sample.
    sparse: Boolean. If true, returns a `SparseTensor` instead of a dense
      `Tensor`. Defaults to `False`.

  Call arguments:
    inputs: A 2D tensor `(samples, timesteps)`.
    count_weights: A 2D tensor in the same shape as `inputs` indicating the
      weight for each sample value when summing up in `"count"` mode. Only
      valid in `"count"` mode.
  """

  def __init__(self,
               num_tokens=None,
               output_mode=MULTI_HOT,
               sparse=False,
               **kwargs):
    # max_tokens is an old name for the num_tokens arg we continue to support
    # because of usage.
    if "max_tokens" in kwargs:
      logging.warning(
          "max_tokens is deprecated, please use num_tokens instead.")
      num_tokens = kwargs["max_tokens"]
      del kwargs["max_tokens"]

    super(CategoryEncoding, self).__init__(**kwargs)

    # Support deprecated names for output_modes.
    if output_mode == "binary":
      output_mode = MULTI_HOT
    # 'output_mode' must be one of (COUNT, ONE_HOT, MULTI_HOT)
    layer_utils.validate_string_arg(
        output_mode,
        allowable_strings=(COUNT, ONE_HOT, MULTI_HOT),
        layer_name="CategoryEncoding",
        arg_name="output_mode")

    if num_tokens is None:
      raise ValueError("num_tokens must be set to use this layer. If the "
                       "number of tokens is not known beforehand, use the "
                       "IntegerLookup layer instead.")
    if num_tokens < 1:
      raise ValueError("num_tokens must be >= 1.")

    self.num_tokens = num_tokens
    self.output_mode = output_mode
    self.sparse = sparse

  def compute_output_shape(self, input_shape):
    if not input_shape:
      return tensor_shape.TensorShape([self.num_tokens])
    if self.output_mode == ONE_HOT and input_shape[-1] != 1:
      return tensor_shape.TensorShape(input_shape + [self.num_tokens])
    else:
      return tensor_shape.TensorShape(input_shape[:-1] + [self.num_tokens])
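
  # Illustrative shape mappings for `compute_output_shape` (a sketch,
  # assuming `num_tokens=4`; each case follows from the branches above):
  #   - "multi_hot"/"count": (batch, timesteps) -> (batch, 4)
  #   - "one_hot" with a last dimension of size 1: (batch, 1) -> (batch, 4)
  #   - "one_hot" with a larger last dimension: (batch, t) -> (batch, t, 4)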

  def compute_output_signature(self, input_spec):
    output_shape = self.compute_output_shape(input_spec.shape.as_list())
    if self.sparse:
      return sparse_tensor.SparseTensorSpec(
          shape=output_shape, dtype=dtypes.int64)
    else:
      return tensor_spec.TensorSpec(shape=output_shape, dtype=dtypes.int64)

  def get_config(self):
    config = {
        "num_tokens": self.num_tokens,
        "output_mode": self.output_mode,
        "sparse": self.sparse,
    }
    base_config = super(CategoryEncoding, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs, count_weights=None):
    if isinstance(inputs, (list, np.ndarray)):
      inputs = ops.convert_to_tensor_v2_with_dispatch(inputs)

    def expand_dims(inputs, axis):
      if tf_utils.is_sparse(inputs):
        return sparse_ops.sparse_expand_dims(inputs, axis)
      else:
        return array_ops.expand_dims(inputs, axis)

    original_shape = inputs.shape
    # In all cases, we should uprank scalar input to a single sample.
    if inputs.shape.rank == 0:
      inputs = expand_dims(inputs, -1)
    # One hot will uprank only if the final output dimension is not already 1.
    if self.output_mode == ONE_HOT:
      if inputs.shape[-1] != 1:
        inputs = expand_dims(inputs, -1)

    # TODO(b/190445202): remove output rank restriction.
    if inputs.shape.rank > 2:
      raise ValueError(
          "Received input shape {}, which would result in output rank {}. "
          "Currently only outputs up to rank 2 are supported.".format(
              original_shape, inputs.shape.rank))

    if count_weights is not None and self.output_mode != COUNT:
      raise ValueError(
          "`count_weights` is not used when `output_mode` is not `'count'`. "
          "Received `count_weights={}`.".format(count_weights))

    out_depth = self.num_tokens
    binary_output = self.output_mode in (MULTI_HOT, ONE_HOT)
    # Assert that all input values fall in [0, num_tokens) before encoding.
    if isinstance(inputs, sparse_tensor.SparseTensor):
      max_value = math_ops.reduce_max(inputs.values)
      min_value = math_ops.reduce_min(inputs.values)
    else:
      max_value = math_ops.reduce_max(inputs)
      min_value = math_ops.reduce_min(inputs)
    condition = math_ops.logical_and(
        math_ops.greater(
            math_ops.cast(out_depth, max_value.dtype), max_value),
        math_ops.greater_equal(
            min_value, math_ops.cast(0, min_value.dtype)))
    assertion = control_flow_ops.Assert(condition, [
        "Input values must be in the range 0 <= values < num_tokens"
        " with num_tokens={}".format(out_depth)
    ])
    with ops.control_dependencies([assertion]):
      if self.sparse:
        return sparse_bincount(inputs, out_depth, binary_output,
                               count_weights)
      else:
        return dense_bincount(inputs, out_depth, binary_output,
                              count_weights)


def sparse_bincount(inputs, out_depth, binary_output, count_weights=None):
  """Apply binary or count encoding to an input and return a sparse tensor."""
  result = bincount_ops.sparse_bincount(
      inputs,
      weights=count_weights,
      minlength=out_depth,
      maxlength=out_depth,
      axis=-1,
      binary_output=binary_output)
  if inputs.shape.rank == 1:
    output_shape = (out_depth,)
  else:
    result = math_ops.cast(result, backend.floatx())
    batch_size = array_ops.shape(result)[0]
    output_shape = (batch_size, out_depth)
  result = sparse_tensor.SparseTensor(
      indices=result.indices,
      values=result.values,
      dense_shape=output_shape)
  return result


def dense_bincount(inputs, out_depth, binary_output, count_weights=None):
  """Apply binary or count encoding to an input."""
  result = bincount_ops.bincount(
      inputs,
      weights=count_weights,
      minlength=out_depth,
      maxlength=out_depth,
      dtype=backend.floatx(),
      axis=-1,
      binary_output=binary_output)
  if inputs.shape.rank == 1:
    result.set_shape(tensor_shape.TensorShape((out_depth,)))
  else:
    batch_size = inputs.shape.as_list()[0]
    result.set_shape(tensor_shape.TensorShape((batch_size, out_depth)))
  return result