# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Utilities used by convolution layers."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import itertools
21
22import numpy as np
23from six.moves import range  # pylint: disable=redefined-builtin
24
25from tensorflow.python.keras import backend


def convert_data_format(data_format, ndim):
  if data_format == 'channels_last':
    if ndim == 3:
      return 'NWC'
    elif ndim == 4:
      return 'NHWC'
    elif ndim == 5:
      return 'NDHWC'
    else:
      raise ValueError('Input rank not supported:', ndim)
  elif data_format == 'channels_first':
    if ndim == 3:
      return 'NCW'
    elif ndim == 4:
      return 'NCHW'
    elif ndim == 5:
      return 'NCDHW'
    else:
      raise ValueError('Input rank not supported:', ndim)
  else:
    raise ValueError('Invalid data_format:', data_format)
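
# A couple of illustrative calls (added here for clarity; not part of the
# original module):
#
#   >>> convert_data_format('channels_last', 4)
#   'NHWC'
#   >>> convert_data_format('channels_first', 5)
#   'NCDHW'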


def normalize_tuple(value, n, name):
  """Transforms a single integer or iterable of integers into an integer tuple.

  Arguments:
    value: The value to validate and convert. Could be an int, or any iterable
      of ints.
    n: The size of the tuple to be returned.
    name: The name of the argument being validated, e.g. "strides" or
      "kernel_size". This is only used to format error messages.

  Returns:
    A tuple of n integers.

  Raises:
    ValueError: If something other than an int/long or iterable thereof was
      passed.
  """
  if isinstance(value, int):
    return (value,) * n
  else:
    try:
      value_tuple = tuple(value)
    except TypeError:
      raise ValueError('The `' + name + '` argument must be a tuple of ' +
                       str(n) + ' integers. Received: ' + str(value))
    if len(value_tuple) != n:
      raise ValueError('The `' + name + '` argument must be a tuple of ' +
                       str(n) + ' integers. Received: ' + str(value))
    for single_value in value_tuple:
      try:
        int(single_value)
      except (ValueError, TypeError):
        raise ValueError('The `' + name + '` argument must be a tuple of ' +
                         str(n) + ' integers. Received: ' + str(value) + ' '
                         'including element ' + str(single_value) + ' of type' +
                         ' ' + str(type(single_value)))
    return value_tuple
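
# Illustrative usage (added for clarity; not part of the original module):
#
#   >>> normalize_tuple(2, n=2, name='strides')
#   (2, 2)
#   >>> normalize_tuple([1, 2], n=2, name='strides')
#   (1, 2)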


def conv_output_length(input_length, filter_size, padding, stride, dilation=1):
  """Determines output length of a convolution given input length.

  Arguments:
      input_length: integer.
      filter_size: integer.
      padding: one of "same", "valid", "full", "causal".
      stride: integer.
      dilation: dilation rate, integer.

  Returns:
      The output length (integer).
  """
  if input_length is None:
    return None
  assert padding in {'same', 'valid', 'full', 'causal'}
  dilated_filter_size = filter_size + (filter_size - 1) * (dilation - 1)
  if padding in ['same', 'causal']:
    output_length = input_length
  elif padding == 'valid':
    output_length = input_length - dilated_filter_size + 1
  elif padding == 'full':
    output_length = input_length + dilated_filter_size - 1
  return (output_length + stride - 1) // stride
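
# A quick worked example (added for clarity; not part of the original module):
#
#   >>> conv_output_length(10, filter_size=3, padding='valid', stride=2)
#   4
#
# With 'valid' padding the (dilated) window of size 3 must fit entirely inside
# the input, leaving 10 - 3 + 1 = 8 valid positions; a stride of 2 then gives
# ceil(8 / 2) = 4 outputs.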


def conv_input_length(output_length, filter_size, padding, stride):
  """Determines input length of a convolution given output length.

  Arguments:
      output_length: integer.
      filter_size: integer.
      padding: one of "same", "valid", "full".
      stride: integer.

  Returns:
      The input length (integer).
  """
  if output_length is None:
    return None
  assert padding in {'same', 'valid', 'full'}
  if padding == 'same':
    pad = filter_size // 2
  elif padding == 'valid':
    pad = 0
  elif padding == 'full':
    pad = filter_size - 1
  return (output_length - 1) * stride - 2 * pad + filter_size
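
# Illustrative inverse of the example above (added for clarity; not part of the
# original module):
#
#   >>> conv_input_length(4, filter_size=3, padding='valid', stride=2)
#   9
#
# 9 is the smallest input length for which conv_output_length returns 4 with
# these parameters (inputs of length 9 and 10 both yield an output of length 4).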


def deconv_output_length(input_length,
                         filter_size,
                         padding,
                         output_padding=None,
                         stride=0,
                         dilation=1):
  """Determines output length of a transposed convolution given input length.

  Arguments:
      input_length: Integer.
      filter_size: Integer.
      padding: one of `"same"`, `"valid"`, `"full"`.
      output_padding: Integer, amount of padding along the output dimension. Can
        be set to `None`, in which case the output length is inferred.
      stride: Integer.
      dilation: Integer.

  Returns:
      The output length (integer).
  """
  assert padding in {'same', 'valid', 'full'}
  if input_length is None:
    return None

  # Get the dilated kernel size
  filter_size = filter_size + (filter_size - 1) * (dilation - 1)

  # Infer length if output padding is None, else compute the exact length
  if output_padding is None:
    if padding == 'valid':
      length = input_length * stride + max(filter_size - stride, 0)
    elif padding == 'full':
      length = input_length * stride - (stride + filter_size - 2)
    elif padding == 'same':
      length = input_length * stride

  else:
    if padding == 'same':
      pad = filter_size // 2
    elif padding == 'valid':
      pad = 0
    elif padding == 'full':
      pad = filter_size - 1

    length = ((input_length - 1) * stride + filter_size - 2 * pad +
              output_padding)
  return length
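
# Illustrative values (added for clarity; not part of the original module):
#
#   >>> deconv_output_length(4, filter_size=3, padding='same', stride=2)
#   8
#   >>> deconv_output_length(4, filter_size=3, padding='valid', stride=2)
#   9
#
# With `output_padding=None` the length is inferred: 'same' padding simply
# scales the input length by the stride, while 'valid' also adds the part of
# the kernel extending past the last upsampled position (max(3 - 2, 0) = 1).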


def normalize_data_format(value):
  if value is None:
    value = backend.image_data_format()
  data_format = value.lower()
  if data_format not in {'channels_first', 'channels_last'}:
    raise ValueError('The `data_format` argument must be one of '
                     '"channels_first", "channels_last". Received: ' +
                     str(value))
  return data_format


def normalize_padding(value):
  if isinstance(value, (list, tuple)):
    return value
  padding = value.lower()
  if padding not in {'valid', 'same', 'causal'}:
    raise ValueError('The `padding` argument must be a list/tuple or one of '
                     '"valid", "same" (or "causal", only for `Conv1D`). '
                     'Received: ' + str(padding))
  return padding


def convert_kernel(kernel):
  """Converts a Numpy kernel matrix from Theano format to TensorFlow format.

  Also works reciprocally, since the transformation is its own inverse.

  This is used for converting legacy Theano-saved model files.

  Arguments:
      kernel: Numpy array (3D, 4D or 5D).

  Returns:
      The converted kernel.

  Raises:
      ValueError: in case of invalid kernel shape or invalid data_format.
  """
  kernel = np.asarray(kernel)
  if not 3 <= kernel.ndim <= 5:
    raise ValueError('Invalid kernel shape:', kernel.shape)
  slices = [slice(None, None, -1) for _ in range(kernel.ndim)]
  no_flip = (slice(None, None), slice(None, None))
  slices[-2:] = no_flip
  # Index with a tuple: indexing an ndarray with a list of slices is
  # deprecated in recent NumPy versions.
  return np.copy(kernel[tuple(slices)])
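
# Illustrative effect (added for clarity; not part of the original module):
# only the spatial axes are reversed, while the trailing in/out channel axes
# are left untouched.
#
#   >>> k = np.arange(3).reshape((3, 1, 1))  # (width, in_channels, out_channels)
#   >>> convert_kernel(k)[:, 0, 0]
#   array([2, 1, 0])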


def conv_kernel_mask(input_shape, kernel_shape, strides, padding):
  """Compute a mask representing the connectivity of a convolution operation.

  Assume a convolution with given parameters is applied to an input having N
  spatial dimensions with `input_shape = (d_in1, ..., d_inN)` to produce an
  output with shape `(d_out1, ..., d_outN)`. This method returns a boolean array
  of shape `(d_in1, ..., d_inN, d_out1, ..., d_outN)` with `True` entries
  indicating pairs of input and output locations that are connected by a weight.

  Example:

    >>> input_shape = (4,)
    >>> kernel_shape = (2,)
    >>> strides = (1,)
    >>> padding = "valid"
    >>> conv_kernel_mask(input_shape, kernel_shape, strides, padding)
    array([[ True, False, False],
           [ True,  True, False],
           [False,  True,  True],
           [False, False,  True]])

    where rows and columns correspond to inputs and outputs respectively.

  Args:
    input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of the
      input.
    kernel_shape: tuple of size N, spatial shape of the convolutional kernel /
      receptive field.
    strides: tuple of size N, strides along each spatial dimension.
    padding: type of padding, string `"same"` or `"valid"`.

  Returns:
    A boolean 2N-D `np.ndarray` of shape
    `(d_in1, ..., d_inN, d_out1, ..., d_outN)`, where `(d_out1, ..., d_outN)`
    is the spatial shape of the output. `True` entries in the mask represent
    pairs of input-output locations that are connected by a weight.

  Raises:
    ValueError: if `input_shape`, `kernel_shape` and `strides` don't have the
      same number of dimensions.
    NotImplementedError: if `padding` is not in {`"same"`, `"valid"`}.
  """
  if padding not in {'same', 'valid'}:
    raise NotImplementedError('Padding type %s not supported. '
                              'Only "valid" and "same" '
                              'are implemented.' % padding)

  in_dims = len(input_shape)
  if isinstance(kernel_shape, int):
    kernel_shape = (kernel_shape,) * in_dims
  if isinstance(strides, int):
    strides = (strides,) * in_dims

  kernel_dims = len(kernel_shape)
  stride_dims = len(strides)
  if kernel_dims != in_dims or stride_dims != in_dims:
    raise ValueError('Number of strides, input and kernel dimensions must all '
                     'match. Received: %d, %d, %d.' %
                     (stride_dims, in_dims, kernel_dims))

  output_shape = conv_output_shape(input_shape, kernel_shape, strides, padding)

  mask_shape = input_shape + output_shape
  # Use the builtin `bool` dtype; `np.bool` is a deprecated alias.
  mask = np.zeros(mask_shape, bool)

  output_axes_ticks = [range(dim) for dim in output_shape]
  for output_position in itertools.product(*output_axes_ticks):
    input_axes_ticks = conv_connected_inputs(input_shape, kernel_shape,
                                             output_position, strides, padding)
    for input_position in itertools.product(*input_axes_ticks):
      mask[input_position + output_position] = True

  return mask


def conv_kernel_idxs(input_shape, kernel_shape, strides, padding, filters_in,
                     filters_out, data_format):
  """Yields output-input tuples of indices in a CNN layer.

  The generator iterates over all `(output_idx, input_idx)` tuples, where
    `output_idx` is an integer index in a flattened tensor representing a single
    output image of a convolutional layer that is connected (via the layer
    weights) to the respective single input image at `input_idx`.

  Example:

    >>> input_shape = (2, 2)
    >>> kernel_shape = (2, 1)
    >>> strides = (1, 1)
    >>> padding = "valid"
    >>> filters_in = 1
    >>> filters_out = 1
    >>> data_format = "channels_last"
    >>> list(conv_kernel_idxs(input_shape, kernel_shape, strides, padding,
    ...                       filters_in, filters_out, data_format))
    [(0, 0), (0, 2), (1, 1), (1, 3)]

  Args:
    input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of the
      input.
    kernel_shape: tuple of size N, spatial shape of the convolutional kernel /
      receptive field.
    strides: tuple of size N, strides along each spatial dimension.
    padding: type of padding, string `"same"` or `"valid"`.
    filters_in: `int`, number of filters in the input to the layer.
    filters_out: `int`, number of filters in the output of the layer.
    data_format: string, "channels_first" or "channels_last".

  Yields:
    The next tuple `(output_idx, input_idx)`, where
    `output_idx` is an integer index in a flattened tensor representing a single
    output image of a convolutional layer that is connected (via the layer
    weights) to the respective single input image at `input_idx`.

  Raises:
      ValueError: if `data_format` is neither `"channels_last"` nor
        `"channels_first"`, or if the numbers of strides, input and kernel
        dimensions do not match.
      NotImplementedError: if `padding` is neither `"same"` nor `"valid"`.
  """
  if padding not in ('same', 'valid'):
    raise NotImplementedError('Padding type %s not supported. '
                              'Only "valid" and "same" '
                              'are implemented.' % padding)

  in_dims = len(input_shape)
  if isinstance(kernel_shape, int):
    kernel_shape = (kernel_shape,) * in_dims
  if isinstance(strides, int):
    strides = (strides,) * in_dims

  kernel_dims = len(kernel_shape)
  stride_dims = len(strides)
  if kernel_dims != in_dims or stride_dims != in_dims:
    raise ValueError('Number of strides, input and kernel dimensions must all '
                     'match. Received: %d, %d, %d.' %
                     (stride_dims, in_dims, kernel_dims))

  output_shape = conv_output_shape(input_shape, kernel_shape, strides, padding)
  output_axes_ticks = [range(dim) for dim in output_shape]

  if data_format == 'channels_first':
    concat_idxs = lambda spatial_idx, filter_idx: (filter_idx,) + spatial_idx
  elif data_format == 'channels_last':
    concat_idxs = lambda spatial_idx, filter_idx: spatial_idx + (filter_idx,)
  else:
    raise ValueError('Data format %s not recognized. '
                     '`data_format` must be "channels_first" or '
                     '"channels_last".' % data_format)

  for output_position in itertools.product(*output_axes_ticks):
    input_axes_ticks = conv_connected_inputs(input_shape, kernel_shape,
                                             output_position, strides, padding)
    for input_position in itertools.product(*input_axes_ticks):
      for f_in in range(filters_in):
        for f_out in range(filters_out):
          out_idx = np.ravel_multi_index(
              multi_index=concat_idxs(output_position, f_out),
              dims=concat_idxs(output_shape, filters_out))
          in_idx = np.ravel_multi_index(
              multi_index=concat_idxs(input_position, f_in),
              dims=concat_idxs(input_shape, filters_in))
          yield (out_idx, in_idx)


def conv_connected_inputs(input_shape, kernel_shape, output_position, strides,
                          padding):
  """Return locations of the input connected to an output position.

  Assume a convolution with given parameters is applied to an input having N
  spatial dimensions with `input_shape = (d_in1, ..., d_inN)`. This method
  returns N ranges specifying the input region that was convolved with the
  kernel to produce the output at position
  `output_position = (p_out1, ..., p_outN)`.

  Example:

    >>> input_shape = (4, 4)
    >>> kernel_shape = (2, 1)
    >>> output_position = (1, 1)
    >>> strides = (1, 1)
    >>> padding = "valid"
    >>> conv_connected_inputs(input_shape, kernel_shape, output_position,
    ...                       strides, padding)
    [range(1, 3), range(1, 2)]

  Args:
    input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of the
      input.
    kernel_shape: tuple of size N, spatial shape of the convolutional kernel /
      receptive field.
    output_position: tuple of size N: `(p_out1, ..., p_outN)`, a single position
      in the output of the convolution.
    strides: tuple of size N, strides along each spatial dimension.
    padding: type of padding, string `"same"` or `"valid"`.

  Returns:
    N ranges `[[p_in_left1, ..., p_in_right1], ...,
              [p_in_leftN, ..., p_in_rightN]]` specifying the region in the
    input connected to output_position.
  """
  ranges = []

  ndims = len(input_shape)
  for d in range(ndims):
    left_shift = int(kernel_shape[d] / 2)
    right_shift = kernel_shape[d] - left_shift

    center = output_position[d] * strides[d]

    if padding == 'valid':
      center += left_shift

    start = max(0, center - left_shift)
    end = min(input_shape[d], center + right_shift)

    ranges.append(range(start, end))

  return ranges


def conv_output_shape(input_shape, kernel_shape, strides, padding):
  """Return the output shape of an N-D convolution.

  Forces dimensions where input is empty (size 0) to remain empty.

  Args:
    input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of the
      input.
    kernel_shape: tuple of size N, spatial shape of the convolutional kernel /
      receptive field.
    strides: tuple of size N, strides along each spatial dimension.
    padding: type of padding, string `"same"` or `"valid"`.

  Returns:
    tuple of size N: `(d_out1, ..., d_outN)`, spatial shape of the output.
  """
  dims = range(len(kernel_shape))
  output_shape = [
      conv_output_length(input_shape[d], kernel_shape[d], padding, strides[d])
      for d in dims
  ]
  output_shape = tuple(
      [0 if input_shape[d] == 0 else output_shape[d] for d in dims])
  return output_shape
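
# Illustrative shapes (added for clarity; not part of the original module):
#
#   >>> conv_output_shape((5, 5), (3, 3), (1, 1), 'valid')
#   (3, 3)
#   >>> conv_output_shape((5, 0), (3, 3), (1, 1), 'same')
#   (5, 0)
#
# The second call shows the special case: a spatial dimension with zero input
# size stays empty in the output regardless of the padding mode.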