# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Contains the InputSpec class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import zip  # pylint: disable=redefined-builtin

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export
from tensorflow.python.util.tf_export import tf_export


@keras_export('keras.layers.InputSpec')
@tf_export(v1=['layers.InputSpec'])
class InputSpec(object):
  """Specifies the rank, dtype and shape of every input to a layer.

  Layers can expose (if appropriate) an `input_spec` attribute:
  an instance of `InputSpec`, or a nested structure of `InputSpec` instances
  (one per input tensor). These objects enable the layer to run input
  compatibility checks for input structure, input rank, input shape, and
  input dtype.

  A None entry in a shape is compatible with any dimension,
  a None shape is compatible with any shape.

  Args:
    dtype: Expected DataType of the input.
    shape: Shape tuple, expected shape of the input
      (may include None for unchecked axes). Includes the batch size.
    ndim: Integer, expected rank of the input. Ignored when `shape` has a
      known rank, in which case the rank is taken from `shape`.
    max_ndim: Integer, maximum rank of the input.
    min_ndim: Integer, minimum rank of the input.
    axes: Dictionary mapping integer axes to
      a specific dimension value.
    allow_last_axis_squeeze: If True, then allow inputs of rank N+1 as long
      as the last axis of the input is 1, as well as inputs of rank N-1
      as long as the last axis of the spec is 1.
    name: Expected key corresponding to this input when passing data as
      a dictionary.

  Raises:
    TypeError: if a key of `axes` cannot be converted to an integer.
    ValueError: if an axis in `axes` exceeds the rank implied by `shape`,
      `ndim` or `max_ndim`.

  Example:

  ```python
  class MyLayer(Layer):
    def __init__(self):
      super(MyLayer, self).__init__()
      # The layer will accept inputs with shape (?, 28, 28) & (?, 28, 28, 1)
      # and raise an appropriate error message otherwise.
      self.input_spec = InputSpec(
          shape=(None, 28, 28, 1),
          allow_last_axis_squeeze=True)
  ```
  """

  def __init__(self,
               dtype=None,
               shape=None,
               ndim=None,
               max_ndim=None,
               min_ndim=None,
               axes=None,
               allow_last_axis_squeeze=False,
               name=None):
    self.dtype = dtypes.as_dtype(dtype).name if dtype is not None else None
    shape = tensor_shape.TensorShape(shape)
    if shape.rank is None:
      shape = None
    else:
      shape = tuple(shape.as_list())
    if shape is not None:
      # A known shape fixes the rank; any explicit `ndim` is superseded.
      self.ndim = len(shape)
      self.shape = shape
    else:
      self.ndim = ndim
      self.shape = None
    self.max_ndim = max_ndim
    self.min_ndim = min_ndim
    self.name = name
    self.allow_last_axis_squeeze = allow_last_axis_squeeze
    try:
      axes = axes or {}
      self.axes = {int(k): axes[k] for k in axes}
    except (ValueError, TypeError):
      raise TypeError('The keys in axes must be integers.')

    if self.axes and (self.ndim is not None or self.max_ndim is not None):
      # `ndim` may legitimately be 0 (scalar input), so compare against None
      # instead of relying on truthiness before falling back to `max_ndim`.
      max_dim = (self.ndim if self.ndim is not None else self.max_ndim) - 1
      max_axis = max(self.axes)
      if max_axis > max_dim:
        raise ValueError('Axis {} is greater than the maximum allowed value: {}'
                         .format(max_axis, max_dim))

  def __repr__(self):
    # Zero-valued ranks and the empty (scalar) shape are meaningful, so those
    # fields are shown whenever they are set rather than when truthy.
    spec = [('dtype=' + str(self.dtype)) if self.dtype else '',
            ('shape=' + str(self.shape)) if self.shape is not None else '',
            ('ndim=' + str(self.ndim)) if self.ndim is not None else '',
            ('max_ndim=' + str(self.max_ndim))
            if self.max_ndim is not None else '',
            ('min_ndim=' + str(self.min_ndim))
            if self.min_ndim is not None else '',
            ('axes=' + str(self.axes)) if self.axes else '']
    return 'InputSpec(%s)' % ', '.join(x for x in spec if x)

  def get_config(self):
    """Returns the constructor arguments of this spec as a dict.

    NOTE(review): `allow_last_axis_squeeze` and `name` are not included, so a
    `from_config(get_config())` round trip resets them to their defaults.
    Adding them would change serialized metadata; confirm before doing so.
    """
    return {
        'dtype': self.dtype,
        'shape': self.shape,
        'ndim': self.ndim,
        'max_ndim': self.max_ndim,
        'min_ndim': self.min_ndim,
        'axes': self.axes}

  @classmethod
  def from_config(cls, config):
    """Creates an `InputSpec` from a dict of constructor arguments."""
    return cls(**config)


