• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Keras category crossing preprocessing layers."""
16# pylint: disable=g-classes-have-attributes
17
18import itertools
19import numpy as np
20
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import ops
23from tensorflow.python.framework import sparse_tensor
24from tensorflow.python.framework import tensor_shape
25from tensorflow.python.framework import tensor_spec
26from tensorflow.python.keras.engine import base_layer
27from tensorflow.python.keras.utils import tf_utils
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import sparse_ops
30from tensorflow.python.ops.ragged import ragged_array_ops
31from tensorflow.python.ops.ragged import ragged_tensor
32from tensorflow.python.util.tf_export import keras_export
33
34
@keras_export('keras.layers.experimental.preprocessing.CategoryCrossing')
class CategoryCrossing(base_layer.Layer):
  """Category crossing layer.

  This layer concatenates multiple categorical inputs into a single categorical
  output (similar to Cartesian product). The output dtype is string.

  Usage:
  >>> inp_1 = ['a', 'b', 'c']
  >>> inp_2 = ['d', 'e', 'f']
  >>> layer = tf.keras.layers.experimental.preprocessing.CategoryCrossing()
  >>> layer([inp_1, inp_2])
  <tf.Tensor: shape=(3, 1), dtype=string, numpy=
    array([[b'a_X_d'],
           [b'b_X_e'],
           [b'c_X_f']], dtype=object)>


  >>> inp_1 = ['a', 'b', 'c']
  >>> inp_2 = ['d', 'e', 'f']
  >>> layer = tf.keras.layers.experimental.preprocessing.CategoryCrossing(
  ...    separator='-')
  >>> layer([inp_1, inp_2])
  <tf.Tensor: shape=(3, 1), dtype=string, numpy=
    array([[b'a-d'],
           [b'b-e'],
           [b'c-f']], dtype=object)>

  Args:
    depth: depth of input crossing. By default None, all inputs are crossed into
      one output. It can also be an int or tuple/list of ints. Passing an
      integer will create combinations of crossed outputs with every depth from
      1 up to and including that integer, i.e., [1, 2, ..., `depth`], and
      passing a tuple of integers will create crossed outputs with depth for the
      specified values in the tuple, i.e., `depth`=(N1, N2) will create all
      possible crossed outputs with depth equal to N1 or N2. Passing `None`
      means a single crossed output with all inputs. For example, with inputs
      `a`, `b` and `c`, `depth=2` means the output will be
      [a; b; c; cross(a, b); cross(a, c); cross(b, c)]. All crossed outputs are
      concatenated along axis 1 into a single output tensor.
    separator: A string added between each input being joined. Defaults to
      '_X_'.
    name: Name to give to the layer.
    **kwargs: Keyword arguments to construct a layer.

  Input shape: a list of string or int tensors, sparse tensors or ragged
    tensors of shape `[batch_size]` or `[batch_size, d1]`. Rank-1 inputs are
    expanded to `[batch_size, 1]` before crossing.

  Output shape: a single string tensor, sparse tensor or ragged tensor of
    shape `[batch_size, None]`; the size of the second dimension depends on
    the number of crossed outputs and on the input values.

  Returns:
    If any input is `RaggedTensor`, the output is `RaggedTensor`.
    Else, if any input is `SparseTensor`, the output is `SparseTensor`.
    Otherwise, the output is `Tensor`.

  Example: (`depth`=None)
    If the layer receives three inputs:
    `a=[[1], [4]]`, `b=[[2], [5]]`, `c=[[3], [6]]`
    the output will be a string tensor:
    `[[b'1_X_2_X_3'], [b'4_X_5_X_6']]`

  Example: (`depth` is an integer)
    With the same input above, and if `depth`=2,
    the output will be a single string tensor holding 6 crossed features
    concatenated along axis 1 (depth-1 crosses first, then depth-2 crosses):
    `[[b'1', b'2', b'3', b'1_X_2', b'1_X_3', b'2_X_3'],
      [b'4', b'5', b'6', b'4_X_5', b'4_X_6', b'5_X_6']]`

  Example: (`depth` is a tuple/list of integers)
    With the same input above, and if `depth`=(2, 3)
    the output will be a single string tensor holding 4 crossed features
    concatenated along axis 1:
    `[[b'1_X_2', b'1_X_3', b'2_X_3', b'1_X_2_X_3'],
      [b'4_X_5', b'4_X_6', b'5_X_6', b'4_X_5_X_6']]`
  """

  def __init__(self, depth=None, name=None, separator='_X_', **kwargs):
    super(CategoryCrossing, self).__init__(name=name, **kwargs)
    self.depth = depth
    self.separator = separator
    # Every crossing depth to compute. `call` only reads this when `depth` is
    # truthy; a falsy `depth` crosses all inputs together instead.
    self._depth_tuple = None
    if isinstance(depth, (tuple, list)):
      self._depth_tuple = depth
    elif depth is not None:
      # An integer depth means every depth from 1 up to and including `depth`.
      self._depth_tuple = tuple(range(1, depth + 1))

  def partial_crossing(self, partial_inputs, ragged_out, sparse_out):
    """Gets the crossed output from a partial list/tuple of inputs.

    Args:
      partial_inputs: The inputs to cross together in a single output.
      ragged_out: Whether to produce a `RaggedTensor` via the ragged cross op.
      sparse_out: Whether to produce a `SparseTensor` (only consulted when
        `ragged_out` is False).

    Returns:
      The crossed string output: ragged, sparse, or dense depending on the
      flags above.

    Raises:
      ValueError: if `ragged_out` is True and a non-default separator was
        configured, since the ragged cross op is called without a separator.
    """
    if ragged_out:
      # TODO(momernick): Support separator with ragged_cross.
      if self.separator != '_X_':
        raise ValueError('Non-default separator with ragged input is not '
                         'supported yet, given {}'.format(self.separator))
      return ragged_array_ops.cross(partial_inputs)
    elif sparse_out:
      return sparse_ops.sparse_cross(partial_inputs, separator=self.separator)
    else:
      # Dense output: cross sparsely, then densify the result.
      return sparse_ops.sparse_tensor_to_dense(
          sparse_ops.sparse_cross(partial_inputs, separator=self.separator))

  def _preprocess_input(self, inp):
    """Converts Python/NumPy inputs to tensors and expands rank-1 inputs."""
    if isinstance(inp, (list, tuple, np.ndarray)):
      inp = ops.convert_to_tensor_v2_with_dispatch(inp)
    if inp.shape.rank == 1:
      # Use the sparse-specific op for `SparseTensor` inputs rather than the
      # dense `array_ops.expand_dims`.
      if isinstance(inp, sparse_tensor.SparseTensor):
        inp = sparse_ops.sparse_expand_dims(inp, axis=-1)
      else:
        inp = array_ops.expand_dims(inp, axis=-1)
    return inp

  def call(self, inputs):
    """Crosses `inputs` at every configured depth and concatenates the results.

    Args:
      inputs: A list of tensors, sparse tensors, ragged tensors, or
        list/tuple/ndarray values convertible to tensors.

    Returns:
      A single crossed output; see the class docstring for its type.

    Raises:
      ValueError: if a requested depth exceeds the number of inputs.
    """
    inputs = [self._preprocess_input(inp) for inp in inputs]
    # A falsy `depth` (None, 0, or an empty tuple/list) crosses all inputs.
    depth_tuple = self._depth_tuple if self.depth else (len(inputs),)
    ragged_out = sparse_out = False
    # Ragged takes precedence over sparse when both kinds are present.
    if any(tf_utils.is_ragged(inp) for inp in inputs):
      ragged_out = True
    elif any(isinstance(inp, sparse_tensor.SparseTensor) for inp in inputs):
      sparse_out = True

    outputs = []
    for depth in depth_tuple:
      if len(inputs) < depth:
        raise ValueError(
            'Number of inputs cannot be less than depth, got {} input tensors, '
            'and depth {}'.format(len(inputs), depth))
      for partial_inps in itertools.combinations(inputs, depth):
        partial_out = self.partial_crossing(
            partial_inps, ragged_out, sparse_out)
        outputs.append(partial_out)
    if sparse_out:
      return sparse_ops.sparse_concat_v2(axis=1, sp_inputs=outputs)
    return array_ops.concat(outputs, axis=1)

  def compute_output_shape(self, input_shape):
    """Returns `[batch_size, None]`; every input shape must be rank 2."""
    if not isinstance(input_shape, (tuple, list)):
      raise ValueError('A `CategoryCrossing` layer should be called '
                       'on a list of inputs.')
    input_shapes = input_shape
    batch_size = None
    for inp_shape in input_shapes:
      inp_tensor_shape = tensor_shape.TensorShape(inp_shape).as_list()
      if len(inp_tensor_shape) != 2:
        raise ValueError('Inputs must be rank 2, get {}'.format(input_shapes))
      # Use the first known batch dimension among the inputs.
      if batch_size is None:
        batch_size = inp_tensor_shape[0]
    # The second dimension is dynamic based on inputs.
    output_shape = [batch_size, None]
    return tensor_shape.TensorShape(output_shape)

  def compute_output_signature(self, input_spec):
    """Returns a string spec whose kind (ragged/sparse/dense) matches `call`."""
    input_shapes = [x.shape for x in input_spec]
    output_shape = self.compute_output_shape(input_shapes)
    if any(
        isinstance(inp_spec, ragged_tensor.RaggedTensorSpec)
        for inp_spec in input_spec):
      # `call` returns a RaggedTensor whenever any input is ragged.
      return ragged_tensor.RaggedTensorSpec(
          shape=output_shape, dtype=dtypes.string, ragged_rank=1)
    elif any(
        isinstance(inp_spec, sparse_tensor.SparseTensorSpec)
        for inp_spec in input_spec):
      return sparse_tensor.SparseTensorSpec(
          shape=output_shape, dtype=dtypes.string)
    return tensor_spec.TensorSpec(shape=output_shape, dtype=dtypes.string)

  def get_config(self):
    """Returns the layer config, including `depth` and `separator`."""
    config = {
        'depth': self.depth,
        'separator': self.separator,
    }
    base_config = super(CategoryCrossing, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
205