# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based einsum dense layer."""
# pylint: disable=g-classes-have-attributes

import re

from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import activations
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.ops import special_math_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.layers.experimental.EinsumDense")
class EinsumDense(Layer):
  """A layer that uses tf.einsum as the backing computation.

  This layer can perform einsum calculations of arbitrary dimensionality.

  Args:
    equation: An equation describing the einsum to perform. This equation must
      be a valid einsum string of the form `ab,bc->ac`, `...ab,bc->...ac`, or
      `ab...,bc->ac...` where 'ab', 'bc', and 'ac' can be any valid einsum axis
      expression sequence.
    output_shape: The expected shape of the output tensor (excluding the batch
      dimension and any dimensions represented by ellipses). You can specify
      None for any dimension that is unknown or can be inferred from the input
      shape.
    activation: Activation function to use. If you don't specify anything, no
      activation is applied (that is, a "linear" activation: `a(x) = x`).
    bias_axes: A string containing the output dimension(s) to apply a bias to.
      Each character in the `bias_axes` string should correspond to a character
      in the output portion of the `equation` string.
    kernel_initializer: Initializer for the `kernel` weights matrix.
    bias_initializer: Initializer for the bias vector.
    kernel_regularizer: Regularizer function applied to the `kernel` weights
      matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    activity_regularizer: Regularizer function applied to the output of the
      layer (its "activation").
    kernel_constraint: Constraint function applied to the `kernel` weights
      matrix.
    bias_constraint: Constraint function applied to the bias vector.

  Examples:

  **Biased dense layer with einsums**

  This example shows how to instantiate a standard Keras dense layer using
  einsum operations. This example is equivalent to
  `tf.keras.layers.Dense(64, use_bias=True)`.

  >>> layer = EinsumDense("ab,bc->ac", output_shape=64, bias_axes="c")
  >>> input_tensor = tf.keras.Input(shape=[32])
  >>> output_tensor = layer(input_tensor)
  >>> output_tensor
  <... shape=(None, 64) dtype=...>

  **Applying a dense layer to a sequence**

  This example shows how to instantiate a layer that applies the same dense
  operation to every element in a sequence. Here, the `output_shape` has two
  values (since there are two non-batch dimensions in the output); the first
  dimension in the `output_shape` is `None`, because the sequence dimension `b`
  has an unknown shape.

  >>> layer = EinsumDense("abc,cd->abd",
  ...                     output_shape=(None, 64),
  ...                     bias_axes="d")
  >>> input_tensor = tf.keras.Input(shape=[32, 128])
  >>> output_tensor = layer(input_tensor)
  >>> output_tensor
  <... shape=(None, 32, 64) dtype=...>

  **Applying a dense layer to a sequence using ellipses**

  This example shows how to instantiate a layer that applies the same dense
  operation to every element in a sequence, but uses the ellipsis notation
  instead of specifying the batch and sequence dimensions.

  Because we are using ellipsis notation and have specified only one axis, the
  `output_shape` arg is a single value. When instantiated in this way, the
  layer can handle any number of sequence dimensions - including the case
  where no sequence dimension exists.

  >>> layer = EinsumDense("...x,xy->...y", output_shape=64, bias_axes="y")
  >>> input_tensor = tf.keras.Input(shape=[32, 128])
  >>> output_tensor = layer(input_tensor)
  >>> output_tensor
  <... shape=(None, 32, 64) dtype=...>
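
  **Computing a multi-head projection**

  As a further illustrative sketch of the same API, this example projects each
  sequence element into multiple attention-style heads in a single einsum; the
  axis labels chosen here are arbitrary.

  >>> layer = EinsumDense("abc,cde->abde",
  ...                     output_shape=(None, 8, 64),
  ...                     bias_axes="de")
  >>> input_tensor = tf.keras.Input(shape=[32, 512])
  >>> output_tensor = layer(input_tensor)
  >>> output_tensor
  <... shape=(None, 32, 8, 64) dtype=...>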
107  """
108
109  def __init__(self,
110               equation,
111               output_shape,
112               activation=None,
113               bias_axes=None,
114               kernel_initializer="glorot_uniform",
115               bias_initializer="zeros",
116               kernel_regularizer=None,
117               bias_regularizer=None,
118               activity_regularizer=None,
119               kernel_constraint=None,
120               bias_constraint=None,
121               **kwargs):
    # Pass activity_regularizer through to the base Layer so that the
    # constructor argument is not silently dropped.
    super(EinsumDense, self).__init__(
        activity_regularizer=activity_regularizer, **kwargs)
    self.equation = equation
    if isinstance(output_shape, int):
      self.partial_output_shape = [output_shape]
    else:
      self.partial_output_shape = list(output_shape)
    self.bias_axes = bias_axes
    self.activation = activations.get(activation)
    self.kernel_initializer = initializers.get(kernel_initializer)
    self.bias_initializer = initializers.get(bias_initializer)
    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)
    self.kernel_constraint = constraints.get(kernel_constraint)
    self.bias_constraint = constraints.get(bias_constraint)

  def build(self, input_shape):
    input_shape = tensor_shape.TensorShape(input_shape)
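    # Resolve the kernel shape, the bias shape, and the full output shape from
    # the einsum equation, now that the input dimensions are known.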
    shape_data = _analyze_einsum_string(self.equation,
                                        self.bias_axes,
                                        input_shape,
                                        self.partial_output_shape)
    kernel_shape, bias_shape, self.full_output_shape = shape_data
    self.kernel = self.add_weight(
        "kernel",
        shape=kernel_shape,
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint,
        dtype=self.dtype,
        trainable=True)

    if bias_shape is not None:
      self.bias = self.add_weight(
          "bias",
          shape=bias_shape,
          initializer=self.bias_initializer,
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint,
          dtype=self.dtype,
          trainable=True)
    else:
      self.bias = None
    super(EinsumDense, self).build(input_shape)

  def compute_output_shape(self, _):
    return tensor_shape.TensorShape(self.full_output_shape)

  def get_config(self):
    config = {
        "output_shape":
            self.partial_output_shape,
        "equation":
            self.equation,
        "activation":
            activations.serialize(self.activation),
        "bias_axes":
            self.bias_axes,
        "kernel_initializer":
            initializers.serialize(self.kernel_initializer),
        "bias_initializer":
            initializers.serialize(self.bias_initializer),
        "kernel_regularizer":
            regularizers.serialize(self.kernel_regularizer),
        "bias_regularizer":
            regularizers.serialize(self.bias_regularizer),
        "activity_regularizer":
            regularizers.serialize(self.activity_regularizer),
        "kernel_constraint":
            constraints.serialize(self.kernel_constraint),
        "bias_constraint":
            constraints.serialize(self.bias_constraint),
    }
    base_config = super(EinsumDense, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs):
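    # Contract the inputs against the kernel using the layer's einsum
    # equation; the kernel shape was resolved from that same equation in
    # build().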
    ret = special_math_ops.einsum(self.equation, inputs, self.kernel)
    if self.bias is not None:
      ret += self.bias
    if self.activation is not None:
      ret = self.activation(ret)
    return ret


def _analyze_einsum_string(equation, bias_axes, input_shape, output_shape):
  """Analyzes an einsum string to determine the required weight shape."""

  dot_replaced_string = re.sub(r"\.\.\.", "0", equation)

  # This is the case where no ellipses are present in the string.
  split_string = re.match("([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)",
                          dot_replaced_string)
  if split_string:
    return _analyze_split_string(split_string, bias_axes, input_shape,
                                 output_shape)

  # This is the case where ellipses are present on the left.
  split_string = re.match("0([a-zA-Z]+),([a-zA-Z]+)->0([a-zA-Z]+)",
                          dot_replaced_string)
  if split_string:
    return _analyze_split_string(
        split_string, bias_axes, input_shape, output_shape, left_elided=True)

  # This is the case where ellipses are present on the right.
  split_string = re.match("([a-zA-Z]{2,})0,([a-zA-Z]+)->([a-zA-Z]+)0",
                          dot_replaced_string)
  if split_string:
    return _analyze_split_string(split_string, bias_axes, input_shape,
                                 output_shape)

  raise ValueError(
      "Invalid einsum equation '%s'. Equations must be in the form "
      "[X],[Y]->[Z], ...[X],[Y]->...[Z], or [X]...,[Y]->[Z]...." % equation)


def _analyze_split_string(split_string,
                          bias_axes,
                          input_shape,
                          output_shape,
                          left_elided=False):
242  """Analyze an pre-split einsum string to find the weight shape."""
  input_spec = split_string.group(1)
  weight_spec = split_string.group(2)
  output_spec = split_string.group(3)
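  # 'elided' counts the input dimensions covered by the ellipsis rather than
  # by an explicit label in the input spec.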
  elided = len(input_shape) - len(input_spec)

  if isinstance(output_shape, int):
    output_shape = [output_shape]
  else:
    output_shape = list(output_shape)

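  # The user-supplied output shape excludes the batch dimension, so prepend
  # that dimension from the input shape before any comparisons.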
  output_shape.insert(0, input_shape[0])

  if elided > 0 and left_elided:
    for i in range(1, elided):
      # We already inserted the 0th input dimension at dim 0, so we need to
      # start at location 1 here.
      output_shape.insert(1, input_shape[i])
  elif elided > 0 and not left_elided:
    for i in range(len(input_shape) - elided, len(input_shape)):
      output_shape.append(input_shape[i])

  if left_elided:
    # If we have beginning dimensions elided, we need to use negative indexing
    # to determine where in the input dimension our values are.
    input_dim_map = {
        dim: (i + elided) - len(input_shape) for i, dim in enumerate(input_spec)
    }
    # Because we've constructed the full output shape already, we don't need
    # to do negative indexing.
    output_dim_map = {dim: (i + elided) for i, dim in enumerate(output_spec)}
  else:
    input_dim_map = {dim: i for i, dim in enumerate(input_spec)}
    output_dim_map = {dim: i for i, dim in enumerate(output_spec)}

  for i, dim in enumerate(input_spec):
    input_shape_at_dim = input_shape[i]
    if dim in output_dim_map:
      output_shape_at_dim = output_shape[output_dim_map[dim]]
      if (output_shape_at_dim is not None and
          output_shape_at_dim != input_shape_at_dim):
        raise ValueError(
            "Input shape and output shape do not match at shared "
            "dimension '%s'. Input shape is %s, and output shape "
            "is %s." %
            (dim, input_shape_at_dim, output_shape[output_dim_map[dim]]))

  for dim in output_spec:
    if dim not in input_spec and dim not in weight_spec:
      raise ValueError("Dimension '%s' was specified in the output '%s' but "
                       "has no corresponding dim in the input spec '%s' or "
                       "weight spec '%s'." % (dim, output_spec, input_spec,
                                              weight_spec))

  weight_shape = []
  for dim in weight_spec:
    if dim in input_dim_map:
      weight_shape.append(input_shape[input_dim_map[dim]])
    elif dim in output_dim_map:
      weight_shape.append(output_shape[output_dim_map[dim]])
    else:
      raise ValueError("Weight dimension '%s' did not have a match in either "
                       "the input spec '%s' or the output spec '%s'. For this "
                       "layer, the weight must be fully specified." %
                       (dim, input_spec, output_spec))

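  # Resolve the bias shape: each requested bias axis keeps its output
  # dimension, and every other trailing output axis broadcasts along 1.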
  if bias_axes is not None:
    num_left_elided = elided if left_elided else 0
    idx_map = {
        char: output_shape[i + num_left_elided]
        for i, char in enumerate(output_spec)
    }

    for char in bias_axes:
      if char not in output_spec:
        raise ValueError("Bias dimension '%s' was requested, but is not a part "
                         "of the output specification '%s'" %
                         (char, output_spec))

    first_bias_location = min([output_spec.find(char) for char in bias_axes])
    bias_output_spec = output_spec[first_bias_location:]

    bias_shape = [
        idx_map[char] if char in bias_axes else 1 for char in bias_output_spec
    ]

    if not left_elided:
      for _ in range(elided):
        bias_shape.append(1)
  else:
    bias_shape = None

  return weight_shape, bias_shape, output_shape