1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utilities used by convolution layers.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import itertools 21 22import numpy as np 23from six.moves import range # pylint: disable=redefined-builtin 24 25from tensorflow.python.keras import backend 26 27 28def convert_data_format(data_format, ndim): 29 if data_format == 'channels_last': 30 if ndim == 3: 31 return 'NWC' 32 elif ndim == 4: 33 return 'NHWC' 34 elif ndim == 5: 35 return 'NDHWC' 36 else: 37 raise ValueError('Input rank not supported:', ndim) 38 elif data_format == 'channels_first': 39 if ndim == 3: 40 return 'NCW' 41 elif ndim == 4: 42 return 'NCHW' 43 elif ndim == 5: 44 return 'NCDHW' 45 else: 46 raise ValueError('Input rank not supported:', ndim) 47 else: 48 raise ValueError('Invalid data_format:', data_format) 49 50 51def normalize_tuple(value, n, name): 52 """Transforms a single integer or iterable of integers into an integer tuple. 53 54 Arguments: 55 value: The value to validate and convert. Could an int, or any iterable of 56 ints. 57 n: The size of the tuple to be returned. 58 name: The name of the argument being validated, e.g. "strides" or 59 "kernel_size". This is only used to format error messages. 60 61 Returns: 62 A tuple of n integers. 63 64 Raises: 65 ValueError: If something else than an int/long or iterable thereof was 66 passed. 67 """ 68 if isinstance(value, int): 69 return (value,) * n 70 else: 71 try: 72 value_tuple = tuple(value) 73 except TypeError: 74 raise ValueError('The `' + name + '` argument must be a tuple of ' + 75 str(n) + ' integers. Received: ' + str(value)) 76 if len(value_tuple) != n: 77 raise ValueError('The `' + name + '` argument must be a tuple of ' + 78 str(n) + ' integers. Received: ' + str(value)) 79 for single_value in value_tuple: 80 try: 81 int(single_value) 82 except (ValueError, TypeError): 83 raise ValueError('The `' + name + '` argument must be a tuple of ' + 84 str(n) + ' integers. Received: ' + str(value) + ' ' 85 'including element ' + str(single_value) + ' of type' + 86 ' ' + str(type(single_value))) 87 return value_tuple 88 89 90def conv_output_length(input_length, filter_size, padding, stride, dilation=1): 91 """Determines output length of a convolution given input length. 92 93 Arguments: 94 input_length: integer. 95 filter_size: integer. 96 padding: one of "same", "valid", "full", "causal" 97 stride: integer. 98 dilation: dilation rate, integer. 99 100 Returns: 101 The output length (integer). 102 """ 103 if input_length is None: 104 return None 105 assert padding in {'same', 'valid', 'full', 'causal'} 106 dilated_filter_size = filter_size + (filter_size - 1) * (dilation - 1) 107 if padding in ['same', 'causal']: 108 output_length = input_length 109 elif padding == 'valid': 110 output_length = input_length - dilated_filter_size + 1 111 elif padding == 'full': 112 output_length = input_length + dilated_filter_size - 1 113 return (output_length + stride - 1) // stride 114 115 116def conv_input_length(output_length, filter_size, padding, stride): 117 """Determines input length of a convolution given output length. 118 119 Arguments: 120 output_length: integer. 121 filter_size: integer. 122 padding: one of "same", "valid", "full". 123 stride: integer. 124 125 Returns: 126 The input length (integer). 127 """ 128 if output_length is None: 129 return None 130 assert padding in {'same', 'valid', 'full'} 131 if padding == 'same': 132 pad = filter_size // 2 133 elif padding == 'valid': 134 pad = 0 135 elif padding == 'full': 136 pad = filter_size - 1 137 return (output_length - 1) * stride - 2 * pad + filter_size 138 139 140def deconv_output_length(input_length, 141 filter_size, 142 padding, 143 output_padding=None, 144 stride=0, 145 dilation=1): 146 """Determines output length of a transposed convolution given input length. 147 148 Arguments: 149 input_length: Integer. 150 filter_size: Integer. 151 padding: one of `"same"`, `"valid"`, `"full"`. 152 output_padding: Integer, amount of padding along the output dimension. Can 153 be set to `None` in which case the output length is inferred. 154 stride: Integer. 155 dilation: Integer. 156 157 Returns: 158 The output length (integer). 159 """ 160 assert padding in {'same', 'valid', 'full'} 161 if input_length is None: 162 return None 163 164 # Get the dilated kernel size 165 filter_size = filter_size + (filter_size - 1) * (dilation - 1) 166 167 # Infer length if output padding is None, else compute the exact length 168 if output_padding is None: 169 if padding == 'valid': 170 length = input_length * stride + max(filter_size - stride, 0) 171 elif padding == 'full': 172 length = input_length * stride - (stride + filter_size - 2) 173 elif padding == 'same': 174 length = input_length * stride 175 176 else: 177 if padding == 'same': 178 pad = filter_size // 2 179 elif padding == 'valid': 180 pad = 0 181 elif padding == 'full': 182 pad = filter_size - 1 183 184 length = ((input_length - 1) * stride + filter_size - 2 * pad + 185 output_padding) 186 return length 187 188 189def normalize_data_format(value): 190 if value is None: 191 value = backend.image_data_format() 192 data_format = value.lower() 193 if data_format not in {'channels_first', 'channels_last'}: 194 raise ValueError('The `data_format` argument must be one of ' 195 '"channels_first", "channels_last". Received: ' + 196 str(value)) 197 return data_format 198 199 200def normalize_padding(value): 201 if isinstance(value, (list, tuple)): 202 return value 203 padding = value.lower() 204 if padding not in {'valid', 'same', 'causal'}: 205 raise ValueError('The `padding` argument must be a list/tuple or one of ' 206 '"valid", "same" (or "causal", only for `Conv1D). ' 207 'Received: ' + str(padding)) 208 return padding 209 210 211def convert_kernel(kernel): 212 """Converts a Numpy kernel matrix from Theano format to TensorFlow format. 213 214 Also works reciprocally, since the transformation is its own inverse. 215 216 This is used for converting legacy Theano-saved model files. 217 218 Arguments: 219 kernel: Numpy array (3D, 4D or 5D). 220 221 Returns: 222 The converted kernel. 223 224 Raises: 225 ValueError: in case of invalid kernel shape or invalid data_format. 226 """ 227 kernel = np.asarray(kernel) 228 if not 3 <= kernel.ndim <= 5: 229 raise ValueError('Invalid kernel shape:', kernel.shape) 230 slices = [slice(None, None, -1) for _ in range(kernel.ndim)] 231 no_flip = (slice(None, None), slice(None, None)) 232 slices[-2:] = no_flip 233 return np.copy(kernel[slices]) 234 235 236def conv_kernel_mask(input_shape, kernel_shape, strides, padding): 237 """Compute a mask representing the connectivity of a convolution operation. 238 239 Assume a convolution with given parameters is applied to an input having N 240 spatial dimensions with `input_shape = (d_in1, ..., d_inN)` to produce an 241 output with shape `(d_out1, ..., d_outN)`. This method returns a boolean array 242 of shape `(d_in1, ..., d_inN, d_out1, ..., d_outN)` with `True` entries 243 indicating pairs of input and output locations that are connected by a weight. 244 245 Example: 246 247 >>> input_shape = (4,) 248 >>> kernel_shape = (2,) 249 >>> strides = (1,) 250 >>> padding = "valid" 251 >>> conv_kernel_mask(input_shape, kernel_shape, strides, padding) 252 array([[ True, False, False], 253 [ True, True, False], 254 [False, True, True], 255 [False, False, True]]) 256 257 where rows and columns correspond to inputs and outputs respectively. 258 259 260 Args: 261 input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of the 262 input. 263 kernel_shape: tuple of size N, spatial shape of the convolutional kernel / 264 receptive field. 265 strides: tuple of size N, strides along each spatial dimension. 266 padding: type of padding, string `"same"` or `"valid"`. 267 268 Returns: 269 A boolean 2N-D `np.ndarray` of shape 270 `(d_in1, ..., d_inN, d_out1, ..., d_outN)`, where `(d_out1, ..., d_outN)` 271 is the spatial shape of the output. `True` entries in the mask represent 272 pairs of input-output locations that are connected by a weight. 273 274 Raises: 275 ValueError: if `input_shape`, `kernel_shape` and `strides` don't have the 276 same number of dimensions. 277 NotImplementedError: if `padding` is not in {`"same"`, `"valid"`}. 278 """ 279 if padding not in {'same', 'valid'}: 280 raise NotImplementedError('Padding type %s not supported. ' 281 'Only "valid" and "same" ' 282 'are implemented.' % padding) 283 284 in_dims = len(input_shape) 285 if isinstance(kernel_shape, int): 286 kernel_shape = (kernel_shape,) * in_dims 287 if isinstance(strides, int): 288 strides = (strides,) * in_dims 289 290 kernel_dims = len(kernel_shape) 291 stride_dims = len(strides) 292 if kernel_dims != in_dims or stride_dims != in_dims: 293 raise ValueError('Number of strides, input and kernel dimensions must all ' 294 'match. Received: %d, %d, %d.' % 295 (stride_dims, in_dims, kernel_dims)) 296 297 output_shape = conv_output_shape(input_shape, kernel_shape, strides, padding) 298 299 mask_shape = input_shape + output_shape 300 mask = np.zeros(mask_shape, np.bool) 301 302 output_axes_ticks = [range(dim) for dim in output_shape] 303 for output_position in itertools.product(*output_axes_ticks): 304 input_axes_ticks = conv_connected_inputs(input_shape, kernel_shape, 305 output_position, strides, padding) 306 for input_position in itertools.product(*input_axes_ticks): 307 mask[input_position + output_position] = True 308 309 return mask 310 311 312def conv_kernel_idxs(input_shape, kernel_shape, strides, padding, filters_in, 313 filters_out, data_format): 314 """Yields output-input tuples of indices in a CNN layer. 315 316 The generator iterates over all `(output_idx, input_idx)` tuples, where 317 `output_idx` is an integer index in a flattened tensor representing a single 318 output image of a convolutional layer that is connected (via the layer 319 weights) to the respective single input image at `input_idx` 320 321 Example: 322 323 >>> input_shape = (2, 2) 324 >>> kernel_shape = (2, 1) 325 >>> strides = (1, 1) 326 >>> padding = "valid" 327 >>> filters_in = 1 328 >>> filters_out = 1 329 >>> data_format = "channels_last" 330 >>> list(conv_kernel_idxs(input_shape, kernel_shape, strides, padding, 331 ... filters_in, filters_out, data_format)) 332 [(0, 0), (0, 2), (1, 1), (1, 3)] 333 334 Args: 335 input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of the 336 input. 337 kernel_shape: tuple of size N, spatial shape of the convolutional kernel / 338 receptive field. 339 strides: tuple of size N, strides along each spatial dimension. 340 padding: type of padding, string `"same"` or `"valid"`. 341 filters_in: `int`, number if filters in the input to the layer. 342 filters_out: `int', number if filters in the output of the layer. 343 data_format: string, "channels_first" or "channels_last". 344 345 Yields: 346 The next tuple `(output_idx, input_idx)`, where 347 `output_idx` is an integer index in a flattened tensor representing a single 348 output image of a convolutional layer that is connected (via the layer 349 weights) to the respective single input image at `input_idx`. 350 351 Raises: 352 ValueError: if `data_format` is neither 353 `"channels_last"` nor `"channels_first"`, or if number of strides, input, 354 and kernel number of dimensions do not match. 355 356 NotImplementedError: if `padding` is neither `"same"` nor `"valid"`. 357 """ 358 if padding not in ('same', 'valid'): 359 raise NotImplementedError('Padding type %s not supported. ' 360 'Only "valid" and "same" ' 361 'are implemented.' % padding) 362 363 in_dims = len(input_shape) 364 if isinstance(kernel_shape, int): 365 kernel_shape = (kernel_shape,) * in_dims 366 if isinstance(strides, int): 367 strides = (strides,) * in_dims 368 369 kernel_dims = len(kernel_shape) 370 stride_dims = len(strides) 371 if kernel_dims != in_dims or stride_dims != in_dims: 372 raise ValueError('Number of strides, input and kernel dimensions must all ' 373 'match. Received: %d, %d, %d.' % 374 (stride_dims, in_dims, kernel_dims)) 375 376 output_shape = conv_output_shape(input_shape, kernel_shape, strides, padding) 377 output_axes_ticks = [range(dim) for dim in output_shape] 378 379 if data_format == 'channels_first': 380 concat_idxs = lambda spatial_idx, filter_idx: (filter_idx,) + spatial_idx 381 elif data_format == 'channels_last': 382 concat_idxs = lambda spatial_idx, filter_idx: spatial_idx + (filter_idx,) 383 else: 384 raise ValueError('Data format %s not recignized.' 385 '`data_format` must be "channels_first" or ' 386 '"channels_last".' % data_format) 387 388 for output_position in itertools.product(*output_axes_ticks): 389 input_axes_ticks = conv_connected_inputs(input_shape, kernel_shape, 390 output_position, strides, padding) 391 for input_position in itertools.product(*input_axes_ticks): 392 for f_in in range(filters_in): 393 for f_out in range(filters_out): 394 out_idx = np.ravel_multi_index( 395 multi_index=concat_idxs(output_position, f_out), 396 dims=concat_idxs(output_shape, filters_out)) 397 in_idx = np.ravel_multi_index( 398 multi_index=concat_idxs(input_position, f_in), 399 dims=concat_idxs(input_shape, filters_in)) 400 yield (out_idx, in_idx) 401 402 403def conv_connected_inputs(input_shape, kernel_shape, output_position, strides, 404 padding): 405 """Return locations of the input connected to an output position. 406 407 Assume a convolution with given parameters is applied to an input having N 408 spatial dimensions with `input_shape = (d_in1, ..., d_inN)`. This method 409 returns N ranges specifying the input region that was convolved with the 410 kernel to produce the output at position 411 `output_position = (p_out1, ..., p_outN)`. 412 413 Example: 414 415 >>> input_shape = (4, 4) 416 >>> kernel_shape = (2, 1) 417 >>> output_position = (1, 1) 418 >>> strides = (1, 1) 419 >>> padding = "valid" 420 >>> conv_connected_inputs(input_shape, kernel_shape, output_position, 421 ... strides, padding) 422 [range(1, 3), range(1, 2)] 423 424 Args: 425 input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of the 426 input. 427 kernel_shape: tuple of size N, spatial shape of the convolutional kernel / 428 receptive field. 429 output_position: tuple of size N: `(p_out1, ..., p_outN)`, a single position 430 in the output of the convolution. 431 strides: tuple of size N, strides along each spatial dimension. 432 padding: type of padding, string `"same"` or `"valid"`. 433 434 Returns: 435 N ranges `[[p_in_left1, ..., p_in_right1], ..., 436 [p_in_leftN, ..., p_in_rightN]]` specifying the region in the 437 input connected to output_position. 438 """ 439 ranges = [] 440 441 ndims = len(input_shape) 442 for d in range(ndims): 443 left_shift = int(kernel_shape[d] / 2) 444 right_shift = kernel_shape[d] - left_shift 445 446 center = output_position[d] * strides[d] 447 448 if padding == 'valid': 449 center += left_shift 450 451 start = max(0, center - left_shift) 452 end = min(input_shape[d], center + right_shift) 453 454 ranges.append(range(start, end)) 455 456 return ranges 457 458 459def conv_output_shape(input_shape, kernel_shape, strides, padding): 460 """Return the output shape of an N-D convolution. 461 462 Forces dimensions where input is empty (size 0) to remain empty. 463 464 Args: 465 input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of the 466 input. 467 kernel_shape: tuple of size N, spatial shape of the convolutional kernel / 468 receptive field. 469 strides: tuple of size N, strides along each spatial dimension. 470 padding: type of padding, string `"same"` or `"valid"`. 471 472 Returns: 473 tuple of size N: `(d_out1, ..., d_outN)`, spatial shape of the output. 474 """ 475 dims = range(len(kernel_shape)) 476 output_shape = [ 477 conv_output_length(input_shape[d], kernel_shape[d], padding, strides[d]) 478 for d in dims 479 ] 480 output_shape = tuple( 481 [0 if input_shape[d] == 0 else output_shape[d] for d in dims]) 482 return output_shape 483