# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Contains the InputSpec class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import zip  # pylint: disable=redefined-builtin

from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export
from tensorflow.python.util.tf_export import tf_export


@keras_export('keras.layers.InputSpec', v1=['keras.layers.InputSpec'])
@tf_export(v1=['layers.InputSpec'])
class InputSpec(object):
  """Specifies the ndim, dtype and shape of every input to a layer.

  Every layer should expose (if appropriate) an `input_spec` attribute:
  a list of instances of InputSpec (one per input tensor).

  A None entry in a shape is compatible with any dimension,
  a None shape is compatible with any shape.

  Arguments:
    dtype: Expected DataType of the input.
    shape: Shape tuple, expected shape of the input
      (may include None for unchecked axes). When given, the expected rank
      is taken from `len(shape)` and the `ndim` argument is ignored.
    ndim: Integer, expected rank of the input.
    max_ndim: Integer, maximum rank of the input.
    min_ndim: Integer, minimum rank of the input.
    axes: Dictionary mapping integer axes to
      a specific dimension value.
  """

  def __init__(self,
               dtype=None,
               shape=None,
               ndim=None,
               max_ndim=None,
               min_ndim=None,
               axes=None):
    self.dtype = dtype
    self.shape = shape
    # An explicit shape fully determines the rank, so it takes precedence
    # over the `ndim` argument.
    if shape is not None:
      self.ndim = len(shape)
    else:
      self.ndim = ndim
    self.max_ndim = max_ndim
    self.min_ndim = min_ndim
    # Normalize a missing mapping to an empty dict.
    self.axes = axes or {}

  def __repr__(self):
    # Compare against None instead of relying on truthiness so that falsy but
    # meaningful values (e.g. `shape=()` / `ndim=0` for scalar inputs, or
    # `min_ndim=0`) still appear in the representation.
    fields = [('dtype', self.dtype),
              ('shape', self.shape),
              ('ndim', self.ndim),
              ('max_ndim', self.max_ndim),
              ('min_ndim', self.min_ndim)]
    spec = [name + '=' + str(value) for name, value in fields
            if value is not None]
    if self.axes:
      spec.append('axes=' + str(self.axes))
    return 'InputSpec(%s)' % ', '.join(spec)


def _incompatible_input_prefix(input_index, layer_name):
  """Returns the shared prefix of per-input incompatibility error messages."""
  return ('Input ' + str(input_index) + ' of layer ' + layer_name +
          ' is incompatible with the layer: ')


def assert_input_compatibility(input_spec, inputs, layer_name):
  """Checks compatibility between the layer and provided inputs.

  This checks that the tensor(s) `inputs` verify the input assumptions
  of a layer (if any). If not, a clear and actionable exception gets raised.

  Arguments:
    input_spec: An InputSpec instance, a list/tuple of InputSpec instances
      (a None entry skips checking the corresponding input), or None.
    inputs: Input tensor or list of input tensors.
    layer_name: String, name of the layer (for error message formatting).

  Raises:
    ValueError: in case of mismatch between
      the provided inputs and the expectations of the layer.
  """
  if not input_spec:
    return
  if not isinstance(input_spec, (list, tuple)):
    input_spec = nest.flatten(input_spec)

  inputs = nest.flatten(inputs)
  if len(inputs) != len(input_spec):
    raise ValueError('Layer ' + layer_name + ' expects ' +
                     str(len(input_spec)) + ' inputs, '
                     'but it received ' + str(len(inputs)) +
                     ' input tensors. Inputs received: ' + str(inputs))
  for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
    # A None spec entry means this input is not checked.
    if spec is None:
      continue

    # Check rank (ndim / max_ndim / min_ndim).
    if (spec.ndim is not None or
        spec.min_ndim is not None or
        spec.max_ndim is not None):
      ndim = x.shape.ndims
      if ndim is None:
        raise ValueError(
            _incompatible_input_prefix(input_index, layer_name) +
            'its rank is undefined, but the layer requires a defined rank.')
      if spec.ndim is not None and ndim != spec.ndim:
        raise ValueError(
            _incompatible_input_prefix(input_index, layer_name) +
            'expected ndim=' + str(spec.ndim) + ', found ndim=' + str(ndim) +
            '. Full shape received: ' + str(x.shape.as_list()))
      if spec.max_ndim is not None and ndim > spec.max_ndim:
        raise ValueError(
            _incompatible_input_prefix(input_index, layer_name) +
            'expected max_ndim=' + str(spec.max_ndim) +
            ', found ndim=' + str(ndim) +
            '. Full shape received: ' + str(x.shape.as_list()))
      if spec.min_ndim is not None and ndim < spec.min_ndim:
        raise ValueError(
            _incompatible_input_prefix(input_index, layer_name) +
            'expected min_ndim=' + str(spec.min_ndim) +
            ', found ndim=' + str(ndim) +
            '. Full shape received: ' + str(x.shape.as_list()))

    # Check dtype.
    if spec.dtype is not None and x.dtype != spec.dtype:
      raise ValueError(
          _incompatible_input_prefix(input_index, layer_name) +
          'expected dtype=' + str(spec.dtype) +
          ', found dtype=' + str(x.dtype))

    # The remaining checks need concrete dimensions. `as_list()` raises on an
    # unknown rank, so skip them in that case; specs that constrain the rank
    # (including every spec with an explicit `shape`) already rejected such
    # inputs above.
    if not spec.axes and spec.shape is None:
      continue
    if x.shape.ndims is None:
      continue
    shape = x.shape.as_list()

    # Check specific shape axes.
    if spec.axes:
      for axis, value in spec.axes.items():
        # Unwrap Dimension-like objects to their plain value.
        if hasattr(value, 'value'):
          value = value.value
        if value is None:
          continue
        axis_index = int(axis)
        # Report a missing axis as a ValueError rather than an IndexError.
        if not -len(shape) <= axis_index < len(shape):
          raise ValueError(
              _incompatible_input_prefix(input_index, layer_name) +
              'expected axis ' + str(axis) + ' of input shape to exist, '
              'but received input with shape ' + str(shape))
        if shape[axis_index] not in {value, None}:
          raise ValueError(
              _incompatible_input_prefix(input_index, layer_name) +
              'expected axis ' + str(axis) +
              ' of input shape to have value ' + str(value) +
              ' but received input with shape ' + str(shape))

    # Check full shape; None on either side matches any dimension.
    if spec.shape is not None:
      for spec_dim, dim in zip(spec.shape, shape):
        if spec_dim is not None and dim is not None and spec_dim != dim:
          raise ValueError('Input ' + str(input_index) +
                           ' is incompatible with layer ' + layer_name +
                           ': expected shape=' + str(spec.shape) +
                           ', found shape=' + str(shape))