# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=not-callable
# pylint: disable=redefined-builtin
"""Layers that can merge several inputs into one."""

from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.util.tf_export import keras_export


class _Merge(Layer):
  """Generic merge layer for elementwise merge functions.

  Used to implement `Add`, `Average`, etc.
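
  A subclass only needs to implement `_merge_function`. A minimal,
  illustrative sketch of an elementwise-maximum merge (the exported
  `Maximum` layer below is the real version):

  ```python
  class _Max(_Merge):

    def _merge_function(self, inputs):
      output = inputs[0]
      for i in range(1, len(inputs)):
        output = math_ops.maximum(output, inputs[i])
      return output
  ```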
  """

  def __init__(self, **kwargs):
    """Initializes a Merge layer.

    Args:
      **kwargs: standard layer keyword arguments.
    """
    super(_Merge, self).__init__(**kwargs)
    self.supports_masking = True

  def _merge_function(self, inputs):
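    """Merges a list of input tensors into a single output tensor.

    Subclasses override this with the actual merge logic (e.g. an
    elementwise sum or a concatenation).
    """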
    raise NotImplementedError

  def _compute_elemwise_op_output_shape(self, shape1, shape2):
    """Computes the shape of the resultant of an elementwise operation.

    Args:
        shape1: tuple or None. Shape of the first tensor.
        shape2: tuple or None. Shape of the second tensor.

    Returns:
        expected output shape when an element-wise operation is
        carried out on 2 tensors with shapes shape1 and shape2.
        tuple or None.

    Raises:
        ValueError: if shape1 and shape2 are not compatible for
            element-wise operations.
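
    Example (an illustrative sketch of the broadcasting rule on this
    private helper; shown for exposition only):

    >>> layer = Add()
    >>> layer._compute_elemwise_op_output_shape((4, 1, 6), (5, 6))
    (4, 5, 6)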
    """
    if None in [shape1, shape2]:
      return None
    elif len(shape1) < len(shape2):
      return self._compute_elemwise_op_output_shape(shape2, shape1)
    elif not shape2:
      return shape1
    output_shape = list(shape1[:-len(shape2)])
    for i, j in zip(shape1[-len(shape2):], shape2):
      if i is None or j is None:
        output_shape.append(None)
      elif i == 1:
        output_shape.append(j)
      elif j == 1:
        output_shape.append(i)
      else:
        if i != j:
          raise ValueError(
              'Operands could not be broadcast '
              'together with shapes ' + str(shape1) + ' ' + str(shape2))
        output_shape.append(i)
    return tuple(output_shape)

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    # Used purely for shape validation.
    if not isinstance(input_shape[0], tuple):
      raise ValueError('A merge layer should be called on a list of inputs.')
    if len(input_shape) < 2:
      raise ValueError('A merge layer should be called '
                       'on a list of at least 2 inputs. '
                       'Got ' + str(len(input_shape)) + ' inputs.')
    batch_sizes = {s[0] for s in input_shape if s} - {None}
    if len(batch_sizes) > 1:
      raise ValueError(
          'Can not merge tensors with different '
          'batch sizes. Got tensors with shapes : ' + str(input_shape))
    if input_shape[0] is None:
      output_shape = None
    else:
      output_shape = input_shape[0][1:]
    for i in range(1, len(input_shape)):
      if input_shape[i] is None:
        shape = None
      else:
        shape = input_shape[i][1:]
      output_shape = self._compute_elemwise_op_output_shape(output_shape, shape)
    # If the inputs have different ranks, we have to reshape them
    # to make them broadcastable.
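    # (For example, when merging inputs shaped (None, 3, 4) and (None, 4),
    # `call` expands the rank-2 input at axis=1 until both have rank 3.)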
    if None not in input_shape and len(set(map(len, input_shape))) == 1:
      self._reshape_required = False
    else:
      self._reshape_required = True

  def call(self, inputs):
    if not isinstance(inputs, (list, tuple)):
      raise ValueError('A merge layer should be called on a list of inputs.')
    if self._reshape_required:
      reshaped_inputs = []
      input_ndims = list(map(backend.ndim, inputs))
      if None not in input_ndims:
        # If ranks of all inputs are available,
        # we simply expand each of them at axis=1
        # until all of them have the same rank.
        max_ndim = max(input_ndims)
        for x in inputs:
          x_ndim = backend.ndim(x)
          for _ in range(max_ndim - x_ndim):
            x = array_ops.expand_dims(x, axis=1)
          reshaped_inputs.append(x)
        return self._merge_function(reshaped_inputs)
      else:
        # Transpose all inputs so that batch size is the last dimension.
        # (batch_size, dim1, dim2, ... ) -> (dim1, dim2, ... , batch_size)
        transposed = False
        for x in inputs:
          x_ndim = backend.ndim(x)
          if x_ndim is None:
            x_shape = array_ops.shape(x)
            batch_size = x_shape[0]
            new_shape = backend.concatenate(
                [x_shape[1:],
                 array_ops.expand_dims(batch_size, axis=-1)])
            x_transposed = array_ops.reshape(
                x,
                array_ops.stack(
                    [batch_size, math_ops.reduce_prod(x_shape[1:])], axis=0))
            x_transposed = array_ops.transpose(x_transposed, perm=(1, 0))
            x_transposed = array_ops.reshape(x_transposed, new_shape)
            reshaped_inputs.append(x_transposed)
            transposed = True
          elif x_ndim > 1:
            dims = list(range(1, x_ndim)) + [0]
            reshaped_inputs.append(array_ops.transpose(x, perm=dims))
            transposed = True
          else:
            # We don't transpose inputs if they are 1D vectors or scalars.
            reshaped_inputs.append(x)
        y = self._merge_function(reshaped_inputs)
        y_ndim = backend.ndim(y)
        if transposed:
          # If inputs have been transposed, we have to transpose the output too.
          if y_ndim is None:
            y_shape = array_ops.shape(y)
            y_ndim = array_ops.shape(y_shape)[0]
            batch_size = y_shape[y_ndim - 1]
            new_shape = backend.concatenate([
                array_ops.expand_dims(batch_size, axis=-1), y_shape[:y_ndim - 1]
            ])
            y = array_ops.reshape(y, (-1, batch_size))
            y = array_ops.transpose(y, perm=(1, 0))
            y = array_ops.reshape(y, new_shape)
          elif y_ndim > 1:
            dims = [y_ndim - 1] + list(range(y_ndim - 1))
            y = array_ops.transpose(y, perm=dims)
        return y
    else:
      return self._merge_function(inputs)

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    if input_shape[0] is None:
      output_shape = None
    else:
      output_shape = input_shape[0][1:]
    for i in range(1, len(input_shape)):
      if input_shape[i] is None:
        shape = None
      else:
        shape = input_shape[i][1:]
      output_shape = self._compute_elemwise_op_output_shape(output_shape, shape)
    batch_sizes = {s[0] for s in input_shape if s is not None} - {None}
    if len(batch_sizes) == 1:
      output_shape = (list(batch_sizes)[0],) + output_shape
    else:
      output_shape = (None,) + output_shape
    return output_shape

  def compute_mask(self, inputs, mask=None):
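    # The output mask is the logical AND of the individual input masks:
    # a position stays unmasked only if it is unmasked in every input
    # that carries a mask.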
    if mask is None:
      return None
    if not isinstance(mask, (tuple, list)):
      raise ValueError('`mask` should be a list.')
    if not isinstance(inputs, (tuple, list)):
      raise ValueError('`inputs` should be a list.')
    if len(mask) != len(inputs):
      raise ValueError('The lists `inputs` and `mask` '
                       'should have the same length.')
    if all(m is None for m in mask):
      return None
    masks = [array_ops.expand_dims(m, axis=0) for m in mask if m is not None]
    return backend.all(
        backend.concatenate(masks, axis=0), axis=0, keepdims=False)


@keras_export('keras.layers.Add')
class Add(_Merge):
  """Layer that adds a list of inputs.

  It takes as input a list of tensors,
  all of the same shape, and returns
  a single tensor (also of the same shape).

  Examples:

  >>> input_shape = (2, 3, 4)
  >>> x1 = tf.random.normal(input_shape)
  >>> x2 = tf.random.normal(input_shape)
  >>> y = tf.keras.layers.Add()([x1, x2])
  >>> print(y.shape)
  (2, 3, 4)

  Used in a functional model:

  >>> input1 = tf.keras.layers.Input(shape=(16,))
  >>> x1 = tf.keras.layers.Dense(8, activation='relu')(input1)
  >>> input2 = tf.keras.layers.Input(shape=(32,))
  >>> x2 = tf.keras.layers.Dense(8, activation='relu')(input2)
  >>> # equivalent to `added = tf.keras.layers.add([x1, x2])`
  >>> added = tf.keras.layers.Add()([x1, x2])
  >>> out = tf.keras.layers.Dense(4)(added)
  >>> model = tf.keras.models.Model(inputs=[input1, input2], outputs=out)

  """

  def _merge_function(self, inputs):
    output = inputs[0]
    for i in range(1, len(inputs)):
      output += inputs[i]
    return output


@keras_export('keras.layers.Subtract')
class Subtract(_Merge):
  """Layer that subtracts two inputs.

  It takes as input a list of exactly two tensors, both of the same shape,
  and returns a single tensor (inputs[0] - inputs[1]),
  also of the same shape.

  Examples:

  ```python
      import keras

      input1 = keras.layers.Input(shape=(16,))
      x1 = keras.layers.Dense(8, activation='relu')(input1)
      input2 = keras.layers.Input(shape=(32,))
      x2 = keras.layers.Dense(8, activation='relu')(input2)
      # Equivalent to subtracted = keras.layers.subtract([x1, x2])
      subtracted = keras.layers.Subtract()([x1, x2])

      out = keras.layers.Dense(4)(subtracted)
      model = keras.models.Model(inputs=[input1, input2], outputs=out)
  ```
  """

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    super(Subtract, self).build(input_shape)
    if len(input_shape) != 2:
      raise ValueError('A `Subtract` layer should be called '
                       'on exactly 2 inputs')

  def _merge_function(self, inputs):
    if len(inputs) != 2:
      raise ValueError('A `Subtract` layer should be called '
                       'on exactly 2 inputs')
    return inputs[0] - inputs[1]


@keras_export('keras.layers.Multiply')
class Multiply(_Merge):
  """Layer that multiplies (element-wise) a list of inputs.

  It takes as input a list of tensors, all of the same shape, and returns
  a single tensor (also of the same shape).

  >>> tf.keras.layers.Multiply()([np.arange(5).reshape(5, 1),
  ...                             np.arange(5, 10).reshape(5, 1)])
  <tf.Tensor: shape=(5, 1), dtype=int64, numpy=
  array([[ 0],
         [ 6],
         [14],
         [24],
         [36]])>

  >>> x1 = tf.keras.layers.Dense(8)(np.arange(10).reshape(5, 2))
  >>> x2 = tf.keras.layers.Dense(8)(np.arange(10, 20).reshape(5, 2))
  >>> multiplied = tf.keras.layers.Multiply()([x1, x2])
  >>> multiplied.shape
  TensorShape([5, 8])
  """

  def _merge_function(self, inputs):
    output = inputs[0]
    for i in range(1, len(inputs)):
      output = output * inputs[i]
    return output


@keras_export('keras.layers.Average')
class Average(_Merge):
  """Layer that averages a list of inputs element-wise.

  It takes as input a list of tensors, all of the same shape, and returns
  a single tensor (also of the same shape).

  Example:

  >>> x1 = np.ones((2, 2))
  >>> x2 = np.zeros((2, 2))
  >>> y = tf.keras.layers.Average()([x1, x2])
  >>> y.numpy().tolist()
  [[0.5, 0.5], [0.5, 0.5]]

  Usage in a functional model:

  >>> input1 = tf.keras.layers.Input(shape=(16,))
  >>> x1 = tf.keras.layers.Dense(8, activation='relu')(input1)
  >>> input2 = tf.keras.layers.Input(shape=(32,))
  >>> x2 = tf.keras.layers.Dense(8, activation='relu')(input2)
  >>> avg = tf.keras.layers.Average()([x1, x2])
  >>> out = tf.keras.layers.Dense(4)(avg)
  >>> model = tf.keras.models.Model(inputs=[input1, input2], outputs=out)

  Raises:
    ValueError: If there is a shape mismatch between the inputs and the shapes
      cannot be broadcasted to match.
  """

  def _merge_function(self, inputs):
    output = inputs[0]
    for i in range(1, len(inputs)):
      output += inputs[i]
    return output / len(inputs)


@keras_export('keras.layers.Maximum')
class Maximum(_Merge):
  """Layer that computes the maximum (element-wise) of a list of inputs.

  It takes as input a list of tensors, all of the same shape, and returns
  a single tensor (also of the same shape).

  >>> tf.keras.layers.Maximum()([np.arange(5).reshape(5, 1),
  ...                            np.arange(5, 10).reshape(5, 1)])
  <tf.Tensor: shape=(5, 1), dtype=int64, numpy=
  array([[5],
         [6],
         [7],
         [8],
         [9]])>

  >>> x1 = tf.keras.layers.Dense(8)(np.arange(10).reshape(5, 2))
  >>> x2 = tf.keras.layers.Dense(8)(np.arange(10, 20).reshape(5, 2))
  >>> maxed = tf.keras.layers.Maximum()([x1, x2])
  >>> maxed.shape
  TensorShape([5, 8])
  """

  def _merge_function(self, inputs):
    output = inputs[0]
    for i in range(1, len(inputs)):
      output = math_ops.maximum(output, inputs[i])
    return output


@keras_export('keras.layers.Minimum')
class Minimum(_Merge):
  """Layer that computes the minimum (element-wise) of a list of inputs.

  It takes as input a list of tensors, all of the same shape, and returns
  a single tensor (also of the same shape).

  >>> tf.keras.layers.Minimum()([np.arange(5).reshape(5, 1),
  ...                            np.arange(5, 10).reshape(5, 1)])
  <tf.Tensor: shape=(5, 1), dtype=int64, numpy=
  array([[0],
         [1],
         [2],
         [3],
         [4]])>

  >>> x1 = tf.keras.layers.Dense(8)(np.arange(10).reshape(5, 2))
  >>> x2 = tf.keras.layers.Dense(8)(np.arange(10, 20).reshape(5, 2))
  >>> minned = tf.keras.layers.Minimum()([x1, x2])
  >>> minned.shape
  TensorShape([5, 8])
  """

  def _merge_function(self, inputs):
    output = inputs[0]
    for i in range(1, len(inputs)):
      output = math_ops.minimum(output, inputs[i])
    return output


@keras_export('keras.layers.Concatenate')
class Concatenate(_Merge):
  """Layer that concatenates a list of inputs.

  It takes as input a list of tensors, all of the same shape except
  for the concatenation axis, and returns a single tensor that is the
  concatenation of all inputs.

  >>> x = np.arange(20).reshape(2, 2, 5)
  >>> print(x)
  [[[ 0  1  2  3  4]
    [ 5  6  7  8  9]]
   [[10 11 12 13 14]
    [15 16 17 18 19]]]
  >>> y = np.arange(20, 30).reshape(2, 1, 5)
  >>> print(y)
  [[[20 21 22 23 24]]
   [[25 26 27 28 29]]]
  >>> tf.keras.layers.Concatenate(axis=1)([x, y])
  <tf.Tensor: shape=(2, 3, 5), dtype=int64, numpy=
  array([[[ 0,  1,  2,  3,  4],
          [ 5,  6,  7,  8,  9],
          [20, 21, 22, 23, 24]],
         [[10, 11, 12, 13, 14],
          [15, 16, 17, 18, 19],
          [25, 26, 27, 28, 29]]])>

  >>> x1 = tf.keras.layers.Dense(8)(np.arange(10).reshape(5, 2))
  >>> x2 = tf.keras.layers.Dense(8)(np.arange(10, 20).reshape(5, 2))
  >>> concatted = tf.keras.layers.Concatenate()([x1, x2])
  >>> concatted.shape
  TensorShape([5, 16])

  """

  def __init__(self, axis=-1, **kwargs):
    """Instantiates a Concatenate layer.

    >>> x = np.arange(20).reshape(2, 2, 5)
    >>> print(x)
    [[[ 0  1  2  3  4]
      [ 5  6  7  8  9]]
     [[10 11 12 13 14]
      [15 16 17 18 19]]]
    >>> y = np.arange(20, 30).reshape(2, 1, 5)
    >>> print(y)
    [[[20 21 22 23 24]]
     [[25 26 27 28 29]]]
    >>> tf.keras.layers.Concatenate(axis=1)([x, y])
    <tf.Tensor: shape=(2, 3, 5), dtype=int64, numpy=
    array([[[ 0,  1,  2,  3,  4],
            [ 5,  6,  7,  8,  9],
            [20, 21, 22, 23, 24]],
           [[10, 11, 12, 13, 14],
            [15, 16, 17, 18, 19],
            [25, 26, 27, 28, 29]]])>

    Args:
      axis: Axis along which to concatenate.
      **kwargs: standard layer keyword arguments.
    """
    super(Concatenate, self).__init__(**kwargs)
    self.axis = axis
    self.supports_masking = True
    self._reshape_required = False

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    # Used purely for shape validation.
    if not isinstance(input_shape[0], tuple) or len(input_shape) < 1:
      raise ValueError('A `Concatenate` layer should be called '
                       'on a list of at least 1 input.')
    if all(shape is None for shape in input_shape):
      return
    reduced_inputs_shapes = [list(shape) for shape in input_shape]
    shape_set = set()
    for i in range(len(reduced_inputs_shapes)):
      del reduced_inputs_shapes[i][self.axis]
      shape_set.add(tuple(reduced_inputs_shapes[i]))

    if len(shape_set) != 1:
      err_msg = ('A `Concatenate` layer requires inputs with matching shapes '
                 'except for the concat axis. Got inputs shapes: %s' %
                 input_shape)
      # Make sure all the shapes have same ranks.
      ranks = set(len(shape) for shape in shape_set)
      if len(ranks) != 1:
        raise ValueError(err_msg)
      # Get the only rank for the set.
      (rank,) = ranks
      for axis in range(rank):
        # Skip the Nones in the shape since they are dynamic, also the axis for
        # concat has been removed above.
        unique_dims = set(
            shape[axis] for shape in shape_set if shape[axis] is not None)
        if len(unique_dims) > 1:
          raise ValueError(err_msg)

  def _merge_function(self, inputs):
    return backend.concatenate(inputs, axis=self.axis)

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    if ((not isinstance(input_shape, (tuple, list))) or
        (not isinstance(input_shape[0], (tuple, list)))):
      # The tf_utils.shape_type_conversion decorator turns tensorshapes
      # into tuples, so we need to verify that `input_shape` is a list/tuple,
      # *and* that the individual elements are themselves shape tuples.
      raise ValueError('A `Concatenate` layer should be called '
                       'on a list of inputs.')
    input_shapes = input_shape
    output_shape = list(input_shapes[0])
    for shape in input_shapes[1:]:
      if output_shape[self.axis] is None or shape[self.axis] is None:
        output_shape[self.axis] = None
        break
      output_shape[self.axis] += shape[self.axis]
    return tuple(output_shape)

  def compute_mask(self, inputs, mask=None):
    if mask is None:
      return None
    if not isinstance(mask, (tuple, list)):
      raise ValueError('`mask` should be a list.')
    if not isinstance(inputs, (tuple, list)):
      raise ValueError('`inputs` should be a list.')
    if len(mask) != len(inputs):
      raise ValueError('The lists `inputs` and `mask` '
                       'should have the same length.')
    if all(m is None for m in mask):
      return None
    # Make a list of masks while making sure
    # the dimensionality of each mask
    # is the same as the corresponding input.
    masks = []
    for input_i, mask_i in zip(inputs, mask):
      if mask_i is None:
        # Input is unmasked. Append all 1s to masks.
        masks.append(array_ops.ones_like(input_i, dtype='bool'))
      elif backend.ndim(mask_i) < backend.ndim(input_i):
        # Mask is smaller than the input, expand it.
        masks.append(array_ops.expand_dims(mask_i, axis=-1))
      else:
        masks.append(mask_i)
    concatenated = backend.concatenate(masks, axis=self.axis)
    return backend.all(concatenated, axis=-1, keepdims=False)

  def get_config(self):
    config = {
        'axis': self.axis,
    }
    base_config = super(Concatenate, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


@keras_export('keras.layers.Dot')
class Dot(_Merge):
  """Layer that computes a dot product between samples in two tensors.

  E.g. if applied to a list of two tensors `a` and `b` of shape
  `(batch_size, n)`, the output will be a tensor of shape `(batch_size, 1)`
  where each entry `i` will be the dot product between
  `a[i]` and `b[i]`.

  >>> x = np.arange(10).reshape(1, 5, 2)
  >>> print(x)
  [[[0 1]
    [2 3]
    [4 5]
    [6 7]
    [8 9]]]
  >>> y = np.arange(10, 20).reshape(1, 2, 5)
  >>> print(y)
  [[[10 11 12 13 14]
    [15 16 17 18 19]]]
  >>> tf.keras.layers.Dot(axes=(1, 2))([x, y])
  <tf.Tensor: shape=(1, 2, 2), dtype=int64, numpy=
  array([[[260, 360],
          [320, 445]]])>

  >>> x1 = tf.keras.layers.Dense(8)(np.arange(10).reshape(5, 2))
  >>> x2 = tf.keras.layers.Dense(8)(np.arange(10, 20).reshape(5, 2))
  >>> dotted = tf.keras.layers.Dot(axes=1)([x1, x2])
  >>> dotted.shape
  TensorShape([5, 1])

  """

  def __init__(self, axes, normalize=False, **kwargs):
    """Initializes a layer that computes a dot product between samples.

      >>> x = np.arange(10).reshape(1, 5, 2)
      >>> print(x)
      [[[0 1]
        [2 3]
        [4 5]
        [6 7]
        [8 9]]]
      >>> y = np.arange(10, 20).reshape(1, 2, 5)
      >>> print(y)
      [[[10 11 12 13 14]
        [15 16 17 18 19]]]
      >>> tf.keras.layers.Dot(axes=(1, 2))([x, y])
      <tf.Tensor: shape=(1, 2, 2), dtype=int64, numpy=
      array([[[260, 360],
              [320, 445]]])>

    Args:
      axes: Integer or tuple of integers,
        axis or axes along which to take the dot product. If a tuple, should
        be two integers corresponding to the desired axis from the first input
        and the desired axis from the second input, respectively. Note that the
        size of the two selected axes must match.
      normalize: Whether to L2-normalize samples along the
        dot product axis before taking the dot product.
        If set to True, then the output of the dot product
        is the cosine proximity between the two samples.
      **kwargs: Standard layer keyword arguments.
    """
    super(Dot, self).__init__(**kwargs)
    if not isinstance(axes, int):
      if not isinstance(axes, (list, tuple)):
        raise TypeError('Invalid type for `axes` - '
                        'should be a list or an int.')
      if len(axes) != 2:
        raise ValueError('Invalid format for `axes` - '
                         'should contain two elements.')
      if not isinstance(axes[0], int) or not isinstance(axes[1], int):
        raise ValueError('Invalid format for `axes` - '
                         'list elements should be "int".')
    self.axes = axes
    self.normalize = normalize
    self.supports_masking = True
    self._reshape_required = False

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    # Used purely for shape validation.
    if not isinstance(input_shape[0], tuple) or len(input_shape) != 2:
      raise ValueError('A `Dot` layer should be called '
                       'on a list of 2 inputs.')
    shape1 = input_shape[0]
    shape2 = input_shape[1]
    if shape1 is None or shape2 is None:
      return
    if isinstance(self.axes, int):
      if self.axes < 0:
        axes = [self.axes % len(shape1), self.axes % len(shape2)]
      else:
        axes = [self.axes] * 2
    else:
      axes = self.axes
    if shape1[axes[0]] != shape2[axes[1]]:
      raise ValueError('Dimension incompatibility '
                       '%s != %s. ' % (shape1[axes[0]], shape2[axes[1]]) +
                       'Layer shapes: %s, %s. ' % (shape1, shape2) +
                       'Chosen axes: %s, %s' % (axes[0], axes[1]))

  def _merge_function(self, inputs):
    base_layer_utils.no_ragged_support(inputs, self.name)
    if len(inputs) != 2:
      raise ValueError('A `Dot` layer should be called on exactly 2 inputs')
    x1 = inputs[0]
    x2 = inputs[1]
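    # Resolve `self.axes` (an int or a pair, possibly negative) into two
    # non-negative axis indices, one per input, for `backend.batch_dot`.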
    if isinstance(self.axes, int):
      if self.axes < 0:
        axes = [self.axes % backend.ndim(x1), self.axes % backend.ndim(x2)]
      else:
        axes = [self.axes] * 2
    else:
      axes = []
      for i in range(len(self.axes)):
        if self.axes[i] < 0:
          axes.append(self.axes[i] % backend.ndim(inputs[i]))
        else:
          axes.append(self.axes[i])
    if self.normalize:
      x1 = nn.l2_normalize(x1, axis=axes[0])
      x2 = nn.l2_normalize(x2, axis=axes[1])
    output = backend.batch_dot(x1, x2, axes)
    return output

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    if not isinstance(input_shape, (tuple, list)) or len(input_shape) != 2:
      raise ValueError('A `Dot` layer should be called '
                       'on a list of 2 inputs.')
    shape1 = list(input_shape[0])
    shape2 = list(input_shape[1])
    if isinstance(self.axes, int):
      if self.axes < 0:
        axes = [self.axes % len(shape1), self.axes % len(shape2)]
      else:
        axes = [self.axes] * 2
    else:
      axes = self.axes
    shape1.pop(axes[0])
    shape2.pop(axes[1])
    shape2.pop(0)
    output_shape = shape1 + shape2
    if len(output_shape) == 1:
      output_shape += [1]
    return tuple(output_shape)

  def compute_mask(self, inputs, mask=None):
    return None

  def get_config(self):
    config = {
        'axes': self.axes,
        'normalize': self.normalize,
    }
    base_config = super(Dot, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


@keras_export('keras.layers.add')
def add(inputs, **kwargs):
  """Functional interface to the `tf.keras.layers.Add` layer.

  Args:
      inputs: A list of input tensors (at least 2) with the same shape.
      **kwargs: Standard layer keyword arguments.

  Returns:
      A tensor as the sum of the inputs. It has the same shape as the inputs.

  Examples:

  >>> input_shape = (2, 3, 4)
  >>> x1 = tf.random.normal(input_shape)
  >>> x2 = tf.random.normal(input_shape)
  >>> y = tf.keras.layers.add([x1, x2])
  >>> print(y.shape)
  (2, 3, 4)

  Used in a functional model:

  >>> input1 = tf.keras.layers.Input(shape=(16,))
  >>> x1 = tf.keras.layers.Dense(8, activation='relu')(input1)
  >>> input2 = tf.keras.layers.Input(shape=(32,))
  >>> x2 = tf.keras.layers.Dense(8, activation='relu')(input2)
  >>> added = tf.keras.layers.add([x1, x2])
  >>> out = tf.keras.layers.Dense(4)(added)
  >>> model = tf.keras.models.Model(inputs=[input1, input2], outputs=out)

  """
  return Add(**kwargs)(inputs)


@keras_export('keras.layers.subtract')
def subtract(inputs, **kwargs):
  """Functional interface to the `Subtract` layer.

  Args:
      inputs: A list of input tensors (exactly 2).
      **kwargs: Standard layer keyword arguments.

  Returns:
      A tensor, the difference of the inputs.

  Examples:

  ```python
      import keras

      input1 = keras.layers.Input(shape=(16,))
      x1 = keras.layers.Dense(8, activation='relu')(input1)
      input2 = keras.layers.Input(shape=(32,))
      x2 = keras.layers.Dense(8, activation='relu')(input2)
      subtracted = keras.layers.subtract([x1, x2])

      out = keras.layers.Dense(4)(subtracted)
      model = keras.models.Model(inputs=[input1, input2], outputs=out)
  ```
  """
  return Subtract(**kwargs)(inputs)


@keras_export('keras.layers.multiply')
def multiply(inputs, **kwargs):
  """Functional interface to the `Multiply` layer.

  Example:

  >>> x1 = np.arange(3.0)
  >>> x2 = np.arange(3.0)
  >>> tf.keras.layers.multiply([x1, x2])
  <tf.Tensor: shape=(3,), dtype=float32, numpy=array([0., 1., 4.], ...)>

  Usage in a functional model:

  >>> input1 = tf.keras.layers.Input(shape=(16,))
  >>> x1 = tf.keras.layers.Dense(8, activation='relu')(input1)  # shape=(None, 8)
  >>> input2 = tf.keras.layers.Input(shape=(32,))
  >>> x2 = tf.keras.layers.Dense(8, activation='relu')(input2)  # shape=(None, 8)
  >>> out = tf.keras.layers.multiply([x1, x2])  # shape=(None, 8)
  >>> out = tf.keras.layers.Dense(4)(out)
  >>> model = tf.keras.models.Model(inputs=[input1, input2], outputs=out)

  Args:
      inputs: A list of input tensors (at least 2).
      **kwargs: Standard layer keyword arguments.

  Returns:
      A tensor, the element-wise product of the inputs.
  """
  return Multiply(**kwargs)(inputs)


@keras_export('keras.layers.average')
def average(inputs, **kwargs):
  """Functional interface to the `tf.keras.layers.Average` layer.

  Example:

  >>> x1 = np.ones((2, 2))
  >>> x2 = np.zeros((2, 2))
  >>> y = tf.keras.layers.Average()([x1, x2])
  >>> y.numpy().tolist()
  [[0.5, 0.5], [0.5, 0.5]]

  Usage in a functional model:

  >>> input1 = tf.keras.layers.Input(shape=(16,))
  >>> x1 = tf.keras.layers.Dense(8, activation='relu')(input1)
  >>> input2 = tf.keras.layers.Input(shape=(32,))
  >>> x2 = tf.keras.layers.Dense(8, activation='relu')(input2)
  >>> avg = tf.keras.layers.Average()([x1, x2])
  >>> out = tf.keras.layers.Dense(4)(avg)
  >>> model = tf.keras.models.Model(inputs=[input1, input2], outputs=out)

  Args:
      inputs: A list of input tensors (at least 2).
      **kwargs: Standard layer keyword arguments.

  Returns:
      A tensor, the average of the inputs.

  Raises:
    ValueError: If there is a shape mismatch between the inputs and the shapes
      cannot be broadcasted to match.
  """
  return Average(**kwargs)(inputs)


@keras_export('keras.layers.maximum')
def maximum(inputs, **kwargs):
  """Functional interface to compute the element-wise maximum of `inputs`.

  This is equivalent to the `tf.keras.layers.Maximum` layer.

  For example:

  ```python
  input1 = tf.keras.layers.Input(shape=(16,))
  x1 = tf.keras.layers.Dense(8, activation='relu')(input1)  # shape=(None, 8)
  input2 = tf.keras.layers.Input(shape=(32,))
  x2 = tf.keras.layers.Dense(8, activation='relu')(input2)  # shape=(None, 8)
  max_inp = tf.keras.layers.maximum([x1, x2])  # shape=(None, 8)
  out = tf.keras.layers.Dense(4)(max_inp)
  model = tf.keras.models.Model(inputs=[input1, input2], outputs=out)
  ```

  Args:
      inputs: A list of input tensors (at least 2) of same shape.
      **kwargs: Standard layer keyword arguments.

  Returns:
      A tensor (of same shape as input tensor) with the element-wise
      maximum of the inputs.

  Raises:
      ValueError: If input tensors are of different shape.
  """
  return Maximum(**kwargs)(inputs)


@keras_export('keras.layers.minimum')
def minimum(inputs, **kwargs):
  """Functional interface to the `Minimum` layer.

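  Example (an illustrative sketch in the style of the `average` example
  above; the element-wise minimum of ones and zeros is all zeros):

  >>> x1 = np.ones((2, 2))
  >>> x2 = np.zeros((2, 2))
  >>> y = tf.keras.layers.minimum([x1, x2])
  >>> y.numpy().tolist()
  [[0.0, 0.0], [0.0, 0.0]]
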
  Args:
      inputs: A list of input tensors (at least 2).
      **kwargs: Standard layer keyword arguments.

  Returns:
      A tensor, the element-wise minimum of the inputs.
  """
  return Minimum(**kwargs)(inputs)


@keras_export('keras.layers.concatenate')
def concatenate(inputs, axis=-1, **kwargs):
  """Functional interface to the `Concatenate` layer.

  >>> x = np.arange(20).reshape(2, 2, 5)
  >>> print(x)
  [[[ 0  1  2  3  4]
    [ 5  6  7  8  9]]
   [[10 11 12 13 14]
    [15 16 17 18 19]]]
  >>> y = np.arange(20, 30).reshape(2, 1, 5)
  >>> print(y)
  [[[20 21 22 23 24]]
   [[25 26 27 28 29]]]
  >>> tf.keras.layers.concatenate([x, y], axis=1)
  <tf.Tensor: shape=(2, 3, 5), dtype=int64, numpy=
  array([[[ 0,  1,  2,  3,  4],
          [ 5,  6,  7,  8,  9],
          [20, 21, 22, 23, 24]],
         [[10, 11, 12, 13, 14],
          [15, 16, 17, 18, 19],
          [25, 26, 27, 28, 29]]])>

  Args:
      inputs: A list of input tensors (at least 2).
      axis: Concatenation axis.
      **kwargs: Standard layer keyword arguments.

  Returns:
      A tensor, the concatenation of the inputs along axis `axis`.
  """
  return Concatenate(axis=axis, **kwargs)(inputs)


@keras_export('keras.layers.dot')
def dot(inputs, axes, normalize=False, **kwargs):
  """Functional interface to the `Dot` layer.

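  Example (an illustrative sketch; the values mirror the `Dot` layer
  doctest above):

  >>> x = np.arange(10).reshape(1, 5, 2)
  >>> y = np.arange(10, 20).reshape(1, 2, 5)
  >>> tf.keras.layers.dot([x, y], axes=(1, 2))
  <tf.Tensor: shape=(1, 2, 2), dtype=int64, numpy=
  array([[[260, 360],
          [320, 445]]])>
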
  Args:
      inputs: A list of input tensors (at least 2).
      axes: Integer or tuple of integers,
          axis or axes along which to take the dot product.
      normalize: Whether to L2-normalize samples along the
          dot product axis before taking the dot product.
          If set to True, then the output of the dot product
          is the cosine proximity between the two samples.
      **kwargs: Standard layer keyword arguments.

  Returns:
      A tensor, the dot product of the samples from the inputs.
  """
  return Dot(axes=axes, normalize=normalize, **kwargs)(inputs)