# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=not-callable
# pylint: disable=redefined-builtin
"""Layers that can merge several inputs into one.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.util.tf_export import keras_export


class _Merge(Layer):
  """Generic merge layer for elementwise merge functions.

  Used to implement `Add`, `Average`, etc.

  Arguments:
      **kwargs: standard layer keyword arguments.
  """

  def __init__(self, **kwargs):
    super(_Merge, self).__init__(**kwargs)
    self.supports_masking = True

  def _merge_function(self, inputs):
    raise NotImplementedError

  def _compute_elemwise_op_output_shape(self, shape1, shape2):
    """Computes the shape of the resultant of an elementwise operation.

    Arguments:
        shape1: tuple or None. Shape of the first tensor.
        shape2: tuple or None. Shape of the second tensor.

    Returns:
        expected output shape when an element-wise operation is
        carried out on 2 tensors with shapes shape1 and shape2.
        tuple or None.
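        For example, an elementwise op on inputs with shapes
        `(None, 3, 1)` and `(4,)` yields output shape `(None, 3, 4)`.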

    Raises:
        ValueError: if shape1 and shape2 are not compatible for
            element-wise operations.
    """
    if None in [shape1, shape2]:
      return None
    elif len(shape1) < len(shape2):
      return self._compute_elemwise_op_output_shape(shape2, shape1)
    elif not shape2:
      return shape1
    output_shape = list(shape1[:-len(shape2)])
    for i, j in zip(shape1[-len(shape2):], shape2):
      if i is None or j is None:
        output_shape.append(None)
      elif i == 1:
        output_shape.append(j)
      elif j == 1:
        output_shape.append(i)
      else:
        if i != j:
          raise ValueError(
              'Operands could not be broadcast '
              'together with shapes ' + str(shape1) + ' ' + str(shape2))
        output_shape.append(i)
    return tuple(output_shape)

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    # Used purely for shape validation.
    if not isinstance(input_shape, list):
      raise ValueError('A merge layer should be called on a list of inputs.')
    if len(input_shape) < 2:
      raise ValueError('A merge layer should be called '
                       'on a list of at least 2 inputs. '
                       'Got ' + str(len(input_shape)) + ' inputs.')
    batch_sizes = [s[0] for s in input_shape if s is not None]
    batch_sizes = set(batch_sizes)
    batch_sizes -= set([None])
    if len(batch_sizes) > 1:
      raise ValueError(
          'Cannot merge tensors with different '
          'batch sizes. Got tensors with shapes: ' + str(input_shape))
    if input_shape[0] is None:
      output_shape = None
    else:
      output_shape = input_shape[0][1:]
    for i in range(1, len(input_shape)):
      if input_shape[i] is None:
        shape = None
      else:
        shape = input_shape[i][1:]
      output_shape = self._compute_elemwise_op_output_shape(output_shape, shape)
    # If the inputs have different ranks, we have to reshape them
    # to make them broadcastable.
    if None not in input_shape and len(set(map(len, input_shape))) == 1:
      self._reshape_required = False
    else:
      self._reshape_required = True

  def call(self, inputs):
    if not isinstance(inputs, list):
      raise ValueError('A merge layer should be called on a list of inputs.')
    if self._reshape_required:
      reshaped_inputs = []
      input_ndims = list(map(K.ndim, inputs))
      if None not in input_ndims:
        # If ranks of all inputs are available,
        # we simply expand each of them at axis=1
        # until all of them have the same rank.
        max_ndim = max(input_ndims)
        for x in inputs:
          x_ndim = K.ndim(x)
          for _ in range(max_ndim - x_ndim):
            x = array_ops.expand_dims(x, axis=1)
          reshaped_inputs.append(x)
        return self._merge_function(reshaped_inputs)
      else:
        # Transpose all inputs so that batch size is the last dimension.
        # (batch_size, dim1, dim2, ... ) -> (dim1, dim2, ... , batch_size)
        transposed = False
        for x in inputs:
          x_ndim = K.ndim(x)
          if x_ndim is None:
            x_shape = array_ops.shape(x)
            batch_size = x_shape[0]
            new_shape = K.concatenate(
                [x_shape[1:],
                 array_ops.expand_dims(batch_size, axis=-1)])
            x_transposed = array_ops.reshape(
                x,
                array_ops.stack(
                    [batch_size, math_ops.reduce_prod(x_shape[1:])], axis=0))
            x_transposed = array_ops.transpose(x_transposed, perm=(1, 0))
            x_transposed = array_ops.reshape(x_transposed, new_shape)
            reshaped_inputs.append(x_transposed)
            transposed = True
          elif x_ndim > 1:
            dims = list(range(1, x_ndim)) + [0]
            reshaped_inputs.append(array_ops.transpose(x, perm=dims))
            transposed = True
          else:
            # We don't transpose inputs if they are 1D vectors or scalars.
            reshaped_inputs.append(x)
        y = self._merge_function(reshaped_inputs)
        y_ndim = K.ndim(y)
        if transposed:
          # If inputs have been transposed, we have to transpose the output too.
          if y_ndim is None:
            y_shape = array_ops.shape(y)
            y_ndim = array_ops.shape(y_shape)[0]
            batch_size = y_shape[y_ndim - 1]
            new_shape = K.concatenate([
                array_ops.expand_dims(batch_size, axis=-1), y_shape[:y_ndim - 1]
            ])
            y = array_ops.reshape(y, (-1, batch_size))
            y = array_ops.transpose(y, perm=(1, 0))
            y = array_ops.reshape(y, new_shape)
          elif y_ndim > 1:
            dims = [y_ndim - 1] + list(range(y_ndim - 1))
            y = array_ops.transpose(y, perm=dims)
        return y
    else:
      return self._merge_function(inputs)

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    if input_shape[0] is None:
      output_shape = None
    else:
      output_shape = input_shape[0][1:]
    for i in range(1, len(input_shape)):
      if input_shape[i] is None:
        shape = None
      else:
        shape = input_shape[i][1:]
      output_shape = self._compute_elemwise_op_output_shape(output_shape, shape)
    batch_sizes = [s[0] for s in input_shape if s is not None]
    batch_sizes = set(batch_sizes)
    batch_sizes -= set([None])
    if len(batch_sizes) == 1:
      output_shape = (list(batch_sizes)[0],) + output_shape
    else:
      output_shape = (None,) + output_shape
    return output_shape

  def compute_mask(self, inputs, mask=None):
    if mask is None:
      return None
    if not isinstance(mask, list):
      raise ValueError('`mask` should be a list.')
    if not isinstance(inputs, list):
      raise ValueError('`inputs` should be a list.')
    if len(mask) != len(inputs):
      raise ValueError('The lists `inputs` and `mask` '
                       'should have the same length.')
    if all(m is None for m in mask):
      return None
    masks = [array_ops.expand_dims(m, axis=0) for m in mask if m is not None]
    return K.all(K.concatenate(masks, axis=0), axis=0, keepdims=False)


@keras_export('keras.layers.Add')
class Add(_Merge):
  """Layer that adds a list of inputs.

  It takes as input a list of tensors,
  all of the same shape, and returns
  a single tensor (also of the same shape).

  Examples:

  ```python
      import keras

      input1 = keras.layers.Input(shape=(16,))
      x1 = keras.layers.Dense(8, activation='relu')(input1)
      input2 = keras.layers.Input(shape=(32,))
      x2 = keras.layers.Dense(8, activation='relu')(input2)
      # equivalent to `added = keras.layers.add([x1, x2])`
      added = keras.layers.Add()([x1, x2])

      out = keras.layers.Dense(4)(added)
      model = keras.models.Model(inputs=[input1, input2], outputs=out)
  ```
  """

  def _merge_function(self, inputs):
    output = inputs[0]
    for i in range(1, len(inputs)):
      output += inputs[i]
    return output


@keras_export('keras.layers.Subtract')
class Subtract(_Merge):
  """Layer that subtracts two inputs.

  It takes as input a list of tensors of size 2,
  both of the same shape, and returns a single tensor, (inputs[0] - inputs[1]),
  also of the same shape.

  Examples:

  ```python
      import keras

      input1 = keras.layers.Input(shape=(16,))
      x1 = keras.layers.Dense(8, activation='relu')(input1)
      input2 = keras.layers.Input(shape=(32,))
      x2 = keras.layers.Dense(8, activation='relu')(input2)
      # Equivalent to subtracted = keras.layers.subtract([x1, x2])
      subtracted = keras.layers.Subtract()([x1, x2])

      out = keras.layers.Dense(4)(subtracted)
      model = keras.models.Model(inputs=[input1, input2], outputs=out)
  ```
  """

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    super(Subtract, self).build(input_shape)
    if len(input_shape) != 2:
      raise ValueError('A `Subtract` layer should be called '
                       'on exactly 2 inputs')

  def _merge_function(self, inputs):
    if len(inputs) != 2:
      raise ValueError('A `Subtract` layer should be called '
                       'on exactly 2 inputs')
    return inputs[0] - inputs[1]


@keras_export('keras.layers.Multiply')
class Multiply(_Merge):
  """Layer that multiplies (element-wise) a list of inputs.

  It takes as input a list of tensors,
  all of the same shape, and returns
  a single tensor (also of the same shape).
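
  Example (a usage sketch in the style of the `Add` example above):

  ```python
      import keras

      input1 = keras.layers.Input(shape=(16,))
      x1 = keras.layers.Dense(8, activation='relu')(input1)
      input2 = keras.layers.Input(shape=(32,))
      x2 = keras.layers.Dense(8, activation='relu')(input2)
      # Both operands have shape (batch_size, 8), as `Multiply` requires.
      multiplied = keras.layers.Multiply()([x1, x2])

      out = keras.layers.Dense(4)(multiplied)
      model = keras.models.Model(inputs=[input1, input2], outputs=out)
  ```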
299  """
300
301  def _merge_function(self, inputs):
302    output = inputs[0]
303    for i in range(1, len(inputs)):
304      output *= inputs[i]
305    return output
306
307
308@keras_export('keras.layers.Average')
309class Average(_Merge):
310  """Layer that averages a list of inputs.
311
312  It takes as input a list of tensors,
313  all of the same shape, and returns
314  a single tensor (also of the same shape).
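
  Example (a sketch of averaging two branches, following the `Add` example
  above):

  ```python
      import keras

      input1 = keras.layers.Input(shape=(16,))
      x1 = keras.layers.Dense(8, activation='relu')(input1)
      input2 = keras.layers.Input(shape=(32,))
      x2 = keras.layers.Dense(8, activation='relu')(input2)
      # Elementwise mean of two (batch_size, 8) tensors.
      avg = keras.layers.Average()([x1, x2])

      out = keras.layers.Dense(4)(avg)
      model = keras.models.Model(inputs=[input1, input2], outputs=out)
  ```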
315  """
316
317  def _merge_function(self, inputs):
318    output = inputs[0]
319    for i in range(1, len(inputs)):
320      output += inputs[i]
321    return output / len(inputs)
322
323
324@keras_export('keras.layers.Maximum')
325class Maximum(_Merge):
326  """Layer that computes the maximum (element-wise) a list of inputs.

  It takes as input a list of tensors,
  all of the same shape, and returns
  a single tensor (also of the same shape).
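
  Example (a usage sketch; the setup mirrors the `Add` example above):

  ```python
      import keras

      input1 = keras.layers.Input(shape=(16,))
      x1 = keras.layers.Dense(8, activation='relu')(input1)
      input2 = keras.layers.Input(shape=(32,))
      x2 = keras.layers.Dense(8, activation='relu')(input2)
      # Elementwise maximum of two (batch_size, 8) tensors.
      maxed = keras.layers.Maximum()([x1, x2])

      out = keras.layers.Dense(4)(maxed)
      model = keras.models.Model(inputs=[input1, input2], outputs=out)
  ```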
331  """
332
333  def _merge_function(self, inputs):
334    output = inputs[0]
335    for i in range(1, len(inputs)):
336      output = math_ops.maximum(output, inputs[i])
337    return output
338
339
340@keras_export('keras.layers.Minimum')
341class Minimum(_Merge):
342  """Layer that computes the minimum (element-wise) a list of inputs.

  It takes as input a list of tensors,
  all of the same shape, and returns
  a single tensor (also of the same shape).
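
  Example (a usage sketch; the setup mirrors the `Add` example above):

  ```python
      import keras

      input1 = keras.layers.Input(shape=(16,))
      x1 = keras.layers.Dense(8, activation='relu')(input1)
      input2 = keras.layers.Input(shape=(32,))
      x2 = keras.layers.Dense(8, activation='relu')(input2)
      # Elementwise minimum of two (batch_size, 8) tensors.
      minned = keras.layers.Minimum()([x1, x2])

      out = keras.layers.Dense(4)(minned)
      model = keras.models.Model(inputs=[input1, input2], outputs=out)
  ```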
347  """
348
349  def _merge_function(self, inputs):
350    output = inputs[0]
351    for i in range(1, len(inputs)):
352      output = math_ops.minimum(output, inputs[i])
353    return output
354
355
356@keras_export('keras.layers.Concatenate')
357class Concatenate(_Merge):
358  """Layer that concatenates a list of inputs.
359
360  It takes as input a list of tensors,
361  all of the same shape except for the concatenation axis,
362  and returns a single tensor, the concatenation of all inputs.
363
364  Arguments:
365      axis: Axis along which to concatenate.
366      **kwargs: standard layer keyword arguments.
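
  Example (a usage sketch; the two branches agree on every axis except the
  concatenation axis):

  ```python
      import keras

      input1 = keras.layers.Input(shape=(16,))
      x1 = keras.layers.Dense(8, activation='relu')(input1)
      input2 = keras.layers.Input(shape=(32,))
      x2 = keras.layers.Dense(8, activation='relu')(input2)
      # Shapes (batch_size, 8) and (batch_size, 8) -> (batch_size, 16).
      concatenated = keras.layers.Concatenate(axis=-1)([x1, x2])

      out = keras.layers.Dense(4)(concatenated)
      model = keras.models.Model(inputs=[input1, input2], outputs=out)
  ```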
367  """
368
369  def __init__(self, axis=-1, **kwargs):
370    super(Concatenate, self).__init__(**kwargs)
371    self.axis = axis
372    self.supports_masking = True
373    self._reshape_required = False
374
375  @tf_utils.shape_type_conversion
376  def build(self, input_shape):
377    # Used purely for shape validation.
378    if not isinstance(input_shape, list) or len(input_shape) < 2:
379      raise ValueError('A `Concatenate` layer should be called '
380                       'on a list of at least 2 inputs')
381    if all(shape is None for shape in input_shape):
382      return
383    reduced_inputs_shapes = [list(shape) for shape in input_shape]
384    shape_set = set()
385    for i in range(len(reduced_inputs_shapes)):
386      del reduced_inputs_shapes[i][self.axis]
387      shape_set.add(tuple(reduced_inputs_shapes[i]))
388    if len(shape_set) > 1:
389      raise ValueError('A `Concatenate` layer requires '
390                       'inputs with matching shapes '
391                       'except for the concat axis. '
                       'Got input shapes: %s' % (input_shape))

  def _merge_function(self, inputs):
    return K.concatenate(inputs, axis=self.axis)

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    if not isinstance(input_shape, list):
      raise ValueError('A `Concatenate` layer should be called '
                       'on a list of inputs.')
    input_shapes = input_shape
    output_shape = list(input_shapes[0])
    for shape in input_shapes[1:]:
      if output_shape[self.axis] is None or shape[self.axis] is None:
        output_shape[self.axis] = None
        break
      output_shape[self.axis] += shape[self.axis]
    return tuple(output_shape)

  def compute_mask(self, inputs, mask=None):
    if mask is None:
      return None
    if not isinstance(mask, list):
      raise ValueError('`mask` should be a list.')
    if not isinstance(inputs, list):
      raise ValueError('`inputs` should be a list.')
    if len(mask) != len(inputs):
      raise ValueError('The lists `inputs` and `mask` '
                       'should have the same length.')
    if all(m is None for m in mask):
      return None
    # Make a list of masks while making sure
    # the dimensionality of each mask
    # is the same as the corresponding input.
    masks = []
    for input_i, mask_i in zip(inputs, mask):
      if mask_i is None:
        # Input is unmasked. Append an all-ones mask.
        masks.append(array_ops.ones_like(input_i, dtype='bool'))
      elif K.ndim(mask_i) < K.ndim(input_i):
        # Mask is smaller than the input, expand it
        masks.append(array_ops.expand_dims(mask_i, axis=-1))
      else:
        masks.append(mask_i)
    concatenated = K.concatenate(masks, axis=self.axis)
    return K.all(concatenated, axis=-1, keepdims=False)

  def get_config(self):
    config = {
        'axis': self.axis,
    }
    base_config = super(Concatenate, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


@keras_export('keras.layers.Dot')
class Dot(_Merge):
  """Layer that computes a dot product between samples in two tensors.

  E.g. if applied to a list of two tensors `a` and `b` of shape
  `(batch_size, n)`, the output will be a tensor of shape `(batch_size, 1)`
  where each entry `i` will be the dot product between
  `a[i]` and `b[i]`.

  Arguments:
      axes: Integer or tuple of integers,
          axis or axes along which to take the dot product.
      normalize: Whether to L2-normalize samples along the
          dot product axis before taking the dot product.
          If set to True, then the output of the dot product
          is the cosine proximity between the two samples.
      **kwargs: Standard layer keyword arguments.
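
  Example (a usage sketch; with `axes=1` the dot product is taken over the
  feature axis of each sample):

  ```python
      import keras

      input1 = keras.layers.Input(shape=(16,))
      x1 = keras.layers.Dense(8, activation='relu')(input1)
      input2 = keras.layers.Input(shape=(32,))
      x2 = keras.layers.Dense(8, activation='relu')(input2)
      # Shapes (batch_size, 8) and (batch_size, 8) -> (batch_size, 1).
      dotted = keras.layers.Dot(axes=1)([x1, x2])
      model = keras.models.Model(inputs=[input1, input2], outputs=dotted)
  ```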
464  """
465
466  def __init__(self, axes, normalize=False, **kwargs):
467    super(Dot, self).__init__(**kwargs)
468    if not isinstance(axes, int):
469      if not isinstance(axes, (list, tuple)):
470        raise TypeError('Invalid type for `axes` - '
471                        'should be a list or an int.')
472      if len(axes) != 2:
473        raise ValueError('Invalid format for `axes` - '
474                         'should contain two elements.')
475      if not isinstance(axes[0], int) or not isinstance(axes[1], int):
476        raise ValueError('Invalid format for `axes` - '
477                         'list elements should be "int".')
478    self.axes = axes
479    self.normalize = normalize
480    self.supports_masking = True
481    self._reshape_required = False
482
483  @tf_utils.shape_type_conversion
484  def build(self, input_shape):
485    # Used purely for shape validation.
486    if not isinstance(input_shape, list) or len(input_shape) != 2:
487      raise ValueError('A `Dot` layer should be called '
488                       'on a list of 2 inputs.')
489    shape1 = input_shape[0]
490    shape2 = input_shape[1]
491    if shape1 is None or shape2 is None:
492      return
493    if isinstance(self.axes, int):
494      if self.axes < 0:
495        axes = [self.axes % len(shape1), self.axes % len(shape2)]
496      else:
497        axes = [self.axes] * 2
498    else:
499      axes = self.axes
500    if shape1[axes[0]] != shape2[axes[1]]:
501      raise ValueError('Dimension incompatibility '
502                       '%s != %s. ' % (shape1[axes[0]], shape2[axes[1]]) +
503                       'Layer shapes: %s, %s' % (shape1, shape2))
504
505  def _merge_function(self, inputs):
506    if len(inputs) != 2:
507      raise ValueError('A `Dot` layer should be called on exactly 2 inputs')
508    x1 = inputs[0]
509    x2 = inputs[1]
510    if isinstance(self.axes, int):
511      if self.axes < 0:
512        axes = [self.axes % K.ndim(x1), self.axes % K.ndim(x2)]
513      else:
514        axes = [self.axes] * 2
515    else:
516      axes = []
517      for i in range(len(self.axes)):
518        if self.axes[i] < 0:
519          axes.append(self.axes[i] % K.ndim(inputs[i]))
520        else:
521          axes.append(self.axes[i])
522    if self.normalize:
523      x1 = nn.l2_normalize(x1, axis=axes[0])
524      x2 = nn.l2_normalize(x2, axis=axes[1])
525    output = K.batch_dot(x1, x2, axes)
526    return output
527
528  @tf_utils.shape_type_conversion
529  def compute_output_shape(self, input_shape):
530    if not isinstance(input_shape, list) or len(input_shape) != 2:
531      raise ValueError('A `Dot` layer should be called '
532                       'on a list of 2 inputs.')
533    shape1 = list(input_shape[0])
534    shape2 = list(input_shape[1])
535    if isinstance(self.axes, int):
536      if self.axes < 0:
537        axes = [self.axes % len(shape1), self.axes % len(shape2)]
538      else:
539        axes = [self.axes] * 2
540    else:
541      axes = self.axes
542    shape1.pop(axes[0])
543    shape2.pop(axes[1])
544    shape2.pop(0)
545    output_shape = shape1 + shape2
546    if len(output_shape) == 1:
547      output_shape += [1]
548    return tuple(output_shape)
549
550  def compute_mask(self, inputs, mask=None):
551    return None
552
553  def get_config(self):
554    config = {
555        'axes': self.axes,
556        'normalize': self.normalize,
557    }
558    base_config = super(Dot, self).get_config()
559    return dict(list(base_config.items()) + list(config.items()))
560
561
562@keras_export('keras.layers.add')
563def add(inputs, **kwargs):
564  """Functional interface to the `Add` layer.
565
566  Arguments:
567      inputs: A list of input tensors (at least 2).
568      **kwargs: Standard layer keyword arguments.
569
570  Returns:
571      A tensor, the sum of the inputs.
572
573  Examples:
574
575  ```python
576      import keras
577
578      input1 = keras.layers.Input(shape=(16,))
579      x1 = keras.layers.Dense(8, activation='relu')(input1)
580      input2 = keras.layers.Input(shape=(32,))
581      x2 = keras.layers.Dense(8, activation='relu')(input2)
582      added = keras.layers.add([x1, x2])
583
584      out = keras.layers.Dense(4)(added)
585      model = keras.models.Model(inputs=[input1, input2], outputs=out)
586  ```
587  """
588  return Add(**kwargs)(inputs)
589
590
591@keras_export('keras.layers.subtract')
592def subtract(inputs, **kwargs):
593  """Functional interface to the `Subtract` layer.
594
595  Arguments:
596      inputs: A list of input tensors (exactly 2).
597      **kwargs: Standard layer keyword arguments.
598
599  Returns:
600      A tensor, the difference of the inputs.
601
602  Examples:
603
604  ```python
605      import keras
606
607      input1 = keras.layers.Input(shape=(16,))
608      x1 = keras.layers.Dense(8, activation='relu')(input1)
609      input2 = keras.layers.Input(shape=(32,))
610      x2 = keras.layers.Dense(8, activation='relu')(input2)
611      subtracted = keras.layers.subtract([x1, x2])
612
613      out = keras.layers.Dense(4)(subtracted)
614      model = keras.models.Model(inputs=[input1, input2], outputs=out)
615  ```
616  """
617  return Subtract(**kwargs)(inputs)
618
619
620@keras_export('keras.layers.multiply')
621def multiply(inputs, **kwargs):
622  """Functional interface to the `Multiply` layer.
623
624  Arguments:
625      inputs: A list of input tensors (at least 2).
626      **kwargs: Standard layer keyword arguments.
627
628  Returns:
629      A tensor, the element-wise product of the inputs.
630  """
631  return Multiply(**kwargs)(inputs)
632
633
634@keras_export('keras.layers.average')
635def average(inputs, **kwargs):
636  """Functional interface to the `Average` layer.
637
638  Arguments:
639      inputs: A list of input tensors (at least 2).
640      **kwargs: Standard layer keyword arguments.
641
642  Returns:
643      A tensor, the average of the inputs.
644  """
645  return Average(**kwargs)(inputs)
646
647
648@keras_export('keras.layers.maximum')
649def maximum(inputs, **kwargs):
650  """Functional interface to the `Maximum` layer.
651
652  Arguments:
653      inputs: A list of input tensors (at least 2).
654      **kwargs: Standard layer keyword arguments.
655
656  Returns:
657      A tensor, the element-wise maximum of the inputs.
658  """
659  return Maximum(**kwargs)(inputs)
660
661
662@keras_export('keras.layers.minimum')
663def minimum(inputs, **kwargs):
664  """Functional interface to the `Minimum` layer.
665
666  Arguments:
667      inputs: A list of input tensors (at least 2).
668      **kwargs: Standard layer keyword arguments.
669
670  Returns:
671      A tensor, the element-wise minimum of the inputs.
672  """
673  return Minimum(**kwargs)(inputs)
674
675
676@keras_export('keras.layers.concatenate')
677def concatenate(inputs, axis=-1, **kwargs):
678  """Functional interface to the `Concatenate` layer.
679
680  Arguments:
681      inputs: A list of input tensors (at least 2).
682      axis: Concatenation axis.
683      **kwargs: Standard layer keyword arguments.
684
685  Returns:
      A tensor, the concatenation of the inputs along axis `axis`.
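
  Example (a usage sketch, analogous to the `add` example above):

  ```python
      import keras

      input1 = keras.layers.Input(shape=(16,))
      x1 = keras.layers.Dense(8, activation='relu')(input1)
      input2 = keras.layers.Input(shape=(32,))
      x2 = keras.layers.Dense(8, activation='relu')(input2)
      # Shapes (batch_size, 8) and (batch_size, 8) -> (batch_size, 16).
      concatenated = keras.layers.concatenate([x1, x2], axis=-1)

      out = keras.layers.Dense(4)(concatenated)
      model = keras.models.Model(inputs=[input1, input2], outputs=out)
  ```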
687  """
688  return Concatenate(axis=axis, **kwargs)(inputs)
689
690
691@keras_export('keras.layers.dot')
692def dot(inputs, axes, normalize=False, **kwargs):
693  """Functional interface to the `Dot` layer.
694
695  Arguments:
      inputs: A list of input tensors (exactly 2).
      axes: Integer or tuple of integers,
          axis or axes along which to take the dot product.
      normalize: Whether to L2-normalize samples along the
          dot product axis before taking the dot product.
          If set to True, then the output of the dot product
          is the cosine proximity between the two samples.
      **kwargs: Standard layer keyword arguments.

  Returns:
      A tensor, the dot product of the samples from the inputs.
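
  Example (a usage sketch; with `normalize=True` the result is the cosine
  proximity of the two branches):

  ```python
      import keras

      input1 = keras.layers.Input(shape=(16,))
      x1 = keras.layers.Dense(8, activation='relu')(input1)
      input2 = keras.layers.Input(shape=(32,))
      x2 = keras.layers.Dense(8, activation='relu')(input2)
      # Cosine proximity along axis 1: output shape (batch_size, 1).
      similarity = keras.layers.dot([x1, x2], axes=1, normalize=True)
      model = keras.models.Model(inputs=[input1, input2], outputs=similarity)
  ```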
707  """
708  return Dot(axes=axes, normalize=normalize, **kwargs)(inputs)
709