def to_tensor_shape(spec):
  """Returns a tf.TensorShape object that matches the shape specifications.

  If the InputSpec's shape or ndim is defined, this method will return a fully
  or partially-known shape. Otherwise, the returned TensorShape has unknown
  rank (`tf.TensorShape(None)`).

  Args:
    spec: an InputSpec object.

  Returns:
    a tf.TensorShape object
  """
  if spec.ndim is None and spec.shape is None:
    return tensor_shape.TensorShape(None)
  elif spec.shape is not None:
    return tensor_shape.TensorShape(spec.shape)
  else:
    shape = [None] * spec.ndim
    # `__init__` has already verified that every axis fits within `ndim`.
    for axis, value in spec.axes.items():
      shape[axis] = value
    return tensor_shape.TensorShape(shape)


def assert_input_compatibility(input_spec, inputs, layer_name):
  """Checks compatibility between the layer and provided inputs.

  This checks that the tensor(s) `inputs` verify the input assumptions
  of a layer (if any). If not, a clear and actionable exception gets raised.

  Inputs whose rank is unknown can only be checked for dtype; all other
  inputs are checked against every constraint of their spec.

  Args:
    input_spec: An InputSpec instance, list of InputSpec instances, a nested
      structure of InputSpec instances, or None.
    inputs: Input tensor, list of input tensors, or a nested structure of
      input tensors.
    layer_name: String, name of the layer (for error message formatting).

  Raises:
    TypeError: if an input does not have a `shape` attribute.
    ValueError: in case of mismatch between
      the provided inputs and the expectations of the layer.
  """
  if not input_spec:
    return

  input_spec = nest.flatten(input_spec)
  if isinstance(inputs, dict):
    # Flatten `inputs` by reference order if input spec names are provided.
    # `None` entries are valid specs, so they must not be dereferenced here.
    names = [spec.name if spec is not None else None for spec in input_spec]
    if all(names):
      list_inputs = []
      for name in names:
        if name not in inputs:
          raise ValueError('Missing data for input "%s". '
                           'You passed a data dictionary with keys %s. '
                           'Expected the following keys: %s' %
                           (name, list(inputs.keys()), names))
        list_inputs.append(inputs[name])
      inputs = list_inputs

  inputs = nest.flatten(inputs)
  for x in inputs:
    # Having a shape/dtype is the only commonality of the various tensor-like
    # objects that may be passed. The most common kind of invalid type we are
    # guarding for is a Layer instance (Functional API), which does not
    # have a `shape` attribute.
    if not hasattr(x, 'shape'):
      raise TypeError('Inputs to a layer should be tensors. Got: %s' % (x,))

  if len(inputs) != len(input_spec):
    raise ValueError('Layer ' + layer_name + ' expects ' +
                     str(len(input_spec)) + ' input(s), '
                     'but it received ' + str(len(inputs)) +
                     ' input tensors. Inputs received: ' + str(inputs))
  for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
    if spec is None:
      continue

    # Normalize once so tuple-shaped inputs (e.g. numpy arrays) are handled
    # the same way as tensors in every check below.
    shape = tensor_shape.TensorShape(x.shape)
    prefix = _incompatibility_prefix(input_index, layer_name)
    if shape.rank is None:
      # Rank, axis and shape checks are impossible without a known rank, but
      # this input's dtype and all remaining inputs must still be validated.
      _check_dtype(x, spec, prefix)
      continue
    _check_rank(shape, spec, prefix)
    _check_dtype(x, spec, prefix)
    _check_axes(shape, spec, prefix)
    _check_shape(shape, spec, input_index, layer_name)


def _incompatibility_prefix(input_index, layer_name):
  """Builds the common leading text of per-input incompatibility errors."""
  return ('Input ' + str(input_index) + ' of layer ' + str(layer_name) +
          ' is incompatible with the layer: ')


def _align_for_last_axis_squeeze(spec_dims, input_dims):
  """Reconciles two dim lists whose ranks may differ by one trailing axis.

  Per the `allow_last_axis_squeeze` contract, an input may carry one extra
  trailing axis of size 1, or lack the spec's trailing axis of size 1. A None
  trailing dimension is treated as compatible with 1.

  Args:
    spec_dims: Sequence of expected dimensions (ints or None).
    input_dims: Sequence of actual dimensions (ints or None).

  Returns:
    A `(spec_dims, input_dims)` pair of equal-length lists, or None when the
    ranks cannot be reconciled by dropping a single trailing axis.
  """
  spec_dims = list(spec_dims)
  input_dims = list(input_dims)
  if len(input_dims) == len(spec_dims):
    return spec_dims, input_dims
  if len(input_dims) == len(spec_dims) + 1 and input_dims[-1] in (1, None):
    return spec_dims, input_dims[:-1]
  if len(spec_dims) == len(input_dims) + 1 and spec_dims[-1] in (1, None):
    return spec_dims[:-1], input_dims
  return None


