# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras category crossing preprocessing layers."""
# pylint: disable=g-classes-have-attributes

import itertools
import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.layers.experimental.preprocessing.CategoryCrossing')
class CategoryCrossing(base_layer.Layer):
  """Category crossing layer.

  This layer concatenates multiple categorical inputs into a single categorical
  output (similar to Cartesian product). The output dtype is string.

  Usage:
  >>> inp_1 = ['a', 'b', 'c']
  >>> inp_2 = ['d', 'e', 'f']
  >>> layer = tf.keras.layers.experimental.preprocessing.CategoryCrossing()
  >>> layer([inp_1, inp_2])
  <tf.Tensor: shape=(3, 1), dtype=string, numpy=
    array([[b'a_X_d'],
           [b'b_X_e'],
           [b'c_X_f']], dtype=object)>


  >>> inp_1 = ['a', 'b', 'c']
  >>> inp_2 = ['d', 'e', 'f']
  >>> layer = tf.keras.layers.experimental.preprocessing.CategoryCrossing(
  ...    separator='-')
  >>> layer([inp_1, inp_2])
  <tf.Tensor: shape=(3, 1), dtype=string, numpy=
    array([[b'a-d'],
           [b'b-e'],
           [b'c-f']], dtype=object)>

  Args:
    depth: depth of input crossing. By default None, all inputs are crossed into
      one output. It can also be an int or tuple/list of ints. Passing an
      integer will create combinations of crossed outputs with depth up to and
      including that integer, i.e., [1, 2, ..., `depth`], and passing a tuple
      of integers will create crossed outputs with depth for the specified
      values in the tuple, i.e., `depth`=(N1, N2) will create all possible
      crossed outputs with depth equal to N1 or N2. Passing `None` means a
      single crossed output with all inputs. For example, with inputs `a`, `b`
      and `c`, `depth=2` means the output will be
      [a;b;c;cross(a, b);cross(a, c);cross(b, c)].
    separator: A string added between each input being joined. Defaults to
      '_X_'. A non-default separator is not supported yet when any input is a
      `RaggedTensor` (a `ValueError` is raised).
    name: Name to give to the layer.
    **kwargs: Keyword arguments to construct a layer.

  Input shape: a list of string or int tensors or sparse tensors of shape
    `[batch_size, d1, ..., dm]`

  Output shape: a single string tensor, sparse tensor or ragged tensor of shape
    `[batch_size, d1, ..., dm]`

  Returns:
    If any input is `RaggedTensor`, the output is `RaggedTensor`.
    Else, if any input is `SparseTensor`, the output is `SparseTensor`.
    Otherwise, the output is `Tensor`.

  Example: (`depth`=None)
    If the layer receives three inputs:
    `a=[[1], [4]]`, `b=[[2], [5]]`, `c=[[3], [6]]`
    the output will be a string tensor:
    `[[b'1_X_2_X_3'], [b'4_X_5_X_6']]`

  Example: (`depth` is an integer)
    With the same input above, and if `depth`=2, the output will be a single
    string tensor whose 6 columns are the crossed combinations (all depth-1
    crosses, then all depth-2 crosses, in input order), concatenated along
    axis 1:
    `[[b'1', b'2', b'3', b'1_X_2', b'1_X_3', b'2_X_3'],
      [b'4', b'5', b'6', b'4_X_5', b'4_X_6', b'5_X_6']]`

  Example: (`depth` is a tuple/list of integers)
    With the same input above, and if `depth`=(2, 3), the output will be a
    single string tensor with 4 columns:
    `[[b'1_X_2', b'1_X_3', b'2_X_3', b'1_X_2_X_3'],
      [b'4_X_5', b'4_X_6', b'5_X_6', b'4_X_5_X_6']]`
  """

  def __init__(self, depth=None, name=None, separator='_X_', **kwargs):
    super(CategoryCrossing, self).__init__(name=name, **kwargs)
    self.depth = depth
    self.separator = separator
    # `None` means "cross all inputs together"; the actual depth is resolved in
    # `call` once the number of inputs is known. Always defining the attribute
    # avoids an AttributeError for code that inspects it.
    self._depth_tuple = None
    if isinstance(depth, (tuple, list)):
      self._depth_tuple = depth
    elif depth is not None:
      # An integer depth N means every depth from 1 to N inclusive.
      self._depth_tuple = tuple(range(1, depth + 1))

  def partial_crossing(self, partial_inputs, ragged_out, sparse_out):
    """Gets the crossed output from a partial list/tuple of inputs.

    Args:
      partial_inputs: The inputs to cross together, in order.
      ragged_out: Whether the output should be a `RaggedTensor`.
      sparse_out: Whether the output should be a `SparseTensor`. Ignored when
        `ragged_out` is True.

    Returns:
      The crossed string output as a ragged, sparse or dense tensor.

    Raises:
      ValueError: If `ragged_out` is True and a non-default separator is set.
    """
    if ragged_out:
      # TODO(momernick): Support separator with ragged_cross.
      if self.separator != '_X_':
        raise ValueError('Non-default separator with ragged input is not '
                         'supported yet, given {}'.format(self.separator))
      return ragged_array_ops.cross(partial_inputs)
    crossed = sparse_ops.sparse_cross(partial_inputs, separator=self.separator)
    if sparse_out:
      return crossed
    return sparse_ops.sparse_tensor_to_dense(crossed)

  def _preprocess_input(self, inp):
    """Converts python/numpy inputs to tensors and lifts rank-1 to rank-2."""
    if isinstance(inp, (list, tuple, np.ndarray)):
      inp = ops.convert_to_tensor_v2_with_dispatch(inp)
    if inp.shape.rank == 1:
      inp = array_ops.expand_dims(inp, axis=-1)
    return inp

  def call(self, inputs):
    """Crosses `inputs` at every configured depth and concatenates the results.

    Args:
      inputs: A list of dense, sparse or ragged tensors (or python/numpy values
        convertible to tensors).

    Returns:
      A single tensor (ragged if any input is ragged, else sparse if any input
      is sparse, else dense) formed by concatenating every partial cross along
      axis 1.

    Raises:
      ValueError: If a requested depth exceeds the number of inputs.
    """
    inputs = [self._preprocess_input(inp) for inp in inputs]
    # A falsy `depth` (None, 0 or an empty sequence) crosses all inputs at once.
    depth_tuple = self._depth_tuple if self.depth else (len(inputs),)
    ragged_out = sparse_out = False
    if any(tf_utils.is_ragged(inp) for inp in inputs):
      ragged_out = True
    elif any(isinstance(inp, sparse_tensor.SparseTensor) for inp in inputs):
      sparse_out = True

    outputs = []
    for depth in depth_tuple:
      if len(inputs) < depth:
        raise ValueError(
            'Number of inputs cannot be less than depth, got {} input tensors, '
            'and depth {}'.format(len(inputs), depth))
      # `combinations` preserves input order, so `a` is always joined before `b`.
      for partial_inps in itertools.combinations(inputs, depth):
        partial_out = self.partial_crossing(
            partial_inps, ragged_out, sparse_out)
        outputs.append(partial_out)
    if sparse_out:
      return sparse_ops.sparse_concat_v2(axis=1, sp_inputs=outputs)
    return array_ops.concat(outputs, axis=1)

  def compute_output_shape(self, input_shape):
    """Returns `[batch_size, None]`; the crossed width depends on the data.

    Raises:
      ValueError: If `input_shape` is not a list/tuple, or any input shape is
        not rank 2.
    """
    if not isinstance(input_shape, (tuple, list)):
      raise ValueError('A `CategoryCrossing` layer should be called '
                       'on a list of inputs.')
    batch_size = None
    for inp_shape in input_shape:
      inp_tensor_shape = tensor_shape.TensorShape(inp_shape).as_list()
      if len(inp_tensor_shape) != 2:
        raise ValueError('Inputs must be rank 2, get {}'.format(input_shape))
      if batch_size is None:
        batch_size = inp_tensor_shape[0]
    # The second dimension is dynamic based on inputs.
    output_shape = [batch_size, None]
    return tensor_shape.TensorShape(output_shape)

  def compute_output_signature(self, input_spec):
    """Returns the output spec, matching the tensor type `call` produces."""
    input_shapes = [x.shape for x in input_spec]
    output_shape = self.compute_output_shape(input_shapes)
    ragged_specs = [
        inp_spec for inp_spec in input_spec
        if isinstance(inp_spec, ragged_tensor.RaggedTensorSpec)
    ]
    if ragged_specs:
      # `call` returns a RaggedTensor whenever any input is ragged, so the
      # signature must be ragged too (not a dense TensorSpec).
      return ragged_tensor.RaggedTensorSpec(
          shape=output_shape,
          dtype=dtypes.string,
          row_splits_dtype=ragged_specs[0].row_splits_dtype)
    if any(
        isinstance(inp_spec, sparse_tensor.SparseTensorSpec)
        for inp_spec in input_spec):
      return sparse_tensor.SparseTensorSpec(
          shape=output_shape, dtype=dtypes.string)
    return tensor_spec.TensorSpec(shape=output_shape, dtype=dtypes.string)

  def get_config(self):
    """Returns the layer config, including `depth` and `separator`."""
    config = {
        'depth': self.depth,
        'separator': self.separator,
    }
    base_config = super(CategoryCrossing, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))