# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=not-callable
# pylint: disable=redefined-builtin
"""Layers that can merge several inputs into one."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.util.tf_export import keras_export


class _Merge(Layer):
  """Generic merge layer for elementwise merge functions.

  Used to implement `Add`, `Average`, etc.

  Arguments:
    **kwargs: standard layer keyword arguments.
  """

  def __init__(self, **kwargs):
    super(_Merge, self).__init__(**kwargs)
    self.supports_masking = True

  def _merge_function(self, inputs):
    raise NotImplementedError

  def _compute_elemwise_op_output_shape(self, shape1, shape2):
    """Computes the shape of the result of an elementwise operation.

    Arguments:
      shape1: tuple or None. Shape of the first tensor.
      shape2: tuple or None. Shape of the second tensor.

    Returns:
      Expected output shape when an elementwise operation is carried out
      on 2 tensors with shapes `shape1` and `shape2`. Tuple or None.

    Raises:
      ValueError: if `shape1` and `shape2` are not compatible for
        elementwise operations.
    """
    if None in [shape1, shape2]:
      return None
    elif len(shape1) < len(shape2):
      return self._compute_elemwise_op_output_shape(shape2, shape1)
    elif not shape2:
      return shape1
    output_shape = list(shape1[:-len(shape2)])
    for i, j in zip(shape1[-len(shape2):], shape2):
      if i is None or j is None:
        output_shape.append(None)
      elif i == 1:
        output_shape.append(j)
      elif j == 1:
        output_shape.append(i)
      else:
        if i != j:
          raise ValueError(
              'Operands could not be broadcast '
              'together with shapes ' + str(shape1) + ' ' + str(shape2))
        output_shape.append(i)
    return tuple(output_shape)

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    # Used purely for shape validation.
    if not isinstance(input_shape, list):
      raise ValueError('A merge layer should be called on a list of inputs.')
    if len(input_shape) < 2:
      raise ValueError('A merge layer should be called '
                       'on a list of at least 2 inputs. '
                       'Got ' + str(len(input_shape)) + ' inputs.')
    batch_sizes = [s[0] for s in input_shape if s is not None]
    batch_sizes = set(batch_sizes)
    batch_sizes -= set([None])
    if len(batch_sizes) > 1:
      raise ValueError(
          'Cannot merge tensors with different '
          'batch sizes. Got tensors with shapes: ' + str(input_shape))
    if input_shape[0] is None:
      output_shape = None
    else:
      output_shape = input_shape[0][1:]
    for i in range(1, len(input_shape)):
      if input_shape[i] is None:
        shape = None
      else:
        shape = input_shape[i][1:]
      output_shape = self._compute_elemwise_op_output_shape(
          output_shape, shape)
    # If the inputs have different ranks, we have to reshape them
    # to make them broadcastable.
    if None not in input_shape and len(set(map(len, input_shape))) == 1:
      self._reshape_required = False
    else:
      self._reshape_required = True

  def call(self, inputs):
    if not isinstance(inputs, list):
      raise ValueError('A merge layer should be called on a list of inputs.')
    if self._reshape_required:
      reshaped_inputs = []
      input_ndims = list(map(K.ndim, inputs))
      if None not in input_ndims:
        # If ranks of all inputs are available,
        # we simply expand each of them at axis=1
        # until all of them have the same rank.
        max_ndim = max(input_ndims)
        for x in inputs:
          x_ndim = K.ndim(x)
          for _ in range(max_ndim - x_ndim):
            x = array_ops.expand_dims(x, axis=1)
          reshaped_inputs.append(x)
        return self._merge_function(reshaped_inputs)
      else:
        # Transpose all inputs so that batch size is the last dimension.
        # (batch_size, dim1, dim2, ...) -> (dim1, dim2, ..., batch_size)
        transposed = False
        for x in inputs:
          x_ndim = K.ndim(x)
          if x_ndim is None:
            x_shape = array_ops.shape(x)
            batch_size = x_shape[0]
            new_shape = K.concatenate(
                [x_shape[1:],
                 array_ops.expand_dims(batch_size, axis=-1)])
            x_transposed = array_ops.reshape(
                x,
                array_ops.stack(
                    [batch_size, math_ops.reduce_prod(x_shape[1:])], axis=0))
            x_transposed = array_ops.transpose(x_transposed, perm=(1, 0))
            x_transposed = array_ops.reshape(x_transposed, new_shape)
            reshaped_inputs.append(x_transposed)
            transposed = True
          elif x_ndim > 1:
            dims = list(range(1, x_ndim)) + [0]
            reshaped_inputs.append(array_ops.transpose(x, perm=dims))
            transposed = True
          else:
            # We don't transpose inputs if they are 1D vectors or scalars.
            reshaped_inputs.append(x)
        y = self._merge_function(reshaped_inputs)
        y_ndim = K.ndim(y)
        if transposed:
          # If inputs have been transposed, we have to transpose the output
          # too.
          if y_ndim is None:
            y_shape = array_ops.shape(y)
            y_ndim = array_ops.shape(y_shape)[0]
            batch_size = y_shape[y_ndim - 1]
            new_shape = K.concatenate([
                array_ops.expand_dims(batch_size, axis=-1),
                y_shape[:y_ndim - 1]
            ])
            y = array_ops.reshape(y, (-1, batch_size))
            y = array_ops.transpose(y, perm=(1, 0))
            y = array_ops.reshape(y, new_shape)
          elif y_ndim > 1:
            dims = [y_ndim - 1] + list(range(y_ndim - 1))
            y = array_ops.transpose(y, perm=dims)
        return y
    else:
      return self._merge_function(inputs)

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    if input_shape[0] is None:
      output_shape = None
    else:
      output_shape = input_shape[0][1:]
    for i in range(1, len(input_shape)):
      if input_shape[i] is None:
        shape = None
      else:
        shape = input_shape[i][1:]
      output_shape = self._compute_elemwise_op_output_shape(
          output_shape, shape)
    batch_sizes = [s[0] for s in input_shape if s is not None]
    batch_sizes = set(batch_sizes)
    batch_sizes -= set([None])
    if len(batch_sizes) == 1:
      output_shape = (list(batch_sizes)[0],) + output_shape
    else:
      output_shape = (None,) + output_shape
    return output_shape

  def compute_mask(self, inputs, mask=None):
    if mask is None:
      return None
    if not isinstance(mask, list):
      raise ValueError('`mask` should be a list.')
    if not isinstance(inputs, list):
      raise ValueError('`inputs` should be a list.')
    if len(mask) != len(inputs):
      raise ValueError('The lists `inputs` and `mask` '
                       'should have the same length.')
    if all(m is None for m in mask):
      return None
    masks = [array_ops.expand_dims(m, axis=0) for m in mask if m is not None]
    return K.all(K.concatenate(masks, axis=0), axis=0, keepdims=False)


@keras_export('keras.layers.Add')
class Add(_Merge):
  """Layer that adds a list of inputs.

  It takes as input a list of tensors,
  all of the same shape, and returns
  a single tensor (also of the same shape).

  Examples:

  ```python
      import keras

      input1 = keras.layers.Input(shape=(16,))
      x1 = keras.layers.Dense(8, activation='relu')(input1)
      input2 = keras.layers.Input(shape=(32,))
      x2 = keras.layers.Dense(8, activation='relu')(input2)
      # equivalent to `added = keras.layers.add([x1, x2])`
      added = keras.layers.Add()([x1, x2])

      out = keras.layers.Dense(4)(added)
      model = keras.models.Model(inputs=[input1, input2], outputs=out)
  ```
  """

  def _merge_function(self, inputs):
    output = inputs[0]
    for i in range(1, len(inputs)):
      output += inputs[i]
    return output


@keras_export('keras.layers.Subtract')
class Subtract(_Merge):
  """Layer that subtracts two inputs.

  It takes as input a list of tensors of size 2,
  both of the same shape, and returns a single tensor,
  (inputs[0] - inputs[1]), also of the same shape.

  Examples:

  ```python
      import keras

      input1 = keras.layers.Input(shape=(16,))
      x1 = keras.layers.Dense(8, activation='relu')(input1)
      input2 = keras.layers.Input(shape=(32,))
      x2 = keras.layers.Dense(8, activation='relu')(input2)
      # Equivalent to subtracted = keras.layers.subtract([x1, x2])
      subtracted = keras.layers.Subtract()([x1, x2])

      out = keras.layers.Dense(4)(subtracted)
      model = keras.models.Model(inputs=[input1, input2], outputs=out)
  ```
  """

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    super(Subtract, self).build(input_shape)
    if len(input_shape) != 2:
      raise ValueError('A `Subtract` layer should be called '
                       'on exactly 2 inputs')

  def _merge_function(self, inputs):
    if len(inputs) != 2:
      raise ValueError('A `Subtract` layer should be called '
                       'on exactly 2 inputs')
    return inputs[0] - inputs[1]


@keras_export('keras.layers.Multiply')
class Multiply(_Merge):
  """Layer that multiplies (element-wise) a list of inputs.

  It takes as input a list of tensors,
  all of the same shape, and returns
  a single tensor (also of the same shape).
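
  Examples (a minimal sketch, mirroring the `Add` example above):

  ```python
      import keras

      input1 = keras.layers.Input(shape=(16,))
      x1 = keras.layers.Dense(8, activation='relu')(input1)
      input2 = keras.layers.Input(shape=(32,))
      x2 = keras.layers.Dense(8, activation='relu')(input2)
      # Both branches output shape (batch_size, 8), so they can be
      # multiplied element-wise; equivalent to
      # `multiplied = keras.layers.multiply([x1, x2])`.
      multiplied = keras.layers.Multiply()([x1, x2])

      out = keras.layers.Dense(4)(multiplied)
      model = keras.models.Model(inputs=[input1, input2], outputs=out)
  ```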
  """

  def _merge_function(self, inputs):
    output = inputs[0]
    for i in range(1, len(inputs)):
      output *= inputs[i]
    return output


@keras_export('keras.layers.Average')
class Average(_Merge):
  """Layer that averages a list of inputs.

  It takes as input a list of tensors,
  all of the same shape, and returns
  a single tensor (also of the same shape).
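
  Examples (a minimal sketch, following the same pattern as `Add` above):

  ```python
      import keras

      input1 = keras.layers.Input(shape=(16,))
      x1 = keras.layers.Dense(8, activation='relu')(input1)
      input2 = keras.layers.Input(shape=(32,))
      x2 = keras.layers.Dense(8, activation='relu')(input2)
      # Element-wise mean of the two (batch_size, 8) branches;
      # equivalent to `avg = keras.layers.average([x1, x2])`.
      avg = keras.layers.Average()([x1, x2])

      out = keras.layers.Dense(4)(avg)
      model = keras.models.Model(inputs=[input1, input2], outputs=out)
  ```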
  """

  def _merge_function(self, inputs):
    output = inputs[0]
    for i in range(1, len(inputs)):
      output += inputs[i]
    return output / len(inputs)


@keras_export('keras.layers.Maximum')
class Maximum(_Merge):
  """Layer that computes the maximum (element-wise) of a list of inputs.

  It takes as input a list of tensors,
  all of the same shape, and returns
  a single tensor (also of the same shape).
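
  Examples (a minimal sketch, following the same pattern as `Add` above):

  ```python
      import keras

      input1 = keras.layers.Input(shape=(16,))
      x1 = keras.layers.Dense(8, activation='relu')(input1)
      input2 = keras.layers.Input(shape=(32,))
      x2 = keras.layers.Dense(8, activation='relu')(input2)
      # Element-wise maximum of the two (batch_size, 8) branches;
      # equivalent to `maxed = keras.layers.maximum([x1, x2])`.
      maxed = keras.layers.Maximum()([x1, x2])

      out = keras.layers.Dense(4)(maxed)
      model = keras.models.Model(inputs=[input1, input2], outputs=out)
  ```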
  """

  def _merge_function(self, inputs):
    output = inputs[0]
    for i in range(1, len(inputs)):
      output = math_ops.maximum(output, inputs[i])
    return output


@keras_export('keras.layers.Minimum')
class Minimum(_Merge):
  """Layer that computes the minimum (element-wise) of a list of inputs.

  It takes as input a list of tensors,
  all of the same shape, and returns
  a single tensor (also of the same shape).
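
  Examples (a minimal sketch, following the same pattern as `Add` above):

  ```python
      import keras

      input1 = keras.layers.Input(shape=(16,))
      x1 = keras.layers.Dense(8, activation='relu')(input1)
      input2 = keras.layers.Input(shape=(32,))
      x2 = keras.layers.Dense(8, activation='relu')(input2)
      # Element-wise minimum of the two (batch_size, 8) branches;
      # equivalent to `minned = keras.layers.minimum([x1, x2])`.
      minned = keras.layers.Minimum()([x1, x2])

      out = keras.layers.Dense(4)(minned)
      model = keras.models.Model(inputs=[input1, input2], outputs=out)
  ```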
  """

  def _merge_function(self, inputs):
    output = inputs[0]
    for i in range(1, len(inputs)):
      output = math_ops.minimum(output, inputs[i])
    return output


@keras_export('keras.layers.Concatenate')
class Concatenate(_Merge):
  """Layer that concatenates a list of inputs.

  It takes as input a list of tensors,
  all of the same shape except for the concatenation axis,
  and returns a single tensor, the concatenation of all inputs.

  Arguments:
    axis: Axis along which to concatenate.
    **kwargs: standard layer keyword arguments.
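
  Examples (a minimal sketch, following the same pattern as `Add` above):

  ```python
      import keras

      input1 = keras.layers.Input(shape=(16,))
      x1 = keras.layers.Dense(8, activation='relu')(input1)
      input2 = keras.layers.Input(shape=(32,))
      x2 = keras.layers.Dense(8, activation='relu')(input2)
      # The non-concatenation axes match, so the two (batch_size, 8)
      # branches concatenate along axis=-1 into shape (batch_size, 16).
      concatenated = keras.layers.Concatenate(axis=-1)([x1, x2])

      out = keras.layers.Dense(4)(concatenated)
      model = keras.models.Model(inputs=[input1, input2], outputs=out)
  ```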
  """

  def __init__(self, axis=-1, **kwargs):
    super(Concatenate, self).__init__(**kwargs)
    self.axis = axis
    self.supports_masking = True
    self._reshape_required = False

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    # Used purely for shape validation.
    if not isinstance(input_shape, list) or len(input_shape) < 2:
      raise ValueError('A `Concatenate` layer should be called '
                       'on a list of at least 2 inputs')
    if all(shape is None for shape in input_shape):
      return
    reduced_inputs_shapes = [list(shape) for shape in input_shape]
    shape_set = set()
    for i in range(len(reduced_inputs_shapes)):
      del reduced_inputs_shapes[i][self.axis]
      shape_set.add(tuple(reduced_inputs_shapes[i]))
    if len(shape_set) > 1:
      raise ValueError('A `Concatenate` layer requires '
                       'inputs with matching shapes '
                       'except for the concat axis. '
                       'Got input shapes: %s' % (input_shape,))

  def _merge_function(self, inputs):
    return K.concatenate(inputs, axis=self.axis)

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    if not isinstance(input_shape, list):
      raise ValueError('A `Concatenate` layer should be called '
                       'on a list of inputs.')
    input_shapes = input_shape
    output_shape = list(input_shapes[0])
    for shape in input_shapes[1:]:
      if output_shape[self.axis] is None or shape[self.axis] is None:
        output_shape[self.axis] = None
        break
      output_shape[self.axis] += shape[self.axis]
    return tuple(output_shape)

  def compute_mask(self, inputs, mask=None):
    if mask is None:
      return None
    if not isinstance(mask, list):
      raise ValueError('`mask` should be a list.')
    if not isinstance(inputs, list):
      raise ValueError('`inputs` should be a list.')
    if len(mask) != len(inputs):
      raise ValueError('The lists `inputs` and `mask` '
                       'should have the same length.')
    if all(m is None for m in mask):
      return None
    # Make a list of masks while making sure
    # the dimensionality of each mask
    # is the same as the corresponding input.
    masks = []
    for input_i, mask_i in zip(inputs, mask):
      if mask_i is None:
        # Input is unmasked. Append all 1s to masks.
        masks.append(array_ops.ones_like(input_i, dtype='bool'))
      elif K.ndim(mask_i) < K.ndim(input_i):
        # Mask is smaller than the input, expand it.
        masks.append(array_ops.expand_dims(mask_i, axis=-1))
      else:
        masks.append(mask_i)
    concatenated = K.concatenate(masks, axis=self.axis)
    return K.all(concatenated, axis=-1, keepdims=False)

  def get_config(self):
    config = {
        'axis': self.axis,
    }
    base_config = super(Concatenate, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


@keras_export('keras.layers.Dot')
class Dot(_Merge):
  """Layer that computes a dot product between samples in two tensors.

  E.g. if applied to a list of two tensors `a` and `b` of shape
  `(batch_size, n)`, the output will be a tensor of shape `(batch_size, 1)`
  where each entry `i` will be the dot product between
  `a[i]` and `b[i]`.

  Arguments:
    axes: Integer or tuple of integers,
      axis or axes along which to take the dot product.
    normalize: Whether to L2-normalize samples along the
      dot product axis before taking the dot product.
      If set to True, then the output of the dot product
      is the cosine proximity between the two samples.
    **kwargs: Standard layer keyword arguments.
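
  Examples (a minimal sketch; `axes=1` takes the dot product over the
  feature axis of two `(batch_size, 16)` inputs):

  ```python
      import keras

      input1 = keras.layers.Input(shape=(16,))
      input2 = keras.layers.Input(shape=(16,))
      # Output shape is (batch_size, 1); with normalize=True this would
      # instead be the cosine proximity between the two samples.
      dotted = keras.layers.Dot(axes=1)([input1, input2])

      model = keras.models.Model(inputs=[input1, input2], outputs=dotted)
  ```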
  """

  def __init__(self, axes, normalize=False, **kwargs):
    super(Dot, self).__init__(**kwargs)
    if not isinstance(axes, int):
      if not isinstance(axes, (list, tuple)):
        raise TypeError('Invalid type for `axes` - '
                        'should be a list or an int.')
      if len(axes) != 2:
        raise ValueError('Invalid format for `axes` - '
                         'should contain two elements.')
      if not isinstance(axes[0], int) or not isinstance(axes[1], int):
        raise ValueError('Invalid format for `axes` - '
                         'list elements should be "int".')
    self.axes = axes
    self.normalize = normalize
    self.supports_masking = True
    self._reshape_required = False

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    # Used purely for shape validation.
    if not isinstance(input_shape, list) or len(input_shape) != 2:
      raise ValueError('A `Dot` layer should be called '
                       'on a list of 2 inputs.')
    shape1 = input_shape[0]
    shape2 = input_shape[1]
    if shape1 is None or shape2 is None:
      return
    if isinstance(self.axes, int):
      if self.axes < 0:
        axes = [self.axes % len(shape1), self.axes % len(shape2)]
      else:
        axes = [self.axes] * 2
    else:
      axes = self.axes
    if shape1[axes[0]] != shape2[axes[1]]:
      raise ValueError('Dimension incompatibility '
                       '%s != %s. ' % (shape1[axes[0]], shape2[axes[1]]) +
                       'Layer shapes: %s, %s' % (shape1, shape2))

  def _merge_function(self, inputs):
    if len(inputs) != 2:
      raise ValueError('A `Dot` layer should be called on exactly 2 inputs')
    x1 = inputs[0]
    x2 = inputs[1]
    if isinstance(self.axes, int):
      if self.axes < 0:
        axes = [self.axes % K.ndim(x1), self.axes % K.ndim(x2)]
      else:
        axes = [self.axes] * 2
    else:
      axes = []
      for i in range(len(self.axes)):
        if self.axes[i] < 0:
          axes.append(self.axes[i] % K.ndim(inputs[i]))
        else:
          axes.append(self.axes[i])
    if self.normalize:
      x1 = nn.l2_normalize(x1, axis=axes[0])
      x2 = nn.l2_normalize(x2, axis=axes[1])
    output = K.batch_dot(x1, x2, axes)
    return output

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    if not isinstance(input_shape, list) or len(input_shape) != 2:
      raise ValueError('A `Dot` layer should be called '
                       'on a list of 2 inputs.')
    shape1 = list(input_shape[0])
    shape2 = list(input_shape[1])
    if isinstance(self.axes, int):
      if self.axes < 0:
        axes = [self.axes % len(shape1), self.axes % len(shape2)]
      else:
        axes = [self.axes] * 2
    else:
      axes = self.axes
    shape1.pop(axes[0])
    shape2.pop(axes[1])
    shape2.pop(0)
    output_shape = shape1 + shape2
    if len(output_shape) == 1:
      output_shape += [1]
    return tuple(output_shape)

  def compute_mask(self, inputs, mask=None):
    return None

  def get_config(self):
    config = {
        'axes': self.axes,
        'normalize': self.normalize,
    }
    base_config = super(Dot, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


@keras_export('keras.layers.add')
def add(inputs, **kwargs):
  """Functional interface to the `Add` layer.

  Arguments:
    inputs: A list of input tensors (at least 2).
    **kwargs: Standard layer keyword arguments.

  Returns:
    A tensor, the sum of the inputs.

  Examples:

  ```python
      import keras

      input1 = keras.layers.Input(shape=(16,))
      x1 = keras.layers.Dense(8, activation='relu')(input1)
      input2 = keras.layers.Input(shape=(32,))
      x2 = keras.layers.Dense(8, activation='relu')(input2)
      added = keras.layers.add([x1, x2])

      out = keras.layers.Dense(4)(added)
      model = keras.models.Model(inputs=[input1, input2], outputs=out)
  ```
  """
  return Add(**kwargs)(inputs)


@keras_export('keras.layers.subtract')
def subtract(inputs, **kwargs):
  """Functional interface to the `Subtract` layer.

  Arguments:
    inputs: A list of input tensors (exactly 2).
    **kwargs: Standard layer keyword arguments.

  Returns:
    A tensor, the difference of the inputs.

  Examples:

  ```python
      import keras

      input1 = keras.layers.Input(shape=(16,))
      x1 = keras.layers.Dense(8, activation='relu')(input1)
      input2 = keras.layers.Input(shape=(32,))
      x2 = keras.layers.Dense(8, activation='relu')(input2)
      subtracted = keras.layers.subtract([x1, x2])

      out = keras.layers.Dense(4)(subtracted)
      model = keras.models.Model(inputs=[input1, input2], outputs=out)
  ```
  """
  return Subtract(**kwargs)(inputs)


@keras_export('keras.layers.multiply')
def multiply(inputs, **kwargs):
  """Functional interface to the `Multiply` layer.

  Arguments:
    inputs: A list of input tensors (at least 2).
    **kwargs: Standard layer keyword arguments.

  Returns:
    A tensor, the element-wise product of the inputs.
  """
  return Multiply(**kwargs)(inputs)


@keras_export('keras.layers.average')
def average(inputs, **kwargs):
  """Functional interface to the `Average` layer.

  Arguments:
    inputs: A list of input tensors (at least 2).
    **kwargs: Standard layer keyword arguments.

  Returns:
    A tensor, the average of the inputs.
  """
  return Average(**kwargs)(inputs)


@keras_export('keras.layers.maximum')
def maximum(inputs, **kwargs):
  """Functional interface to the `Maximum` layer.

  Arguments:
    inputs: A list of input tensors (at least 2).
    **kwargs: Standard layer keyword arguments.

  Returns:
    A tensor, the element-wise maximum of the inputs.
  """
  return Maximum(**kwargs)(inputs)


@keras_export('keras.layers.minimum')
def minimum(inputs, **kwargs):
  """Functional interface to the `Minimum` layer.

  Arguments:
    inputs: A list of input tensors (at least 2).
    **kwargs: Standard layer keyword arguments.

  Returns:
    A tensor, the element-wise minimum of the inputs.
  """
  return Minimum(**kwargs)(inputs)


@keras_export('keras.layers.concatenate')
def concatenate(inputs, axis=-1, **kwargs):
  """Functional interface to the `Concatenate` layer.

  Arguments:
    inputs: A list of input tensors (at least 2).
    axis: Concatenation axis.
    **kwargs: Standard layer keyword arguments.

  Returns:
    A tensor, the concatenation of the inputs alongside axis `axis`.
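
  Examples (a minimal sketch, following the same pattern as `add` above):

  ```python
      import keras

      input1 = keras.layers.Input(shape=(16,))
      x1 = keras.layers.Dense(8, activation='relu')(input1)
      input2 = keras.layers.Input(shape=(32,))
      x2 = keras.layers.Dense(8, activation='relu')(input2)
      # Concatenates the two (batch_size, 8) branches into (batch_size, 16).
      concatenated = keras.layers.concatenate([x1, x2], axis=-1)

      out = keras.layers.Dense(4)(concatenated)
      model = keras.models.Model(inputs=[input1, input2], outputs=out)
  ```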
  """
  return Concatenate(axis=axis, **kwargs)(inputs)


@keras_export('keras.layers.dot')
def dot(inputs, axes, normalize=False, **kwargs):
  """Functional interface to the `Dot` layer.

  Arguments:
    inputs: A list of input tensors (exactly 2).
    axes: Integer or tuple of integers,
      axis or axes along which to take the dot product.
    normalize: Whether to L2-normalize samples along the
      dot product axis before taking the dot product.
      If set to True, then the output of the dot product
      is the cosine proximity between the two samples.
    **kwargs: Standard layer keyword arguments.

  Returns:
    A tensor, the dot product of the samples from the inputs.
  """
  return Dot(axes=axes, normalize=normalize, **kwargs)(inputs)