def _check_rank(shape, spec, prefix):
  """Validates a known-rank `shape` against `ndim`, `max_ndim`, `min_ndim`."""
  ndim = shape.rank
  if spec.ndim is not None:
    if spec.allow_last_axis_squeeze:
      # Ranks may differ by one squeezable trailing axis; exact dims are
      # compared later by `_check_shape`.
      reference = (spec.shape if spec.shape is not None
                   else (None,) * spec.ndim)
      rank_ok = _align_for_last_axis_squeeze(
          reference, shape.as_list()) is not None
    else:
      rank_ok = ndim == spec.ndim
    if not rank_ok:
      raise ValueError(prefix + 'expected ndim=' + str(spec.ndim) +
                       ', found ndim=' + str(ndim) +
                       '. Full shape received: ' + str(tuple(shape)))
  if spec.max_ndim is not None and ndim > spec.max_ndim:
    raise ValueError(prefix + 'expected max_ndim=' + str(spec.max_ndim) +
                     ', found ndim=' + str(ndim))
  if spec.min_ndim is not None and ndim < spec.min_ndim:
    raise ValueError(prefix + 'expected min_ndim=' + str(spec.min_ndim) +
                     ', found ndim=' + str(ndim) +
                     '. Full shape received: ' + str(tuple(shape)))


def _check_dtype(x, spec, prefix):
  """Validates the dtype name of input `x` against `spec.dtype`."""
  if spec.dtype is not None and x.dtype.name != spec.dtype:
    raise ValueError(prefix + 'expected dtype=' + str(spec.dtype) +
                     ', found dtype=' + str(x.dtype))


def _check_axes(shape, spec, prefix):
  """Validates the per-axis dimension constraints in `spec.axes`."""
  if not spec.axes:
    return
  dims = shape.as_list()
  rank = len(dims)
  for axis, value in spec.axes.items():
    if hasattr(value, 'value'):  # Unwrap legacy `Dimension` objects.
      value = value.value
    if value is None:
      continue
    axis = int(axis)
    # An axis the input does not have cannot satisfy the constraint; report
    # it as a mismatch instead of letting an IndexError escape.
    in_range = -rank <= axis < rank
    if not in_range or dims[axis] not in (value, None):
      raise ValueError(prefix + 'expected axis ' + str(axis) +
                       ' of input shape to have value ' + str(value) +
                       ' but received input with shape ' +
                       display_shape(shape))


def _check_shape(shape, spec, input_index, layer_name):
  """Validates a known-rank `shape` dimension-wise against `spec.shape`."""
  if spec.shape is None:
    return

  def _mismatch():
    return ValueError('Input ' + str(input_index) +
                      ' is incompatible with layer ' + str(layer_name) +
                      ': expected shape=' + str(spec.shape) +
                      ', found shape=' + display_shape(shape))

  spec_dims = spec.shape
  input_dims = shape.as_list()
  if spec.allow_last_axis_squeeze:
    aligned = _align_for_last_axis_squeeze(spec_dims, input_dims)
    if aligned is None:
      raise _mismatch()
    spec_dims, input_dims = aligned
  for spec_dim, dim in zip(spec_dims, input_dims):
    if spec_dim is not None and dim is not None and spec_dim != dim:
      raise _mismatch()


def display_shape(shape):
  """Formats a shape-like value as a tuple string, e.g. `(None, 28, 28)`.

  Accepts a `TensorShape`, a tuple/list of dims, or None. An unknown-rank
  shape is rendered with `TensorShape`'s own string form.
  """
  shape = tensor_shape.TensorShape(shape)
  if shape.rank is None:
    return str(shape)
  return str(tuple(shape.as_list()))


def to_tensor_spec(input_spec, default_dtype=None):
  """Converts a Keras InputSpec object to a TensorSpec.

  Args:
    input_spec: An `InputSpec`, or any other value (which yields a spec with
      unknown shape).
    default_dtype: dtype used when `input_spec` does not set one; falls back
      to `backend.floatx()` when falsy.

  Returns:
    A `tf.TensorSpec`.
  """
  default_dtype = default_dtype or backend.floatx()
  if isinstance(input_spec, InputSpec):
    dtype = input_spec.dtype or default_dtype
    return tensor_spec.TensorSpec(to_tensor_shape(input_spec), dtype)
  return tensor_spec.TensorSpec(None, default_dtype)