# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Primitive Neural Net (NN) Operations.
16
17## Notes on padding
18
19Several neural network operations, such as `tf.nn.conv2d` and
20`tf.nn.max_pool2d`, take a `padding` parameter, which controls how the input is
21padded before running the operation. The input is padded by inserting values
22(typically zeros) before and after the tensor in each spatial dimension. The
23`padding` parameter can either be the string `'VALID'`, which means use no
24padding, or `'SAME'` which adds padding according to a formula which is
25described below. Certain ops also allow the amount of padding per dimension to
26be explicitly specified by passing a list to `padding`.
27
28In the case of convolutions, the input is padded with zeros. In case of pools,
29the padded input values are ignored. For example, in a max pool, the sliding
30window ignores padded values, which is equivalent to the padded values being
31`-infinity`.
32
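Here is an example showing that pooling ignores padded values rather than
treating them as zeros (all inputs are negative, so zero padding would change
the result):

>>> x = tf.reshape(tf.constant([-1., -2., -3., -4.]), (1, 2, 2, 1))
>>> out = tf.nn.max_pool2d(x, ksize=2, strides=1, padding='SAME')
>>> tf.reduce_max(out).numpy()
-1.0
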
### `'VALID'` padding

Passing `padding='VALID'` to an op causes no padding to be used. The output
size is then typically smaller than the input size, even when the stride is
one. In the 2D case, the output size is computed as:

```
out_height = ceil((in_height - filter_height + 1) / stride_height)
out_width  = ceil((in_width - filter_width + 1) / stride_width)
```

The 1D and 3D cases are similar. Note `filter_height` and `filter_width` refer
to the filter size after dilations (if any) for convolutions, and refer to the
window size for pools.

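Here is an example of `'VALID'` padding (analogous to the `'SAME'` example
below):

>>> inp = tf.ones((1, 5, 5, 1))
>>> filter = tf.ones((3, 3, 1, 1))
>>> strides = [2, 2]
>>> output = tf.nn.conv2d(inp, filter, strides, padding='VALID')
>>> output.shape[1]  # output_height: ceil((5 - 3 + 1) / 2)
2
>>> output.shape[2]  # output_width: ceil((5 - 3 + 1) / 2)
2
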
### `'SAME'` padding

With `'SAME'` padding, padding is applied to each spatial dimension. When the
strides are 1, the input is padded such that the output size is the same as the
input size. In the 2D case, the output size is computed as:

```
out_height = ceil(in_height / stride_height)
out_width  = ceil(in_width / stride_width)
```

The amount of padding used is the smallest amount that results in the output
size. The formula for the total amount of padding per dimension is:

```
if (in_height % stride_height == 0):
  pad_along_height = max(filter_height - stride_height, 0)
else:
  pad_along_height = max(filter_height - (in_height % stride_height), 0)
if (in_width % stride_width == 0):
  pad_along_width = max(filter_width - stride_width, 0)
else:
  pad_along_width = max(filter_width - (in_width % stride_width), 0)
```

Finally, the padding on the top, bottom, left and right are:

```
pad_top = pad_along_height // 2
pad_bottom = pad_along_height - pad_top
pad_left = pad_along_width // 2
pad_right = pad_along_width - pad_left
```

Note that the division by 2 means that there might be cases when the padding on
the two sides (top vs bottom, left vs right) is off by one. In this case, the
bottom and right sides always get the one additional padded pixel. For example,
when `pad_along_height` is 5, we pad 2 pixels at the top and 3 pixels at the
bottom. Note that this is different from existing libraries such as PyTorch and
Caffe, which explicitly specify the number of padded pixels and always pad the
same number of pixels on both sides.

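As a worked example with the values used below: `in_width = 2`,
`filter_width = 2` and `stride_width = 1`, so `in_width % stride_width == 0`
and `pad_along_width = max(2 - 1, 0) = 1`, which is split as `pad_left = 0` and
`pad_right = 1`, with the extra padded pixel on the right.
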
Here is an example of `'SAME'` padding:

>>> in_height = 5
>>> filter_height = 3
>>> stride_height = 2
>>>
>>> in_width = 2
>>> filter_width = 2
>>> stride_width = 1
>>>
>>> inp = tf.ones((2, in_height, in_width, 2))
>>> filter = tf.ones((filter_height, filter_width, 2, 2))
>>> strides = [stride_height, stride_width]
>>> output = tf.nn.conv2d(inp, filter, strides, padding='SAME')
>>> output.shape[1]  # output_height: ceil(5 / 2)
3
>>> output.shape[2]  # output_width: ceil(2 / 1)
2

### Explicit padding

Certain ops, like `tf.nn.conv2d`, also allow a list of explicit padding amounts
to be passed to the `padding` parameter. This list is in the same format as
what is passed to `tf.pad`, except the padding must be a nested list, not a
tensor. For example, in the 2D case, the list is in the format `[[0, 0],
[pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]` when `data_format` is
its default value of `'NHWC'`. The two `[0, 0]` pairs indicate the batch and
channel dimensions have no padding, which is required, as only spatial
dimensions can have padding.

For example:

>>> inp = tf.ones((1, 3, 3, 1))
>>> filter = tf.ones((2, 2, 1, 1))
>>> strides = [1, 1]
>>> padding = [[0, 0], [1, 2], [0, 1], [0, 0]]
>>> output = tf.nn.conv2d(inp, filter, strides, padding=padding)
>>> tuple(output.shape)
(1, 5, 3, 1)
>>> # Equivalently, tf.pad can be used, since convolutions pad with zeros.
>>> inp = tf.pad(inp, padding)
>>> # 'VALID' means to use no padding in conv2d (we already padded inp).
>>> output2 = tf.nn.conv2d(inp, filter, strides, padding='VALID')
>>> tf.debugging.assert_equal(output, output2)
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import numbers

import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.ops import variables as variables_lib
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_nn_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.platform import device_context
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.deprecation import deprecated_argument_lookup

from tensorflow.python.util.tf_export import tf_export

# Aliases for some automatically-generated names.
local_response_normalization = gen_nn_ops.lrn

# pylint: disable=protected-access

# Acceptable channels last formats (robust to H, W, D order).
_CHANNELS_LAST_FORMATS = frozenset({
    "NWC", "NHC", "NHWC", "NWHC", "NDHWC", "NDWHC", "NHDWC", "NHWDC", "NWDHC",
    "NWHDC"
})


def _get_sequence(value, n, channel_index, name):
  """Formats a value input for gen_nn_ops."""
  # Performance is fast-pathed for common cases:
  # `None`, `list`, `tuple` and `int`.
  if value is None:
    return [1] * (n + 2)

  # Always convert `value` to a `list`.
  if isinstance(value, list):
    pass
  elif isinstance(value, tuple):
    value = list(value)
  elif isinstance(value, int):
    value = [value]
  elif not isinstance(value, collections_abc.Sized):
    value = [value]
  else:
    value = list(value)  # Try casting to a list.

  len_value = len(value)

  # Fully specified, including batch and channel dims.
  if len_value == n + 2:
    return value

  # Apply value to spatial dims only.
  if len_value == 1:
    value = value * n  # Broadcast to spatial dimensions.
  elif len_value != n:
    raise ValueError("{} should be of length 1, {} or {} but was {}".format(
        name, n, n + 2, len_value))

  # Add batch and channel dims (always 1).
  if channel_index == 1:
    return [1, 1] + value
  else:
    return [1] + value + [1]


def _non_atrous_convolution(
    input,  # pylint: disable=redefined-builtin
    filter,  # pylint: disable=redefined-builtin
    padding,
    data_format=None,  # pylint: disable=redefined-builtin
    strides=None,
    name=None):
  """Computes sums of N-D convolutions (actually cross-correlation).

  It is required that 1 <= N <= 3.

  This is used to implement the more generic `convolution` function, which
  extends the interface of this function with a `dilation_rate` parameter.

  Args:
    input: Rank N+2 tensor of type T of shape
      `[batch_size] + input_spatial_shape + [in_channels]` if `data_format`
      does not start with `"NC"`, or
      `[batch_size, in_channels] + input_spatial_shape` if `data_format` starts
      with `"NC"`.
    filter: Rank N+2 tensor of type T of shape
      `filter_spatial_shape + [in_channels, out_channels]`.  Rank of either
      `input` or `filter` must be known.
    padding: Padding method to use, must be either "VALID" or "SAME".
    data_format: A string or None.  Specifies whether the channel dimension of
      the `input` and output is the last dimension (default, or if `data_format`
      does not start with "NC"), or the second dimension (if `data_format`
      starts with "NC").  For N=1, the valid values are "NWC" (default) and
      "NCW".  For N=2, the valid values are "NHWC" (default) and "NCHW".
      For N=3, the valid values are "NDHWC" (default) and "NCDHW".
    strides: Sequence of N positive integers, defaults to `[1] * N`.
    name: Name prefix to use.

  Returns:
    Rank N+2 tensor of type T of shape
    `[batch_size] + output_spatial_shape + [out_channels]`, where
    if padding == "SAME":
      output_spatial_shape = input_spatial_shape
    if padding == "VALID":
      output_spatial_shape = input_spatial_shape - filter_spatial_shape + 1.

  Raises:
    ValueError: if ranks are incompatible.
  """
  with ops.name_scope(name, "non_atrous_convolution", [input, filter]) as scope:
    input = ops.convert_to_tensor(input, name="input")  # pylint: disable=redefined-builtin
    input_shape = input.shape
    filter = ops.convert_to_tensor(filter, name="filter")  # pylint: disable=redefined-builtin
    filter_shape = filter.shape
    op = _NonAtrousConvolution(
        input_shape,
        filter_shape=filter_shape,
        padding=padding,
        data_format=data_format,
        strides=strides,
        name=scope)
    return op(input, filter)


class _NonAtrousConvolution(object):
  """Helper class for _non_atrous_convolution.

  Note that this class assumes that shapes of input and filter passed to
  `__call__` are compatible with `input_shape` and `filter_shape` passed to the
  constructor.

  Args:
    input_shape: static input shape, i.e. input.shape.
    filter_shape: static filter shape, i.e. filter.shape.
    padding: see _non_atrous_convolution.
    data_format: see _non_atrous_convolution.
    strides: see _non_atrous_convolution.
    name: see _non_atrous_convolution.
    num_batch_dims: (Optional.)  The number of batch dimensions in the input;
      if not provided, the default of `1` is used.
  """

  def __init__(
      self,
      input_shape,
      filter_shape,
      padding,
      data_format=None,
      strides=None,
      name=None,
      num_batch_dims=1):
    # filter rank is always num_spatial_dims + 2
    # and num_spatial_dims == input_shape.ndims - num_batch_dims - 1
    if input_shape.ndims is not None:
      filter_shape = filter_shape.with_rank(
          input_shape.ndims - num_batch_dims + 1)
    self.padding = padding
    self.name = name
    # input rank is num_spatial_dims + num_batch_dims + 1
    # and filter rank is always num_spatial_dims + 2
    if filter_shape.ndims is not None:
      input_shape = input_shape.with_rank(
          filter_shape.ndims + num_batch_dims - 1)
    if input_shape.ndims is None:
      raise ValueError(
          "Rank of convolution must be known, but saw input_shape.ndims == {}"
          .format(input_shape.ndims))
    if input_shape.ndims < 3 or input_shape.ndims - num_batch_dims + 1 > 5:
      raise ValueError(
          "`input_shape.ndims - num_batch_dims + 1` must be at least 3 and at "
          "most 5 but saw `input_shape.ndims == {}` and `num_batch_dims == {}`"
          .format(input_shape.ndims, num_batch_dims))
    conv_dims = input_shape.ndims - num_batch_dims - 1
    if strides is None:
      strides = [1] * conv_dims
    elif len(strides) != conv_dims:
      raise ValueError("len(strides)=%d, but should be %d" % (len(strides),
                                                              conv_dims))
    if conv_dims == 1:
      # conv1d uses the 2-d data format names.
      if data_format is None:
        data_format = "NWC"
      elif data_format not in {"NCW", "NWC", "NCHW", "NHWC"}:
        raise ValueError("data_format must be one of \"NWC\", \"NCW\", "
                         "\"NHWC\" or \"NCHW\".")
      self.strides = strides[0]
      self.data_format = data_format
      self.conv_op = self._conv1d
    elif conv_dims == 2:
      if data_format is None or data_format == "NHWC":
        data_format = "NHWC"
        strides = [1] + list(strides) + [1]
      elif data_format == "NCHW":
        strides = [1, 1] + list(strides)
      else:
        raise ValueError("data_format must be \"NHWC\" or \"NCHW\".")
      self.strides = strides
      self.data_format = data_format
      self.conv_op = conv2d
    elif conv_dims == 3:
      if data_format is None or data_format == "NDHWC":
        strides = [1] + list(strides) + [1]
      elif data_format == "NCDHW":
        strides = [1, 1] + list(strides)
      else:
        raise ValueError("data_format must be \"NDHWC\" or \"NCDHW\". Have: %s"
                         % data_format)
      self.strides = strides
      self.data_format = data_format
      self.conv_op = _conv3d_expanded_batch

  # Note that we need this adapter since argument names for conv1d don't match
  # those for gen_nn_ops.conv2d and gen_nn_ops.conv3d.
  # pylint: disable=redefined-builtin
  def _conv1d(self, input, filter, strides, padding, data_format, name):
    return conv1d(
        value=input,
        filters=filter,
        stride=strides,
        padding=padding,
        data_format=data_format,
        name=name)
  # pylint: enable=redefined-builtin

  def __call__(self, inp, filter):  # pylint: disable=redefined-builtin
    return self.conv_op(
        input=inp,
        filter=filter,
        strides=self.strides,
        padding=self.padding,
        data_format=self.data_format,
        name=self.name)


def squeeze_batch_dims(inp, op, inner_rank, name=None):
  """Returns `unsqueeze_batch(op(squeeze_batch(inp)))`.

  Where `squeeze_batch` reshapes `inp` to shape
  `[prod(inp.shape[:-inner_rank])] + inp.shape[-inner_rank:]`
  and `unsqueeze_batch` does the reverse reshape but on the output.

  Args:
    inp: A tensor with dims `batch_shape + inner_shape` where `inner_shape`
      is length `inner_rank`.
    op: A callable that takes a single input tensor and returns a single
      output tensor.
    inner_rank: A python integer.
    name: A string.

  Returns:
    `unsqueeze_batch(op(squeeze_batch(inp)))`.
  """
  with ops.name_scope(name, "squeeze_batch_dims", [inp]):
    inp = ops.convert_to_tensor(inp, name="input")
    shape = inp.shape

    inner_shape = shape[-inner_rank:]
    if not inner_shape.is_fully_defined():
      inner_shape = array_ops.shape(inp)[-inner_rank:]

    batch_shape = shape[:-inner_rank]
    if not batch_shape.is_fully_defined():
      batch_shape = array_ops.shape(inp)[:-inner_rank]

    if isinstance(inner_shape, tensor_shape.TensorShape):
      inp_reshaped = array_ops.reshape(inp, [-1] + inner_shape.as_list())
    else:
      inp_reshaped = array_ops.reshape(
          inp, array_ops.concat(([-1], inner_shape), axis=-1))

    out_reshaped = op(inp_reshaped)

    out_inner_shape = out_reshaped.shape[-inner_rank:]
    if not out_inner_shape.is_fully_defined():
      out_inner_shape = array_ops.shape(out_reshaped)[-inner_rank:]

    out = array_ops.reshape(
        out_reshaped, array_ops.concat((batch_shape, out_inner_shape), axis=-1))

    out.set_shape(inp.shape[:-inner_rank] + out.shape[-inner_rank:])
    return out


446@tf_export("nn.dilation2d", v1=[])
447@dispatch.add_dispatch_support
448def dilation2d_v2(
449    input,   # pylint: disable=redefined-builtin
450    filters,  # pylint: disable=redefined-builtin
451    strides,
452    padding,
453    data_format,
454    dilations,
455    name=None):
456  """Computes the grayscale dilation of 4-D `input` and 3-D `filters` tensors.
457
458  The `input` tensor has shape `[batch, in_height, in_width, depth]` and the
459  `filters` tensor has shape `[filter_height, filter_width, depth]`, i.e., each
460  input channel is processed independently of the others with its own
461  structuring function. The `output` tensor has shape
462  `[batch, out_height, out_width, depth]`. The spatial dimensions of the output
463  tensor depend on the `padding` algorithm. We currently only support the
464  default "NHWC" `data_format`.
465
466  In detail, the grayscale morphological 2-D dilation is the max-sum correlation
467  (for consistency with `conv2d`, we use unmirrored filters):
468
469      output[b, y, x, c] =
470         max_{dy, dx} input[b,
471                            strides[1] * y + rates[1] * dy,
472                            strides[2] * x + rates[2] * dx,
473                            c] +
474                      filters[dy, dx, c]
475
476  Max-pooling is a special case when the filter has size equal to the pooling
477  kernel size and contains all zeros.
478
479  Note on duality: The dilation of `input` by the `filters` is equal to the
480  negation of the erosion of `-input` by the reflected `filters`.
481
482  Args:
483    input: A `Tensor`. Must be one of the following types: `float32`, `float64`,
484      `int32`, `uint8`, `int16`, `int8`, `int64`, `bfloat16`, `uint16`, `half`,
485      `uint32`, `uint64`.
486      4-D with shape `[batch, in_height, in_width, depth]`.
487    filters: A `Tensor`. Must have the same type as `input`.
488      3-D with shape `[filter_height, filter_width, depth]`.
489    strides: A list of `ints` that has length `>= 4`.
490      The stride of the sliding window for each dimension of the input
491      tensor. Must be: `[1, stride_height, stride_width, 1]`.
492    padding: A `string` from: `"SAME", "VALID"`.
493      The type of padding algorithm to use.
494    data_format: A `string`, only `"NHWC"` is currently supported.
495    dilations: A list of `ints` that has length `>= 4`.
496      The input stride for atrous morphological dilation. Must be:
497      `[1, rate_height, rate_width, 1]`.
498    name: A name for the operation (optional).
499
500  Returns:
501    A `Tensor`. Has the same type as `input`.
502  """
  if data_format != "NHWC":
    raise ValueError("Data formats other than NHWC are not yet supported")

  return gen_nn_ops.dilation2d(input=input,
                               filter=filters,
                               strides=strides,
                               rates=dilations,
                               padding=padding,
                               name=name)


@tf_export(v1=["nn.dilation2d"])
@dispatch.add_dispatch_support
def dilation2d_v1(  # pylint: disable=missing-docstring
    input,  # pylint: disable=redefined-builtin
    filter=None,  # pylint: disable=redefined-builtin
    strides=None,
    rates=None,
    padding=None,
    name=None,
    filters=None,
    dilations=None):
  filter = deprecated_argument_lookup("filters", filters, "filter", filter)
  rates = deprecated_argument_lookup("dilations", dilations, "rates", rates)
  return gen_nn_ops.dilation2d(input, filter, strides, rates, padding, name)


dilation2d_v1.__doc__ = gen_nn_ops.dilation2d.__doc__


533@tf_export("nn.with_space_to_batch")
534@dispatch.add_dispatch_support
535def with_space_to_batch(
536    input,  # pylint: disable=redefined-builtin
537    dilation_rate,
538    padding,
539    op,
540    filter_shape=None,
541    spatial_dims=None,
542    data_format=None):
543  """Performs `op` on the space-to-batch representation of `input`.
544
545  This has the effect of transforming sliding window operations into the
546  corresponding "atrous" operation in which the input is sampled at the
547  specified `dilation_rate`.
548
549  In the special case that `dilation_rate` is uniformly 1, this simply returns:
550
551    op(input, num_spatial_dims, padding)
552
553  Otherwise, it returns:
554
555    batch_to_space_nd(
556      op(space_to_batch_nd(input, adjusted_dilation_rate, adjusted_paddings),
557         num_spatial_dims,
558         "VALID")
559      adjusted_dilation_rate,
560      adjusted_crops),
561
562  where:
563
564    adjusted_dilation_rate is an int64 tensor of shape [max(spatial_dims)],
565    adjusted_{paddings,crops} are int64 tensors of shape [max(spatial_dims), 2]
566
567  defined as follows:
568
569  We first define two int64 tensors `paddings` and `crops` of shape
570  `[num_spatial_dims, 2]` based on the value of `padding` and the spatial
571  dimensions of the `input`:
572
573  If `padding = "VALID"`, then:
574
575    paddings, crops = required_space_to_batch_paddings(
576      input_shape[spatial_dims],
577      dilation_rate)
578
579  If `padding = "SAME"`, then:
580
581    dilated_filter_shape =
582      filter_shape + (filter_shape - 1) * (dilation_rate - 1)
583
584    paddings, crops = required_space_to_batch_paddings(
585      input_shape[spatial_dims],
586      dilation_rate,
587      [(dilated_filter_shape - 1) // 2,
588       dilated_filter_shape - 1 - (dilated_filter_shape - 1) // 2])
589
590  Because `space_to_batch_nd` and `batch_to_space_nd` assume that the spatial
591  dimensions are contiguous starting at the second dimension, but the specified
592  `spatial_dims` may not be, we must adjust `dilation_rate`, `paddings` and
593  `crops` in order to be usable with these operations.  For a given dimension,
594  if the block size is 1, and both the starting and ending padding and crop
595  amounts are 0, then space_to_batch_nd effectively leaves that dimension alone,
596  which is what is needed for dimensions not part of `spatial_dims`.
597  Furthermore, `space_to_batch_nd` and `batch_to_space_nd` handle this case
598  efficiently for any number of leading and trailing dimensions.
599
600  For 0 <= i < len(spatial_dims), we assign:
601
602    adjusted_dilation_rate[spatial_dims[i] - 1] = dilation_rate[i]
603    adjusted_paddings[spatial_dims[i] - 1, :] = paddings[i, :]
604    adjusted_crops[spatial_dims[i] - 1, :] = crops[i, :]
605
606  All unassigned values of `adjusted_dilation_rate` default to 1, while all
607  unassigned values of `adjusted_paddings` and `adjusted_crops` default to 0.
608
609  Note in the case that `dilation_rate` is not uniformly 1, specifying "VALID"
610  padding is equivalent to specifying `padding = "SAME"` with a filter_shape of
611  `[1]*N`.
612
613  Advanced usage. Note the following optimization: A sequence of
614  `with_space_to_batch` operations with identical (not uniformly 1)
615  `dilation_rate` parameters and "VALID" padding
616
617    net = with_space_to_batch(net, dilation_rate, "VALID", op_1)
618    ...
619    net = with_space_to_batch(net, dilation_rate, "VALID", op_k)
620
621  can be combined into a single `with_space_to_batch` operation as follows:
622
623    def combined_op(converted_input, num_spatial_dims, _):
624      result = op_1(converted_input, num_spatial_dims, "VALID")
625      ...
626      result = op_k(result, num_spatial_dims, "VALID")
627
628    net = with_space_to_batch(net, dilation_rate, "VALID", combined_op)
629
630  This eliminates the overhead of `k-1` calls to `space_to_batch_nd` and
631  `batch_to_space_nd`.
632
633  Similarly, a sequence of `with_space_to_batch` operations with identical (not
634  uniformly 1) `dilation_rate` parameters, "SAME" padding, and odd filter
635  dimensions
636
637    net = with_space_to_batch(net, dilation_rate, "SAME", op_1, filter_shape_1)
638    ...
639    net = with_space_to_batch(net, dilation_rate, "SAME", op_k, filter_shape_k)
640
641  can be combined into a single `with_space_to_batch` operation as follows:
642
643    def combined_op(converted_input, num_spatial_dims, _):
644      result = op_1(converted_input, num_spatial_dims, "SAME")
645      ...
646      result = op_k(result, num_spatial_dims, "SAME")
647
648    net = with_space_to_batch(net, dilation_rate, "VALID", combined_op)
649
650  Args:
651    input: Tensor of rank > max(spatial_dims).
652    dilation_rate: int32 Tensor of *known* shape [num_spatial_dims].
653    padding: str constant equal to "VALID" or "SAME"
654    op: Function that maps (input, num_spatial_dims, padding) -> output
655    filter_shape: If padding = "SAME", specifies the shape of the convolution
656      kernel/pooling window as an integer Tensor of shape [>=num_spatial_dims].
657      If padding = "VALID", filter_shape is ignored and need not be specified.
658    spatial_dims: Monotonically increasing sequence of `num_spatial_dims`
659      integers (which are >= 1) specifying the spatial dimensions of `input`
660      and output.  Defaults to: `range(1, num_spatial_dims+1)`.
661    data_format: A string or None.  Specifies whether the channel dimension of
662      the `input` and output is the last dimension (default, or if `data_format`
663      does not start with "NC"), or the second dimension (if `data_format`
664      starts with "NC").  For N=1, the valid values are "NWC" (default) and
665      "NCW".  For N=2, the valid values are "NHWC" (default) and "NCHW".
666      For N=3, the valid values are "NDHWC" (default) and "NCDHW".
667
668  Returns:
669    The output Tensor as described above, dimensions will vary based on the op
670    provided.
671
672  Raises:
673    ValueError: if `padding` is invalid or the arguments are incompatible.
674    ValueError: if `spatial_dims` are invalid.
675
676  """
  input = ops.convert_to_tensor(input, name="input")  # pylint: disable=redefined-builtin
  input_shape = input.shape

  def build_op(num_spatial_dims, padding):
    return lambda inp, _: op(inp, num_spatial_dims, padding)

  new_op = _WithSpaceToBatch(
      input_shape,
      dilation_rate,
      padding,
      build_op,
      filter_shape=filter_shape,
      spatial_dims=spatial_dims,
      data_format=data_format)
  return new_op(input, None)


class _WithSpaceToBatch(object):
  """Helper class for with_space_to_batch.

  Note that this class assumes that shapes of input and filter passed to
  `__call__` are compatible with `input_shape`, `filter_shape`, and
  `spatial_dims` passed to the constructor.

  Args:
    input_shape: static shape of input, i.e. input.shape.
    dilation_rate: see `with_space_to_batch`.
    padding: see `with_space_to_batch`.
    build_op: Function that maps (num_spatial_dims, paddings) -> (function that
      maps (input, filter) -> output).
    filter_shape: see `with_space_to_batch`.
    spatial_dims: see `with_space_to_batch`.
    data_format: see `with_space_to_batch`.
    num_batch_dims: (Optional.)  Number of batch dims in `input_shape`.
  """

  def __init__(self,
               input_shape,
               dilation_rate,
               padding,
               build_op,
               filter_shape=None,
               spatial_dims=None,
               data_format=None,
               num_batch_dims=1):
    """Helper class for _with_space_to_batch."""
    dilation_rate = ops.convert_to_tensor(
        dilation_rate, dtypes.int32, name="dilation_rate")
    if dilation_rate.shape.ndims not in (None, 1):
      raise ValueError(
          "rate must be rank 1 but saw {}".format(dilation_rate.shape.ndims))

    if not dilation_rate.shape.is_fully_defined():
      raise ValueError("rate must have known shape, but saw {}"
                       .format(dilation_rate.shape))

    num_spatial_dims = dilation_rate.shape.dims[0].value

    if data_format is not None and data_format.startswith("NC"):
      starting_spatial_dim = num_batch_dims + 1
    else:
      starting_spatial_dim = num_batch_dims

    if spatial_dims is None:
      spatial_dims = range(starting_spatial_dim,
                           num_spatial_dims + starting_spatial_dim)
    orig_spatial_dims = list(spatial_dims)
    spatial_dims = sorted(set(int(x) for x in orig_spatial_dims))
    if spatial_dims != orig_spatial_dims or any(x < 1 for x in spatial_dims):
      raise ValueError(
          "spatial_dims must be a monotonically increasing sequence of "
          "positive integers, but saw: {}".format(orig_spatial_dims))

    if data_format is not None and data_format.startswith("NC"):
      expected_input_rank = spatial_dims[-1]
    else:
      expected_input_rank = spatial_dims[-1] + 1

    try:
      input_shape.with_rank_at_least(expected_input_rank)
    except ValueError:
      raise ValueError(
          "input tensor must have rank at least {}, but saw rank {}"
          .format(expected_input_rank, input_shape.ndims))

    const_rate = tensor_util.constant_value(dilation_rate)
    rate_or_const_rate = dilation_rate
    if const_rate is not None:
      rate_or_const_rate = const_rate
      if np.any(const_rate < 1):
        raise ValueError("dilation_rate must be positive, but saw: {}"
                         .format(const_rate))
      if np.all(const_rate == 1):
        self.call = build_op(num_spatial_dims, padding)
        return

    padding, explicit_paddings = convert_padding(padding)

    # We have two padding contributions. The first is used for converting "SAME"
    # to "VALID". The second is required so that the height and width of the
    # zero-padded value tensor are multiples of rate.

    # Padding required to reduce to "VALID" convolution.
    if padding == "SAME":
      if filter_shape is None:
        raise ValueError("filter_shape must be specified for SAME padding")
      filter_shape = ops.convert_to_tensor(filter_shape, name="filter_shape")
      const_filter_shape = tensor_util.constant_value(filter_shape)
      if const_filter_shape is not None:
        filter_shape = const_filter_shape
        self.base_paddings = _with_space_to_batch_base_paddings(
            const_filter_shape, num_spatial_dims, rate_or_const_rate)
      else:
        self.num_spatial_dims = num_spatial_dims
        self.rate_or_const_rate = rate_or_const_rate
        self.base_paddings = None
    elif padding == "VALID":
      self.base_paddings = np.zeros([num_spatial_dims, 2], np.int32)
    elif padding == "EXPLICIT":
      base_paddings = (np.array(explicit_paddings)
                       .reshape([num_spatial_dims + 2, 2]))
      # Remove batch and channel dimensions.
      if data_format is not None and data_format.startswith("NC"):
        self.base_paddings = base_paddings[2:]
      else:
        self.base_paddings = base_paddings[1:-1]
    else:
      raise ValueError("Invalid padding method %r" % padding)

    self.input_shape = input_shape
    self.spatial_dims = spatial_dims
    self.dilation_rate = dilation_rate
    self.data_format = data_format
    self.op = build_op(num_spatial_dims, "VALID")
    self.call = self._with_space_to_batch_call

  def _with_space_to_batch_call(self, inp, filter):  # pylint: disable=redefined-builtin
    """Call functionality for with_space_to_batch."""
    # Handle input whose shape is unknown during graph creation.
    input_spatial_shape = None
    input_shape = self.input_shape
    spatial_dims = self.spatial_dims
    if input_shape.ndims is not None:
      input_shape_list = input_shape.as_list()
      input_spatial_shape = [input_shape_list[i] for i in spatial_dims]
    if input_spatial_shape is None or None in input_spatial_shape:
      input_shape_tensor = array_ops.shape(inp)
      input_spatial_shape = array_ops.stack(
          [input_shape_tensor[i] for i in spatial_dims])

    base_paddings = self.base_paddings
    if base_paddings is None:
      # base_paddings could not be computed at build time since static filter
      # shape was not fully defined.
      filter_shape = array_ops.shape(filter)
      base_paddings = _with_space_to_batch_base_paddings(
          filter_shape, self.num_spatial_dims, self.rate_or_const_rate)

    paddings, crops = array_ops.required_space_to_batch_paddings(
        input_shape=input_spatial_shape,
        base_paddings=base_paddings,
        block_shape=self.dilation_rate)

    dilation_rate = _with_space_to_batch_adjust(self.dilation_rate, 1,
                                                spatial_dims)
    paddings = _with_space_to_batch_adjust(paddings, 0, spatial_dims)
    crops = _with_space_to_batch_adjust(crops, 0, spatial_dims)
    input_converted = array_ops.space_to_batch_nd(
        input=inp, block_shape=dilation_rate, paddings=paddings)

    result = self.op(input_converted, filter)

    result_converted = array_ops.batch_to_space_nd(
        input=result, block_shape=dilation_rate, crops=crops)

    # Recover channel information for output shape if channels are not last.
    if self.data_format is not None and self.data_format.startswith("NC"):
      if not result_converted.shape.dims[1].value and filter is not None:
        output_shape = result_converted.shape.as_list()
        output_shape[1] = filter.shape[-1]
        result_converted.set_shape(output_shape)

    return result_converted

  def __call__(self, inp, filter):  # pylint: disable=redefined-builtin
    return self.call(inp, filter)


def _with_space_to_batch_base_paddings(filter_shape, num_spatial_dims,
                                       rate_or_const_rate):
  """Helper function to compute base_paddings."""
  # Spatial dimensions of the filters and the upsampled filters in which we
  # introduce (rate - 1) zeros between consecutive filter values.
  filter_spatial_shape = filter_shape[:num_spatial_dims]
  pad_extra_shape = (filter_spatial_shape - 1) * rate_or_const_rate

  # When pad_extra_shape is odd, we pad more at the end, following the same
  # convention as conv2d.
  pad_extra_start = pad_extra_shape // 2
  pad_extra_end = pad_extra_shape - pad_extra_start
  base_paddings = array_ops.stack(
      [[pad_extra_start[i], pad_extra_end[i]] for i in range(num_spatial_dims)])
  return base_paddings


def _with_space_to_batch_adjust(orig, fill_value, spatial_dims):
  """Returns an `adjusted` version of `orig` based on `spatial_dims`.

  Tensor of the same type as `orig` and with shape
  `[max(spatial_dims), ...]` where:

    adjusted[spatial_dims[i] - 1, ...] = orig[i, ...]

  for 0 <= i < len(spatial_dims), and

    adjusted[j, ...] = fill_value

  for j != spatial_dims[i] - 1 for some i.

  If `orig` is a constant value, then the result will be a constant value.

  Args:
    orig: Tensor of rank > max(spatial_dims).
    fill_value: Numpy scalar (of same data type as `orig`) specifying the fill
      value for non-spatial dimensions.
    spatial_dims: See with_space_to_batch.

  Returns:
    `adjusted` tensor.
  """
  fill_dims = orig.get_shape().as_list()[1:]
  dtype = orig.dtype.as_numpy_dtype
  parts = []
  const_orig = tensor_util.constant_value(orig)
  const_or_orig = const_orig if const_orig is not None else orig
  prev_spatial_dim = 0
  i = 0
  while i < len(spatial_dims):
    start_i = i
    start_spatial_dim = spatial_dims[i]
    if start_spatial_dim > 1:
      # Fill in any gap from the previous spatial dimension (or dimension 1 if
      # this is the first spatial dimension) with `fill_value`.
      parts.append(
          np.full(
              [start_spatial_dim - 1 - prev_spatial_dim] + fill_dims,
              fill_value,
              dtype=dtype))
    # Find the largest value of i such that:
    #   [spatial_dims[start_i], ..., spatial_dims[i]]
    #     == [start_spatial_dim, ..., start_spatial_dim + i - start_i],
    # i.e. the end of a contiguous group of spatial dimensions.
    while (i + 1 < len(spatial_dims) and
           spatial_dims[i + 1] == spatial_dims[i] + 1):
      i += 1
    parts.append(const_or_orig[start_i:i + 1])
    prev_spatial_dim = spatial_dims[i]
    i += 1
  if const_orig is not None:
    return np.concatenate(parts)
  else:
    return array_ops.concat(parts, 0)


def _get_strides_and_dilation_rate(num_spatial_dims, strides, dilation_rate):
  """Helper function for verifying strides and dilation_rate arguments.

  This is used by `convolution` and `pool`.

  Args:
    num_spatial_dims: int
    strides: Optional.  List of N ints >= 1.  Defaults to [1]*N.  If any value
      of strides is > 1, then all values of dilation_rate must be 1.
    dilation_rate: Optional.  List of N ints >= 1.  Defaults to [1]*N.  If any
      value of dilation_rate is > 1, then all values of strides must be 1.

  Returns:
    Normalized (strides, dilation_rate) as int32 numpy arrays of shape
    [num_spatial_dims].

  Raises:
    ValueError: if the parameters are invalid.
  """
  if dilation_rate is None:
    dilation_rate = [1] * num_spatial_dims
  elif len(dilation_rate) != num_spatial_dims:
    raise ValueError("len(dilation_rate)=%d but should be %d" %
                     (len(dilation_rate), num_spatial_dims))
  dilation_rate = np.array(dilation_rate, dtype=np.int32)
  if np.any(dilation_rate < 1):
    raise ValueError("all values of dilation_rate must be positive")

  if strides is None:
    strides = [1] * num_spatial_dims
  elif len(strides) != num_spatial_dims:
    raise ValueError("len(strides)=%d but should be %d" % (len(strides),
                                                           num_spatial_dims))
  strides = np.array(strides, dtype=np.int32)
  if np.any(strides < 1):
    raise ValueError("all values of strides must be positive")

  if np.any(strides > 1) and np.any(dilation_rate > 1):
    raise ValueError(
        "strides > 1 not supported in conjunction with dilation_rate > 1")
  return strides, dilation_rate


984@tf_export(v1=["nn.convolution"])
985@dispatch.add_dispatch_support
986def convolution(
987    input,  # pylint: disable=redefined-builtin
988    filter,  # pylint: disable=redefined-builtin
989    padding,
990    strides=None,
991    dilation_rate=None,
992    name=None,
993    data_format=None,
994    filters=None,
995    dilations=None):  # pylint: disable=g-doc-args
996  """Computes sums of N-D convolutions (actually cross-correlation).
997
998  This also supports either output striding via the optional `strides` parameter
999  or atrous convolution (also known as convolution with holes or dilated
1000  convolution, based on the French word "trous" meaning holes in English) via
1001  the optional `dilation_rate` parameter.  Currently, however, output striding
1002  is not supported for atrous convolutions.
1003
1004  Specifically, in the case that `data_format` does not start with "NC", given
1005  a rank (N+2) `input` Tensor of shape
1006
1007    [num_batches,
1008     input_spatial_shape[0],
1009     ...,
1010     input_spatial_shape[N-1],
1011     num_input_channels],
1012
1013  a rank (N+2) `filter` Tensor of shape
1014
1015    [spatial_filter_shape[0],
1016     ...,
1017     spatial_filter_shape[N-1],
1018     num_input_channels,
1019     num_output_channels],
1020
1021  an optional `dilation_rate` tensor of shape [N] (defaulting to [1]*N)
1022  specifying the filter upsampling/input downsampling rate, and an optional list
1023  of N `strides` (defaulting [1]*N), this computes for each N-D spatial output
1024  position (x[0], ..., x[N-1]):
1025
1026  ```
1027    output[b, x[0], ..., x[N-1], k] =
1028        sum_{z[0], ..., z[N-1], q}
1029            filter[z[0], ..., z[N-1], q, k] *
1030            padded_input[b,
1031                         x[0]*strides[0] + dilation_rate[0]*z[0],
1032                         ...,
1033                         x[N-1]*strides[N-1] + dilation_rate[N-1]*z[N-1],
1034                         q]
1035  ```
1036  where b is the index into the batch, k is the output channel number, q is the
1037  input channel number, and z is the N-D spatial offset within the filter. Here,
1038  `padded_input` is obtained by zero padding the input using an effective
1039  spatial filter shape of `(spatial_filter_shape-1) * dilation_rate + 1` and
1040  output striding `strides`.
1041
1042  In the case that `data_format` does start with `"NC"`, the `input` and output
1043  (but not the `filter`) are simply transposed as follows:
1044
1045    convolution(input, data_format, **kwargs) =
1046      tf.transpose(convolution(tf.transpose(input, [0] + range(2,N+2) + [1]),
1047                               **kwargs),
1048                   [0, N+1] + range(1, N+1))
1049
1050  It is required that 1 <= N <= 3.
1051
1052  Args:
1053    input: An (N+2)-D `Tensor` of type `T`, of shape
1054      `[batch_size] + input_spatial_shape + [in_channels]` if data_format does
1055      not start with "NC" (default), or
1056      `[batch_size, in_channels] + input_spatial_shape` if data_format starts
1057      with "NC".
1058    filter: An (N+2)-D `Tensor` with the same type as `input` and shape
1059      `spatial_filter_shape + [in_channels, out_channels]`.
1060    padding: A string, either `"VALID"` or `"SAME"`. The padding algorithm.
1061      `"valid"` means no padding. `"same"` results in padding evenly to
1062      the left/right or up/down of the input such that output has the same
1063      height/width dimension as the input.
1064    strides: Optional.  Sequence of N ints >= 1.  Specifies the output stride.
1065      Defaults to [1]*N.  If any value of strides is > 1, then all values of
1066      dilation_rate must be 1.
1067    dilation_rate: Optional.  Sequence of N ints >= 1.  Specifies the filter
1068      upsampling/input downsampling rate.  In the literature, the same parameter
1069      is sometimes called `input stride` or `dilation`.  The effective filter
1070      size used for the convolution will be `spatial_filter_shape +
1071      (spatial_filter_shape - 1) * (rate - 1)`, obtained by inserting
1072      (dilation_rate[i]-1) zeros between consecutive elements of the original
1073      filter in each spatial dimension i.  If any value of dilation_rate is > 1,
1074      then all values of strides must be 1.
1075    name: Optional name for the returned tensor.
1076    data_format: A string or None.  Specifies whether the channel dimension of
1077      the `input` and output is the last dimension (default, or if `data_format`
1078      does not start with "NC"), or the second dimension (if `data_format`
1079      starts with "NC").  For N=1, the valid values are "NWC" (default) and
1080      "NCW".  For N=2, the valid values are "NHWC" (default) and "NCHW".
1081      For N=3, the valid values are "NDHWC" (default) and "NCDHW".
1082
1083  Returns:
1084    A `Tensor` with the same type as `input` of shape
1085
1086        `[batch_size] + output_spatial_shape + [out_channels]`
1087
1088    if data_format is None or does not start with "NC", or
1089
1090        `[batch_size, out_channels] + output_spatial_shape`
1091
1092    if data_format starts with "NC",
1093    where `output_spatial_shape` depends on the value of `padding`.
1094
1095    If padding == "SAME":
1096      output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i])
1097
1098    If padding == "VALID":
1099      output_spatial_shape[i] =
1100        ceil((input_spatial_shape[i] -
1101              (spatial_filter_shape[i]-1) * dilation_rate[i])
1102             / strides[i]).
1103
1104  Raises:
1105    ValueError: If input/output depth does not match `filter` shape, if padding
1106      is other than `"VALID"` or `"SAME"`, or if data_format is invalid.
1107
1108  """
  filter = deprecated_argument_lookup("filters", filters, "filter", filter)
  dilation_rate = deprecated_argument_lookup(
      "dilations", dilations, "dilation_rate", dilation_rate)
  return convolution_internal(
      input,
      filter,
      strides=strides,
      padding=padding,
      data_format=data_format,
      dilations=dilation_rate,
      name=name)


1122@tf_export("nn.convolution", v1=[])
1123@dispatch.add_dispatch_support
1124def convolution_v2(  # pylint: disable=missing-docstring
1125    input,  # pylint: disable=redefined-builtin
1126    filters,
1127    strides=None,
1128    padding="VALID",
1129    data_format=None,
1130    dilations=None,
1131    name=None):
1132  return convolution_internal(
1133      input,  # pylint: disable=redefined-builtin
1134      filters,
1135      strides=strides,
1136      padding=padding,
1137      data_format=data_format,
1138      dilations=dilations,
1139      name=name)
1140
1141
1142convolution_v2.__doc__ = deprecation.rewrite_argument_docstring(
1143    deprecation.rewrite_argument_docstring(
1144        convolution.__doc__, "dilation_rate", "dilations"),
1145    "filter", "filters")
1146
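# A minimal usage sketch (hypothetical shapes): a dilated 2-D convolution via
# the public `tf.nn.convolution` endpoint,
#   inp = tf.ones((1, 10, 10, 3))
#   filt = tf.ones((3, 3, 3, 8))
#   out = tf.nn.convolution(inp, filt, padding="SAME", dilations=[2, 2])
# gives out.shape == (1, 10, 10, 8), since "SAME" with unit strides preserves
# the spatial dimensions.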

def convolution_internal(
    input,  # pylint: disable=redefined-builtin
    filters,
    strides=None,
    padding="VALID",
    data_format=None,
    dilations=None,
    name=None,
    call_from_convolution=True,
    num_spatial_dims=None):
  """Internal function which performs rank agnostic convolution.

  Args:
    input: See `convolution`.
    filters: See `convolution`.
    strides: See `convolution`.
    padding: See `convolution`.
    data_format: See `convolution`.
    dilations: See `convolution`.
    name: See `convolution`.
    call_from_convolution: See `convolution`.
    num_spatial_dims: (Optional.)  An integer describing the rank of the
      spatial dimensions.  For `1-D`, `2-D` and `3-D` convolutions, the value
      of `num_spatial_dims` is `1`, `2`, and `3`, respectively.  This argument
      is only required to disambiguate the rank of `batch_shape` when
      `filter_shape.ndims is None` and `len(batch_shape) > 1`.  For backwards
      compatibility, if `num_spatial_dims is None` and `filter_shape.ndims is
      None`, then `len(batch_shape)` is assumed to be `1` (i.e., the input is
      expected to be `[batch_size, num_channels] + input_spatial_shape` or
      `[batch_size] + input_spatial_shape + [num_channels]`).

  Returns:
    A tensor with the same dtype as `input`, of the convolution output shape.

  Raises:
    ValueError: If input and filter both have unknown shapes, or if
      `num_spatial_dims` is provided and incompatible with the value
      estimated from `filters.shape`.
  """
  if (not isinstance(filters, variables_lib.Variable) and
      not tensor_util.is_tf_type(filters)):
    with ops.name_scope("convolution_internal", None, [filters, input]):
      filters = ops.convert_to_tensor(filters, name="filters")
  if (not isinstance(input, ops.Tensor) and not tensor_util.is_tf_type(input)):
    with ops.name_scope("convolution_internal", None, [filters, input]):
      input = ops.convert_to_tensor(input, name="input")

  filters_rank = filters.shape.rank
  inputs_rank = input.shape.rank
  if num_spatial_dims is None:
    if filters_rank:
      num_spatial_dims = filters_rank - 2
    elif inputs_rank:
      num_spatial_dims = inputs_rank - 2
    else:
      raise ValueError("rank of input or filter must be known")
  elif filters_rank and filters_rank - 2 != num_spatial_dims:
    raise ValueError(
        "inconsistent estimate of spatial dims ({}) vs. actual passed "
        "num_spatial_dims ({}).  n was estimated as len(filters.shape) - 2, "
        "but filters shape is: {}".format(filters_rank, num_spatial_dims,
                                          filters.shape))

  if inputs_rank:
    # Subtract one more for the channel dimension.
    num_batch_dims = inputs_rank - num_spatial_dims - 1
  else:
    num_batch_dims = 1  # By default, assume a single batch dimension.

  if num_spatial_dims not in {1, 2, 3}:
    raise ValueError(
        "num_spatial_dims (input.shape.ndims - num_batch_dims - 1) must be one "
        "of 1, 2 or 3 but saw {}.  num_batch_dims: {}.".format(
            num_spatial_dims, num_batch_dims))

  if data_format is None or data_format in _CHANNELS_LAST_FORMATS:
    channel_index = num_batch_dims + num_spatial_dims
  else:
    channel_index = num_batch_dims

  dilations = _get_sequence(dilations, num_spatial_dims, channel_index,
                            "dilations")
  # A `dilations` of None normalizes to all ones, so this covers both cases.
  is_dilated_conv = any(i != 1 for i in dilations)

  strides = _get_sequence(strides, num_spatial_dims, channel_index, "strides")
  has_tpu_context = device_context.enclosing_tpu_context() is not None

  if name:
    default_name = None
  elif not has_tpu_context or call_from_convolution:
    default_name = "convolution"
  elif num_spatial_dims == 2:  # Most common case.
    default_name = "Conv2D"
  elif num_spatial_dims == 3:
    default_name = "Conv3D"
  else:
    default_name = "conv1d"

  with ops.name_scope(name, default_name, [input, filters]) as name:
    # Fast path for TPU, or if no dilation is used, as the gradient of dilated
    # convolutions is only supported on TPU.
    if not is_dilated_conv or has_tpu_context:
      if num_spatial_dims == 2:  # Most common case.
        op = _conv2d_expanded_batch
      elif num_spatial_dims == 3:
        op = _conv3d_expanded_batch
      else:
        op = conv1d

      return op(
          input,
          filters,
          strides,
          padding=padding,
          data_format=data_format,
          dilations=dilations,
          name=name)
    else:
      if channel_index == 1:
        strides = strides[2:]
        dilations = dilations[2:]
      else:
        strides = strides[1:-1]
        dilations = dilations[1:-1]

      op = Convolution(
          tensor_shape.as_shape(input.shape),
          tensor_shape.as_shape(filters.shape),
          padding,
          strides=strides,
          dilation_rate=dilations,
          name=name,
          data_format=data_format,
          num_spatial_dims=num_spatial_dims)
      return op(input, filters)


class Convolution(object):
  """Helper class for convolution.

  Note that this class assumes that shapes of input and filter passed to
  `__call__` are compatible with `input_shape`, `filter_shape`, and
  `num_spatial_dims` passed to the constructor.

  Args:
    input_shape: static shape of input, i.e. input.shape.  It is
      `batch_shape + input_spatial_shape + [num_channels]` if `data_format`
      does not start with `NC`, or
      `batch_shape + [num_channels] + input_spatial_shape` if `data_format`
      starts with `NC`.
    filter_shape: static shape of the filter, i.e. filter.shape.
    padding: The padding algorithm, must be "SAME" or "VALID".
    strides: see convolution.
    dilation_rate: see convolution.
    name: see convolution.
    data_format: A string or `None`.  Specifies whether the channel dimension
      of the `input` and output is the last dimension (if `data_format` is
      `None` or does not start with `NC`), or the first post-batch dimension
      (i.e. if `data_format` starts with `NC`).
    num_spatial_dims: (Usually optional.)  Python integer, the rank of the
      spatial dimensions.  For `1-D`, `2-D` and `3-D` convolutions, the value
      of `num_spatial_dims` is `1`, `2`, and `3`, respectively.  This argument
      is only required to disambiguate the rank of `batch_shape` when
      `filter_shape.ndims is None` and `len(batch_shape) > 1`.  For backwards
      compatibility, if `num_spatial_dims is None` and `filter_shape.ndims is
      None`, then `len(batch_shape)` is assumed to be `1` (i.e., the input is
      expected to be `[batch_size, num_channels] + input_spatial_shape` or
      `[batch_size] + input_spatial_shape + [num_channels]`).
  """

1324  def __init__(self,
1325               input_shape,
1326               filter_shape,
1327               padding,
1328               strides=None,
1329               dilation_rate=None,
1330               name=None,
1331               data_format=None,
1332               num_spatial_dims=None):
1333    """Helper function for convolution."""
    num_batch_dims = None
    filter_shape = tensor_shape.as_shape(filter_shape)
    input_shape = tensor_shape.as_shape(input_shape)

    if filter_shape.ndims is not None:
      if (num_spatial_dims is not None and
          filter_shape.ndims != num_spatial_dims + 2):
        raise ValueError(
            "Expected filter_shape.ndims == num_spatial_dims + 2, "
            "but saw filter_shape.ndims == {} and num_spatial_dims == {}"
            .format(filter_shape.ndims, num_spatial_dims))
      else:
        num_spatial_dims = filter_shape.ndims - 2

    if input_shape.ndims is not None and num_spatial_dims is not None:
      num_batch_dims = input_shape.ndims - num_spatial_dims - 1

    if num_spatial_dims is None:
      num_spatial_dims = input_shape.ndims - 2
    else:
      if input_shape.ndims is not None:
        if input_shape.ndims < num_spatial_dims + 2:
          raise ValueError(
              "Expected input_shape.ndims >= num_spatial_dims + 2, but saw "
              "input_shape.ndims == {} and num_spatial_dims == {}"
              .format(input_shape.ndims, num_spatial_dims))
        else:
          if num_batch_dims is None:
            num_batch_dims = input_shape.ndims - num_spatial_dims - 1

    if num_spatial_dims is None:
      raise ValueError(
          "Cannot estimate num_spatial_dims since input_shape.ndims is None, "
          "filter_shape.ndims is None, and argument num_spatial_dims is also "
          "None.")

    if num_batch_dims is None:
      num_batch_dims = 1

    if num_batch_dims < 1:
      raise ValueError(
          "num_batch_dims should be >= 1, but saw {}.  num_batch_dims was "
          "estimated as `input_shape.ndims - num_spatial_dims - 1` and "
          "num_spatial_dims was either provided or estimated as "
          "`filter_shape.ndims - 2`.  input_shape.ndims: {}, "
          "num_spatial_dims: {}, filter_shape.ndims: {}"
          .format(num_batch_dims, input_shape.ndims, num_spatial_dims,
                  filter_shape.ndims))

    if data_format is None or not data_format.startswith("NC"):
      input_channels_dim = tensor_shape.dimension_at_index(
          input_shape, num_spatial_dims + num_batch_dims)
      spatial_dims = range(num_batch_dims, num_spatial_dims + num_batch_dims)
    else:
      input_channels_dim = tensor_shape.dimension_at_index(
          input_shape, num_batch_dims)
      spatial_dims = range(
          num_batch_dims + 1, num_spatial_dims + num_batch_dims + 1)

    filter_dim = tensor_shape.dimension_at_index(filter_shape, num_spatial_dims)
    if not (input_channels_dim % filter_dim).is_compatible_with(0):
      raise ValueError("The number of input channels is not divisible by the "
                       "filter's in_channels dimension. Received: "
                       "input channels={}, filter in_channels={}".format(
                           input_channels_dim, filter_dim))

    strides, dilation_rate = _get_strides_and_dilation_rate(
        num_spatial_dims, strides, dilation_rate)

    self.input_shape = input_shape
    self.filter_shape = filter_shape
    self.data_format = data_format
    self.strides = strides
    self.padding = padding
    self.name = name
    self.dilation_rate = dilation_rate
    self.num_batch_dims = num_batch_dims
    self.num_spatial_dims = num_spatial_dims
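    # Dilation (`dilation_rate > 1`) is handled by wrapping the non-atrous
    # convolution in space-to-batch / batch-to-space transforms.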
    self.conv_op = _WithSpaceToBatch(
        input_shape,
        dilation_rate=dilation_rate,
        padding=padding,
        build_op=self._build_op,
        filter_shape=filter_shape,
        spatial_dims=spatial_dims,
        data_format=data_format,
        num_batch_dims=num_batch_dims)

  def _build_op(self, _, padding):
    return _NonAtrousConvolution(
        self.input_shape,
        filter_shape=self.filter_shape,
        padding=padding,
        data_format=self.data_format,
        strides=self.strides,
        name=self.name,
        num_batch_dims=self.num_batch_dims)

  def __call__(self, inp, filter):  # pylint: disable=redefined-builtin
    # TPU convolution supports dilations greater than 1.
    if device_context.enclosing_tpu_context() is not None:
      return convolution_internal(
          inp,
          filter,
          strides=self.strides,
          padding=self.padding,
          data_format=self.data_format,
          dilations=self.dilation_rate,
          name=self.name,
          call_from_convolution=False,
          num_spatial_dims=self.num_spatial_dims)
    else:
      return self.conv_op(inp, filter)


@tf_export(v1=["nn.pool"])
@dispatch.add_dispatch_support
def pool(
    input,  # pylint: disable=redefined-builtin
    window_shape,
    pooling_type,
    padding,
    dilation_rate=None,
    strides=None,
    name=None,
    data_format=None,
    dilations=None):
  """Performs an N-D pooling operation.

  In the case that `data_format` does not start with "NC", computes for
      0 <= b < batch_size,
      0 <= x[i] < output_spatial_shape[i],
      0 <= c < num_channels:

  ```
    output[b, x[0], ..., x[N-1], c] =
      REDUCE_{z[0], ..., z[N-1]}
        input[b,
              x[0] * strides[0] - pad_before[0] + dilation_rate[0]*z[0],
              ...
              x[N-1]*strides[N-1] - pad_before[N-1] + dilation_rate[N-1]*z[N-1],
              c],
  ```

  where the reduction function REDUCE depends on the value of `pooling_type`,
  and pad_before is defined based on the value of `padding` as described in
  the "returns" section of `tf.nn.convolution`.
  The reduction never includes out-of-bounds positions.

  In the case that `data_format` starts with `"NC"`, the `input` and output are
  simply transposed as follows:

  ```
    pool(input, data_format, **kwargs) =
      tf.transpose(pool(tf.transpose(input, [0] + range(2,N+2) + [1]),
                        **kwargs),
                   [0, N+1] + range(1, N+1))
  ```
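
  For example, a minimal sketch of 1-D average pooling in the default `"NWC"`
  layout (the output values follow from the formula above):

  ```python
  x = tf.constant([[[1.], [2.], [3.], [4.]]])  # shape [1, 4, 1]
  tf.compat.v1.nn.pool(x, window_shape=[2], pooling_type="AVG",
                       padding="VALID")
  # -> [[[1.5], [2.5], [3.5]]], shape [1, 3, 1]
  ```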

  Args:
    input: Tensor of rank N+2, of shape
      `[batch_size] + input_spatial_shape + [num_channels]` if data_format does
      not start with "NC" (default), or
      `[batch_size, num_channels] + input_spatial_shape` if data_format starts
      with "NC".  Pooling happens over the spatial dimensions only.
    window_shape: Sequence of N ints >= 1.
    pooling_type: Specifies pooling operation, must be "AVG" or "MAX".
    padding: The padding algorithm, must be "SAME" or "VALID".
      See the "returns" section of `tf.nn.convolution` for details.
    dilation_rate: Optional.  Dilation rate.  List of N ints >= 1.
      Defaults to [1]*N.  If any value of dilation_rate is > 1, then all values
      of strides must be 1.
    strides: Optional.  Sequence of N ints >= 1.  Defaults to [1]*N.
      If any value of strides is > 1, then all values of dilation_rate must be
      1.
    name: Optional. Name of the op.
    data_format: A string or None.  Specifies whether the channel dimension of
      the `input` and output is the last dimension (default, or if `data_format`
      does not start with "NC"), or the second dimension (if `data_format`
      starts with "NC").  For N=1, the valid values are "NWC" (default) and
      "NCW".  For N=2, the valid values are "NHWC" (default) and "NCHW".
      For N=3, the valid values are "NDHWC" (default) and "NCDHW".
    dilations: Alias for dilation_rate.

  Returns:
    Tensor of rank N+2, of shape
      [batch_size] + output_spatial_shape + [num_channels]

    if data_format is None or does not start with "NC", or

      [batch_size, num_channels] + output_spatial_shape

    if data_format starts with "NC",
    where `output_spatial_shape` depends on the value of padding:

    If padding = "SAME":
      output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i])

    If padding = "VALID":
      output_spatial_shape[i] =
        ceil((input_spatial_shape[i] - (window_shape[i] - 1) * dilation_rate[i])
             / strides[i]).

  Raises:
    ValueError: if arguments are invalid.

  """
  dilation_rate = deprecated_argument_lookup(
      "dilations", dilations, "dilation_rate", dilation_rate)
  with ops.name_scope(name, "%s_pool" % (pooling_type.lower()),
                      [input]) as scope:
    input = ops.convert_to_tensor(input, name="input")  # pylint: disable=redefined-builtin

    num_spatial_dims = len(window_shape)
    if num_spatial_dims < 1 or num_spatial_dims > 3:
      raise ValueError("It is required that 1 <= num_spatial_dims <= 3.")

    input.get_shape().with_rank(num_spatial_dims + 2)

    strides, dilation_rate = _get_strides_and_dilation_rate(
        num_spatial_dims, strides, dilation_rate)

    if padding == "SAME" and np.any(dilation_rate > 1):
      raise ValueError(
          "pooling with SAME padding is not implemented for dilation_rate > 1")

    if np.any(strides > window_shape):
      raise ValueError(
          "strides > window_shape not supported due to inconsistency between "
          "CPU and GPU implementations")

    pooling_ops = {
        ("MAX", 1): max_pool,
        ("MAX", 2): max_pool,
        ("MAX", 3): max_pool3d,  # pylint: disable=undefined-variable
        ("AVG", 1): avg_pool,
        ("AVG", 2): avg_pool,
        ("AVG", 3): avg_pool3d,  # pylint: disable=undefined-variable
    }
    op_key = (pooling_type, num_spatial_dims)
    if op_key not in pooling_ops:
      raise ValueError("%d-D %s pooling is not supported." % (op_key[1],
                                                              op_key[0]))

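    # Pad window_shape and strides with singleton batch and channel entries so
    # they line up with the 4-D/5-D layouts expected by the underlying ops.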
    if data_format is None or not data_format.startswith("NC"):
      adjusted_window_shape = [1] + list(window_shape) + [1]
      adjusted_strides = [1] + list(strides) + [1]
      spatial_dims = range(1, num_spatial_dims + 1)
    else:
      adjusted_window_shape = [1, 1] + list(window_shape)
      adjusted_strides = [1, 1] + list(strides)
      spatial_dims = range(2, num_spatial_dims + 2)

    if num_spatial_dims == 1:
      if data_format is None or data_format == "NWC":
        data_format_kwargs = dict(data_format="NHWC")
      elif data_format == "NCW":
        data_format_kwargs = dict(data_format="NCHW")
      else:
        raise ValueError("data_format must be either \"NWC\" or \"NCW\".")
      adjusted_window_shape = [1] + adjusted_window_shape
      adjusted_strides = [1] + adjusted_strides
    else:
      data_format_kwargs = dict(data_format=data_format)

    def op(converted_input, _, converted_padding):  # pylint: disable=missing-docstring
      if num_spatial_dims == 1:
        converted_input = array_ops.expand_dims(converted_input,
                                                spatial_dims[0])
      result = pooling_ops[op_key](
          converted_input,
          adjusted_window_shape,
          adjusted_strides,
          converted_padding,
          name=scope,
          **data_format_kwargs)
      if num_spatial_dims == 1:
        result = array_ops.squeeze(result, [spatial_dims[0]])
      return result

    return with_space_to_batch(
        input=input,
        dilation_rate=dilation_rate,
        padding=padding,
        op=op,
        spatial_dims=spatial_dims,
        filter_shape=window_shape)


@tf_export("nn.pool", v1=[])
@dispatch.add_dispatch_support
def pool_v2(
    input,  # pylint: disable=redefined-builtin
    window_shape,
    pooling_type,
    strides=None,
    padding="VALID",
    data_format=None,
    dilations=None,
    name=None):
  # pylint: disable=line-too-long
  """Performs an N-D pooling operation.

  In the case that `data_format` does not start with "NC", computes for
      0 <= b < batch_size,
      0 <= x[i] < output_spatial_shape[i],
      0 <= c < num_channels:

  ```
    output[b, x[0], ..., x[N-1], c] =
      REDUCE_{z[0], ..., z[N-1]}
        input[b,
              x[0] * strides[0] - pad_before[0] + dilation_rate[0]*z[0],
              ...
              x[N-1]*strides[N-1] - pad_before[N-1] + dilation_rate[N-1]*z[N-1],
              c],
  ```

  where the reduction function REDUCE depends on the value of `pooling_type`,
  and pad_before is defined based on the value of `padding` as described in
  the "returns" section of `tf.nn.convolution`.
  The reduction never includes out-of-bounds positions.

  In the case that `data_format` starts with `"NC"`, the `input` and output are
  simply transposed as follows:

  ```
    pool(input, data_format, **kwargs) =
      tf.transpose(pool(tf.transpose(input, [0] + range(2,N+2) + [1]),
                        **kwargs),
                   [0, N+1] + range(1, N+1))
  ```
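
  For example, a minimal sketch of 1-D max pooling in the default `"NWC"`
  layout:

  ```python
  x = tf.constant([[[1.], [2.], [4.], [3.]]])  # shape [1, 4, 1]
  tf.nn.pool(x, window_shape=[2], pooling_type="MAX", padding="VALID")
  # -> [[[2.], [4.], [4.]]], shape [1, 3, 1]
  ```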

  Args:
    input: Tensor of rank N+2, of shape `[batch_size] + input_spatial_shape +
      [num_channels]` if data_format does not start with "NC" (default), or
      `[batch_size, num_channels] + input_spatial_shape` if data_format starts
      with "NC".  Pooling happens over the spatial dimensions only.
    window_shape: Sequence of N ints >= 1.
    pooling_type: Specifies pooling operation, must be "AVG" or "MAX".
    strides: Optional. Sequence of N ints >= 1.  Defaults to [1]*N. If any value of
      strides is > 1, then all values of dilation_rate must be 1.
    padding: The padding algorithm, must be "SAME" or "VALID". Defaults to
      "VALID".  See the "returns" section of `tf.nn.convolution` for details.
    data_format: A string or None.  Specifies whether the channel dimension of
      the `input` and output is the last dimension (default, or if `data_format`
      does not start with "NC"), or the second dimension (if `data_format`
      starts with "NC").  For N=1, the valid values are "NWC" (default) and
      "NCW".  For N=2, the valid values are "NHWC" (default) and "NCHW". For
      N=3, the valid values are "NDHWC" (default) and "NCDHW".
    dilations: Optional.  Dilation rate.  List of N ints >= 1. Defaults to
      [1]*N.  If any value of dilation_rate is > 1, then all values of strides
      must be 1.
    name: Optional. Name of the op.

  Returns:
    Tensor of rank N+2, of shape
      [batch_size] + output_spatial_shape + [num_channels]

    if data_format is None or does not start with "NC", or

      [batch_size, num_channels] + output_spatial_shape

    if data_format starts with "NC",
    where `output_spatial_shape` depends on the value of padding:

    If padding = "SAME":
      output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i])

    If padding = "VALID":
      output_spatial_shape[i] =
        ceil((input_spatial_shape[i] - (window_shape[i] - 1) * dilation_rate[i])
             / strides[i]).

  Raises:
    ValueError: if arguments are invalid.

  """
  return pool(
      input=input,
      window_shape=window_shape,
      pooling_type=pooling_type,
      padding=padding,
      dilation_rate=dilations,
      strides=strides,
      name=name,
      data_format=data_format)


@tf_export("nn.atrous_conv2d")
@dispatch.add_dispatch_support
def atrous_conv2d(value, filters, rate, padding, name=None):
  """Atrous convolution (a.k.a. convolution with holes or dilated convolution).

  This function is a simpler wrapper around the more general
  `tf.nn.convolution`, and exists only for backwards compatibility. You can
  use `tf.nn.convolution` to perform 1-D, 2-D, or 3-D atrous convolution.

  Computes a 2-D atrous convolution, also known as convolution with holes or
  dilated convolution, given 4-D `value` and `filters` tensors. If the `rate`
  parameter is equal to one, it performs regular 2-D convolution. If the `rate`
  parameter is greater than one, it performs convolution with holes, sampling
  the input values every `rate` pixels in the `height` and `width` dimensions.
  This is equivalent to convolving the input with a set of upsampled filters,
  produced by inserting `rate - 1` zeros between two consecutive values of the
  filters along the `height` and `width` dimensions, hence the name atrous
  convolution or convolution with holes (the French word "trous" means
  "holes").

  More specifically:

  ```
  output[batch, height, width, out_channel] =
      sum_{dheight, dwidth, in_channel} (
          filters[dheight, dwidth, in_channel, out_channel] *
          value[batch, height + rate*dheight, width + rate*dwidth, in_channel]
      )
  ```
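
  For example, the following minimal sketch shows that `atrous_conv2d` matches
  `tf.nn.conv2d` with a dilation rate of `rate` in both spatial dimensions:

  ```python
  x = tf.random.normal([1, 7, 7, 1])
  w = tf.random.normal([3, 3, 1, 1])
  y1 = tf.nn.atrous_conv2d(x, w, rate=2, padding="SAME")
  y2 = tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding="SAME",
                    dilations=[1, 2, 2, 1])
  # y1 and y2 agree up to floating-point round-off.
  ```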

  Atrous convolution allows us to explicitly control how densely to compute
  feature responses in fully convolutional networks. Used in conjunction with
  bilinear interpolation, it offers an alternative to `conv2d_transpose` in
  dense prediction tasks such as semantic image segmentation, optical flow
  computation, or depth estimation. It also allows us to effectively enlarge
  the field of view of filters without increasing the number of parameters or
  the amount of computation.

  For a description of atrous convolution and how it can be used for dense
  feature extraction, please see: (Chen et al., 2015). The same operation is
  investigated further in (Yu et al., 2016). Previous works that effectively
  use atrous convolution in different ways are, among others,
  (Sermanet et al., 2014) and (Giusti et al., 2013).
  Atrous convolution is also closely related to the so-called noble identities
  in multi-rate signal processing.

  There are many different ways to implement atrous convolution (see the refs
  above). The implementation here reduces

  ```python
      atrous_conv2d(value, filters, rate, padding=padding)
  ```

  to the following three operations:

  ```python
      paddings = ...
      net = space_to_batch(value, paddings, block_size=rate)
      net = conv2d(net, filters, strides=[1, 1, 1, 1], padding="VALID")
      crops = ...
      net = batch_to_space(net, crops, block_size=rate)
  ```

  Advanced usage. Note the following optimization: A sequence of
  `atrous_conv2d` operations with identical `rate` parameters, 'SAME'
  `padding`, and filters with odd heights/widths:

  ```python
      net = atrous_conv2d(net, filters1, rate, padding="SAME")
      net = atrous_conv2d(net, filters2, rate, padding="SAME")
      ...
      net = atrous_conv2d(net, filtersK, rate, padding="SAME")
  ```

  can be performed equivalently, and more cheaply in terms of computation and
  memory, as:

  ```python
      pad = ...  # padding so that the input dims are multiples of rate
      net = space_to_batch(net, paddings=pad, block_size=rate)
      net = conv2d(net, filters1, strides=[1, 1, 1, 1], padding="SAME")
      net = conv2d(net, filters2, strides=[1, 1, 1, 1], padding="SAME")
      ...
      net = conv2d(net, filtersK, strides=[1, 1, 1, 1], padding="SAME")
      net = batch_to_space(net, crops=pad, block_size=rate)
  ```

  because a pair of consecutive `space_to_batch` and `batch_to_space` ops with
  the same `block_size` cancel out when their respective `paddings` and `crops`
  inputs are identical.

  Args:
    value: A 4-D `Tensor` of type `float`. It needs to be in the default "NHWC"
      format. Its shape is `[batch, in_height, in_width, in_channels]`.
    filters: A 4-D `Tensor` with the same type as `value` and shape
      `[filter_height, filter_width, in_channels, out_channels]`. `filters`'
      `in_channels` dimension must match that of `value`. Atrous convolution is
      equivalent to standard convolution with upsampled filters with effective
      height `filter_height + (filter_height - 1) * (rate - 1)` and effective
      width `filter_width + (filter_width - 1) * (rate - 1)`, produced by
      inserting `rate - 1` zeros along consecutive elements across the
      `filters`' spatial dimensions.
    rate: A positive int32. The stride with which we sample input values across
      the `height` and `width` dimensions. Equivalently, the rate by which we
      upsample the filter values by inserting zeros across the `height` and
      `width` dimensions. In the literature, the same parameter is sometimes
      called `input stride` or `dilation`.
    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
    name: Optional name for the returned tensor.

  Returns:
    A `Tensor` with the same type as `value`.
    Output shape with `'VALID'` padding is:

        [batch, height - rate * (filter_height - 1),
         width - rate * (filter_width - 1), out_channels].

    Output shape with `'SAME'` padding is:

        [batch, height, width, out_channels].

  Raises:
    ValueError: If input/output depth does not match `filters`' shape, or if
      padding is other than `'VALID'` or `'SAME'`.

  References:
    Multi-Scale Context Aggregation by Dilated Convolutions:
      [Yu et al., 2016](https://arxiv.org/abs/1511.07122)
      ([pdf](https://arxiv.org/pdf/1511.07122.pdf))
    Semantic Image Segmentation with Deep Convolutional Nets and Fully
    Connected CRFs:
      [Chen et al., 2015](http://arxiv.org/abs/1412.7062)
      ([pdf](https://arxiv.org/pdf/1412.7062))
    OverFeat - Integrated Recognition, Localization and Detection using
    Convolutional Networks:
      [Sermanet et al., 2014](https://arxiv.org/abs/1312.6229)
      ([pdf](https://arxiv.org/pdf/1312.6229.pdf))
    Fast Image Scanning with Deep Max-Pooling Convolutional Neural Networks:
      [Giusti et al., 2013]
      (https://ieeexplore.ieee.org/abstract/document/6738831)
      ([pdf](https://arxiv.org/pdf/1302.1700.pdf))
  """
  return convolution(
      input=value,
      filter=filters,
      padding=padding,
      dilation_rate=np.broadcast_to(rate, (2,)),
      name=name)


def convert_padding(padding, expected_length=4):
  """Converts Python padding to C++ padding for ops which take EXPLICIT padding.

  Args:
    padding: the `padding` argument for a Python op which supports EXPLICIT
      padding.
    expected_length: Expected number of entries in the padding list when
      explicit padding is used.

  Returns:
    (padding, explicit_paddings) pair, which should be passed as attributes to a
    C++ op.
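
  For example, a minimal sketch (with an `NHWC`-style 4-entry padding list):

  ```python
  convert_padding([[0, 0], [1, 2], [3, 4], [0, 0]])
  # -> ("EXPLICIT", [0, 0, 1, 2, 3, 4, 0, 0])
  convert_padding("SAME")
  # -> ("SAME", [])
  ```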

  Raises:
    ValueError: If padding is invalid.
  """
  explicit_paddings = []
  if padding == "EXPLICIT":
    # Give a better error message if EXPLICIT is passed.
    raise ValueError('"EXPLICIT" is not a valid value for the padding '
                     "parameter. To use explicit padding, the padding "
                     "parameter must be a list.")
  if isinstance(padding, (list, tuple)):
    for i, dim_paddings in enumerate(padding):
      if not isinstance(dim_paddings, (list, tuple)):
        raise ValueError("When padding is a list, each element of padding must "
                         "be a list/tuple of size 2. Element with index %d of "
                         "padding is not a list/tuple" % i)
      if len(dim_paddings) != 2:
        raise ValueError("When padding is a list, each element of padding must "
                         "be a list/tuple of size 2. Element with index %d of "
                         "padding has size %d" % (i, len(dim_paddings)))
      explicit_paddings.extend(dim_paddings)
    if len(padding) != expected_length:
      raise ValueError("When padding is a list, it must be of size %d. Got "
                       "padding of size: %d" % (expected_length, len(padding)))
    padding = "EXPLICIT"
  return padding, explicit_paddings


@tf_export(v1=["nn.conv1d"])
@dispatch.add_dispatch_support
@deprecation.deprecated_arg_values(
    None,
    "`NCHW` for data_format is deprecated, use `NCW` instead",
    warn_once=True,
    data_format="NCHW")
@deprecation.deprecated_arg_values(
    None,
    "`NHWC` for data_format is deprecated, use `NWC` instead",
    warn_once=True,
    data_format="NHWC")
def conv1d(
    value=None,
    filters=None,
    stride=None,
    padding=None,
    use_cudnn_on_gpu=None,
    data_format=None,
    name=None,
    input=None,  # pylint: disable=redefined-builtin
    dilations=None):
  r"""Computes a 1-D convolution of input with rank `>=3` and a `3-D` filter.

  Given an input tensor of shape
    `batch_shape + [in_width, in_channels]`
  if `data_format` is `"NWC"`, or
    `batch_shape + [in_channels, in_width]`
  if `data_format` is `"NCW"`,
  and a filter / kernel tensor of shape
  `[filter_width, in_channels, out_channels]`, this op reshapes
  the arguments to pass them to `conv2d` to perform the equivalent
  convolution operation.

  Internally, this op reshapes the input tensors and invokes `tf.nn.conv2d`.
  For example, if `data_format` does not start with "NC", a tensor of shape
    `batch_shape + [in_width, in_channels]`
  is reshaped to
    `batch_shape + [1, in_width, in_channels]`,
  and the filter is reshaped to
    `[1, filter_width, in_channels, out_channels]`.
  The result is then reshaped back to
    `batch_shape + [out_width, out_channels]`
  \(where out_width is a function of the stride and padding as in conv2d\) and
  returned to the caller.

  Args:
    value: A Tensor of rank at least 3. Must be of type `float16`, `float32`, or
      `float64`.
    filters: A Tensor of rank at least 3.  Must have the same type as `value`.
    stride: An int or list of `ints` that has length `1` or `3`.  The number of
      entries by which the filter is moved right at each step.
    padding: 'SAME' or 'VALID'.
    use_cudnn_on_gpu: An optional `bool`.  Defaults to `True`.
    data_format: An optional `string` from `"NWC", "NCW"`.  Defaults to `"NWC"`,
      the data is stored in the order of `batch_shape + [in_width,
      in_channels]`.  The `"NCW"` format stores data as `batch_shape +
      [in_channels, in_width]`.
    name: A name for the operation (optional).
    input: Alias for value.
    dilations: An int or list of `ints` that has length `1` or `3` which
      defaults to 1. The dilation factor for each dimension of input. If set to
      k > 1, there will be k-1 skipped cells between each filter element on that
      dimension. Dilations in the batch and depth dimensions must be 1.

  Returns:
    A `Tensor`.  Has the same type as input.

  Raises:
    ValueError: if `data_format` is invalid.
  """
  value = deprecation.deprecated_argument_lookup("input", input, "value", value)
  with ops.name_scope(name, "conv1d", [value, filters]) as name:
    # Reshape the input tensor to batch_shape + [1, in_width, in_channels]
    if data_format is None or data_format == "NHWC" or data_format == "NWC":
      data_format = "NHWC"
      spatial_start_dim = -3
      channel_index = 2
    elif data_format == "NCHW" or data_format == "NCW":
      data_format = "NCHW"
      spatial_start_dim = -2
      channel_index = 1
    else:
      raise ValueError("data_format must be \"NWC\" or \"NCW\".")
    strides = [1] + _get_sequence(stride, 1, channel_index, "stride")
    dilations = [1] + _get_sequence(dilations, 1, channel_index, "dilations")

    value = array_ops.expand_dims(value, spatial_start_dim)
    filters = array_ops.expand_dims(filters, 0)
    if value.shape.ndims in (4, 3, 2, 1, 0, None):
      result = gen_nn_ops.conv2d(
          value,
          filters,
          strides,
          padding,
          use_cudnn_on_gpu=use_cudnn_on_gpu,
          data_format=data_format,
          dilations=dilations,
          name=name)
    else:
      result = squeeze_batch_dims(
          value,
          functools.partial(
              gen_nn_ops.conv2d,
              filter=filters,
              strides=strides,
              padding=padding,
              use_cudnn_on_gpu=use_cudnn_on_gpu,
              data_format=data_format,
              dilations=dilations,
          ),
          inner_rank=3,
          name=name)
    return array_ops.squeeze(result, [spatial_start_dim])


@tf_export("nn.conv1d", v1=[])
@dispatch.add_dispatch_support
def conv1d_v2(
    input,  # pylint: disable=redefined-builtin
    filters,
    stride,
    padding,
    data_format="NWC",
    dilations=None,
    name=None):
  r"""Computes a 1-D convolution of input with rank `>=3` and a `3-D` filter.

  Given an input tensor of shape
    `batch_shape + [in_width, in_channels]`
  if `data_format` is `"NWC"`, or
    `batch_shape + [in_channels, in_width]`
  if `data_format` is `"NCW"`,
  and a filter / kernel tensor of shape
  `[filter_width, in_channels, out_channels]`, this op reshapes
  the arguments to pass them to `conv2d` to perform the equivalent
  convolution operation.

  Internally, this op reshapes the input tensors and invokes `tf.nn.conv2d`.
  For example, if `data_format` does not start with `"NC"`, a tensor of shape
    `batch_shape + [in_width, in_channels]`
  is reshaped to
    `batch_shape + [1, in_width, in_channels]`,
  and the filter is reshaped to
    `[1, filter_width, in_channels, out_channels]`.
  The result is then reshaped back to
    `batch_shape + [out_width, out_channels]`
  \(where out_width is a function of the stride and padding as in conv2d\) and
  returned to the caller.
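
  For example, a minimal sketch with a length-2 box filter:

  ```python
  x = tf.constant([[[1.], [2.], [3.], [4.]]])  # [batch, in_width, in_channels]
  w = tf.constant([[[1.]], [[1.]]])            # [filter_width, in, out]
  tf.nn.conv1d(x, w, stride=1, padding="VALID")
  # -> [[[3.], [5.], [7.]]], shape [1, 3, 1]
  ```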

  Args:
    input: A Tensor of rank at least 3. Must be of type `float16`, `float32`, or
      `float64`.
    filters: A Tensor of rank at least 3.  Must have the same type as `input`.
    stride: An int or list of `ints` that has length `1` or `3`.  The number of
      entries by which the filter is moved right at each step.
    padding: 'SAME' or 'VALID'.
    data_format: An optional `string` from `"NWC", "NCW"`.  Defaults to `"NWC"`,
      the data is stored in the order of
      `batch_shape + [in_width, in_channels]`.  The `"NCW"` format stores data
      as `batch_shape + [in_channels, in_width]`.
    dilations: An int or list of `ints` that has length `1` or `3` which
      defaults to 1. The dilation factor for each dimension of input. If set to
      k > 1, there will be k-1 skipped cells between each filter element on that
      dimension. Dilations in the batch and depth dimensions must be 1.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`.  Has the same type as input.

  Raises:
    ValueError: if `data_format` is invalid.
  """
  return conv1d(
      input,  # pylint: disable=redefined-builtin
      filters,
      stride,
      padding,
      use_cudnn_on_gpu=True,
      data_format=data_format,
      name=name,
      dilations=dilations)


@tf_export("nn.conv1d_transpose")
@dispatch.add_dispatch_support
def conv1d_transpose(
    input,  # pylint: disable=redefined-builtin
    filters,
    output_shape,
    strides,
    padding="SAME",
    data_format="NWC",
    dilations=None,
    name=None):
  """The transpose of `conv1d`.

  This operation is sometimes called "deconvolution" after
  (Zeiler et al., 2010), but is actually the transpose (gradient) of `conv1d`
  rather than an actual deconvolution.
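
  For example, a minimal sketch that doubles the spatial width with `"SAME"`
  padding and stride 2:

  ```python
  x = tf.random.normal([2, 8, 3])  # [batch, in_width, in_channels]
  w = tf.random.normal([4, 5, 3])  # [filter_width, output_channels, in_channels]
  y = tf.nn.conv1d_transpose(x, w, output_shape=[2, 16, 5], strides=2)
  # y.shape == TensorShape([2, 16, 5])
  ```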

  Args:
    input: A 3-D `Tensor` of type `float` and shape
      `[batch, in_width, in_channels]` for `NWC` data format or
      `[batch, in_channels, in_width]` for `NCW` data format.
    filters: A 3-D `Tensor` with the same type as `input` and shape
      `[filter_width, output_channels, in_channels]`.  `filters`'
      `in_channels` dimension must match that of `input`.
    output_shape: A 1-D `Tensor`, containing three elements, representing the
      output shape of the deconvolution op.
    strides: An int or list of `ints` that has length `1` or `3`.  The number of
      entries by which the filter is moved right at each step.
    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
      See the "returns" section of `tf.nn.convolution` for details.
    data_format: A string. `'NWC'` and `'NCW'` are supported.
    dilations: An int or list of `ints` that has length `1` or `3` which
      defaults to 1. The dilation factor for each dimension of input. If set to
      k > 1, there will be k-1 skipped cells between each filter element on that
      dimension. Dilations in the batch and depth dimensions must be 1.
    name: Optional name for the returned tensor.

  Returns:
    A `Tensor` with the same type as `input`.

  Raises:
    ValueError: If input/output depth does not match `filters`' shape, if
      `output_shape` is not a 3-element vector, if `padding` is other than
      `'VALID'` or `'SAME'`, or if `data_format` is invalid.

  References:
    Deconvolutional Networks:
      [Zeiler et al., 2010]
      (https://ieeexplore.ieee.org/abstract/document/5539957)
      ([pdf]
      (http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.232.4023&rep=rep1&type=pdf))
  """
  with ops.name_scope(name, "conv1d_transpose",
                      [input, filters, output_shape]) as name:
    # The format could be either NWC or NCW, map to NHWC or NCHW
    if data_format is None or data_format == "NWC":
      data_format = "NHWC"
      spatial_start_dim = 1
      channel_index = 2
    elif data_format == "NCW":
      data_format = "NCHW"
      spatial_start_dim = 2
      channel_index = 1
    else:
      raise ValueError("data_format must be \"NWC\" or \"NCW\".")

    # Reshape the input tensor to [batch, 1, in_width, in_channels]
    strides = [1] + _get_sequence(strides, 1, channel_index, "stride")
    dilations = [1] + _get_sequence(dilations, 1, channel_index, "dilations")

    input = array_ops.expand_dims(input, spatial_start_dim)
    filters = array_ops.expand_dims(filters, 0)
    output_shape = list(output_shape) if not isinstance(
        output_shape, ops.Tensor) else output_shape
    output_shape = array_ops.concat([output_shape[: spatial_start_dim], [1],
                                     output_shape[spatial_start_dim:]], 0)

    result = gen_nn_ops.conv2d_backprop_input(
        input_sizes=output_shape,
        filter=filters,
        out_backprop=input,
        strides=strides,
        padding=padding,
        data_format=data_format,
        dilations=dilations,
        name=name)
    return array_ops.squeeze(result, spatial_start_dim)


@tf_export("nn.conv2d", v1=[])
@dispatch.add_dispatch_support
def conv2d_v2(input,  # pylint: disable=redefined-builtin
              filters,
              strides,
              padding,
              data_format="NHWC",
              dilations=None,
              name=None):
  # pylint: disable=line-too-long
  r"""Computes a 2-D convolution given `input` and 4-D `filters` tensors.

  The `input` tensor may have rank `4` or higher, where shape dimensions `[:-3]`
  are considered batch dimensions (`batch_shape`).

  Given an input tensor of shape
  `batch_shape + [in_height, in_width, in_channels]` and a filter / kernel
  tensor of shape `[filter_height, filter_width, in_channels, out_channels]`,
  this op performs the following:

  1. Flattens the filter to a 2-D matrix with shape
     `[filter_height * filter_width * in_channels, output_channels]`.
  2. Extracts image patches from the input tensor to form a *virtual*
     tensor of shape `[batch, out_height, out_width,
     filter_height * filter_width * in_channels]`.
  3. For each patch, right-multiplies the filter matrix and the image patch
     vector.

  In detail, with the default NHWC format,

      output[b, i, j, k] =
          sum_{di, dj, q} input[b, strides[1] * i + di, strides[2] * j + dj, q] *
                          filter[di, dj, q, k]

  Must have `strides[0] = strides[3] = 1`.  For the most common case of the same
  horizontal and vertical strides, `strides = [1, stride, stride, 1]`.

  Usage Example:

  >>> x_in = np.array([[
  ...   [[2], [1], [2], [0], [1]],
  ...   [[1], [3], [2], [2], [3]],
  ...   [[1], [1], [3], [3], [0]],
  ...   [[2], [2], [0], [1], [1]],
  ...   [[0], [0], [3], [1], [2]], ]])
  >>> kernel_in = np.array([
  ...  [ [[2, 0.1]], [[3, 0.2]] ],
  ...  [ [[0, 0.3]],[[1, 0.4]] ], ])
  >>> x = tf.constant(x_in, dtype=tf.float32)
  >>> kernel = tf.constant(kernel_in, dtype=tf.float32)
  >>> tf.nn.conv2d(x, kernel, strides=[1, 1, 1, 1], padding='VALID')
  <tf.Tensor: shape=(1, 4, 4, 2), dtype=float32, numpy=..., dtype=float32)>

  Args:
    input: A `Tensor`. Must be one of the following types:
      `half`, `bfloat16`, `float32`, `float64`.
      A Tensor of rank at least 4. The dimension order is interpreted according
      to the value of `data_format`; with the all-but-inner-3 dimensions acting
      as batch dimensions. See below for details.
    filters: A `Tensor`. Must have the same type as `input`.
      A 4-D tensor of shape
      `[filter_height, filter_width, in_channels, out_channels]`
    strides: An int or list of `ints` that has length `1`, `2` or `4`.  The
      stride of the sliding window for each dimension of `input`. If a single
      value is given it is replicated in the `H` and `W` dimension. By default
      the `N` and `C` dimensions are set to 1. The dimension order is determined
      by the value of `data_format`, see below for details.
    padding: Either the `string` `"SAME"` or `"VALID"` indicating the type of
      padding algorithm to use, or a list indicating the explicit paddings at
      the start and end of each dimension. When explicit padding is used and
      data_format is `"NHWC"`, this should be in the form `[[0, 0], [pad_top,
      pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit padding is
      used and data_format is `"NCHW"`, this should be in the form `[[0, 0],
      [0, 0], [pad_top, pad_bottom], [pad_left, pad_right]]`.
    data_format: An optional `string` from: `"NHWC", "NCHW"`.
      Defaults to `"NHWC"`.
      Specify the data format of the input and output data. With the
      default format "NHWC", the data is stored in the order of:
          `batch_shape + [height, width, channels]`.
      Alternatively, the format could be "NCHW", the data storage order of:
          `batch_shape + [channels, height, width]`.
    dilations: An int or list of `ints` that has length `1`, `2` or `4`,
      defaults to 1. The dilation factor for each dimension of `input`. If a
      single value is given it is replicated in the `H` and `W` dimension. By
      default the `N` and `C` dimensions are set to 1. If set to k > 1, there
      will be k-1 skipped cells between each filter element on that dimension.
      The dimension order is determined by the value of `data_format`, see above
      for details. Dilations in the batch and depth dimensions if a 4-d tensor
      must be 1.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input` and the same outer batch shape.
  """
  # pylint: enable=line-too-long
  return conv2d(input,  # pylint: disable=redefined-builtin
                filters,
                strides,
                padding,
                use_cudnn_on_gpu=True,
                data_format=data_format,
                dilations=dilations,
                name=name)


@tf_export(v1=["nn.conv2d"])
@dispatch.add_dispatch_support
def conv2d(  # pylint: disable=redefined-builtin,dangerous-default-value
    input,
    filter=None,
    strides=None,
    padding=None,
    use_cudnn_on_gpu=True,
    data_format="NHWC",
    dilations=[1, 1, 1, 1],
    name=None,
    filters=None):
  r"""Computes a 2-D convolution given 4-D `input` and `filter` tensors.

  Given an input tensor of shape `[batch, in_height, in_width, in_channels]`
  and a filter / kernel tensor of shape
  `[filter_height, filter_width, in_channels, out_channels]`, this op
  performs the following:

  1. Flattens the filter to a 2-D matrix with shape
     `[filter_height * filter_width * in_channels, output_channels]`.
  2. Extracts image patches from the input tensor to form a *virtual*
     tensor of shape `[batch, out_height, out_width,
     filter_height * filter_width * in_channels]`.
  3. For each patch, right-multiplies the filter matrix and the image patch
     vector.

  In detail, with the default NHWC format,

      output[b, i, j, k] =
          sum_{di, dj, q} input[b, strides[1] * i + di, strides[2] * j + dj, q]
                          * filter[di, dj, q, k]

  Must have `strides[0] = strides[3] = 1`.  For the most common case of the same
  horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
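
  For example, a minimal sketch of explicit padding (one extra row on top and
  one extra column on the left):

  ```python
  x = tf.ones([1, 3, 3, 1])
  w = tf.ones([2, 2, 1, 1])
  tf.compat.v1.nn.conv2d(x, w, strides=[1, 1, 1, 1],
                         padding=[[0, 0], [1, 0], [1, 0], [0, 0]])
  # -> shape [1, 3, 3, 1]
  ```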

  Args:
    input: A `Tensor`. Must be one of the following types:
      `half`, `bfloat16`, `float32`, `float64`.
      A 4-D tensor. The dimension order is interpreted according to the value
      of `data_format`, see below for details.
    filter: A `Tensor`. Must have the same type as `input`.
      A 4-D tensor of shape
      `[filter_height, filter_width, in_channels, out_channels]`
    strides: An int or list of `ints` that has length `1`, `2` or `4`.  The
      stride of the sliding window for each dimension of `input`. If a single
      value is given it is replicated in the `H` and `W` dimension. By default
      the `N` and `C` dimensions are set to 1. The dimension order is determined
      by the value of `data_format`, see below for details.
    padding: Either the `string` `"SAME"` or `"VALID"` indicating the type of
      padding algorithm to use, or a list indicating the explicit paddings at
      the start and end of each dimension. When explicit padding is used and
      data_format is `"NHWC"`, this should be in the form `[[0, 0], [pad_top,
      pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit padding is
      used and data_format is `"NCHW"`, this should be in the form `[[0, 0],
      [0, 0], [pad_top, pad_bottom], [pad_left, pad_right]]`.
    use_cudnn_on_gpu: An optional `bool`. Defaults to `True`.
    data_format: An optional `string` from: `"NHWC", "NCHW"`.
      Defaults to `"NHWC"`.
      Specify the data format of the input and output data. With the
      default format "NHWC", the data is stored in the order of:
          [batch, height, width, channels].
      Alternatively, the format could be "NCHW", the data storage order of:
          [batch, channels, height, width].
    dilations: An int or list of `ints` that has length `1`, `2` or `4`,
      defaults to 1. The dilation factor for each dimension of `input`. If a
      single value is given it is replicated in the `H` and `W` dimension. By
      default the `N` and `C` dimensions are set to 1. If set to k > 1, there
      will be k-1 skipped cells between each filter element on that dimension.
      The dimension order is determined by the value of `data_format`, see above
      for details. Dilations in the batch and depth dimensions if a 4-d tensor
      must be 1.
    name: A name for the operation (optional).
    filters: Alias for filter.

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  filter = deprecation.deprecated_argument_lookup(
      "filters", filters, "filter", filter)
  padding, explicit_paddings = convert_padding(padding)
  if data_format is None:
    data_format = "NHWC"
  channel_index = 1 if data_format.startswith("NC") else 3

  strides = _get_sequence(strides, 2, channel_index, "strides")
  dilations = _get_sequence(dilations, 2, channel_index, "dilations")

  shape = input.shape
  # shape object may lack ndims, e.g., if input is an np.ndarray.  In that case,
  # we fall back to len(shape).
  ndims = getattr(shape, "ndims", -1)
  if ndims == -1:
    ndims = len(shape)
  if ndims in (4, 3, 2, 1, 0, None):
    # We avoid calling squeeze_batch_dims to reduce extra python function
    # call slowdown in eager mode.  This branch doesn't require reshapes.
    return gen_nn_ops.conv2d(
        input,
        filter=filter,
        strides=strides,
        padding=padding,
        use_cudnn_on_gpu=use_cudnn_on_gpu,
        explicit_paddings=explicit_paddings,
        data_format=data_format,
        dilations=dilations,
        name=name)
  return squeeze_batch_dims(
      input,
      functools.partial(
          gen_nn_ops.conv2d,
          filter=filter,
          strides=strides,
          padding=padding,
          use_cudnn_on_gpu=use_cudnn_on_gpu,
          explicit_paddings=explicit_paddings,
          data_format=data_format,
          dilations=dilations),
      inner_rank=3,
      name=name)


@tf_export(v1=["nn.conv2d_backprop_filter"])
@dispatch.add_dispatch_support
def conv2d_backprop_filter(  # pylint: disable=redefined-builtin,dangerous-default-value
    input,
    filter_sizes,
    out_backprop,
    strides,
    padding,
    use_cudnn_on_gpu=True,
    data_format="NHWC",
    dilations=[1, 1, 1, 1],
    name=None):
  r"""Computes the gradients of convolution with respect to the filter.

  Args:
    input: A `Tensor`. Must be one of the following types:
      `half`, `bfloat16`, `float32`, `float64`.
      4-D with shape `[batch, in_height, in_width, in_channels]`.
    filter_sizes: A `Tensor` of type `int32`.
      An integer vector representing the tensor shape of `filter`,
      where `filter` is a 4-D
      `[filter_height, filter_width, in_channels, out_channels]` tensor.
    out_backprop: A `Tensor`. Must have the same type as `input`.
      4-D with shape `[batch, out_height, out_width, out_channels]`.
      Gradients w.r.t. the output of the convolution.
    strides: A list of `ints`.
      The stride of the sliding window for each dimension of the input
      of the convolution. Must be in the same order as the dimension specified
      with `data_format`.
    padding: Either the `string` `"SAME"` or `"VALID"` indicating the type of
      padding algorithm to use, or a list indicating the explicit paddings at
      the start and end of each dimension. When explicit padding is used and
      data_format is `"NHWC"`, this should be in the form `[[0, 0], [pad_top,
      pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit padding is
      used and data_format is `"NCHW"`, this should be in the form `[[0, 0],
      [0, 0], [pad_top, pad_bottom], [pad_left, pad_right]]`.
    use_cudnn_on_gpu: An optional `bool`. Defaults to `True`.
    data_format: An optional `string` from: `"NHWC", "NCHW"`.
      Defaults to `"NHWC"`.
      Specify the data format of the input and output data. With the
      default format "NHWC", the data is stored in the order of:
          [batch, in_height, in_width, in_channels].
      Alternatively, the format could be "NCHW", the data storage order of:
          [batch, in_channels, in_height, in_width].
    dilations: An optional list of `ints`. Defaults to `[1, 1, 1, 1]`.
      1-D tensor of length 4.  The dilation factor for each dimension of
      `input`. If set to k > 1, there will be k-1 skipped cells between each
      filter element on that dimension. The dimension order is determined by
      the value of `data_format`, see above for details. Dilations in the batch
      and depth dimensions must be 1.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  padding, explicit_paddings = convert_padding(padding)
  return gen_nn_ops.conv2d_backprop_filter(
      input, filter_sizes, out_backprop, strides, padding, use_cudnn_on_gpu,
      explicit_paddings, data_format, dilations, name)


@tf_export(v1=["nn.conv2d_backprop_input"])
@dispatch.add_dispatch_support
def conv2d_backprop_input(  # pylint: disable=redefined-builtin,dangerous-default-value
    input_sizes,
    filter=None,
    out_backprop=None,
    strides=None,
    padding=None,
    use_cudnn_on_gpu=True,
    data_format="NHWC",
    dilations=[1, 1, 1, 1],
    name=None,
    filters=None):
  r"""Computes the gradients of convolution with respect to the input.

  Args:
    input_sizes: A `Tensor` of type `int32`.
      An integer vector representing the shape of `input`,
      where `input` is a 4-D `[batch, height, width, channels]` tensor.
    filter: A `Tensor`. Must be one of the following types:
      `half`, `bfloat16`, `float32`, `float64`.
      4-D with shape
      `[filter_height, filter_width, in_channels, out_channels]`.
    out_backprop: A `Tensor`. Must have the same type as `filter`.
      4-D with shape `[batch, out_height, out_width, out_channels]`.
      Gradients w.r.t. the output of the convolution.
    strides: A list of `ints`.
      The stride of the sliding window for each dimension of the input
      of the convolution. Must be in the same order as the dimension specified
      with `data_format`.
    padding: Either the `string` `"SAME"` or `"VALID"` indicating the type of
      padding algorithm to use, or a list indicating the explicit paddings at
      the start and end of each dimension. When explicit padding is used and
      data_format is `"NHWC"`, this should be in the form `[[0, 0], [pad_top,
      pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit padding is
      used and data_format is `"NCHW"`, this should be in the form `[[0, 0],
      [0, 0], [pad_top, pad_bottom], [pad_left, pad_right]]`.
    use_cudnn_on_gpu: An optional `bool`. Defaults to `True`.
    data_format: An optional `string` from: `"NHWC", "NCHW"`.
      Defaults to `"NHWC"`.
      Specify the data format of the input and output data. With the
      default format "NHWC", the data is stored in the order of:
          [batch, in_height, in_width, in_channels].
      Alternatively, the format could be "NCHW", the data storage order of:
          [batch, in_channels, in_height, in_width].
    dilations: An optional list of `ints`. Defaults to `[1, 1, 1, 1]`.
      1-D tensor of length 4.  The dilation factor for each dimension of
      `input`. If set to k > 1, there will be k-1 skipped cells between each
      filter element on that dimension. The dimension order is determined by
      the value of `data_format`, see above for details. Dilations in the batch
      and depth dimensions must be 1.
    name: A name for the operation (optional).
    filters: Alias for filter.

  Returns:
    A `Tensor`. Has the same type as `filter`.
  """
  filter = deprecation.deprecated_argument_lookup(
      "filters", filters, "filter", filter)
  padding, explicit_paddings = convert_padding(padding)
  return gen_nn_ops.conv2d_backprop_input(
      input_sizes, filter, out_backprop, strides, padding, use_cudnn_on_gpu,
      explicit_paddings, data_format, dilations, name)


@tf_export(v1=["nn.conv2d_transpose"])
@dispatch.add_dispatch_support
def conv2d_transpose(
    value=None,
    filter=None,  # pylint: disable=redefined-builtin
    output_shape=None,
    strides=None,
    padding="SAME",
    data_format="NHWC",
    name=None,
    input=None,  # pylint: disable=redefined-builtin
    filters=None,
    dilations=None):
  """The transpose of `conv2d`.

  This operation is sometimes called "deconvolution" after
  (Zeiler et al., 2010), but is really the transpose (gradient) of `conv2d`
  rather than an actual deconvolution.

  Args:
    value: A 4-D `Tensor` of type `float` and shape
      `[batch, height, width, in_channels]` for `NHWC` data format or
      `[batch, in_channels, height, width]` for `NCHW` data format.
    filter: A 4-D `Tensor` with the same type as `value` and shape
      `[height, width, output_channels, in_channels]`.  `filter`'s
      `in_channels` dimension must match that of `value`.
    output_shape: A 1-D `Tensor` representing the output shape of the
      deconvolution op.
    strides: An int or list of `ints` that has length `1`, `2` or `4`.  The
      stride of the sliding window for each dimension of `input`. If a single
      value is given it is replicated in the `H` and `W` dimension. By default
      the `N` and `C` dimensions are set to 1. The dimension order is determined
      by the value of `data_format`, see below for details.
    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
      See the "returns" section of `tf.nn.convolution` for details.
    data_format: A string. 'NHWC' and 'NCHW' are supported.
    name: Optional name for the returned tensor.
    input: Alias for value.
    filters: Alias for filter.
    dilations: An int or list of `ints` that has length `1`, `2` or `4`,
      defaults to 1. The dilation factor for each dimension of `input`. If a
      single value is given it is replicated in the `H` and `W` dimension. By
      default the `N` and `C` dimensions are set to 1. If set to k > 1, there
      will be k-1 skipped cells between each filter element on that dimension.
      The dimension order is determined by the value of `data_format`, see above
      for details. Dilations in the batch and depth dimensions if a 4-d tensor
      must be 1.

  Returns:
    A `Tensor` with the same type as `value`.

  Raises:
    ValueError: If input/output depth does not match `filter`'s shape, or if
      padding is other than `'VALID'` or `'SAME'`.

  References:
    Deconvolutional Networks:
      [Zeiler et al., 2010]
      (https://ieeexplore.ieee.org/abstract/document/5539957)
      ([pdf]
      (http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.232.4023&rep=rep1&type=pdf))
  """
  value = deprecated_argument_lookup("input", input, "value", value)
  filter = deprecated_argument_lookup("filters", filters, "filter", filter)
  with ops.name_scope(name, "conv2d_transpose",
                      [value, filter, output_shape]) as name:
    return conv2d_transpose_v2(
        value,
        filter,
        output_shape,
        strides,
        padding=padding,
        data_format=data_format,
        dilations=dilations,
        name=name)
2615
2616
@tf_export("nn.conv2d_transpose", v1=[])
@dispatch.add_dispatch_support
def conv2d_transpose_v2(
    input,  # pylint: disable=redefined-builtin
    filters,  # pylint: disable=redefined-builtin
    output_shape,
    strides,
    padding="SAME",
    data_format="NHWC",
    dilations=None,
    name=None):
  """The transpose of `conv2d`.

  This operation is sometimes called "deconvolution" after
  (Zeiler et al., 2010), but is really the transpose (gradient) of `conv2d`
  rather than an actual deconvolution.

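  For example, a minimal sketch that doubles the spatial size of an `NHWC`
  feature map (the shapes here are illustrative only):

  >>> x = tf.random.normal([1, 3, 3, 8])
  >>> kernel = tf.random.normal([2, 2, 4, 8])
  >>> tf.nn.conv2d_transpose(
  ...     x, kernel, output_shape=[1, 6, 6, 4], strides=2).shape
  TensorShape([1, 6, 6, 4])
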
  Args:
    input: A 4-D `Tensor` of type `float` and shape `[batch, height, width,
      in_channels]` for `NHWC` data format or `[batch, in_channels, height,
      width]` for `NCHW` data format.
    filters: A 4-D `Tensor` with the same type as `input` and shape `[height,
      width, output_channels, in_channels]`.  `filters`' `in_channels` dimension
      must match that of `input`.
    output_shape: A 1-D `Tensor` representing the output shape of the
      deconvolution op.
    strides: An int or list of `ints` that has length `1`, `2` or `4`.  The
      stride of the sliding window for each dimension of `input`. If a single
      value is given it is replicated in the `H` and `W` dimensions. By default
      the `N` and `C` dimensions are set to 1. The dimension order is determined
      by the value of `data_format`, see below for details.
    padding: Either the `string` `"SAME"` or `"VALID"` indicating the type of
      padding algorithm to use, or a list indicating the explicit paddings at
      the start and end of each dimension. When explicit padding is used and
      data_format is `"NHWC"`, this should be in the form `[[0, 0], [pad_top,
      pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit padding is
      used and data_format is `"NCHW"`, this should be in the form `[[0, 0],
      [0, 0], [pad_top, pad_bottom], [pad_left, pad_right]]`.
    data_format: A string. 'NHWC' and 'NCHW' are supported.
    dilations: An int or list of `ints` that has length `1`, `2` or `4`,
      defaults to 1. The dilation factor for each dimension of `input`. If a
      single value is given it is replicated in the `H` and `W` dimensions. By
      default the `N` and `C` dimensions are set to 1. If set to k > 1, there
      will be k-1 skipped cells between each filter element on that dimension.
      The dimension order is determined by the value of `data_format`, see above
      for details. If a 4-d tensor is given, the dilations in the batch and
      depth dimensions must be 1.
    name: Optional name for the returned tensor.

  Returns:
    A `Tensor` with the same type as `input`.

  Raises:
    ValueError: If input/output depth does not match `filters`' shape, or if
      padding is other than `'VALID'` or `'SAME'`.

  References:
    Deconvolutional Networks:
      [Zeiler et al., 2010]
      (https://ieeexplore.ieee.org/abstract/document/5539957)
      ([pdf]
      (http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.232.4023&rep=rep1&type=pdf))
  """
  with ops.name_scope(name, "conv2d_transpose",
                      [input, filters, output_shape]) as name:
    if data_format is None:
      data_format = "NHWC"
    channel_index = 1 if data_format.startswith("NC") else 3

    strides = _get_sequence(strides, 2, channel_index, "strides")
    dilations = _get_sequence(dilations, 2, channel_index, "dilations")
    padding, explicit_paddings = convert_padding(padding)

    return gen_nn_ops.conv2d_backprop_input(
        input_sizes=output_shape,
        filter=filters,
        out_backprop=input,
        strides=strides,
        padding=padding,
        explicit_paddings=explicit_paddings,
        data_format=data_format,
        dilations=dilations,
        name=name)


def _conv2d_expanded_batch(
    input,  # pylint: disable=redefined-builtin
    filters,
    strides,
    padding,
    data_format,
    dilations,
    name):
  """Helper function for `convolution_internal`; handles expanded batches."""
  # Try really hard to avoid modifying the legacy name scopes - return early.
  input_rank = input.shape.rank
  if input_rank is None or input_rank < 5:
    # We avoid calling squeeze_batch_dims to reduce extra python function
    # call slowdown in eager mode.  This branch doesn't require reshapes.
    return gen_nn_ops.conv2d(
        input,
        filter=filters,
        strides=strides,
        padding=padding,
        data_format=data_format,
        dilations=dilations,
        name=name)
  return squeeze_batch_dims(
      input,
      functools.partial(
          gen_nn_ops.conv2d,
          filter=filters,
          strides=strides,
          padding=padding,
          data_format=data_format,
          dilations=dilations),
      inner_rank=3,
      name=name)


@tf_export("nn.atrous_conv2d_transpose")
@dispatch.add_dispatch_support
def atrous_conv2d_transpose(value,
                            filters,
                            output_shape,
                            rate,
                            padding,
                            name=None):
  """The transpose of `atrous_conv2d`.

  This operation is sometimes called "deconvolution" after
  (Zeiler et al., 2010), but is really the transpose (gradient) of
  `atrous_conv2d` rather than an actual deconvolution.

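  For example, a minimal sketch (shapes are illustrative; with `rate=2` the
  3x3 `filters` cover a 5x5 input window):

  >>> x = tf.random.normal([1, 8, 8, 3])
  >>> kernel = tf.random.normal([3, 3, 16, 3])
  >>> y = tf.nn.atrous_conv2d_transpose(
  ...     x, kernel, output_shape=[1, 8, 8, 16], rate=2, padding='SAME')
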
  Args:
    value: A 4-D `Tensor` of type `float`. It needs to be in the default `NHWC`
      format. Its shape is `[batch, in_height, in_width, in_channels]`.
    filters: A 4-D `Tensor` with the same type as `value` and shape
      `[filter_height, filter_width, out_channels, in_channels]`. `filters`'
      `in_channels` dimension must match that of `value`. Atrous convolution is
      equivalent to standard convolution with upsampled filters with effective
      height `filter_height + (filter_height - 1) * (rate - 1)` and effective
      width `filter_width + (filter_width - 1) * (rate - 1)`, produced by
      inserting `rate - 1` zeros along consecutive elements across the
      `filters`' spatial dimensions.
    output_shape: A 1-D `Tensor` representing the output shape of the
      deconvolution op.
    rate: A positive int32. The stride with which we sample input values across
      the `height` and `width` dimensions. Equivalently, the rate by which we
      upsample the filter values by inserting zeros across the `height` and
      `width` dimensions. In the literature, the same parameter is sometimes
      called `input stride` or `dilation`.
    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
    name: Optional name for the returned tensor.

  Returns:
    A `Tensor` with the same type as `value`.

  Raises:
    ValueError: If input/output depth does not match `filters`' shape, or if
      padding is other than `'VALID'` or `'SAME'`, or if the `rate` is less
      than one, or if the output_shape is not a tensor with 4 elements.

  References:
    Deconvolutional Networks:
      [Zeiler et al., 2010]
      (https://ieeexplore.ieee.org/abstract/document/5539957)
      ([pdf]
      (http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.232.4023&rep=rep1&type=pdf))
  """
  with ops.name_scope(name, "atrous_conv2d_transpose",
                      [value, filters, output_shape]) as name:
    value = ops.convert_to_tensor(value, name="value")
    filters = ops.convert_to_tensor(filters, name="filters")
    if not value.get_shape().dims[3].is_compatible_with(filters.get_shape()[3]):
      raise ValueError(
          "value's input channels does not match filters' input channels, "
          "{} != {}".format(value.get_shape()[3],
                            filters.get_shape()[3]))
    if rate < 1:
      raise ValueError("rate {} cannot be less than one".format(rate))

    if rate == 1:
      return conv2d_transpose(
          value,
          filters,
          output_shape,
          strides=[1, 1, 1, 1],
          padding=padding,
          data_format="NHWC")

    output_shape_ = ops.convert_to_tensor(output_shape, name="output_shape")
    if not output_shape_.get_shape().is_compatible_with(
        tensor_shape.TensorShape([4])):
      raise ValueError("output_shape must have shape (4,), got {}".format(
          output_shape_.get_shape()))

    if isinstance(output_shape, tuple):
      output_shape = list(output_shape)

    if isinstance(output_shape, (list, np.ndarray)):
      # output_shape's shape should be == [4] if we reach this point.
      if not filters.get_shape().dims[2].is_compatible_with(output_shape[3]):
        raise ValueError(
            "output_shape does not match filter's output channels, "
            "{} != {}".format(output_shape[3],
                              filters.get_shape()[2]))

    # We have two padding contributions. The first is used for converting "SAME"
    # to "VALID". The second is required so that the height and width of the
    # zero-padded value tensor are multiples of rate.

    # Padding required to reduce to "VALID" convolution
    if padding == "SAME":
      # Handle filters whose shape is unknown during graph creation.
      if filters.get_shape().is_fully_defined():
        filter_shape = filters.get_shape().as_list()
      else:
        filter_shape = array_ops.shape(filters)
      filter_height, filter_width = filter_shape[0], filter_shape[1]

      # Spatial dimensions of the filters and the upsampled filters in which we
      # introduce (rate - 1) zeros between consecutive filter values.
      filter_height_up = filter_height + (filter_height - 1) * (rate - 1)
      filter_width_up = filter_width + (filter_width - 1) * (rate - 1)

      pad_height = filter_height_up - 1
      pad_width = filter_width_up - 1

      # When pad_height (pad_width) is odd, we pad more to bottom (right),
      # following the same convention as conv2d().
      pad_top = pad_height // 2
      pad_bottom = pad_height - pad_top
      pad_left = pad_width // 2
      pad_right = pad_width - pad_left
    elif padding == "VALID":
      pad_top = 0
      pad_bottom = 0
      pad_left = 0
      pad_right = 0
    else:
      raise ValueError("padding must be either VALID or SAME:"
                       " {}".format(padding))

    in_height = output_shape[1] + pad_top + pad_bottom
    in_width = output_shape[2] + pad_left + pad_right

    # More padding so that rate divides the height and width of the input.
    pad_bottom_extra = (rate - in_height % rate) % rate
    pad_right_extra = (rate - in_width % rate) % rate

    # The paddings argument to space_to_batch is just the extra padding
    # component.
    space_to_batch_pad = [[0, pad_bottom_extra], [0, pad_right_extra]]

    value = array_ops.space_to_batch(
        input=value, paddings=space_to_batch_pad, block_size=rate)

    input_sizes = [
        rate * rate * output_shape[0], (in_height + pad_bottom_extra) // rate,
        (in_width + pad_right_extra) // rate, output_shape[3]
    ]

    value = gen_nn_ops.conv2d_backprop_input(
        input_sizes=input_sizes,
        filter=filters,
        out_backprop=value,
        strides=[1, 1, 1, 1],
        padding="VALID",
        data_format="NHWC")

    # The crops argument to batch_to_space includes both padding components.
    batch_to_space_crop = [[pad_top, pad_bottom + pad_bottom_extra],
                           [pad_left, pad_right + pad_right_extra]]

    return array_ops.batch_to_space(
        input=value, crops=batch_to_space_crop, block_size=rate)


@tf_export(v1=["nn.depthwise_conv2d_native"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("nn.depthwise_conv2d_native")
def depthwise_conv2d_native(  # pylint: disable=redefined-builtin,dangerous-default-value
    input,
    filter,
    strides,
    padding,
    data_format="NHWC",
    dilations=[1, 1, 1, 1],
    name=None):
  r"""Computes a 2-D depthwise convolution.

  Given an input tensor of shape `[batch, in_height, in_width, in_channels]`
  and a filter / kernel tensor of shape
  `[filter_height, filter_width, in_channels, channel_multiplier]`, containing
  `in_channels` convolutional filters of depth 1, `depthwise_conv2d` applies
  a different filter to each input channel (expanding from 1 channel to
  `channel_multiplier` channels for each), then concatenates the results
  together. Thus, the output has `in_channels * channel_multiplier` channels.

  ```
  for k in 0..in_channels-1
    for q in 0..channel_multiplier-1
      output[b, i, j, k * channel_multiplier + q] =
        sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, k] *
                          filter[di, dj, k, q]
  ```

  Must have `strides[0] = strides[3] = 1`.  For the most common case of the same
  horizontal and vertical strides, `strides = [1, stride, stride, 1]`.

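  For example, a minimal sketch (illustrative shapes; `channel_multiplier`
  is 2, so 3 input channels produce 6 output channels):

  >>> x = tf.random.normal([1, 5, 5, 3])
  >>> kernel = tf.random.normal([2, 2, 3, 2])
  >>> tf.compat.v1.nn.depthwise_conv2d_native(
  ...     x, kernel, strides=[1, 1, 1, 1], padding='VALID').shape
  TensorShape([1, 4, 4, 6])
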
  Args:
    input: A `Tensor`. Must be one of the following types: `half`, `bfloat16`,
      `float32`, `float64`.
    filter: A `Tensor`. Must have the same type as `input`.
    strides: A list of `ints`. 1-D of length 4.  The stride of the sliding
      window for each dimension of `input`.
    padding: Controls how to pad the image before applying the convolution. Can
      be the string `"SAME"` or `"VALID"` indicating the type of padding
      algorithm to use, or a list indicating the explicit paddings at the start
      and end of each dimension. When explicit padding is used and data_format
      is `"NHWC"`, this should be in the form `[[0, 0], [pad_top, pad_bottom],
      [pad_left, pad_right], [0, 0]]`. When explicit padding is used and
      data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
      [pad_top, pad_bottom], [pad_left, pad_right]]`.
    data_format: An optional `string` from: `"NHWC", "NCHW"`. Defaults to
      `"NHWC"`. Specify the data format of the input and output data. With the
      default format "NHWC", the data is stored in the order of:
      [batch, height, width, channels]. Alternatively, the format could be
      "NCHW", the data storage order of: [batch, channels, height, width].
    dilations: An optional list of `ints`. Defaults to `[1, 1, 1, 1]`. 1-D
      tensor of length 4.  The dilation factor for each dimension of `input`. If
      set to k > 1, there will be k-1 skipped cells between each filter element
      on that dimension. The dimension order is determined by the value of
      `data_format`, see above for details. Dilations in the batch and depth
      dimensions must be 1.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  padding, explicit_paddings = convert_padding(padding)
  return gen_nn_ops.depthwise_conv2d_native(
      input,
      filter,
      strides,
      padding,
      explicit_paddings=explicit_paddings,
      data_format=data_format,
      dilations=dilations,
      name=name)


@tf_export(
    "nn.depthwise_conv2d_backprop_input",
    v1=[
        "nn.depthwise_conv2d_native_backprop_input",
        "nn.depthwise_conv2d_backprop_input"
    ])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("nn.depthwise_conv2d_native_backprop_input")
def depthwise_conv2d_native_backprop_input(  # pylint: disable=redefined-builtin,dangerous-default-value
    input_sizes,
    filter,
    out_backprop,
    strides,
    padding,
    data_format="NHWC",
    dilations=[1, 1, 1, 1],
    name=None):
  r"""Computes the gradients of depthwise convolution with respect to the input.

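  For example, a minimal sketch (illustrative shapes; the result has the
  shape given by `input_sizes`):

  >>> filt = tf.random.normal([2, 2, 3, 1])
  >>> grads = tf.random.normal([1, 4, 4, 3])
  >>> tf.nn.depthwise_conv2d_backprop_input(
  ...     [1, 5, 5, 3], filt, grads, strides=[1, 1, 1, 1],
  ...     padding='VALID').shape
  TensorShape([1, 5, 5, 3])
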
  Args:
    input_sizes: A `Tensor` of type `int32`. An integer vector representing the
      shape of `input`, based on `data_format`.  For example, if `data_format`
      is 'NHWC' then `input` is a 4-D `[batch, height, width, channels]` tensor.
    filter: A `Tensor`. Must be one of the following types: `half`, `bfloat16`,
      `float32`, `float64`. 4-D with shape `[filter_height, filter_width,
      in_channels, depthwise_multiplier]`.
    out_backprop: A `Tensor`. Must have the same type as `filter`. 4-D with
      shape based on `data_format`. For example, if `data_format` is 'NHWC'
      then out_backprop shape is `[batch, out_height, out_width, out_channels]`.
      Gradients w.r.t. the output of the convolution.
    strides: A list of `ints`. The stride of the sliding window for each
      dimension of the input of the convolution.
    padding: Controls how to pad the image before applying the convolution. Can
      be the string `"SAME"` or `"VALID"` indicating the type of padding
      algorithm to use, or a list indicating the explicit paddings at the start
      and end of each dimension. When explicit padding is used and data_format
      is `"NHWC"`, this should be in the form `[[0, 0], [pad_top, pad_bottom],
      [pad_left, pad_right], [0, 0]]`. When explicit padding is used and
      data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
      [pad_top, pad_bottom], [pad_left, pad_right]]`.
    data_format: An optional `string` from: `"NHWC", "NCHW"`. Defaults to
      `"NHWC"`. Specify the data format of the input and output data. With the
      default format "NHWC", the data is stored in the order of:
      [batch, height, width, channels]. Alternatively, the format could be
      "NCHW", the data storage order of: [batch, channels, height, width].
    dilations: An optional list of `ints`. Defaults to `[1, 1, 1, 1]`. 1-D
      tensor of length 4.  The dilation factor for each dimension of `input`. If
      set to k > 1, there will be k-1 skipped cells between each filter element
      on that dimension. The dimension order is determined by the value of
      `data_format`, see above for details. Dilations in the batch and depth
      dimensions must be 1.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `filter`.
  """
  padding, explicit_paddings = convert_padding(padding)
  return gen_nn_ops.depthwise_conv2d_native_backprop_input(
      input_sizes,
      filter,
      out_backprop,
      strides,
      padding,
      explicit_paddings=explicit_paddings,
      data_format=data_format,
      dilations=dilations,
      name=name)


@tf_export(
    "nn.depthwise_conv2d_backprop_filter",
    v1=[
        "nn.depthwise_conv2d_native_backprop_filter",
        "nn.depthwise_conv2d_backprop_filter"
    ])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("nn.depthwise_conv2d_native_backprop_filter")
def depthwise_conv2d_native_backprop_filter(  # pylint: disable=redefined-builtin,dangerous-default-value
    input,
    filter_sizes,
    out_backprop,
    strides,
    padding,
    data_format="NHWC",
    dilations=[1, 1, 1, 1],
    name=None):
  r"""Computes the gradients of depthwise convolution with respect to the filter.

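  For example, a minimal sketch (illustrative shapes; the result has the
  shape given by `filter_sizes`):

  >>> x = tf.random.normal([1, 5, 5, 3])
  >>> grads = tf.random.normal([1, 4, 4, 3])
  >>> tf.nn.depthwise_conv2d_backprop_filter(
  ...     x, [2, 2, 3, 1], grads, strides=[1, 1, 1, 1],
  ...     padding='VALID').shape
  TensorShape([2, 2, 3, 1])
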
  Args:
    input: A `Tensor`. Must be one of the following types: `half`, `bfloat16`,
      `float32`, `float64`. 4-D with shape based on `data_format`.  For example,
      if `data_format` is 'NHWC' then `input` is a 4-D `[batch, in_height,
      in_width, in_channels]` tensor.
    filter_sizes: A `Tensor` of type `int32`. An integer vector representing the
      tensor shape of `filter`, where `filter` is a 4-D `[filter_height,
      filter_width, in_channels, depthwise_multiplier]` tensor.
    out_backprop: A `Tensor`. Must have the same type as `input`. 4-D with shape
      based on `data_format`. For example, if `data_format` is 'NHWC' then
      out_backprop shape is `[batch, out_height, out_width, out_channels]`.
      Gradients w.r.t. the output of the convolution.
    strides: A list of `ints`. The stride of the sliding window for each
      dimension of the input of the convolution.
    padding: Controls how to pad the image before applying the convolution. Can
      be the string `"SAME"` or `"VALID"` indicating the type of padding
      algorithm to use, or a list indicating the explicit paddings at the start
      and end of each dimension. When explicit padding is used and data_format
      is `"NHWC"`, this should be in the form `[[0, 0], [pad_top, pad_bottom],
      [pad_left, pad_right], [0, 0]]`. When explicit padding is used and
      data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
      [pad_top, pad_bottom], [pad_left, pad_right]]`.
    data_format: An optional `string` from: `"NHWC", "NCHW"`. Defaults to
      `"NHWC"`. Specify the data format of the input and output data. With the
      default format "NHWC", the data is stored in the order of:
      [batch, height, width, channels]. Alternatively, the format could be
      "NCHW", the data storage order of: [batch, channels, height, width].
    dilations: An optional list of `ints`. Defaults to `[1, 1, 1, 1]`. 1-D
      tensor of length 4.  The dilation factor for each dimension of `input`. If
      set to k > 1, there will be k-1 skipped cells between each filter element
      on that dimension. The dimension order is determined by the value of
      `data_format`, see above for details. Dilations in the batch and depth
      dimensions must be 1.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  padding, explicit_paddings = convert_padding(padding)
  return gen_nn_ops.depthwise_conv2d_native_backprop_filter(
      input,
      filter_sizes,
      out_backprop,
      strides,
      padding,
      explicit_paddings=explicit_paddings,
      data_format=data_format,
      dilations=dilations,
      name=name)


def _conv3d_expanded_batch(
    input,  # pylint: disable=redefined-builtin
    filter,  # pylint: disable=redefined-builtin
    strides,
    padding,
    data_format,
    dilations=None,
    name=None):
  """Helper function for `conv3d`; handles expanded batches."""
  shape = input.shape
  # shape object may lack ndims, e.g., if input is an np.ndarray.  In that case,
  # we fall back to len(shape).
  ndims = getattr(shape, "ndims", -1)
  if ndims == -1:
    ndims = len(shape)
  if ndims in (5, 4, 3, 2, 1, 0, None):
    # We avoid calling squeeze_batch_dims to reduce extra python function
    # call slowdown in eager mode.  This branch doesn't require reshapes.
    return gen_nn_ops.conv3d(
        input,
        filter,
        strides,
        padding,
        data_format=data_format,
        dilations=dilations,
        name=name)
  else:
    return squeeze_batch_dims(
        input,
        functools.partial(
            gen_nn_ops.conv3d,
            filter=filter,
            strides=strides,
            padding=padding,
            data_format=data_format,
            dilations=dilations),
        inner_rank=4,
        name=name)


@tf_export("nn.conv3d", v1=[])
@dispatch.add_dispatch_support
def conv3d_v2(input,  # pylint: disable=redefined-builtin,missing-docstring
              filters,
              strides,
              padding,
              data_format="NDHWC",
              dilations=None,
              name=None):
  if dilations is None:
    dilations = [1, 1, 1, 1, 1]
  return _conv3d_expanded_batch(input, filters, strides, padding, data_format,
                                dilations, name)


@tf_export(v1=["nn.conv3d"])
@dispatch.add_dispatch_support
def conv3d_v1(  # pylint: disable=missing-docstring,dangerous-default-value
    input,  # pylint: disable=redefined-builtin
    filter=None,  # pylint: disable=redefined-builtin
    strides=None,
    padding=None,
    data_format="NDHWC",
    dilations=[1, 1, 1, 1, 1],
    name=None,
    filters=None):
  filter = deprecated_argument_lookup("filters", filters, "filter", filter)
  return gen_nn_ops.conv3d(
      input, filter, strides, padding, data_format, dilations, name)


conv3d_v2.__doc__ = deprecation.rewrite_argument_docstring(
    gen_nn_ops.conv3d.__doc__, "filter", "filters")
conv3d_v1.__doc__ = gen_nn_ops.conv3d.__doc__


@tf_export(v1=["nn.conv3d_transpose"])
@dispatch.add_dispatch_support
def conv3d_transpose(
    value,
    filter=None,  # pylint: disable=redefined-builtin
    output_shape=None,
    strides=None,
    padding="SAME",
    data_format="NDHWC",
    name=None,
    input=None,  # pylint: disable=redefined-builtin
    filters=None,
    dilations=None):
  """The transpose of `conv3d`.

  This operation is sometimes called "deconvolution" after
  (Zeiler et al., 2010), but is really the transpose (gradient) of `conv3d`
  rather than an actual deconvolution.

  Args:
    value: A 5-D `Tensor` of type `float` and shape
      `[batch, depth, height, width, in_channels]`.
    filter: A 5-D `Tensor` with the same type as `value` and shape
      `[depth, height, width, output_channels, in_channels]`.  `filter`'s
      `in_channels` dimension must match that of `value`.
    output_shape: A 1-D `Tensor` representing the output shape of the
      deconvolution op.
    strides: A list of ints. The stride of the sliding window for each
      dimension of the input tensor.
    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
      See the "returns" section of `tf.nn.convolution` for details.
    data_format: A string, either `'NDHWC'` or `'NCDHW'` specifying the layout
      of the input and output tensors. Defaults to `'NDHWC'`.
    name: Optional name for the returned tensor.
    input: Alias of value.
    filters: Alias of filter.
    dilations: An int or list of `ints` that has length `1`, `3` or `5`,
      defaults to 1. The dilation factor for each dimension of `input`. If a
      single value is given it is replicated in the `D`, `H` and `W` dimensions.
      By default the `N` and `C` dimensions are set to 1. If set to k > 1, there
      will be k-1 skipped cells between each filter element on that dimension.
      The dimension order is determined by the value of `data_format`, see above
      for details. If a 5-d tensor is given, the dilations in the batch and
      depth dimensions must be 1.

  Returns:
    A `Tensor` with the same type as `value`.

  Raises:
    ValueError: If input/output depth does not match `filter`'s shape, or if
      padding is other than `'VALID'` or `'SAME'`.

  References:
    Deconvolutional Networks:
      [Zeiler et al., 2010]
      (https://ieeexplore.ieee.org/abstract/document/5539957)
      ([pdf]
      (http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.232.4023&rep=rep1&type=pdf))
  """
  filter = deprecated_argument_lookup("filters", filters, "filter", filter)
  value = deprecated_argument_lookup("input", input, "value", value)
  return conv3d_transpose_v2(
      value,
      filter,
      output_shape,
      strides,
      padding=padding,
      data_format=data_format,
      dilations=dilations,
      name=name)


@tf_export("nn.conv3d_transpose", v1=[])
@dispatch.add_dispatch_support
def conv3d_transpose_v2(input,  # pylint: disable=redefined-builtin
                        filters,
                        output_shape,
                        strides,
                        padding="SAME",
                        data_format="NDHWC",
                        dilations=None,
                        name=None):
  """The transpose of `conv3d`.

  This operation is sometimes called "deconvolution" after
  (Zeiler et al., 2010), but is really the transpose (gradient) of `conv3d`
  rather than an actual deconvolution.

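  For example, a minimal sketch that doubles each spatial dimension of an
  `NDHWC` volume (the shapes here are illustrative only):

  >>> x = tf.random.normal([1, 2, 2, 2, 8])
  >>> kernel = tf.random.normal([2, 2, 2, 4, 8])
  >>> tf.nn.conv3d_transpose(
  ...     x, kernel, output_shape=[1, 4, 4, 4, 4], strides=2).shape
  TensorShape([1, 4, 4, 4, 4])
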
  Args:
    input: A 5-D `Tensor` of type `float` and shape `[batch, depth, height,
      width, in_channels]` for `NDHWC` data format or `[batch, in_channels,
      depth, height, width]` for `NCDHW` data format.
    filters: A 5-D `Tensor` with the same type as `input` and shape `[depth,
      height, width, output_channels, in_channels]`.  `filters`' `in_channels`
      dimension must match that of `input`.
    output_shape: A 1-D `Tensor` representing the output shape of the
      deconvolution op.
    strides: An int or list of `ints` that has length `1`, `3` or `5`.  The
      stride of the sliding window for each dimension of `input`. If a single
      value is given it is replicated in the `D`, `H` and `W` dimensions. By
      default the `N` and `C` dimensions are set to 1. The dimension order is
      determined by the value of `data_format`, see below for details.
    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
      the "returns" section of `tf.nn.convolution` for details.
    data_format: A string. 'NDHWC' and 'NCDHW' are supported.
    dilations: An int or list of `ints` that has length `1`, `3` or `5`,
      defaults to 1. The dilation factor for each dimension of `input`. If a
      single value is given it is replicated in the `D`, `H` and `W` dimensions.
      By default the `N` and `C` dimensions are set to 1. If set to k > 1, there
      will be k-1 skipped cells between each filter element on that dimension.
      The dimension order is determined by the value of `data_format`, see above
      for details. If a 5-d tensor is given, the dilations in the batch and
      depth dimensions must be 1.
    name: Optional name for the returned tensor.

  Returns:
    A `Tensor` with the same type as `input`.

  References:
    Deconvolutional Networks:
      [Zeiler et al., 2010]
      (https://ieeexplore.ieee.org/abstract/document/5539957)
      ([pdf]
      (http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.232.4023&rep=rep1&type=pdf))
  """
  with ops.name_scope(name, "conv3d_transpose",
                      [input, filters, output_shape]) as name:
    if data_format is None:
      data_format = "NDHWC"
    channel_index = 1 if data_format.startswith("NC") else 4

    strides = _get_sequence(strides, 3, channel_index, "strides")
    dilations = _get_sequence(dilations, 3, channel_index, "dilations")

    return gen_nn_ops.conv3d_backprop_input_v2(
        input_sizes=output_shape,
        filter=filters,
        out_backprop=input,
        strides=strides,
        padding=padding,
        data_format=data_format,
        dilations=dilations,
        name=name)


CONV_TRANSPOSE_OPS = (
    conv1d_transpose,
    conv2d_transpose_v2,
    conv3d_transpose_v2,
)


@tf_export("nn.conv_transpose")
@dispatch.add_dispatch_support
def conv_transpose(input,  # pylint: disable=redefined-builtin
                   filters,
                   output_shape,
                   strides,
                   padding="SAME",
                   data_format=None,
                   dilations=None,
                   name=None):
  """The transpose of `convolution`.

  This operation is sometimes called "deconvolution" after
  (Zeiler et al., 2010), but is really the transpose (gradient) of
  `convolution` rather than an actual deconvolution.

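  For example, a minimal 2-D sketch; `N` is inferred from `output_shape`
  (the shapes here are illustrative only):

  >>> x = tf.random.normal([1, 4, 4, 2])
  >>> kernel = tf.random.normal([3, 3, 8, 2])
  >>> tf.nn.conv_transpose(
  ...     x, kernel, output_shape=[1, 8, 8, 8], strides=2).shape
  TensorShape([1, 8, 8, 8])
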
  Args:
    input: An N+2 dimensional `Tensor` of shape
      `[batch_size] + input_spatial_shape + [in_channels]` if data_format does
      not start with "NC" (default), or
      `[batch_size, in_channels] + input_spatial_shape` if data_format starts
      with "NC". It must be one of the following types:
      `half`, `bfloat16`, `float32`, `float64`.
    filters: An N+2 dimensional `Tensor` with the same type as `input` and
      shape `spatial_filter_shape + [in_channels, out_channels]`.
    output_shape: A 1-D `Tensor` representing the output shape of the
      deconvolution op.
    strides: An int or list of `ints` that has length `1`, `N` or `N+2`.  The
      stride of the sliding window for each dimension of `input`. If a single
      value is given it is replicated in the spatial dimensions. By default
      the `N` and `C` dimensions are set to 1. The dimension order is determined
      by the value of `data_format`, see below for details.
    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
      the "returns" section of `tf.nn.convolution` for details.
    data_format: A string or None.  Specifies whether the channel dimension of
      the `input` and output is the last dimension (default, or if `data_format`
      does not start with "NC"), or the second dimension (if `data_format`
      starts with "NC").  For N=1, the valid values are "NWC" (default) and
      "NCW".  For N=2, the valid values are "NHWC" (default) and "NCHW".
      For N=3, the valid values are "NDHWC" (default) and "NCDHW".
    dilations: An int or list of `ints` that has length `1`, `N` or `N+2`,
      defaults to 1. The dilation factor for each dimension of `input`. If a
      single value is given it is replicated in the spatial dimensions. By
      default the `N` and `C` dimensions are set to 1. If set to k > 1, there
      will be k-1 skipped cells between each filter element on that dimension.
      The dimension order is determined by the value of `data_format`, see above
      for details.
    name: A name for the operation (optional). If not specified "conv_transpose"
      is used.

  Returns:
    A `Tensor` with the same type as `input`.

  References:
    Deconvolutional Networks:
      [Zeiler et al., 2010]
      (https://ieeexplore.ieee.org/abstract/document/5539957)
      ([pdf]
      (http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.232.4023&rep=rep1&type=pdf))
  """
  with ops.name_scope(name, "conv_transpose",
                      [input, filters, output_shape]) as name:
    if tensor_util.is_tf_type(output_shape):
      n = output_shape.shape[0] - 2
    elif isinstance(output_shape, collections_abc.Sized):
      n = len(output_shape) - 2
    else:
      raise ValueError("output_shape must be a tensor or sized collection.")

    if not 1 <= n <= 3:
      raise ValueError(
          "output_shape must be of length 3, 4 or 5 but was {}.".format(n + 2))

    op = CONV_TRANSPOSE_OPS[n - 1]
    return op(
        input,
        filters,
        output_shape,
        strides,
        padding=padding,
        data_format=data_format,
        dilations=dilations,
        name=name)


@tf_export("nn.bias_add")
@dispatch.add_dispatch_support
def bias_add(value, bias, data_format=None, name=None):
  """Adds `bias` to `value`.

  This is (mostly) a special case of `tf.add` where `bias` is restricted to 1-D.
  Broadcasting is supported, so `value` may have any number of dimensions.
  Unlike `tf.add`, the type of `bias` is allowed to differ from `value` in the
  case where both types are quantized.

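  For example:

  >>> x = tf.reshape(tf.range(6.0), [2, 3])
  >>> tf.nn.bias_add(x, tf.constant([10., 20., 30.])).numpy()
  array([[10., 21., 32.],
         [13., 24., 35.]], dtype=float32)
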
  Args:
    value: A `Tensor` with type `float`, `double`, `int64`, `int32`, `uint8`,
      `int16`, `int8`, `complex64`, or `complex128`.
    bias: A 1-D `Tensor` with size matching the channel dimension of `value`.
      Must be the same type as `value` unless `value` is a quantized type,
      in which case a different quantized type may be used.
    data_format: A string. 'N...C' and 'NC...' are supported. If `None` (the
      default) is specified then 'N...C' is assumed.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with the same type as `value`.

  Raises:
    ValueError: if the data format is unrecognized, if `value` has fewer than
    two dimensions when `data_format` is 'N...C'/`None` or fewer than three
    dimensions when `data_format` is 'NC...', if `bias` does not have exactly
    one dimension (i.e. is not a vector), or if the size of `bias` does not
    match the size of the channel dimension of `value`.
  """
  with ops.name_scope(name, "BiasAdd", [value, bias]) as name:
    if data_format is not None:
      if data_format.startswith("NC"):
        data_format = "NCHW"
      elif data_format.startswith("N") and data_format.endswith("C"):
        data_format = "NHWC"
      else:
        raise ValueError("data_format must be of the form `N...C` or `NC...`")

    if not context.executing_eagerly():
      value = ops.convert_to_tensor(value, name="input")
      bias = ops.convert_to_tensor(bias, dtype=value.dtype, name="bias")

    # TODO(duncanriach): Implement deterministic functionality at CUDA kernel
    #   level.
    if config.deterministic_ops_enabled():
      # Note that this code does not implement the same error checks as the
      # pre-existing C++ ops.
      if data_format == "NCHW":
        broadcast_shape_head = [1, array_ops.size(bias)]
        broadcast_shape_tail = array_ops.ones(
            array_ops.rank(value) - 2, dtype=dtypes.int32)
        broadcast_shape = array_ops.concat(
            [broadcast_shape_head, broadcast_shape_tail], 0)
        return math_ops.add(
            value, array_ops.reshape(bias, broadcast_shape), name=name)
      else:  # data_format == 'NHWC' or data_format == None
        return math_ops.add(value, bias, name=name)
    else:
      return gen_nn_ops.bias_add(
          value, bias, data_format=data_format, name=name)


def bias_add_v1(value, bias, name=None):
  """Adds `bias` to `value`.

  This is a deprecated version of `bias_add` and will soon be removed.

  This is (mostly) a special case of `tf.add` where `bias` is restricted to 1-D.
  Broadcasting is supported, so `value` may have any number of dimensions.
  Unlike `tf.add`, the type of `bias` is allowed to differ from `value` in the
  case where both types are quantized.

  Args:
    value: A `Tensor` with type `float`, `double`, `int64`, `int32`, `uint8`,
      `int16`, `int8`, `complex64`, or `complex128`.
    bias: A 1-D `Tensor` with size matching the last dimension of `value`.
      Must be the same type as `value` unless `value` is a quantized type,
      in which case a different quantized type may be used.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with the same type as `value`.
  """
  with ops.name_scope(name, "BiasAddV1", [value, bias]) as name:
    value = ops.convert_to_tensor(value, name="input")
    bias = ops.convert_to_tensor(bias, dtype=value.dtype, name="bias")
    return gen_nn_ops.bias_add_v1(value, bias, name=name)


@tf_export(v1=["nn.crelu"])
@dispatch.add_dispatch_support
def crelu(features, name=None, axis=-1):
  """Computes Concatenated ReLU.

  Concatenates a ReLU which selects only the positive part of the activation
  with a ReLU which selects only the *negative* part of the activation.
  Note that as a result this non-linearity doubles the depth of the activations.
  Source: [Understanding and Improving Convolutional Neural Networks via
  Concatenated Rectified Linear Units. W. Shang, et
  al.](https://arxiv.org/abs/1603.05201)

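  For example:

  >>> x = tf.constant([[-1.0, 2.0]])
  >>> tf.nn.crelu(x).numpy()
  array([[0., 2., 1., 0.]], dtype=float32)
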
  Args:
    features: A `Tensor` with type `float`, `double`, `int32`, `int64`, `uint8`,
      `int16`, or `int8`.
    name: A name for the operation (optional).
    axis: The axis that the output values are concatenated along. Default is -1.

  Returns:
    A `Tensor` with the same type as `features`.

  References:
    Understanding and Improving Convolutional Neural Networks via Concatenated
    Rectified Linear Units:
      [Shang et al., 2016](http://proceedings.mlr.press/v48/shang16)
      ([pdf](http://proceedings.mlr.press/v48/shang16.pdf))
  """
  with ops.name_scope(name, "CRelu", [features]) as name:
    features = ops.convert_to_tensor(features, name="features")
    c = array_ops.concat([features, -features], axis, name=name)  # pylint: disable=invalid-unary-operand-type
    return gen_nn_ops.relu(c)


@tf_export("nn.crelu", v1=[])
@dispatch.add_dispatch_support
def crelu_v2(features, axis=-1, name=None):
  return crelu(features, name=name, axis=axis)
crelu_v2.__doc__ = crelu.__doc__


@tf_export("nn.relu6")
@dispatch.add_dispatch_support
def relu6(features, name=None):
  """Computes Rectified Linear 6: `min(max(features, 0), 6)`.

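  For example:

  >>> tf.nn.relu6(tf.constant([-3.0, 0.5, 8.0])).numpy()
  array([0. , 0.5, 6. ], dtype=float32)
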
  Args:
    features: A `Tensor` with type `float`, `double`, `int32`, `int64`, `uint8`,
      `int16`, or `int8`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with the same type as `features`.

  References:
    Convolutional Deep Belief Networks on CIFAR-10:
      Krizhevsky et al., 2010
      ([pdf](http://www.cs.utoronto.ca/~kriz/conv-cifar10-aug2010.pdf))
  """
  with ops.name_scope(name, "Relu6", [features]) as name:
    features = ops.convert_to_tensor(features, name="features")
    return gen_nn_ops.relu6(features, name=name)


@tf_export("nn.leaky_relu")
@dispatch.add_dispatch_support
def leaky_relu(features, alpha=0.2, name=None):
  """Compute the Leaky ReLU activation function.

  Source: [Rectifier Nonlinearities Improve Neural Network Acoustic Models.
  AL Maas, AY Hannun, AY Ng - Proc. ICML, 2013]
  (https://ai.stanford.edu/~amaas/papers/relu_hybrid_icml2013_final.pdf).

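  For example:

  >>> tf.nn.leaky_relu(tf.constant([-2.0, 0.0, 3.0]), alpha=0.2).numpy()
  array([-0.4,  0. ,  3. ], dtype=float32)
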
  Args:
    features: A `Tensor` representing preactivation values. Must be one of
      the following types: `float16`, `float32`, `float64`, `int32`, `int64`.
    alpha: Slope of the activation function at x < 0.
    name: A name for the operation (optional).

  Returns:
    The activation value.

  References:
    Rectifier Nonlinearities Improve Neural Network Acoustic Models:
      [Maas et al., 2013]
      (http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.693.1422)
      ([pdf]
      (http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.693.1422&rep=rep1&type=pdf))
  """
  with ops.name_scope(name, "LeakyRelu", [features, alpha]) as name:
    features = ops.convert_to_tensor(features, name="features")
    if features.dtype.is_integer:
      features = math_ops.cast(features, dtypes.float32)
    if isinstance(alpha, np.ndarray):
      alpha = alpha.item()
    return gen_nn_ops.leaky_relu(features, alpha=alpha, name=name)


@tf_export("nn.gelu", v1=[])
@dispatch.add_dispatch_support
def gelu(features, approximate=False, name=None):
  """Compute the Gaussian Error Linear Unit (GELU) activation function.

  Gaussian error linear unit (GELU) computes
  `x * P(X <= x)`, where `P(X) ~ N(0, 1)`.
  The (GELU) nonlinearity weights inputs by their value, rather than gates
  inputs by their sign as in ReLU.

  For example:

  >>> x = tf.constant([-3.0, -1.0, 0.0, 1.0, 3.0], dtype=tf.float32)
  >>> y = tf.nn.gelu(x)
  >>> y.numpy()
  array([-0.00404951, -0.15865529,  0.        ,  0.8413447 ,  2.9959507 ],
      dtype=float32)
  >>> y = tf.nn.gelu(x, approximate=True)
  >>> y.numpy()
  array([-0.00363752, -0.15880796,  0.        ,  0.841192  ,  2.9963627 ],
      dtype=float32)

  Args:
    features: A `Tensor` representing preactivation values.
    approximate: An optional `bool`. Defaults to `False`. Whether to enable
      approximation.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with the same type as `features`.

  References:
    [Gaussian Error Linear Units (GELUs)](https://arxiv.org/abs/1606.08415).
  """
  with ops.name_scope(name, "Gelu", [features]):
    features = ops.convert_to_tensor(features, name="features")
    if approximate:
      coeff = math_ops.cast(0.044715, features.dtype)
      return 0.5 * features * (
          1.0 + math_ops.tanh(0.7978845608028654 *
                              (features + coeff * math_ops.pow(features, 3))))
    else:
      return 0.5 * features * (1.0 + math_ops.erf(
          features / math_ops.cast(1.4142135623730951, features.dtype)))


def _flatten_outer_dims(logits):
  """Flattens logits' outer dimensions and keeps its last dimension."""
  rank = array_ops.rank(logits)
  last_dim_size = array_ops.slice(
      array_ops.shape(logits), [math_ops.subtract(rank, 1)], [1])
  output = array_ops.reshape(logits, array_ops.concat([[-1], last_dim_size], 0))

  # Set output shape if known.
  if not context.executing_eagerly():
    shape = logits.get_shape()
    if shape is not None and shape.dims is not None:
      shape = shape.as_list()
      product = 1
      product_valid = True
      for d in shape[:-1]:
        if d is None:
          product_valid = False
          break
        else:
          product *= d
      if product_valid:
        output_shape = [product, shape[-1]]
        output.set_shape(output_shape)

  return output


def _wrap_2d_function(inputs, compute_op, dim=-1, name=None):
  """Helper function for ops that accept and return 2d inputs of the same shape.

  It reshapes and transposes the inputs into a 2-D Tensor and then invokes
  the given function. The output is transposed and reshaped back.
  If the given function returns a tuple of tensors, each of them is
  transposed and reshaped.

  Args:
    inputs: A non-empty `Tensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    compute_op: The function to wrap. Must accept the input tensor as its first
      argument, and a second keyword argument `name`.
    dim: The dimension the op would be performed on. The default is -1 which
      indicates the last dimension.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same shape as inputs. If compute_op returns multiple
      tensors, each of them has the same shape as the input.
  Raises:
    InvalidArgumentError: if `inputs` is empty or `dim` is beyond the last
      dimension of `inputs`.
  """

  def _swap_axis(input_tensor, dim_index, last_index, name=None):
    """Swaps input_tensor's dim_index and last_index axes."""
    return array_ops.transpose(
        input_tensor,
        array_ops.concat([
            math_ops.range(dim_index), [last_index],
            math_ops.range(dim_index + 1, last_index), [dim_index]
        ], 0),
        name=name)

  inputs = ops.convert_to_tensor(inputs)

  # We need its original shape for shape inference.
  shape = inputs.get_shape()
  is_last_dim = (dim == -1) or (dim == shape.ndims - 1)

  if is_last_dim:
    return compute_op(inputs, name=name)

  dim_val = dim
  if isinstance(dim, ops.Tensor):
    dim_val = tensor_util.constant_value(dim)
  if dim_val is not None and not -shape.ndims <= dim_val < shape.ndims:
    raise errors_impl.InvalidArgumentError(
        None, None,
        "Dimension (%d) must be in the range [%d, %d) where %d is the number of"
        " dimensions in the input." % (dim_val, -shape.ndims, shape.ndims,
                                       shape.ndims))

  # If dim is not the last dimension, we have to do a transpose so that we can
  # still perform the op on its last dimension.

  # In case dim is negative (and is not last dimension -1), add shape.ndims
  ndims = array_ops.rank(inputs)
  if not isinstance(dim, ops.Tensor):
    if dim < 0:
      dim += ndims
  else:
    dim = array_ops.where(math_ops.less(dim, 0), dim + ndims, dim)

  # Swap logits' dimension of dim and its last dimension.
  input_rank = array_ops.rank(inputs)
  dim_axis = dim % shape.ndims
  inputs = _swap_axis(inputs, dim_axis, math_ops.subtract(input_rank, 1))

  # Do the actual call on its last dimension.
  def fix_output(output):
    output = _swap_axis(
        output, dim_axis, math_ops.subtract(input_rank, 1), name=name)

    # Make shape inference work since transpose may erase its static shape.
    output.set_shape(shape)
    return output

  outputs = compute_op(inputs)
  if isinstance(outputs, tuple):
    return tuple(fix_output(output) for output in outputs)
  else:
    return fix_output(outputs)


@tf_export("nn.softmax", "math.softmax", v1=[])
@dispatch.add_dispatch_support
def softmax_v2(logits, axis=None, name=None):
  """Computes softmax activations.

  Used for multi-class predictions. The sum of all outputs generated by softmax
  is 1.

  This function performs the equivalent of

      softmax = tf.exp(logits) / tf.reduce_sum(tf.exp(logits), axis)

  Example usage:

  >>> softmax = tf.nn.softmax([-1, 0., 1.])
  >>> softmax
  <tf.Tensor: shape=(3,), dtype=float32,
  numpy=array([0.09003057, 0.24472848, 0.66524094], dtype=float32)>
  >>> sum(softmax)
  <tf.Tensor: shape=(), dtype=float32, numpy=1.0>

  Args:
    logits: A non-empty `Tensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    axis: The dimension softmax would be performed on. The default is -1 which
      indicates the last dimension.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type and shape as `logits`.

  Raises:
    InvalidArgumentError: if `logits` is empty or `axis` is beyond the last
      dimension of `logits`.
  """
  if axis is None:
    axis = -1
  return _wrap_2d_function(logits, gen_nn_ops.softmax, axis, name)


@tf_export(v1=["nn.softmax", "math.softmax"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def softmax(logits, axis=None, name=None, dim=None):
  axis = deprecation.deprecated_argument_lookup("axis", axis, "dim", dim)
  if axis is None:
    axis = -1
  return _wrap_2d_function(logits, gen_nn_ops.softmax, axis, name)


softmax.__doc__ = softmax_v2.__doc__


@tf_export(v1=["nn.log_softmax", "math.log_softmax"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def log_softmax(logits, axis=None, name=None, dim=None):
  """Computes log softmax activations.

  For each batch `i` and class `j` we have

      logsoftmax = logits - log(reduce_sum(exp(logits), axis))

  Args:
    logits: A non-empty `Tensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    axis: The dimension softmax would be performed on. The default is -1 which
      indicates the last dimension.
    name: A name for the operation (optional).
    dim: Deprecated alias for `axis`.

  Returns:
    A `Tensor`. Has the same type as `logits`. Same shape as `logits`.

  Raises:
    InvalidArgumentError: if `logits` is empty or `axis` is beyond the last
      dimension of `logits`.
  """
  axis = deprecation.deprecated_argument_lookup("axis", axis, "dim", dim)
  if axis is None:
    axis = -1
  return _wrap_2d_function(logits, gen_nn_ops.log_softmax, axis, name)


@tf_export("nn.log_softmax", "math.log_softmax", v1=[])
@dispatch.add_dispatch_support
def log_softmax_v2(logits, axis=None, name=None):
  """Computes log softmax activations.

  For each batch `i` and class `j` we have

      logsoftmax = logits - log(reduce_sum(exp(logits), axis))

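  Example usage (a small sketch; exact values omitted, see `tf.nn.softmax`
  for a related worked example):

  >>> logits = tf.constant([[4.0, 2.0, 1.0]])
  >>> tf.nn.log_softmax(logits).shape
  TensorShape([1, 3])
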
  Args:
    logits: A non-empty `Tensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    axis: The dimension softmax would be performed on. The default is -1 which
      indicates the last dimension.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `logits`. Same shape as `logits`.

  Raises:
    InvalidArgumentError: if `logits` is empty or `axis` is beyond the last
      dimension of `logits`.
  """
  if axis is None:
    axis = -1
  return _wrap_2d_function(logits, gen_nn_ops.log_softmax, axis, name)


def _ensure_xent_args(name, sentinel, labels, logits):
  # Make sure that all arguments were passed as named arguments.
  if sentinel is not None:
    raise ValueError("Only call `%s` with "
                     "named arguments (labels=..., logits=..., ...)" % name)
  if labels is None or logits is None:
    raise ValueError("Both labels and logits must be provided.")


@tf_export("nn.softmax_cross_entropy_with_logits", v1=[])
@dispatch.add_dispatch_support
def softmax_cross_entropy_with_logits_v2(labels, logits, axis=-1, name=None):
  """Computes softmax cross entropy between `logits` and `labels`.

  Measures the probability error in discrete classification tasks in which the
  classes are mutually exclusive (each entry is in exactly one class).  For
  example, each CIFAR-10 image is labeled with one and only one label: an image
  can be a dog or a truck, but not both.

  **NOTE:**  While the classes are mutually exclusive, their probabilities
  need not be.  All that is required is that each row of `labels` is
  a valid probability distribution.  If they are not, the computation of the
  gradient will be incorrect.

  If using exclusive `labels` (wherein one and only
  one class is true at a time), see `sparse_softmax_cross_entropy_with_logits`.

  Usage:

  >>> logits = [[4.0, 2.0, 1.0], [0.0, 5.0, 1.0]]
  >>> labels = [[1.0, 0.0, 0.0], [0.0, 0.8, 0.2]]
  >>> tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
  <tf.Tensor: shape=(2,), dtype=float32,
  numpy=array([0.16984604, 0.82474494], dtype=float32)>

  **WARNING:** This op expects unscaled logits, since it performs a `softmax`
  on `logits` internally for efficiency.  Do not call this op with the
  output of `softmax`, as it will produce incorrect results.

  A common use case is to have logits and labels of shape
  `[batch_size, num_classes]`, but higher dimensions are supported, with
  the `axis` argument specifying the class dimension.

  `logits` and `labels` must have the same dtype (either `float16`, `float32`,
  or `float64`).

3929  Backpropagation will happen into both `logits` and `labels`.  To disallow
3930  backpropagation into `labels`, pass label tensors through `tf.stop_gradient`
3931  before feeding them to this function.
3932
3933  **Note that to avoid confusion, it is required to pass only named arguments to
3934  this function.**
3935
3936  Args:
3937    labels: Each vector along the class dimension should hold a valid
3938      probability distribution e.g. for the case in which labels are of shape
3939      `[batch_size, num_classes]`, each row of `labels[i]` must be a valid
3940      probability distribution.
3941    logits: Per-label activations, typically a linear output. These activation
3942      energies are interpreted as unnormalized log probabilities.
3943    axis: The class dimension. Defaults to -1, which is the last dimension.
3944    name: A name for the operation (optional).
3945
3946  Returns:
3947    A `Tensor` that contains the softmax cross entropy loss. Its type is the
3948    same as `logits` and its shape is the same as `labels` except that it does
3949    not have the last dimension of `labels`.
3950  """
3951  return softmax_cross_entropy_with_logits_v2_helper(
3952      labels=labels, logits=logits, axis=axis, name=name)
3953
3954
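# --- Illustrative sketch; not part of this module's API. ---
# Assuming `import tensorflow as tf`, this spells out the quantity the op
# above computes: per-example cross entropy over the class axis,
# -sum(labels * log_softmax(logits)). The helper name is hypothetical.
def _softmax_xent_example():
  import tensorflow as tf
  logits = tf.constant([[4.0, 2.0, 1.0], [0.0, 5.0, 1.0]])
  labels = tf.constant([[1.0, 0.0, 0.0], [0.0, 0.8, 0.2]])
  fused = tf.nn.softmax_cross_entropy_with_logits(
      labels=labels, logits=logits)
  manual = -tf.reduce_sum(labels * tf.nn.log_softmax(logits), axis=-1)
  tf.debugging.assert_near(fused, manual)  # same values; fused op is faster

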
3955@tf_export(v1=["nn.softmax_cross_entropy_with_logits_v2"])
3956@dispatch.add_dispatch_support
3957@deprecated_args(None, "dim is deprecated, use axis instead", "dim")
3958def softmax_cross_entropy_with_logits_v2_helper(
3959    labels, logits, axis=None, name=None, dim=None):
3960  """Computes softmax cross entropy between `logits` and `labels`.
3961
3962  Measures the probability error in discrete classification tasks in which the
3963  classes are mutually exclusive (each entry is in exactly one class).  For
3964  example, each CIFAR-10 image is labeled with one and only one label: an image
3965  can be a dog or a truck, but not both.
3966
3967  **NOTE:**  While the classes are mutually exclusive, their probabilities
3968  need not be.  All that is required is that each row of `labels` is
3969  a valid probability distribution.  If they are not, the computation of the
3970  gradient will be incorrect.
3971
3972  If using exclusive `labels` (wherein one and only
3973  one class is true at a time), see `sparse_softmax_cross_entropy_with_logits`.
3974
3975  **WARNING:** This op expects unscaled logits, since it performs a `softmax`
3976  on `logits` internally for efficiency.  Do not call this op with the
3977  output of `softmax`, as it will produce incorrect results.
3978
3979  A common use case is to have logits and labels of shape
3980  `[batch_size, num_classes]`, but higher dimensions are supported, with
3981  the `axis` argument specifying the class dimension.
3982
3983  `logits` and `labels` must have the same dtype (either `float16`, `float32`,
3984  or `float64`).
3985
3986  Backpropagation will happen into both `logits` and `labels`.  To disallow
3987  backpropagation into `labels`, pass label tensors through `tf.stop_gradient`
3988  before feeding them to this function.
3989
3990  **Note that to avoid confusion, it is required to pass only named arguments to
3991  this function.**
3992
3993  Args:
3994    labels: Each vector along the class dimension should hold a valid
3995      probability distribution e.g. for the case in which labels are of shape
3996      `[batch_size, num_classes]`, each row of `labels[i]` must be a valid
3997      probability distribution.
3998    logits: Unscaled log probabilities.
3999    axis: The class dimension. Defaults to -1, which is the last dimension.
4000    name: A name for the operation (optional).
4001    dim: Deprecated alias for axis.
4002
4003  Returns:
4004    A `Tensor` that contains the softmax cross entropy loss. Its type is the
4005    same as `logits` and its shape is the same as `labels` except that it does
4006    not have the last dimension of `labels`.
4007  """
4008  # TODO(pcmurray) Raise an error when the labels do not sum to 1. Note: This
4009  # could break users who call this with bad labels, but disregard the bad
4010  # results.
4011  axis = deprecated_argument_lookup("axis", axis, "dim", dim)
4012  del dim
4013  if axis is None:
4014    axis = -1
4015
4016  with ops.name_scope(name, "softmax_cross_entropy_with_logits",
4017                      [logits, labels]) as name:
4018    logits = ops.convert_to_tensor(logits, name="logits")
4019    labels = ops.convert_to_tensor(labels, name="labels")
4020    convert_to_float32 = (
4021        logits.dtype == dtypes.float16 or logits.dtype == dtypes.bfloat16)
4022    precise_logits = math_ops.cast(
4023        logits, dtypes.float32) if convert_to_float32 else logits
4024    # labels and logits must be of the same type
4025    labels = math_ops.cast(labels, precise_logits.dtype)
4026    input_rank = array_ops.rank(precise_logits)
4027    # For shape inference.
4028    shape = logits.get_shape()
4029
4030    # Move the dim to the end if dim is not the last dimension.
4031    if axis != -1:
4032
4033      def _move_dim_to_end(tensor, dim_index, rank):
4034        return array_ops.transpose(
4035            tensor,
4036            array_ops.concat([
4037                math_ops.range(dim_index),
4038                math_ops.range(dim_index + 1, rank), [dim_index]
4039            ], 0))
4040
4041      precise_logits = _move_dim_to_end(precise_logits, axis, input_rank)
4042      labels = _move_dim_to_end(labels, axis, input_rank)
4043
4044    input_shape = array_ops.shape(precise_logits)
4045
4046    # Make precise_logits and labels into matrices.
4047    precise_logits = _flatten_outer_dims(precise_logits)
4048    labels = _flatten_outer_dims(labels)
4049
4050    # Do the actual op computation.
4051    if config.deterministic_ops_enabled():
4052      log_probs = log_softmax_v2(precise_logits)
4053      cost = -math_ops.reduce_sum(labels * log_probs, axis=1)
4054    else:
4055      # The second output tensor contains the gradients.  We use it in
4056      # CrossEntropyGrad() in nn_grad but not here.
4057      cost, unused_backprop = gen_nn_ops.softmax_cross_entropy_with_logits(
4058          precise_logits, labels, name=name)
4059
4060    # The output cost shape should be the input minus axis.
4061    output_shape = array_ops.slice(input_shape, [0],
4062                                   [math_ops.subtract(input_rank, 1)])
4063    cost = array_ops.reshape(cost, output_shape)
4064
4065    # Make shape inference work, since reshape and transpose may erase the
4066    # static shape.
4067    if not context.executing_eagerly(
4068    ) and shape is not None and shape.dims is not None:
4069      shape = shape.as_list()
4070      del shape[axis]
4071      cost.set_shape(shape)
4072
4073    if convert_to_float32:
4074      return math_ops.cast(cost, logits.dtype)
4075    else:
4076      return cost
4077
4078
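# --- Illustrative sketch; hypothetical helper. ---
# The helper above transposes the class dimension to the end before
# flattening. Assuming `import tensorflow as tf`, the same permutation can
# be written directly as:
def _move_class_axis_example():
  import tensorflow as tf
  logits = tf.zeros([8, 10, 5])  # class axis 1 in [batch, classes, time]
  axis, rank = 1, logits.shape.rank
  perm = [d for d in range(rank) if d != axis] + [axis]
  return tf.transpose(logits, perm)  # shape [8, 5, 10]

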
4079_XENT_DEPRECATION = """
4080Future major versions of TensorFlow will allow gradients to flow
4081into the labels input on backprop by default.
4082
4083See `tf.nn.softmax_cross_entropy_with_logits_v2`.
4084"""
4085
4086
4087@tf_export(v1=["nn.softmax_cross_entropy_with_logits"])
4088@dispatch.add_dispatch_support
4089@deprecation.deprecated(date=None, instructions=_XENT_DEPRECATION)
4090def softmax_cross_entropy_with_logits(
4091    _sentinel=None,  # pylint: disable=invalid-name
4092    labels=None,
4093    logits=None,
4094    dim=-1,
4095    name=None,
4096    axis=None):
4097  """Computes softmax cross entropy between `logits` and `labels`.
4098
4099  Measures the probability error in discrete classification tasks in which the
4100  classes are mutually exclusive (each entry is in exactly one class).  For
4101  example, each CIFAR-10 image is labeled with one and only one label: an image
4102  can be a dog or a truck, but not both.
4103
4104  **NOTE:**  While the classes are mutually exclusive, their probabilities
4105  need not be.  All that is required is that each row of `labels` is
4106  a valid probability distribution.  If they are not, the computation of the
4107  gradient will be incorrect.
4108
4109  If using exclusive `labels` (wherein one and only
4110  one class is true at a time), see `sparse_softmax_cross_entropy_with_logits`.
4111
4112  **WARNING:** This op expects unscaled logits, since it performs a `softmax`
4113  on `logits` internally for efficiency.  Do not call this op with the
4114  output of `softmax`, as it will produce incorrect results.
4115
4116  A common use case is to have logits and labels of shape
4117  `[batch_size, num_classes]`, but higher dimensions are supported, with
4118  the `dim` argument specifying the class dimension.
4119
4120  Backpropagation will happen only into `logits`.  To calculate a cross entropy
4121  loss that allows backpropagation into both `logits` and `labels`, see
4122  `tf.nn.softmax_cross_entropy_with_logits_v2`.
4123
4124  **Note that to avoid confusion, it is required to pass only named arguments to
4125  this function.**
4126
4127  Args:
4128    _sentinel: Used to prevent positional parameters. Internal, do not use.
4129    labels: Each vector along the class dimension should hold a valid
4130      probability distribution e.g. for the case in which labels are of shape
4131      `[batch_size, num_classes]`, each row of `labels[i]` must be a valid
4132      probability distribution.
4133    logits: Per-label activations, typically a linear output. These activation
4134      energies are interpreted as unnormalized log probabilities.
4135    dim: The class dimension. Defaults to -1, which is the last dimension.
4136    name: A name for the operation (optional).
4137    axis: Alias for dim.
4138
4139  Returns:
4140    A `Tensor` that contains the softmax cross entropy loss. Its type is the
4141    same as `logits` and its shape is the same as `labels` except that it does
4142    not have the last dimension of `labels`.
4143  """
4144  dim = deprecated_argument_lookup("axis", axis, "dim", dim)
4145  _ensure_xent_args("softmax_cross_entropy_with_logits", _sentinel, labels,
4146                    logits)
4147
4148  with ops.name_scope(name, "softmax_cross_entropy_with_logits_sg",
4149                      [logits, labels]) as name:
4150    labels = array_ops.stop_gradient(labels, name="labels_stop_gradient")
4151
4152  return softmax_cross_entropy_with_logits_v2(
4153      labels=labels, logits=logits, axis=dim, name=name)
4154
4155
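# --- Illustrative sketch; not part of this module's API. ---
# Assuming `import tensorflow as tf`, this demonstrates the difference
# documented above: the v1 op stops gradients at `labels`, which the v2 op
# reproduces when labels are wrapped in `tf.stop_gradient`.
def _stop_gradient_labels_example():
  import tensorflow as tf
  logits = tf.Variable([[2.0, 1.0, 0.5]])
  labels = tf.Variable([[0.7, 0.2, 0.1]])
  with tf.GradientTape() as tape:
    loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=tf.stop_gradient(labels), logits=logits)
  grads = tape.gradient(loss, [logits, labels])
  return grads  # grads[1] is None: the label path carries no gradient

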
4156@tf_export(v1=["nn.sparse_softmax_cross_entropy_with_logits"])
4157@dispatch.add_dispatch_support
4158def sparse_softmax_cross_entropy_with_logits(
4159    _sentinel=None,  # pylint: disable=invalid-name
4160    labels=None,
4161    logits=None,
4162    name=None):
4163  """Computes sparse softmax cross entropy between `logits` and `labels`.
4164
4165  Measures the probability error in discrete classification tasks in which the
4166  classes are mutually exclusive (each entry is in exactly one class).  For
4167  example, each CIFAR-10 image is labeled with one and only one label: an image
4168  can be a dog or a truck, but not both.
4169
4170  **NOTE:**  For this operation, the probability of a given label is considered
4171  exclusive.  That is, soft classes are not allowed, and the `labels` vector
4172  must provide a single specific index for the true class for each row of
4173  `logits` (each minibatch entry).  For soft softmax classification with
4174  a probability distribution for each entry, see
4175  `softmax_cross_entropy_with_logits_v2`.
4176
4177  **WARNING:** This op expects unscaled logits, since it performs a `softmax`
4178  on `logits` internally for efficiency.  Do not call this op with the
4179  output of `softmax`, as it will produce incorrect results.
4180
4181  A common use case is to have logits of shape
4182  `[batch_size, num_classes]` and have labels of shape
4183  `[batch_size]`, but higher dimensions are supported, in which
4184  case the last dimension is assumed to be of size `num_classes`.
4185  `logits` must have the dtype of `float16`, `float32`, or `float64`, and
4186  `labels` must have the dtype of `int32` or `int64`.
4187
4188  **Note that to avoid confusion, it is required to pass only named arguments to
4189  this function.**
4190
4191  Args:
4192    _sentinel: Used to prevent positional parameters. Internal, do not use.
4193    labels: `Tensor` of shape `[d_0, d_1, ..., d_{r-1}]` (where `r` is rank of
4194      `labels` and result) and dtype `int32` or `int64`. Each entry in `labels`
4195      must be an index in `[0, num_classes)`. Other values will raise an
4196      exception when this op is run on CPU, and return `NaN` for corresponding
4197      loss and gradient rows on GPU.
4198    logits: Per-label activations (typically a linear output) of shape
4199      `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float16`, `float32`, or
4200      `float64`. These activation energies are interpreted as unnormalized log
4201      probabilities.
4202    name: A name for the operation (optional).
4203
4204  Returns:
4205    A `Tensor` of the same shape as `labels` and of the same type as `logits`
4206    with the softmax cross entropy loss.
4207
4208  Raises:
4209    ValueError: If logits are scalars (need to have rank >= 1) or if the rank
4210      of the labels is not equal to the rank of the logits minus one.
4211  """
4212  _ensure_xent_args("sparse_softmax_cross_entropy_with_logits", _sentinel,
4213                    labels, logits)
4214
4215  # TODO(pcmurray) Raise an error when the label is not an index in
4216  # [0, num_classes). Note: This could break users who call this with bad
4217  # labels, but disregard the bad results.
4218
4219  # Reshape logits and labels to rank 2.
4220  with ops.name_scope(name, "SparseSoftmaxCrossEntropyWithLogits",
4221                      [labels, logits]):
4222    labels = ops.convert_to_tensor(labels)
4223    logits = ops.convert_to_tensor(logits)
4224    precise_logits = math_ops.cast(logits, dtypes.float32) if (dtypes.as_dtype(
4225        logits.dtype) == dtypes.float16) else logits
4226
4227    # Store label shape for result later.
4228    labels_static_shape = labels.get_shape()
4229    labels_shape = array_ops.shape(labels)
4230    static_shapes_fully_defined = (
4231        labels_static_shape.is_fully_defined() and
4232        logits.get_shape()[:-1].is_fully_defined())
4233    if logits.get_shape().ndims is not None and logits.get_shape().ndims == 0:
4234      raise ValueError(
4235          "Logits cannot be scalars - received shape %s." % logits.get_shape())
4236    if logits.get_shape().ndims is not None and (
4237        labels_static_shape.ndims is not None and
4238        labels_static_shape.ndims != logits.get_shape().ndims - 1):
4239      raise ValueError("Rank mismatch: Rank of labels (received %s) should "
4240                       "equal rank of logits minus 1 (received %s)." %
4241                       (labels_static_shape.ndims, logits.get_shape().ndims))
4242    if (static_shapes_fully_defined and
4243        labels_static_shape != logits.get_shape()[:-1]):
4244      raise ValueError("Shape mismatch: The shape of labels (received %s) "
4245                       "should equal the shape of logits except for the last "
4246                       "dimension (received %s)." % (labels_static_shape,
4247                                                     logits.get_shape()))
4248    # Check if no reshapes are required.
4249    if logits.get_shape().ndims == 2:
4250      cost, _ = gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
4251          precise_logits, labels, name=name)
4252      if logits.dtype == dtypes.float16:
4253        return math_ops.cast(cost, dtypes.float16)
4254      else:
4255        return cost
4256
4257    # Perform a check of the dynamic shapes if the static shapes are not fully
4258    # defined.
4259    shape_checks = []
4260    if not static_shapes_fully_defined:
4261      shape_checks.append(
4262          check_ops.assert_equal(
4263              array_ops.shape(labels),
4264              array_ops.shape(logits)[:-1]))
4265    with ops.control_dependencies(shape_checks):
4266      # Reshape logits to 2 dim, labels to 1 dim.
4267      num_classes = array_ops.shape(logits)[array_ops.rank(logits) - 1]
4268      precise_logits = array_ops.reshape(precise_logits, [-1, num_classes])
4269      labels = array_ops.reshape(labels, [-1])
4270      # The second output tensor contains the gradients.  We use it in
4271      # _CrossEntropyGrad() in nn_grad but not here.
4272      cost, _ = gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
4273          precise_logits, labels, name=name)
4274      cost = array_ops.reshape(cost, labels_shape)
4275      cost.set_shape(labels_static_shape)
4276      if logits.dtype == dtypes.float16:
4277        return math_ops.cast(cost, dtypes.float16)
4278      else:
4279        return cost
4280
4281
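# --- Illustrative sketch; not part of this module's API. ---
# Assuming `import tensorflow as tf`, the sparse op above is equivalent to
# one-hot encoding the integer labels and calling the dense op, but it
# never materializes the one-hot tensor:
def _sparse_vs_dense_xent_example():
  import tensorflow as tf
  logits = tf.constant([[2.0, -5.0, 0.5], [0.0, 1.9, 1.4]])
  labels = tf.constant([0, 1])  # one class index per example
  sparse = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits)
  dense = tf.nn.softmax_cross_entropy_with_logits(
      labels=tf.one_hot(labels, depth=3), logits=logits)
  tf.debugging.assert_near(sparse, dense)

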
4282@tf_export("nn.sparse_softmax_cross_entropy_with_logits", v1=[])
4283@dispatch.add_dispatch_support
4284def sparse_softmax_cross_entropy_with_logits_v2(labels, logits, name=None):
4285  """Computes sparse softmax cross entropy between `logits` and `labels`.
4286
4287  Measures the probability error in discrete classification tasks in which the
4288  classes are mutually exclusive (each entry is in exactly one class).  For
4289  example, each CIFAR-10 image is labeled with one and only one label: an image
4290  can be a dog or a truck, but not both.
4291
4292  Note:  For this operation, the probability of a given label is considered
4293  exclusive.  That is, soft classes are not allowed, and the `labels` vector
4294  must provide a single specific index for the true class for each row of
4295  `logits` (each minibatch entry).  For soft softmax classification with
4296  a probability distribution for each entry, see
4297  `softmax_cross_entropy_with_logits_v2`.
4298
4299  Warning: This op expects unscaled logits, since it performs a `softmax`
4300  on `logits` internally for efficiency.  Do not call this op with the
4301  output of `softmax`, as it will produce incorrect results.
4302
4303  A common use case is to have logits of shape
4304  `[batch_size, num_classes]` and have labels of shape
4305  `[batch_size]`, but higher dimensions are supported, in which
4306  case the last dimension is assumed to be of size `num_classes`.
4307  `logits` must have the dtype of `float16`, `float32`, or `float64`, and
4308  `labels` must have the dtype of `int32` or `int64`.
4309
4310  >>> logits = tf.constant([[2., -5., .5, -.1],
4311  ...                       [0., 0., 1.9, 1.4],
4312  ...                       [-100., 100., -100., -100.]])
4313  >>> labels = tf.constant([0, 3, 1])
4314  >>> tf.nn.sparse_softmax_cross_entropy_with_logits(
4315  ...     labels=labels, logits=logits).numpy()
4316  array([0.29750752, 1.1448325 , 0.        ], dtype=float32)
4317
4318  To avoid confusion, passing only named arguments to this function is
4319  recommended.
4320
4321  Args:
4322    labels: `Tensor` of shape `[d_0, d_1, ..., d_{r-1}]` (where `r` is rank of
4323      `labels` and result) and dtype `int32` or `int64`. Each entry in `labels`
4324      must be an index in `[0, num_classes)`. Other values will raise an
4325      exception when this op is run on CPU, and return `NaN` for corresponding
4326      loss and gradient rows on GPU.
4327    logits: Unscaled log probabilities of shape `[d_0, d_1, ..., d_{r-1},
4328      num_classes]` and dtype `float16`, `float32`, or `float64`.
4329    name: A name for the operation (optional).
4330
4331  Returns:
4332    A `Tensor` of the same shape as `labels` and of the same type as `logits`
4333    with the softmax cross entropy loss.
4334
4335  Raises:
4336    ValueError: If logits are scalars (need to have rank >= 1) or if the rank
4337      of the labels is not equal to the rank of the logits minus one.
4338  """
4339  return sparse_softmax_cross_entropy_with_logits(
4340      labels=labels, logits=logits, name=name)
4341
4342
4343@tf_export("nn.avg_pool", v1=["nn.avg_pool_v2"])
4344@dispatch.add_dispatch_support
4345def avg_pool_v2(input, ksize, strides, padding, data_format=None, name=None):  # pylint: disable=redefined-builtin
4346  """Performs the avg pooling on the input.
4347
4348  Each entry in `output` is the mean of the corresponding size `ksize`
4349  window in `input`.
4350
4351  Args:
4352    input:  Tensor of rank N+2, of shape `[batch_size] + input_spatial_shape +
4353      [num_channels]` if `data_format` does not start with "NC" (default), or
4354      `[batch_size, num_channels] + input_spatial_shape` if data_format starts
4355      with "NC". Pooling happens over the spatial dimensions only.
4356    ksize: An int or list of `ints` that has length `1`, `N` or `N+2`. The size
4357      of the window for each dimension of the input tensor.
4358    strides: An int or list of `ints` that has length `1`, `N` or `N+2`. The
4359      stride of the sliding window for each dimension of the input tensor.
4360    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
4361      the "returns" section of `tf.nn.convolution` for details.
4362    data_format: A string. Specifies the channel dimension. For N=1 it can be
4363      either "NWC" (default) or "NCW", for N=2 it can be either "NHWC" (default)
4364      or "NCHW" and for N=3 either "NDHWC" (default) or "NCDHW".
4365    name: Optional name for the operation.
4366
4367  Returns:
4368    A `Tensor` of format specified by `data_format`.
4369    The average pooled output tensor.
4370  """
4371  if input.shape is not None:
4372    n = len(input.shape) - 2
4373  elif data_format is not None:
4374    n = len(data_format) - 2
4375  else:
4376    raise ValueError(
4377        "The input must have a rank or a data format must be given.")
4378  if not 1 <= n <= 3:
4379    raise ValueError(
4380        "Input tensor must be of rank 3, 4 or 5 but was {}.".format(n + 2))
4381
4382  if data_format is None:
4383    channel_index = n + 1
4384  else:
4385    channel_index = 1 if data_format.startswith("NC") else n + 1
4386
4387  ksize = _get_sequence(ksize, n, channel_index, "ksize")
4388  strides = _get_sequence(strides, n, channel_index, "strides")
4389
4390  avg_pooling_ops = {
4391      1: avg_pool1d,
4392      2: gen_nn_ops.avg_pool,
4393      3: gen_nn_ops.avg_pool3d
4394  }
4395
4396  op = avg_pooling_ops[n]
4397  return op(
4398      input,
4399      ksize=ksize,
4400      strides=strides,
4401      padding=padding,
4402      data_format=data_format,
4403      name=name)
4404
4405
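# --- Illustrative sketch; not part of this module's API. ---
# Assuming `import tensorflow as tf`, `tf.nn.avg_pool` dispatches on the
# input rank as shown above, so one call form covers 1-D, 2-D and 3-D:
def _avg_pool_rank_dispatch_example():
  import tensorflow as tf
  x1 = tf.ones([1, 8, 3])        # NWC   -> avg_pool1d
  x2 = tf.ones([1, 8, 8, 3])     # NHWC  -> 2-D avg pool
  x3 = tf.ones([1, 8, 8, 8, 3])  # NDHWC -> avg_pool3d
  return [tf.nn.avg_pool(x, ksize=2, strides=2, padding="SAME")
          for x in (x1, x2, x3)]

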
4406@tf_export(v1=["nn.avg_pool", "nn.avg_pool2d"])
4407@dispatch.add_dispatch_support
4408def avg_pool(value, ksize, strides, padding, data_format="NHWC",
4409             name=None, input=None):  # pylint: disable=redefined-builtin
4410  """Performs the average pooling on the input.
4411
4412  Each entry in `output` is the mean of the corresponding size `ksize`
4413  window in `value`.
4414
4415  Args:
4416    value: A 4-D `Tensor` of shape `[batch, height, width, channels]` and type
4417      `float32`, `float64`, `qint8`, `quint8`, or `qint32`.
4418    ksize: An int or list of `ints` that has length `1`, `2` or `4`. The size of
4419      the window for each dimension of the input tensor.
4420    strides: An int or list of `ints` that has length `1`, `2` or `4`. The
4421      stride of the sliding window for each dimension of the input tensor.
4422    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
4423      See the "returns" section of `tf.nn.convolution` for details.
4424    data_format: A string. 'NHWC' and 'NCHW' are supported.
4425    name: Optional name for the operation.
4426    input: Alias for value.
4427
4428  Returns:
4429    A `Tensor` with the same type as `value`.  The average pooled output tensor.
4430  """
4431  with ops.name_scope(name, "AvgPool", [value]) as name:
4432    value = deprecation.deprecated_argument_lookup(
4433        "input", input, "value", value)
4434
4435    if data_format is None:
4436      data_format = "NHWC"
4437    channel_index = 1 if data_format.startswith("NC") else 3
4438
4439    ksize = _get_sequence(ksize, 2, channel_index, "ksize")
4440    strides = _get_sequence(strides, 2, channel_index, "strides")
4441
4442    return gen_nn_ops.avg_pool(
4443        value,
4444        ksize=ksize,
4445        strides=strides,
4446        padding=padding,
4447        data_format=data_format,
4448        name=name)
4449
4450
4451@tf_export("nn.avg_pool2d", v1=[])
4452@dispatch.add_dispatch_support
4453def avg_pool2d(input, ksize, strides, padding, data_format="NHWC", name=None):  # pylint: disable=redefined-builtin
4454  """Performs the average pooling on the input.
4455
4456  Each entry in `output` is the mean of the corresponding size `ksize`
4457  window in `input`.
4458
4459  Args:
4460    input: A 4-D `Tensor` of shape `[batch, height, width, channels]` and type
4461      `float32`, `float64`, `qint8`, `quint8`, or `qint32`.
4462    ksize: An int or list of `ints` that has length `1`, `2` or `4`. The size of
4463      the window for each dimension of the input tensor.
4464    strides: An int or list of `ints` that has length `1`, `2` or `4`. The
4465      stride of the sliding window for each dimension of the input tensor.
4466    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
4467      See the "returns" section of `tf.nn.convolution` for details.
4468    data_format: A string. 'NHWC' and 'NCHW' are supported.
4469    name: Optional name for the operation.
4470
4471  Returns:
4472    A `Tensor` with the same type as `input`.  The average pooled output tensor.
4473  """
4474  with ops.name_scope(name, "AvgPool2D", [input]) as name:
4475    if data_format is None:
4476      data_format = "NHWC"
4477    channel_index = 1 if data_format.startswith("NC") else 3
4478
4479    ksize = _get_sequence(ksize, 2, channel_index, "ksize")
4480    strides = _get_sequence(strides, 2, channel_index, "strides")
4481
4482    return gen_nn_ops.avg_pool(
4483        input,
4484        ksize=ksize,
4485        strides=strides,
4486        padding=padding,
4487        data_format=data_format,
4488        name=name)
4489
4490
4491@tf_export("nn.avg_pool1d")
4492@dispatch.add_dispatch_support
4493def avg_pool1d(input, ksize, strides, padding, data_format="NWC", name=None):  # pylint: disable=redefined-builtin
4494  """Performs the average pooling on the input.
4495
4496  Each entry in `output` is the mean of the corresponding size `ksize`
4497  window in `input`.
4498
4499  Note: internally this op reshapes the input and uses the underlying 2-D op.
4500
4501  Args:
4502    input: A 3-D `Tensor` of the format specified by `data_format`.
4503    ksize: An int or list of `ints` that has length `1` or `3`. The size of the
4504      window for each dimension of the input tensor.
4505    strides: An int or list of `ints` that has length `1` or `3`. The stride of
4506      the sliding window for each dimension of the input tensor.
4507    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
4508      the "returns" section of `tf.nn.convolution` for details.
4509    data_format: An optional string from: "NWC", "NCW". Defaults to "NWC".
4510    name: A name for the operation (optional).
4511
4512  Returns:
4513    A `Tensor` of format specified by `data_format`.
4514    The average pooled output tensor.
4515  """
4516  with ops.name_scope(name, "AvgPool1D", [input]) as name:
4517    if data_format is None:
4518      data_format = "NWC"
4519    channel_index = 1 if data_format.startswith("NC") else 2
4520    ksize = [1] + _get_sequence(ksize, 1, channel_index, "ksize")
4521    strides = [1] + _get_sequence(strides, 1, channel_index, "strides")
4522
4523    expanding_dim = 1 if data_format == "NWC" else 2
4524    data_format = "NHWC" if data_format == "NWC" else "NCHW"
4525
4526    input = array_ops.expand_dims_v2(input, expanding_dim)
4527    result = gen_nn_ops.avg_pool(
4528        input,
4529        ksize=ksize,
4530        strides=strides,
4531        padding=padding,
4532        data_format=data_format,
4533        name=name)
4534    return array_ops.squeeze(result, expanding_dim)
4535
4536
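# --- Illustrative sketch; hypothetical helper. ---
# Assuming `import tensorflow as tf`, the expand/pool/squeeze trick used by
# `avg_pool1d` above can be reproduced directly with the 2-D op:
def _avg_pool1d_via_2d_example():
  import tensorflow as tf
  x = tf.reshape(tf.range(8, dtype=tf.float32), [1, 8, 1])  # NWC
  direct = tf.nn.avg_pool1d(x, ksize=2, strides=2, padding="VALID")
  expanded = tf.expand_dims(x, 1)  # NWC -> NHWC with height 1
  via_2d = tf.squeeze(
      tf.nn.avg_pool2d(expanded, ksize=[1, 2], strides=[1, 2],
                       padding="VALID"), 1)
  tf.debugging.assert_near(direct, via_2d)

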
4537@tf_export("nn.avg_pool3d")
4538@dispatch.add_dispatch_support
4539def avg_pool3d(input, ksize, strides, padding, data_format="NDHWC", name=None):  # pylint: disable=redefined-builtin
4540  """Performs the average pooling on the input.
4541
4542  Each entry in `output` is the mean of the corresponding size `ksize`
4543  window in `input`.
4544
4545  Args:
4546    input: A 5-D `Tensor` of shape `[batch, depth, height, width, channels]`
4547      and type `float32`, `float64`, `qint8`, `quint8`, or `qint32`.
4548    ksize: An int or list of `ints` that has length `1`, `3` or `5`. The size of
4549      the window for each dimension of the input tensor.
4550    strides: An int or list of `ints` that has length `1`, `3` or `5`. The
4551      stride of the sliding window for each dimension of the input tensor.
4552    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
4553      See the "returns" section of `tf.nn.convolution` for details.
4554    data_format: A string. 'NDHWC' and 'NCDHW' are supported.
4555    name: Optional name for the operation.
4556
4557  Returns:
4558    A `Tensor` with the same type as `input`.  The average pooled output tensor.
4559  """
4560  with ops.name_scope(name, "AvgPool3D", [input]) as name:
4561    if data_format is None:
4562      data_format = "NDHWC"
4563    channel_index = 1 if data_format.startswith("NC") else 3
4564
4565    ksize = _get_sequence(ksize, 3, channel_index, "ksize")
4566    strides = _get_sequence(strides, 3, channel_index, "strides")
4567
4568    return gen_nn_ops.avg_pool3d(
4569        input,
4570        ksize=ksize,
4571        strides=strides,
4572        padding=padding,
4573        data_format=data_format,
4574        name=name)
4575
4576
4577# pylint: disable=redefined-builtin
4578@tf_export("nn.max_pool", v1=["nn.max_pool_v2"])
4579@dispatch.add_dispatch_support
4580def max_pool_v2(input, ksize, strides, padding, data_format=None, name=None):
4581  """Performs max pooling on the input.
4582
4583  For a given window of `ksize`, takes the maximum value within that window.
4584  Used for reducing computation and preventing overfitting.
4585
4586  Consider an example of pooling with 2x2, non-overlapping windows:
4587
4588  >>> matrix = tf.constant([
4589  ...     [0, 0, 1, 7],
4590  ...     [0, 2, 0, 0],
4591  ...     [5, 2, 0, 0],
4592  ...     [0, 0, 9, 8],
4593  ... ])
4594  >>> reshaped = tf.reshape(matrix, (1, 4, 4, 1))
4595  >>> tf.nn.max_pool(reshaped, ksize=2, strides=2, padding="SAME")
4596  <tf.Tensor: shape=(1, 2, 2, 1), dtype=int32, numpy=
4597  array([[[[2],
4598           [7]],
4599          [[5],
4600           [9]]]], dtype=int32)>
4601
4602  We can adjust the window size using the `ksize` parameter. For example, if we
4603  were to expand the window to 3:
4604
4605  >>> tf.nn.max_pool(reshaped, ksize=3, strides=2, padding="SAME")
4606  <tf.Tensor: shape=(1, 2, 2, 1), dtype=int32, numpy=
4607  array([[[[5],
4608           [7]],
4609          [[9],
4610           [9]]]], dtype=int32)>
4611
4612  We've now picked up two additional large numbers (5 and 9) in two of the
4613  pooled spots.
4614
4615  Note that our windows are now overlapping, since we're still moving by 2 units
4616  on each iteration. This is causing us to see the same 9 repeated twice, since
4617  it is part of two overlapping windows.
4618
4619  We can adjust how far we move our window with each iteration using the
4620  `strides` parameter. Updating this to the same value as our window size
4621  eliminates the overlap:
4622
4623  >>> tf.nn.max_pool(reshaped, ksize=3, strides=3, padding="SAME")
4624  <tf.Tensor: shape=(1, 2, 2, 1), dtype=int32, numpy=
4625  array([[[[2],
4626           [7]],
4627          [[5],
4628           [9]]]], dtype=int32)>
4629
4630  Because the window does not neatly fit into our input, padding is added around
4631  the edges, giving us the same result as when we used a 2x2 window. We can skip
4632  padding altogether and simply drop the windows that do not fully fit into our
4633  input by instead passing `"VALID"` to the `padding` argument:
4634
4635  >>> tf.nn.max_pool(reshaped, ksize=3, strides=3, padding="VALID")
4636  <tf.Tensor: shape=(1, 1, 1, 1), dtype=int32, numpy=array([[[[5]]]],
4637   dtype=int32)>
4638
4639  Now we've grabbed the largest value in the 3x3 window starting from the upper-
4640  left corner. Since no other windows fit in our input, they are dropped.
4641
4642  Args:
4643    input:  Tensor of rank N+2, of shape `[batch_size] + input_spatial_shape +
4644      [num_channels]` if `data_format` does not start with "NC" (default), or
4645      `[batch_size, num_channels] + input_spatial_shape` if data_format starts
4646      with "NC". Pooling happens over the spatial dimensions only.
4647    ksize: An int or list of `ints` that has length `1`, `N` or `N+2`. The size
4648      of the window for each dimension of the input tensor.
4649    strides: An int or list of `ints` that has length `1`, `N` or `N+2`. The
4650      stride of the sliding window for each dimension of the input tensor.
4651    padding: Either the `string` `"SAME"` or `"VALID"` indicating the type of
4652      padding algorithm to use, or a list indicating the explicit paddings at
4653      the start and end of each dimension. When explicit padding is used and
4654      data_format is `"NHWC"`, this should be in the form `[[0, 0], [pad_top,
4655      pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit padding is
4656      used and data_format is `"NCHW"`, this should be `[[0, 0], [0, 0],
4657      [pad_top, pad_bottom], [pad_left, pad_right]]`. When using explicit
4658      padding, the size of the paddings cannot be greater than the sliding
4659      window size.
4660    data_format: A string. Specifies the channel dimension. For N=1 it can be
4661      either "NWC" (default) or "NCW", for N=2 it can be either "NHWC" (default)
4662      or "NCHW" and for N=3 either "NDHWC" (default) or "NCDHW".
4663    name: Optional name for the operation.
4664
4665  Returns:
4666    A `Tensor` of format specified by `data_format`.
4667    The max pooled output tensor.
4668  """
4669  if input.shape is not None:
4670    n = len(input.shape) - 2
4671  elif data_format is not None:
4672    n = len(data_format) - 2
4673  else:
4674    raise ValueError(
4675        "The input must have a rank or a data format must be given.")
4676  if not 1 <= n <= 3:
4677    raise ValueError(
4678        "Input tensor must be of rank 3, 4 or 5 but was {}.".format(n + 2))
4679
4680  if data_format is None:
4681    channel_index = n + 1
4682  else:
4683    channel_index = 1 if data_format.startswith("NC") else n + 1
4684
4685  if isinstance(padding, (list, tuple)) and data_format == "NCHW_VECT_C":
4686    raise ValueError("Data formats NCHW_VECT_C is not yet supported with "
4687                     "explicit padding")
4688
4689  ksize = _get_sequence(ksize, n, channel_index, "ksize")
4690  strides = _get_sequence(strides, n, channel_index, "strides")
4691
4692  if (isinstance(padding, (list, tuple)) and n == 3):
4693    raise ValueError("Explicit padding is not yet supported with an input "
4694                     "tensor of rank 5")
4695
4696  max_pooling_ops = {
4697      1: max_pool1d,
4698      2: max_pool2d,
4699      3: gen_nn_ops.max_pool3d
4700  }
4701
4702  op = max_pooling_ops[n]
4703  return op(
4704      input,
4705      ksize=ksize,
4706      strides=strides,
4707      padding=padding,
4708      data_format=data_format,
4709      name=name)
4710# pylint: enable=redefined-builtin
4711
4712
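# --- Illustrative sketch; not part of this module's API. ---
# Assuming `import tensorflow as tf`, this shows the explicit-padding form
# of the `padding` argument documented above for NHWC data:
def _max_pool_explicit_padding_example():
  import tensorflow as tf
  x = tf.reshape(tf.range(16, dtype=tf.float32), [1, 4, 4, 1])
  # One row of padding on top/bottom and one column on left/right.
  pads = [[0, 0], [1, 1], [1, 1], [0, 0]]
  return tf.nn.max_pool2d(x, ksize=2, strides=2, padding=pads)  # [1, 3, 3, 1]

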
4713@tf_export(v1=["nn.max_pool"])
4714@dispatch.add_dispatch_support
4715def max_pool(value,
4716             ksize,
4717             strides,
4718             padding,
4719             data_format="NHWC",
4720             name=None,
4721             input=None):  # pylint: disable=redefined-builtin
4722  """Performs the max pooling on the input.
4723
4724  Args:
4725    value: A 4-D `Tensor` of the format specified by `data_format`.
4726    ksize: An int or list of `ints` that has length `1`, `2` or `4`.
4727      The size of the window for each dimension of the input tensor.
4728    strides: An int or list of `ints` that has length `1`, `2` or `4`.
4729      The stride of the sliding window for each dimension of the input tensor.
4730    padding: Either the `string` `"SAME"` or `"VALID"` indicating the type of
4731      padding algorithm to use, or a list indicating the explicit paddings at
4732      the start and end of each dimension. When explicit padding is used and
4733      data_format is `"NHWC"`, this should be in the form `[[0, 0], [pad_top,
4734      pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit padding is
4735      used and data_format is `"NCHW"`, this should be `[[0, 0], [0, 0],
4736      [pad_top, pad_bottom], [pad_left, pad_right]]`. When using explicit
4737      padding, the size of the paddings cannot be greater than the sliding
4738      window size.
4739    data_format: A string. 'NHWC', 'NCHW' and 'NCHW_VECT_C' are supported.
4740    name: Optional name for the operation.
4741    input: Alias for value.
4742
4743  Returns:
4744    A `Tensor` of format specified by `data_format`.
4745    The max pooled output tensor.
4746  """
4747  value = deprecation.deprecated_argument_lookup("input", input, "value", value)
4748  with ops.name_scope(name, "MaxPool", [value]) as name:
4749    if data_format is None:
4750      data_format = "NHWC"
4751    channel_index = 1 if data_format.startswith("NC") else 3
4752
4753    ksize = _get_sequence(ksize, 2, channel_index, "ksize")
4754    strides = _get_sequence(strides, 2, channel_index, "strides")
4755    if isinstance(padding, (list, tuple)) and data_format == "NCHW_VECT_C":
4756      raise ValueError("Data formats NCHW_VECT_C is not yet supported with "
4757                       "explicit padding")
4758    padding, explicit_paddings = convert_padding(padding)
4759    if ((np.isscalar(ksize) and ksize == 0) or
4760        (isinstance(ksize,
4761                    (list, tuple, np.ndarray)) and any(v == 0 for v in ksize))):
4762      raise ValueError("ksize cannot be zero.")
4763
4764    return gen_nn_ops.max_pool(
4765        value,
4766        ksize=ksize,
4767        strides=strides,
4768        padding=padding,
4769        explicit_paddings=explicit_paddings,
4770        data_format=data_format,
4771        name=name)
4772
4773
4774# pylint: disable=redefined-builtin
4775@tf_export("nn.max_pool1d")
4776@dispatch.add_dispatch_support
4777def max_pool1d(input, ksize, strides, padding, data_format="NWC", name=None):
4778  """Performs the max pooling on the input.
4779
4780  Note: internally this op reshapes the input and uses the underlying 2-D op.
4781
4782  Args:
4783    input: A 3-D `Tensor` of the format specified by `data_format`.
4784    ksize: An int or list of `ints` that has length `1` or `3`. The size of the
4785      window for each dimension of the input tensor.
4786    strides: An int or list of `ints` that has length `1` or `3`. The stride of
4787      the sliding window for each dimension of the input tensor.
4788    padding: Either the `string` `"SAME"` or `"VALID"` indicating the type of
4789      padding algorithm to use, or a list indicating the explicit paddings at
4790      the start and end of each dimension. When explicit padding is used and
4791      data_format is `"NWC"`, this should be in the form `[[0, 0], [pad_left,
4792      pad_right], [0, 0]]`. When explicit padding is used and data_format
4793      is `"NCW"`, this should be in the form `[[0, 0], [0, 0], [pad_left,
4794      pad_right]]`. When using explicit padding, the size of the paddings cannot
4795      be greater than the sliding window size.
4796    data_format: An optional string from: "NWC", "NCW". Defaults to "NWC".
4797    name: A name for the operation (optional).
4798
4799  Returns:
4800    A `Tensor` of format specified by `data_format`.
4801    The max pooled output tensor.
4802  """
4803  with ops.name_scope(name, "MaxPool1d", [input]) as name:
4804    if isinstance(padding, (list, tuple)) and data_format == "NCHW_VECT_C":
4805      raise ValueError("Data formats NCHW_VECT_C is not yet supported with "
4806                       "explicit padding")
4807    if data_format is None:
4808      data_format = "NWC"
4809    channel_index = 1 if data_format.startswith("NC") else 2
4810    ksize = [1] + _get_sequence(ksize, 1, channel_index, "ksize")
4811    strides = [1] + _get_sequence(strides, 1, channel_index, "strides")
4812    padding, explicit_paddings = convert_padding(padding, 3)
4813    if padding == "EXPLICIT":
4814      explicit_paddings = [0, 0] + explicit_paddings
4815
4816    expanding_dim = 1 if data_format == "NWC" else 2
4817    data_format = "NHWC" if data_format == "NWC" else "NCHW"
4818
4819    input = array_ops.expand_dims_v2(input, expanding_dim)
4820    result = gen_nn_ops.max_pool(
4821        input,
4822        ksize=ksize,
4823        strides=strides,
4824        padding=padding,
4825        explicit_paddings=explicit_paddings,
4826        data_format=data_format,
4827        name=name)
4828    return array_ops.squeeze(result, expanding_dim)
4829# pylint: enable=redefined-builtin
4830
4831
4832# pylint: disable=redefined-builtin
4833@tf_export("nn.max_pool2d")
4834@dispatch.add_dispatch_support
4835def max_pool2d(input, ksize, strides, padding, data_format="NHWC", name=None):
4836  """Performs the max pooling on the input.
4837
4838  Args:
4839    input: A 4-D `Tensor` of the format specified by `data_format`.
4840    ksize: An int or list of `ints` that has length `1`, `2` or `4`. The size of
4841      the window for each dimension of the input tensor.
4842    strides: An int or list of `ints` that has length `1`, `2` or `4`. The
4843      stride of the sliding window for each dimension of the input tensor.
4844    padding: Either the `string` `"SAME"` or `"VALID"` indicating the type of
4845      padding algorithm to use, or a list indicating the explicit paddings at
4846      the start and end of each dimension. When explicit padding is used and
4847      data_format is `"NHWC"`, this should be in the form `[[0, 0], [pad_top,
4848      pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit padding is
4849      used and data_format is `"NCHW"`, this should be `[[0, 0], [0, 0],
4850      [pad_top, pad_bottom], [pad_left, pad_right]]`. When using explicit
4851      padding, the size of the paddings cannot be greater than the sliding
4852      window size.
4853    data_format: A string. 'NHWC', 'NCHW' and 'NCHW_VECT_C' are supported.
4854    name: Optional name for the operation.
4855
4856  Returns:
4857    A `Tensor` of format specified by `data_format`.
4858    The max pooled output tensor.
4859  """
4860  with ops.name_scope(name, "MaxPool2d", [input]) as name:
4861    if data_format is None:
4862      data_format = "NHWC"
4863    channel_index = 1 if data_format.startswith("NC") else 3
4864
4865    ksize = _get_sequence(ksize, 2, channel_index, "ksize")
4866    strides = _get_sequence(strides, 2, channel_index, "strides")
4867    if isinstance(padding, (list, tuple)) and data_format == "NCHW_VECT_C":
4868      raise ValueError("Data formats NCHW_VECT_C is not yet supported with "
4869                       "explicit padding")
4870    padding, explicit_paddings = convert_padding(padding)
4871
4872    return gen_nn_ops.max_pool(
4873        input,
4874        ksize=ksize,
4875        strides=strides,
4876        padding=padding,
4877        explicit_paddings=explicit_paddings,
4878        data_format=data_format,
4879        name=name)
4880# pylint: enable=redefined-builtin
4881
4882
4883# pylint: disable=redefined-builtin
4884@tf_export("nn.max_pool3d")
4885@dispatch.add_dispatch_support
4886def max_pool3d(input, ksize, strides, padding, data_format="NDHWC", name=None):
4887  """Performs the max pooling on the input.
4888
4889  Args:
4890    input: A 5-D `Tensor` of the format specified by `data_format`.
4891    ksize: An int or list of `ints` that has length `1`, `3` or `5`. The size of
4892      the window for each dimension of the input tensor.
4893    strides: An int or list of `ints` that has length `1`, `3` or `5`. The
4894      stride of the sliding window for each dimension of the input tensor.
4895    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
4896      the "returns" section of `tf.nn.convolution` for details.
4897    data_format: An optional string from: "NDHWC", "NCDHW". Defaults to "NDHWC".
4898      The data format of the input and output data. With the default format
4899      "NDHWC", the data is stored in the order of: [batch, in_depth, in_height,
4900        in_width, in_channels]. Alternatively, the format could be "NCDHW", the
4901      data storage order is: [batch, in_channels, in_depth, in_height,
4902        in_width].
4903    name: A name for the operation (optional).
4904
4905  Returns:
4906    A `Tensor` of format specified by `data_format`.
4907    The max pooled output tensor.
4908  """
4909  with ops.name_scope(name, "MaxPool3D", [input]) as name:
4910    if data_format is None:
4911      data_format = "NDHWC"
4912    channel_index = 1 if data_format.startswith("NC") else 4
4913
4914    ksize = _get_sequence(ksize, 3, channel_index, "ksize")
4915    strides = _get_sequence(strides, 3, channel_index, "strides")
4916
4917    return gen_nn_ops.max_pool3d(
4918        input,
4919        ksize=ksize,
4920        strides=strides,
4921        padding=padding,
4922        data_format=data_format,
4923        name=name)
4924# pylint: enable=redefined-builtin
4925
4926
4927@tf_export("nn.max_pool_with_argmax", v1=[])
4928@dispatch.add_dispatch_support
4929def max_pool_with_argmax_v2(
4930    input,  # pylint: disable=redefined-builtin
4931    ksize,
4932    strides,
4933    padding,
4934    data_format="NHWC",
4935    output_dtype=dtypes.int64,
4936    include_batch_in_index=False,
4937    name=None):
4938  """Performs max pooling on the input and outputs both max values and indices.
4939
4940  The indices in `argmax` are flattened, so that a maximum value at position
4941  `[b, y, x, c]` becomes flattened index: `(y * width + x) * channels + c` if
4942  `include_batch_in_index` is False;
4943  `((b * height + y) * width + x) * channels + c`
4944  if `include_batch_in_index` is True.
4945
4946  The indices returned are always in `[0, height) x [0, width)` before
4947  flattening, even if padding is involved and the mathematically correct answer
4948  is outside (either negative or too large).  This is a bug, but fixing it is
4949  difficult to do in a safe backwards compatible way, especially due to
4950  flattening.
4951
4952  Args:
4953    input: A `Tensor`. Must be one of the following types: `float32`, `float64`,
4954      `int32`, `uint8`, `int16`, `int8`, `int64`, `bfloat16`, `uint16`, `half`,
4955      `uint32`, `uint64`.
4956      4-D with shape `[batch, height, width, channels]`.  Input to pool over.
4957    ksize: An int or list of `ints` that has length `1`, `2` or `4`.
4958      The size of the window for each dimension of the input tensor.
4959    strides: An int or list of `ints` that has length `1`, `2` or `4`.
4960      The stride of the sliding window for each dimension of the
4961      input tensor.
4962    padding: A `string` from: `"SAME", "VALID"`.
4963      The type of padding algorithm to use.
4964    data_format: An optional `string`, must be set to `"NHWC"` (the
4965      default).
4966      Specifies the data format of the input and output data.
4967    output_dtype: An optional `tf.DType` from: `tf.int32, tf.int64`.
4968      Defaults to `tf.int64`.
4969      The dtype of the returned argmax tensor.
4970    include_batch_in_index: An optional `boolean`. Defaults to `False`.
4971      Whether to include batch dimension in flattened index of `argmax`.
4972    name: A name for the operation (optional).
4973
4974  Returns:
4975    A tuple of `Tensor` objects (output, argmax).
4976
4977    output: A `Tensor`. Has the same type as `input`.
4978    argmax: A `Tensor` of type `output_dtype`.
4979  """
4980
4981  if data_format != "NHWC":
4982    raise ValueError("Data formats other than 'NHWC' are not yet supported")
4983
4984  ksize = _get_sequence(ksize, 2, 3, "ksize")
4985  strides = _get_sequence(strides, 2, 3, "strides")
4986
4987  return gen_nn_ops.max_pool_with_argmax(
4988      input=input,
4989      ksize=ksize,
4990      strides=strides,
4991      padding=padding,
4992      Targmax=output_dtype,
4993      include_batch_in_index=include_batch_in_index,
4994      name=name)
4995
4996
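# --- Illustrative sketch; not part of this module's API. ---
# Assuming `import tensorflow as tf`, the flattened `argmax` indices
# described above can be decoded back into (y, x, c) coordinates with
# `tf.unravel_index` when `include_batch_in_index=False`:
def _decode_argmax_example():
  import tensorflow as tf
  x = tf.reshape(tf.range(16, dtype=tf.float32), [1, 4, 4, 1])
  _, argmax = tf.nn.max_pool_with_argmax(
      x, ksize=2, strides=2, padding="VALID")
  flat = tf.cast(tf.reshape(argmax, [-1]), tf.int32)
  # Rows of the result are (y, x, c), inverting (y * width + x) * channels + c.
  return tf.transpose(tf.unravel_index(flat, dims=[4, 4, 1]))

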
4997@tf_export(v1=["nn.max_pool_with_argmax"])
4998@dispatch.add_dispatch_support
4999def max_pool_with_argmax_v1(  # pylint: disable=missing-docstring,invalid-name
5000    input,  # pylint: disable=redefined-builtin
5001    ksize,
5002    strides,
5003    padding,
5004    data_format="NHWC",
5005    Targmax=None,
5006    name=None,
5007    output_dtype=None,
5008    include_batch_in_index=False):
5009  if data_format != "NHWC":
5010    raise ValueError("Data formats other than 'NHWC' are not yet supported")
5011
5012  Targmax = deprecated_argument_lookup(
5013      "output_dtype", output_dtype, "Targmax", Targmax)
5014  if Targmax is None:
5015    Targmax = dtypes.int64
5016  return gen_nn_ops.max_pool_with_argmax(
5017      input=input,
5018      ksize=ksize,
5019      strides=strides,
5020      padding=padding,
5021      Targmax=Targmax,
5022      include_batch_in_index=include_batch_in_index,
5023      name=name)
5024
5025
5026max_pool_with_argmax_v1.__doc__ = gen_nn_ops.max_pool_with_argmax.__doc__
5027
5028
5029@ops.RegisterStatistics("Conv3D", "flops")
5030def _calc_conv3d_flops(graph, node):
5031  """Calculates the compute resources needed for Conv3D."""
5032  input_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
5033  input_shape.assert_is_fully_defined()
5034  filter_shape = graph_util.tensor_shape_from_node_def_name(
5035      graph, node.input[1])
5036  filter_shape.assert_is_fully_defined()
5037  output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
5038  output_shape.assert_is_fully_defined()
5039  filter_time = int(filter_shape[0])
5040  filter_height = int(filter_shape[1])
5041  filter_width = int(filter_shape[2])
5042  filter_in_depth = int(filter_shape[3])
5043  output_count = np.prod(output_shape.as_list(), dtype=np.int64)
5044  return ops.OpStats("flops", (output_count * filter_in_depth * filter_time *
5045                               filter_height * filter_width * 2))
5046
5047
5048@ops.RegisterStatistics("Conv2D", "flops")
5049def _calc_conv_flops(graph, node):
5050  """Calculates the compute resources needed for Conv2D."""
5051  input_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
5052  input_shape.assert_is_fully_defined()
5053  filter_shape = graph_util.tensor_shape_from_node_def_name(
5054      graph, node.input[1])
5055  filter_shape.assert_is_fully_defined()
5056  output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
5057  output_shape.assert_is_fully_defined()
5058  filter_height = int(filter_shape[0])
5059  filter_width = int(filter_shape[1])
5060  filter_in_depth = int(filter_shape[2])
5061  output_count = np.prod(output_shape.as_list(), dtype=np.int64)
5062  return ops.OpStats(
5063      "flops",
5064      (output_count * filter_in_depth * filter_height * filter_width * 2))
5065
5066
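# --- Illustrative sketch; hypothetical, plain Python. ---
# The Conv2D statistic above counts one multiply and one add per filter tap
# per output element, hence the factor of 2. Worked example:
def _conv2d_flops_by_hand():
  output_count = 1 * 28 * 28 * 128  # batch 1, 28x28 output, 128 filters
  filter_height, filter_width, filter_in_depth = 3, 3, 64
  return output_count * filter_in_depth * filter_height * filter_width * 2

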
5067@ops.RegisterStatistics("DepthwiseConv2dNative", "flops")
5068def _calc_depthwise_conv_flops(graph, node):
5069  """Calculates the compute resources needed for DepthwiseConv2dNative."""
5070  input_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
5071  input_shape.assert_is_fully_defined()
5072  filter_shape = graph_util.tensor_shape_from_node_def_name(
5073      graph, node.input[1])
5074  filter_shape.assert_is_fully_defined()
5075  output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
5076  output_shape.assert_is_fully_defined()
5077  filter_height = int(filter_shape[0])
5078  filter_width = int(filter_shape[1])
5079  output_count = np.prod(output_shape.as_list(), dtype=np.int64)
5080  return ops.OpStats("flops", (output_count * filter_height * filter_width * 2))
5081
5082
5083@ops.RegisterStatistics("BiasAdd", "flops")
5084def _calc_bias_add_flops(graph, node):
5085  """Calculates the computing needed for BiasAdd."""
5086  input_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
5087  input_shape.assert_is_fully_defined()
5088  input_count = np.prod(input_shape.as_list())
5089  return ops.OpStats("flops", input_count)
5090
5091
5092@tf_export(v1=["nn.xw_plus_b"])
5093@dispatch.add_dispatch_support
5094def xw_plus_b(x, weights, biases, name=None):  # pylint: disable=invalid-name
5095  """Computes matmul(x, weights) + biases.
5096
5097  Args:
5098    x: a 2D tensor.  Dimensions typically: batch, in_units
5099    weights: a 2D tensor.  Dimensions typically: in_units, out_units
5100    biases: a 1D tensor.  Dimensions: out_units
5101    name: A name for the operation (optional).  If not specified
5102      "xw_plus_b" is used.
5103
5104  Returns:
5105    A 2-D Tensor computing matmul(x, weights) + biases.
5106    Dimensions typically: batch, out_units.
5107  """
5108  with ops.name_scope(name, "xw_plus_b", [x, weights, biases]) as name:
5109    x = ops.convert_to_tensor(x, name="x")
5110    weights = ops.convert_to_tensor(weights, name="weights")
5111    biases = ops.convert_to_tensor(biases, name="biases")
5112    mm = math_ops.matmul(x, weights)
5113    return bias_add(mm, biases, name=name)
5114
5115
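# --- Illustrative sketch; not part of this module's API. ---
# Assuming `import tensorflow as tf`, `xw_plus_b` is the usual dense-layer
# affine transform; the two forms below agree:
def _xw_plus_b_example():
  import tensorflow as tf
  x, w, b = tf.ones([2, 3]), tf.ones([3, 4]), tf.zeros([4])
  fused = tf.compat.v1.nn.xw_plus_b(x, w, b)
  manual = tf.nn.bias_add(tf.matmul(x, w), b)
  tf.debugging.assert_near(fused, manual)

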
5116def xw_plus_b_v1(x, weights, biases, name=None):
5117  """Computes matmul(x, weights) + biases.
5118
5119  This is a deprecated version that will soon be removed.
5120
5121  Args:
5122    x: a 2D tensor.  Dimensions typically: batch, in_units
5123    weights: a 2D tensor.  Dimensions typically: in_units, out_units
5124    biases: a 1D tensor.  Dimensions: out_units
5125    name: A name for the operation (optional).  If not specified
5126      "xw_plus_b_v1" is used.
5127
5128  Returns:
5129    A 2-D Tensor computing matmul(x, weights) + biases.
5130    Dimensions typically: batch, out_units.
5131  """
5132  with ops.name_scope(name, "xw_plus_b_v1", [x, weights, biases]) as name:
5133    x = ops.convert_to_tensor(x, name="x")
5134    weights = ops.convert_to_tensor(weights, name="weights")
5135    biases = ops.convert_to_tensor(biases, name="biases")
5136    mm = math_ops.matmul(x, weights)
5137    return bias_add_v1(mm, biases, name=name)
5138
5139
5140def _get_noise_shape(x, noise_shape):
5141  # If noise_shape is None, return immediately.
5142  if noise_shape is None:
5143    return array_ops.shape(x)
5144
5145  try:
5146    # Best effort to figure out the intended shape.
5147    # If not possible, let the op handle it.
5148    # In eager mode the exception will show up.
5149    noise_shape_ = tensor_shape.as_shape(noise_shape)
5150  except (TypeError, ValueError):
5151    return noise_shape
5152
5153  if x.shape.dims is not None and len(x.shape.dims) == len(noise_shape_.dims):
5154    new_dims = []
5155    for i, dim in enumerate(x.shape.dims):
5156      if noise_shape_.dims[i].value is None and dim.value is not None:
5157        new_dims.append(dim.value)
5158      else:
5159        new_dims.append(noise_shape_.dims[i].value)
5160    return tensor_shape.TensorShape(new_dims)
5161
5162  return noise_shape
5163
5164
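# --- Illustrative sketch; not part of this module's API. ---
# Assuming `import tensorflow as tf`, this shows the `noise_shape`
# broadcasting documented for dropout below: with noise_shape
# `[batch, 1, features]`, all timesteps of a sequence share one mask.
def _noise_shape_example():
  import tensorflow as tf
  x = tf.ones([2, 5, 4])  # [batch, time, features]
  # Each keep/drop decision is made once per (batch, feature) pair, so the
  # five timesteps of y[b, :, f] are identical.
  return tf.nn.dropout(x, rate=0.5, noise_shape=[2, 1, 4])

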
5165@tf_export(v1=["nn.dropout"])
5166@dispatch.add_dispatch_support
5167@deprecation.deprecated_args(None, "Please use `rate` instead of `keep_prob`. "
5168                             "Rate should be set to `rate = 1 - keep_prob`.",
5169                             "keep_prob")
5170def dropout(x, keep_prob=None, noise_shape=None, seed=None, name=None,
5171            rate=None):
5172  """Computes dropout.
5173
5174  For each element of `x`, with probability `rate`, outputs `0`, and otherwise
5175  scales up the input by `1 / (1-rate)`. The scaling is such that the expected
5176  sum is unchanged.
5177
5178  By default, each element is kept or dropped independently.  If `noise_shape`
5179  is specified, it must be
5180  [broadcastable](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
5181  to the shape of `x`, and only dimensions with `noise_shape[i] == shape(x)[i]`
5182  will make independent decisions.  For example, if `shape(x) = [k, l, m, n]`
5183  and `noise_shape = [k, 1, 1, n]`, each batch and channel component will be
5184  kept independently and each row and column will be kept or not kept together.
5185
5186  Args:
5187    x: A floating point tensor.
5188    keep_prob: (deprecated) A deprecated alias for `(1-rate)`.
5189    noise_shape: A 1-D integer `Tensor`, representing the
5190      shape for randomly generated keep/drop flags.
5191    seed: A Python integer. Used to create random seeds. See
5192      `tf.random.set_seed` for behavior.
5193    name: A name for this operation (optional).
5194    rate: A scalar `Tensor` with the same type as `x`. The probability that each
5195      element of `x` is discarded.
5196
5197  Returns:
    A Tensor of the same shape as `x`.

  Raises:
    ValueError: If `rate` is not in `[0, 1)` or if `x` is not a floating
      point tensor.
  """
  try:
    rate_from_keep_prob = 1. - keep_prob if keep_prob is not None else None
  except TypeError:
    raise ValueError("keep_prob must be a floating point number or Tensor "
                     "(got %r)" % keep_prob)

  rate = deprecation.deprecated_argument_lookup(
      "rate", rate,
      "keep_prob", rate_from_keep_prob)

  if rate is None:
    raise ValueError("You must provide a rate to dropout.")

  return dropout_v2(x, rate, noise_shape=noise_shape, seed=seed, name=name)


@tf_export("nn.dropout", v1=[])
@dispatch.add_dispatch_support
def dropout_v2(x, rate, noise_shape=None, seed=None, name=None):
  """Computes dropout: randomly sets elements to zero to prevent overfitting.

  Warning: You should consider using
  `tf.nn.experimental.stateless_dropout` instead of this function. The
  difference between `tf.nn.experimental.stateless_dropout` and this
  function is analogous to the difference between
  `tf.random.stateless_uniform` and `tf.random.uniform`. Please see the
  [Random number
  generation](https://www.tensorflow.org/guide/random_numbers) guide
  for a detailed description of the various RNG systems in TF. As the
  guide states, legacy stateful RNG ops like `tf.random.uniform` and
  `tf.nn.dropout` are not deprecated yet but highly discouraged,
  because their states are hard to control.

  Note: The behavior of dropout has changed between TensorFlow 1.x and 2.x.
  When converting 1.x code, please use named arguments to ensure behavior stays
  consistent.

  See also: `tf.keras.layers.Dropout` for a dropout layer.

  [Dropout](https://arxiv.org/abs/1207.0580) is useful for regularizing DNN
  models. Input elements are randomly set to zero (and the other elements are
  rescaled). This encourages each node to be independently useful, as it cannot
  rely on the output of other nodes.

  More precisely: With probability `rate` elements of `x` are set to `0`.
  The remaining elements are scaled up by `1.0 / (1 - rate)`, so that the
  expected value is preserved.

  >>> tf.random.set_seed(0)
  >>> x = tf.ones([3,5])
  >>> tf.nn.dropout(x, rate = 0.5, seed = 1).numpy()
  array([[2., 0., 0., 2., 2.],
         [2., 2., 2., 2., 2.],
         [2., 0., 2., 0., 2.]], dtype=float32)

  >>> tf.random.set_seed(0)
  >>> x = tf.ones([3,5])
  >>> tf.nn.dropout(x, rate = 0.8, seed = 1).numpy()
  array([[0., 0., 0., 5., 5.],
         [0., 5., 0., 5., 0.],
         [5., 0., 5., 0., 5.]], dtype=float32)

  >>> tf.nn.dropout(x, rate = 0.0) == x
  <tf.Tensor: shape=(3, 5), dtype=bool, numpy=
  array([[ True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True]])>


  By default, each element is kept or dropped independently.  If `noise_shape`
  is specified, it must be
  [broadcastable](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
  to the shape of `x`, and only dimensions with `noise_shape[i] == shape(x)[i]`
  will make independent decisions. This is useful for dropping whole
  channels from an image or sequence. For example:

  >>> tf.random.set_seed(0)
  >>> x = tf.ones([3,10])
  >>> tf.nn.dropout(x, rate = 2/3, noise_shape=[1,10], seed=1).numpy()
  array([[0., 0., 0., 3., 3., 0., 3., 3., 3., 0.],
         [0., 0., 0., 3., 3., 0., 3., 3., 3., 0.],
         [0., 0., 0., 3., 3., 0., 3., 3., 3., 0.]], dtype=float32)

  Args:
    x: A floating point tensor.
    rate: A scalar `Tensor` with the same type as x. The probability
      that each element is dropped. For example, setting rate=0.1 would drop
      10% of input elements.
    noise_shape: A 1-D integer `Tensor`, representing the
      shape for randomly generated keep/drop flags.
    seed: A Python integer. Used to create random seeds. See
      `tf.random.set_seed` for behavior.
    name: A name for this operation (optional).

  Returns:
    A Tensor of the same shape as `x`.

  Raises:
    ValueError: If `rate` is not in `[0, 1)` or if `x` is not a floating point
      tensor. `rate=1` is disallowed, because the output would be all zeros,
      which is likely not what was intended.
  """
  uniform_sampler = functools.partial(random_ops.random_uniform, seed=seed)
  def dummy_rng_step():
    random_seed.get_seed(seed)
  return _dropout(x=x, rate=rate, noise_shape=noise_shape,
                  uniform_sampler=uniform_sampler,
                  dummy_rng_step=dummy_rng_step, name=name,
                  default_name="dropout")


@tf_export("nn.experimental.stateless_dropout")
@dispatch.add_dispatch_support
def stateless_dropout(x, rate, seed, rng_alg=None, noise_shape=None, name=None):
  """Computes dropout: randomly sets elements to zero to prevent overfitting.

  [Dropout](https://arxiv.org/abs/1207.0580) is useful for regularizing DNN
  models. Input elements are randomly set to zero (and the other elements are
  rescaled). This encourages each node to be independently useful, as it cannot
  rely on the output of other nodes.

  More precisely: With probability `rate` elements of `x` are set to `0`.
  The remaining elements are scaled up by `1.0 / (1 - rate)`, so that the
  expected value is preserved.

  >>> x = tf.ones([3,5])
  >>> tf.nn.experimental.stateless_dropout(x, rate=0.5, seed=[1, 0])
  <tf.Tensor: shape=(3, 5), dtype=float32, numpy=
  array([[2., 0., 2., 0., 0.],
         [0., 0., 2., 0., 2.],
         [0., 0., 0., 0., 2.]], dtype=float32)>

  >>> x = tf.ones([3,5])
  >>> tf.nn.experimental.stateless_dropout(x, rate=0.8, seed=[1, 0])
  <tf.Tensor: shape=(3, 5), dtype=float32, numpy=
  array([[5., 0., 0., 0., 0.],
         [0., 0., 0., 0., 5.],
         [0., 0., 0., 0., 5.]], dtype=float32)>

  >>> tf.nn.experimental.stateless_dropout(x, rate=0.0, seed=[1, 0]) == x
  <tf.Tensor: shape=(3, 5), dtype=bool, numpy=
  array([[ True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True]])>


  This function is a stateless version of `tf.nn.dropout`, in the
  sense that no matter how many times you call this function, the same
  `seed` will lead to the same results, and a different `seed` will
  lead to different results.

  >>> x = tf.ones([3,5])
  >>> tf.nn.experimental.stateless_dropout(x, rate=0.8, seed=[1, 0])
  <tf.Tensor: shape=(3, 5), dtype=float32, numpy=
  array([[5., 0., 0., 0., 0.],
         [0., 0., 0., 0., 5.],
         [0., 0., 0., 0., 5.]], dtype=float32)>
  >>> tf.nn.experimental.stateless_dropout(x, rate=0.8, seed=[1, 0])
  <tf.Tensor: shape=(3, 5), dtype=float32, numpy=
  array([[5., 0., 0., 0., 0.],
         [0., 0., 0., 0., 5.],
         [0., 0., 0., 0., 5.]], dtype=float32)>
  >>> tf.nn.experimental.stateless_dropout(x, rate=0.8, seed=[2, 0])
  <tf.Tensor: shape=(3, 5), dtype=float32, numpy=
  array([[5., 0., 0., 0., 0.],
         [0., 0., 0., 5., 0.],
         [0., 0., 0., 0., 0.]], dtype=float32)>
  >>> tf.nn.experimental.stateless_dropout(x, rate=0.8, seed=[2, 0])
  <tf.Tensor: shape=(3, 5), dtype=float32, numpy=
  array([[5., 0., 0., 0., 0.],
         [0., 0., 0., 5., 0.],
         [0., 0., 0., 0., 0.]], dtype=float32)>

  Compare the above results to those of `tf.nn.dropout` below. The
  second time `tf.nn.dropout` is called with the same seed, it will
  give a different output.

  >>> tf.random.set_seed(0)
  >>> x = tf.ones([3,5])
  >>> tf.nn.dropout(x, rate=0.8, seed=1)
  <tf.Tensor: shape=(3, 5), dtype=float32, numpy=
  array([[0., 0., 0., 5., 5.],
         [0., 5., 0., 5., 0.],
         [5., 0., 5., 0., 5.]], dtype=float32)>
  >>> tf.nn.dropout(x, rate=0.8, seed=1)
  <tf.Tensor: shape=(3, 5), dtype=float32, numpy=
  array([[0., 0., 0., 0., 0.],
         [0., 0., 0., 5., 0.],
         [0., 0., 0., 0., 0.]], dtype=float32)>
  >>> tf.nn.dropout(x, rate=0.8, seed=2)
  <tf.Tensor: shape=(3, 5), dtype=float32, numpy=
  array([[0., 0., 0., 0., 0.],
         [0., 5., 0., 5., 0.],
         [0., 0., 0., 0., 0.]], dtype=float32)>
  >>> tf.nn.dropout(x, rate=0.8, seed=2)
  <tf.Tensor: shape=(3, 5), dtype=float32, numpy=
  array([[0., 0., 0., 0., 0.],
         [5., 0., 5., 0., 5.],
         [0., 5., 0., 0., 5.]], dtype=float32)>

  The difference between this function and `tf.nn.dropout` is
  analogous to the difference between `tf.random.stateless_uniform`
  and `tf.random.uniform`. Please see the [Random number
  generation](https://www.tensorflow.org/guide/random_numbers) guide
  for a detailed description of the various RNG systems in TF. As the
  guide states, legacy stateful RNG ops like `tf.random.uniform` and
  `tf.nn.dropout` are not deprecated yet but highly discouraged,
  because their states are hard to control.

  By default, each element is kept or dropped independently.  If `noise_shape`
  is specified, it must be
  [broadcastable](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
  to the shape of `x`, and only dimensions with `noise_shape[i] == shape(x)[i]`
  will make independent decisions. This is useful for dropping whole
  channels from an image or sequence. For example:

  >>> x = tf.ones([3,10])
  >>> tf.nn.experimental.stateless_dropout(x, rate=2/3, noise_shape=[1,10],
  ...                                      seed=[1, 0])
  <tf.Tensor: shape=(3, 10), dtype=float32, numpy=
  array([[3., 0., 0., 0., 0., 0., 0., 3., 0., 3.],
         [3., 0., 0., 0., 0., 0., 0., 3., 0., 3.],
         [3., 0., 0., 0., 0., 0., 0., 3., 0., 3.]], dtype=float32)>

  Args:
    x: A floating point tensor.
    rate: A scalar `Tensor` with the same type as x. The probability
      that each element is dropped. For example, setting rate=0.1 would drop
      10% of input elements.
    seed: An integer tensor of shape `[2]`. The seed of the random numbers.
    rng_alg: The algorithm used to generate the random numbers
      (defaults to `"auto_select"`). See the `alg` argument of
      `tf.random.stateless_uniform` for the supported values.
    noise_shape: A 1-D integer `Tensor`, representing the
      shape for randomly generated keep/drop flags.
    name: A name for this operation (optional).

  Returns:
    A Tensor of the same shape and dtype as `x`.

  Raises:
    ValueError: If `rate` is not in `[0, 1)` or if `x` is not a floating point
      tensor. `rate=1` is disallowed, because the output would be all zeros,
      which is likely not what was intended.
5448  """
5449  uniform_sampler = functools.partial(
5450      stateless_random_ops.stateless_random_uniform, seed=seed, alg=rng_alg)
5451  def dummy_rng_step():
5452    pass
5453  return _dropout(x=x, rate=rate, noise_shape=noise_shape,
5454                  uniform_sampler=uniform_sampler,
5455                  dummy_rng_step=dummy_rng_step, name=name,
5456                  default_name="stateless_dropout")
5457
5458
5459def _dropout(x, rate, noise_shape, uniform_sampler, dummy_rng_step, name,
5460             default_name):
5461  """Shared implementation of the various dropout functions.
5462
5463  Args:
5464    x: same as the namesake in `dropout_v2`.
5465    rate: same as the namesake in `dropout_v2`.
5466    noise_shape: same as the namesake in `dropout_v2`.
5467    uniform_sampler: a callable of signature `(shape, dtype) ->
5468      Tensor`, used to generate a tensor of uniformly-distributed
5469      random numbers, of the given shape and dtype.
5470    dummy_rng_step: a callable of signature `() -> None`, to make a
5471      dummy RNG call in the fast path. In the fast path where rate is
5472      0, we don't need to generate random numbers, but some samplers
5473      still require you to make an RNG call, to make sure that RNG
5474      states won't depend on whether the fast path is taken.
5475    name: same as the namesake in `dropout_v2`.
5476    default_name: a default name in case `name` is `None`.
5477
5478  Returns:
    A Tensor of the same shape and dtype as `x`.
5480  """
5481  with ops.name_scope(name, default_name, [x]) as name:
5482    is_rate_number = isinstance(rate, numbers.Real)
5483    if is_rate_number and (rate < 0 or rate >= 1):
5484      raise ValueError("rate must be a scalar tensor or a float in the "
5485                       "range [0, 1), got %g" % rate)
5486    x = ops.convert_to_tensor(x, name="x")
5487    x_dtype = x.dtype
5488    if not x_dtype.is_floating:
5489      raise ValueError("x has to be a floating point tensor since it's going "
5490                       "to be scaled. Got a %s tensor instead." % x_dtype)
    if is_rate_number and rate == 0:
      # Fast path: return the input immediately if rate is a non-tensor `0`.
      # We trigger this after all error checking
      # and after `x` has been converted to a tensor, to prevent inconsistent
      # tensor conversions/error raising if rate is changed to/from 0.
      #
      # We also explicitly call `dummy_rng_step` to make sure
      # we don't change the random number generation behavior of
      # stateful random ops by entering a fast path,
      # despite not generating a random tensor in the fast path.
      dummy_rng_step()
      return x

    is_executing_eagerly = context.executing_eagerly()
    if not tensor_util.is_tf_type(rate):
      if is_rate_number:
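        # Scale the kept elements by 1 / (1 - rate) so that the expected sum
        # is unchanged; e.g. rate = 0.2 gives keep_prob = 0.8, scale = 1.25.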
        keep_prob = 1 - rate
        scale = 1 / keep_prob
        scale = ops.convert_to_tensor(scale, dtype=x_dtype)
        ret = gen_math_ops.mul(x, scale)
      else:
        raise ValueError(
            "rate is neither a scalar nor a scalar tensor: %r" % rate)
    else:
      rate.get_shape().assert_has_rank(0)
      rate_dtype = rate.dtype
      if rate_dtype != x_dtype:
        if not rate_dtype.is_compatible_with(x_dtype):
          raise ValueError(
              "`x` has dtype %s which is incompatible with `rate`'s dtype %s" %
              (x_dtype.name, rate_dtype.name))
        rate = gen_math_ops.cast(rate, x_dtype, name="rate")
      one_tensor = constant_op.constant(1, dtype=x_dtype)
      ret = gen_math_ops.real_div(x, gen_math_ops.sub(one_tensor, rate))

    noise_shape = _get_noise_shape(x, noise_shape)
    # Sample a uniform distribution on [0.0, 1.0) and select values larger
    # than or equal to `rate`.
    random_tensor = uniform_sampler(shape=noise_shape, dtype=x_dtype)
    keep_mask = random_tensor >= rate
    ret = gen_math_ops.mul(ret, gen_math_ops.cast(keep_mask, x_dtype))
    if not is_executing_eagerly:
      ret.set_shape(x.get_shape())
    return ret


@tf_export("math.top_k", "nn.top_k")
@dispatch.add_dispatch_support
def top_k(input, k=1, sorted=True, name=None):  # pylint: disable=redefined-builtin
  """Finds values and indices of the `k` largest entries for the last dimension.

  If the input is a vector (rank=1), finds the `k` largest entries in the vector
  and outputs their values and indices as vectors.  Thus `values[j]` is the
  `j`-th largest entry in `input`, and its index is `indices[j]`.

  >>> result = tf.math.top_k([1, 2, 98, 1, 1, 99, 3, 1, 3, 96, 4, 1],
  ...                         k=3)
  >>> result.values.numpy()
  array([99, 98, 96], dtype=int32)
  >>> result.indices.numpy()
  array([5, 2, 9], dtype=int32)

  For matrices (resp. higher rank input), computes the top `k` entries in each
  row (resp. vector along the last dimension).  Thus,

  >>> input = tf.random.normal(shape=(3,4,5,6))
  >>> k = 2
  >>> values, indices = tf.math.top_k(input, k=k)
  >>> values.shape.as_list()
  [3, 4, 5, 2]
  >>>
  >>> values.shape == indices.shape == input.shape[:-1] + [k]
  True

  The indices can be used to `gather` from a tensor whose shape matches `input`.

  >>> gathered_values = tf.gather(input, indices, batch_dims=-1)
  >>> assert tf.reduce_all(gathered_values == values)

  If two elements are equal, the lower-index element appears first.

  >>> result = tf.math.top_k([1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0],
  ...                        k=3)
  >>> result.indices.numpy()
  array([0, 1, 3], dtype=int32)

  Args:
    input: 1-D or higher `Tensor` with last dimension at least `k`.
    k: 0-D `int32` `Tensor`.  Number of top elements to look for along the last
      dimension (along each row for matrices).
    sorted: If true the resulting `k` elements will be sorted by the values in
      descending order.
    name: Optional name for the operation.

  Returns:
    A tuple with two named fields:
    values: The `k` largest elements along each last dimensional slice.
    indices: The indices of `values` within the last dimension of `input`.
  """
  return gen_nn_ops.top_kv2(input, k=k, sorted=sorted, name=name)


def nth_element(input, n, reverse=False, name=None):  # pylint: disable=redefined-builtin
  r"""Finds the value of the `n`-th smallest element for the last dimension.

  Note that n is zero-indexed.

  If the input is a vector (rank-1), finds the entry which is the n-th smallest
  value in the vector and outputs its value as a scalar tensor.

  For matrices (resp. higher rank input), computes the entry which is the
  n-th smallest value in each row (resp. vector along the last dimension). Thus,

      values.shape = input.shape[:-1]

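  For example, a minimal sketch (assuming eager execution and calling the
  module-level helper directly; the values are illustrative):

  >>> x = tf.constant([3., 1., 2.])
  >>> nth_element(x, 1).numpy()  # the second-smallest value (n is 0-indexed)
  2.0
  >>> nth_element(x, 0, reverse=True).numpy()  # the largest value
  3.0
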
  Args:
    input: 1-D or higher `Tensor` with last dimension at least `n+1`.
    n: A `Tensor` of type `int32`.
      0-D. Position of sorted vector to select along the last dimension (along
      each row for matrices). Valid range of n is `[0, input.shape[-1])`.
    reverse: An optional `bool`. Defaults to `False`.
      When set to True, find the nth-largest value in the vector and vice
      versa.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
    The `n`-th order statistic along each last dimensional slice.
  """
  return gen_nn_ops.nth_element(input, n, reverse=reverse, name=name)


@tf_export(v1=["nn.fractional_max_pool"])
@dispatch.add_dispatch_support
@deprecation.deprecated(date=None, instructions="`seed2` and `deterministic` "
                        "args are deprecated.  Use fractional_max_pool_v2.")
def fractional_max_pool(value,
                        pooling_ratio,
                        pseudo_random=False,
                        overlapping=False,
                        deterministic=False,
                        seed=0,
                        seed2=0,
                        name=None):   # pylint: disable=redefined-builtin
  r"""Performs fractional max pooling on the input.

  This is a deprecated version of `fractional_max_pool_v2`.

  Fractional max pooling is slightly different from regular max pooling.  In
  regular max pooling, you downsize an input set by taking the maximum value of
  smaller N x N subsections of the set (often 2x2), and try to reduce the set by
  a factor of N, where N is an integer.  Fractional max pooling, as you might
  expect from the word "fractional", means that the overall reduction ratio N
  does not have to be an integer.

  The sizes of the pooling regions are generated randomly but are fairly
  uniform.  For example, let's look at the height dimension, and the constraints
  on the list of rows that will be pool boundaries.

  First we define the following:

  1.  input_row_length : the number of rows from the input set
  2.  output_row_length : which will be smaller than the input
  3.  alpha = input_row_length / output_row_length : our reduction ratio
  4.  K = floor(alpha)
  5.  row_pooling_sequence : this is the result list of pool boundary rows

  Then, row_pooling_sequence should satisfy:

  1.  a[0] = 0 : the first value of the sequence is 0
  2.  a[end] = input_row_length : the last value of the sequence is the size
  3.  K <= (a[i+1] - a[i]) <= K+1 : all intervals are K or K+1 size
  4.  length(row_pooling_sequence) = output_row_length+1

  Args:
    value: A `Tensor`. 4-D with shape `[batch, height, width, channels]`.
    pooling_ratio: A list of `floats` that has length >= 4.  Pooling ratio for
      each dimension of `value`, currently only supports row and col dimension
      and should be >= 1.0. For example, a valid pooling ratio looks like [1.0,
      1.44, 1.73, 1.0]. The first and last elements must be 1.0 because we don't
      allow pooling on batch and channels dimensions.  1.44 and 1.73 are pooling
      ratio on height and width dimensions respectively.
    pseudo_random: An optional `bool`.  Defaults to `False`. When set to `True`,
      generates the pooling sequence in a pseudorandom fashion, otherwise, in a
      random fashion. Check (Graham, 2015) for the difference between
      pseudorandom and random.
    overlapping: An optional `bool`.  Defaults to `False`.  When set to `True`,
      it means when pooling, the values at the boundary of adjacent pooling
      cells are used by both cells. For example:
      `index  0  1  2  3  4`
      `value  20 5  16 3  7`
      If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used
      twice.  The result would be [20, 16] for fractional max pooling.
    deterministic: An optional `bool`.  Deprecated; use `fractional_max_pool_v2`
      instead.
    seed: An optional `int`.  Defaults to `0`.  If set to be non-zero, the
      random number generator is seeded by the given seed.  Otherwise it is
      seeded by a random seed.
    seed2: An optional `int`.  Deprecated; use `fractional_max_pool_v2` instead.
    name: A name for the operation (optional).

  Returns:
    A tuple of `Tensor` objects (`output`, `row_pooling_sequence`,
    `col_pooling_sequence`).
    output: Output `Tensor` after fractional max pooling.  Has the same type as
      `value`.
    row_pooling_sequence: A `Tensor` of type `int64`.
    col_pooling_sequence: A `Tensor` of type `int64`.

  References:
    Fractional Max-Pooling:
      [Graham, 2015](https://arxiv.org/abs/1412.6071)
      ([pdf](https://arxiv.org/pdf/1412.6071.pdf))
  """
  return gen_nn_ops.fractional_max_pool(value, pooling_ratio, pseudo_random,
                                        overlapping, deterministic, seed, seed2,
                                        name)


@tf_export("nn.fractional_max_pool", v1=[])
@dispatch.add_dispatch_support
def fractional_max_pool_v2(value,
                           pooling_ratio,
                           pseudo_random=False,
                           overlapping=False,
                           seed=0,
                           name=None):  # pylint: disable=redefined-builtin
  r"""Performs fractional max pooling on the input.

  Fractional max pooling is slightly different from regular max pooling.  In
  regular max pooling, you downsize an input set by taking the maximum value of
  smaller N x N subsections of the set (often 2x2), and try to reduce the set by
  a factor of N, where N is an integer.  Fractional max pooling, as you might
  expect from the word "fractional", means that the overall reduction ratio N
  does not have to be an integer.

  The sizes of the pooling regions are generated randomly but are fairly
  uniform.  For example, let's look at the height dimension, and the constraints
  on the list of rows that will be pool boundaries.

  First we define the following:

  1.  input_row_length : the number of rows from the input set
  2.  output_row_length : which will be smaller than the input
  3.  alpha = input_row_length / output_row_length : our reduction ratio
  4.  K = floor(alpha)
  5.  row_pooling_sequence : this is the result list of pool boundary rows

  Then, row_pooling_sequence should satisfy:

  1.  a[0] = 0 : the first value of the sequence is 0
  2.  a[end] = input_row_length : the last value of the sequence is the size
  3.  K <= (a[i+1] - a[i]) <= K+1 : all intervals are K or K+1 size
  4.  length(row_pooling_sequence) = output_row_length+1
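
  For example, with `input_row_length = 5` and `output_row_length = 3` we get
  `alpha = 5/3` and `K = 1`, so one valid `row_pooling_sequence` is
  `[0, 2, 3, 5]`: it starts at 0, ends at 5, has length `3 + 1 = 4`, and all
  of its intervals (2, 1 and 2) have size K or K+1.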

  Args:
    value: A `Tensor`. 4-D with shape `[batch, height, width, channels]`.
    pooling_ratio: An int or float, or a list of numbers that has length `1`,
      `2` or `4`. Pooling ratio for each dimension of `value`, currently only
      supports row and col dimension and should be >= 1.0. For example, a valid
      pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last
      elements must be 1.0 because we don't allow pooling on batch and channels
      dimensions.  1.44 and 1.73 are pooling ratio on height and width
      dimensions respectively.
    pseudo_random: An optional `bool`.  Defaults to `False`. When set to `True`,
      generates the pooling sequence in a pseudorandom fashion, otherwise, in a
      random fashion. Check paper (Graham, 2015) for the difference between
      pseudorandom and random.
    overlapping: An optional `bool`.  Defaults to `False`.  When set to `True`,
      it means when pooling, the values at the boundary of adjacent pooling
      cells are used by both cells. For example:
      `index  0  1  2  3  4`
      `value  20 5  16 3  7`
      If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used
      twice.  The result would be [20, 16] for fractional max pooling.
    seed: An optional `int`.  Defaults to `0`.  If set to be non-zero, the
      random number generator is seeded by the given seed.  Otherwise it is
      seeded by a random seed.
    name: A name for the operation (optional).

  Returns:
    A tuple of `Tensor` objects (`output`, `row_pooling_sequence`,
    `col_pooling_sequence`).
    output: Output `Tensor` after fractional max pooling.  Has the same type as
      `value`.
    row_pooling_sequence: A `Tensor` of type `int64`.
    col_pooling_sequence: A `Tensor` of type `int64`.

  References:
    Fractional Max-Pooling:
      [Graham, 2015](https://arxiv.org/abs/1412.6071)
      ([pdf](https://arxiv.org/pdf/1412.6071.pdf))
  """
  if isinstance(pooling_ratio, (list, tuple)):
    if pooling_ratio[0] != 1.0 or pooling_ratio[-1] != 1.0:
      raise ValueError(
          "The first and last elements of pooling ratio must be 1.0.")
    for element in pooling_ratio:
      if element < 1.0:
        raise ValueError("pooling_ratio should be >= 1.0.")
  elif isinstance(pooling_ratio, (int, float)):
    if pooling_ratio < 1.0:
      raise ValueError("pooling_ratio should be >= 1.0.")
  else:
    raise ValueError(
        "pooling_ratio should be an int, a float, or a list/tuple of numbers.")

  pooling_ratio = _get_sequence(pooling_ratio, 2, 3, "pooling_ratio")

  if seed == 0:
    return gen_nn_ops.fractional_max_pool(value, pooling_ratio, pseudo_random,
                                          overlapping, deterministic=False,
                                          seed=0, seed2=0, name=name)
  else:
    seed1, seed2 = random_seed.get_seed(seed)
    return gen_nn_ops.fractional_max_pool(value, pooling_ratio, pseudo_random,
                                          overlapping, deterministic=True,
                                          seed=seed1, seed2=seed2, name=name)


@tf_export(v1=["nn.fractional_avg_pool"])
@dispatch.add_dispatch_support
@deprecation.deprecated(date=None, instructions="`seed2` and `deterministic` "
                        "args are deprecated.  Use fractional_avg_pool_v2.")
def fractional_avg_pool(value,
                        pooling_ratio,
                        pseudo_random=False,
                        overlapping=False,
                        deterministic=False,
                        seed=0,
                        seed2=0,
                        name=None):  # pylint: disable=redefined-builtin
  r"""Performs fractional average pooling on the input.

  This is a deprecated version of `fractional_avg_pool_v2`.

  Fractional average pooling is similar to fractional max pooling in the pooling
  region generation step. The only difference is that after pooling regions are
  generated, a mean operation is performed instead of a max operation in each
  pooling region.

  Args:
    value: A `Tensor`. 4-D with shape `[batch, height, width, channels]`.
    pooling_ratio: A list of `floats` that has length >= 4.  Pooling ratio for
      each dimension of `value`, currently only supports row and col dimension
      and should be >= 1.0. For example, a valid pooling ratio looks like [1.0,
      1.44, 1.73, 1.0]. The first and last elements must be 1.0 because we don't
      allow pooling on batch and channels dimensions.  1.44 and 1.73 are pooling
      ratio on height and width dimensions respectively.
    pseudo_random: An optional `bool`.  Defaults to `False`. When set to `True`,
      generates the pooling sequence in a pseudorandom fashion, otherwise, in a
      random fashion. Check paper (Graham, 2015) for the difference between
      pseudorandom and random.
    overlapping: An optional `bool`.  Defaults to `False`.  When set to `True`,
      it means when pooling, the values at the boundary of adjacent pooling
      cells are used by both cells. For example:
      `index  0  1  2  3  4`
      `value  20 5  16 3  7`
      If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used
      twice.  The result would be [20, 16] for fractional avg pooling.
    deterministic: An optional `bool`.  Deprecated; use `fractional_avg_pool_v2`
      instead.
    seed: An optional `int`.  Defaults to `0`.  If set to be non-zero, the
      random number generator is seeded by the given seed.  Otherwise it is
      seeded by a random seed.
    seed2: An optional `int`.  Deprecated; use `fractional_avg_pool_v2` instead.
    name: A name for the operation (optional).

  Returns:
    A tuple of `Tensor` objects (`output`, `row_pooling_sequence`,
    `col_pooling_sequence`).
    output: Output `Tensor` after fractional avg pooling.  Has the same type as
      `value`.
    row_pooling_sequence: A `Tensor` of type `int64`.
    col_pooling_sequence: A `Tensor` of type `int64`.

  References:
    Fractional Max-Pooling:
      [Graham, 2015](https://arxiv.org/abs/1412.6071)
      ([pdf](https://arxiv.org/pdf/1412.6071.pdf))
  """
  return gen_nn_ops.fractional_avg_pool(value, pooling_ratio, pseudo_random,
                                        overlapping, deterministic, seed, seed2,
                                        name=name)


@tf_export("nn.fractional_avg_pool", v1=[])
@dispatch.add_dispatch_support
def fractional_avg_pool_v2(value,
                           pooling_ratio,
                           pseudo_random=False,
                           overlapping=False,
                           seed=0,
                           name=None):  # pylint: disable=redefined-builtin
  r"""Performs fractional average pooling on the input.

  Fractional average pooling is similar to fractional max pooling in the pooling
  region generation step. The only difference is that after pooling regions are
  generated, a mean operation is performed instead of a max operation in each
  pooling region.

  Args:
    value: A `Tensor`. 4-D with shape `[batch, height, width, channels]`.
    pooling_ratio: A list of `floats` that has length >= 4.  Pooling ratio for
      each dimension of `value`, currently only supports row and col dimension
      and should be >= 1.0. For example, a valid pooling ratio looks like [1.0,
      1.44, 1.73, 1.0]. The first and last elements must be 1.0 because we don't
      allow pooling on batch and channels dimensions.  1.44 and 1.73 are pooling
      ratio on height and width dimensions respectively.
    pseudo_random: An optional `bool`.  Defaults to `False`. When set to `True`,
      generates the pooling sequence in a pseudorandom fashion, otherwise, in a
      random fashion. Check paper (Graham, 2015) for the difference between
      pseudorandom and random.
    overlapping: An optional `bool`.  Defaults to `False`.  When set to `True`,
      it means when pooling, the values at the boundary of adjacent pooling
      cells are used by both cells. For example:
      `index  0  1  2  3  4`
      `value  20 5  16 3  7`
      If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used
      twice.  The result would be [20, 16] for fractional avg pooling.
    seed: An optional `int`.  Defaults to `0`.  If set to be non-zero, the
      random number generator is seeded by the given seed.  Otherwise it is
      seeded by a random seed.
    name: A name for the operation (optional).

  Returns:
    A tuple of `Tensor` objects (`output`, `row_pooling_sequence`,
    `col_pooling_sequence`).
    output: Output `Tensor` after fractional avg pooling.  Has the same type as
      `value`.
    row_pooling_sequence: A `Tensor` of type `int64`.
    col_pooling_sequence: A `Tensor` of type `int64`.

  References:
    Fractional Max-Pooling:
      [Graham, 2015](https://arxiv.org/abs/1412.6071)
      ([pdf](https://arxiv.org/pdf/1412.6071.pdf))
  """
  if seed == 0:
    return gen_nn_ops.fractional_avg_pool(value, pooling_ratio, pseudo_random,
                                          overlapping, deterministic=False,
                                          seed=0, seed2=0, name=name)
  else:
    seed1, seed2 = random_seed.get_seed(seed)
    return gen_nn_ops.fractional_avg_pool(value, pooling_ratio, pseudo_random,
                                          overlapping, deterministic=True,
                                          seed=seed1, seed2=seed2, name=name)


@ops.RegisterStatistics("Dilation2D", "flops")
def _calc_dilation2d_flops(graph, node):
  """Calculates the compute resources needed for Dilation2D."""
  input_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  input_shape.assert_is_fully_defined()
  filter_shape = graph_util.tensor_shape_from_node_def_name(
      graph, node.input[1])
  filter_shape.assert_is_fully_defined()
  output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
  output_shape.assert_is_fully_defined()
  filter_height = int(filter_shape[0])
  filter_width = int(filter_shape[1])
  output_count = np.prod(output_shape.as_list(), dtype=np.int64)
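  # Each output element reduces over filter_height * filter_width window
  # positions, at roughly two ops per position (an add and a max/compare),
  # hence the factor of 2 below.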
  return ops.OpStats("flops", (output_count * filter_height * filter_width * 2))


@tf_export(v1=["nn.erosion2d"])
@dispatch.add_dispatch_support
def erosion2d(value, kernel, strides, rates, padding, name=None):
  """Computes the grayscale erosion of 4-D `value` and 3-D `kernel` tensors.

  The `value` tensor has shape `[batch, in_height, in_width, depth]` and the
  `kernel` tensor has shape `[kernel_height, kernel_width, depth]`, i.e.,
  each input channel is processed independently of the others with its own
  structuring function. The `output` tensor has shape
  `[batch, out_height, out_width, depth]`. The spatial dimensions of the
  output tensor depend on the `padding` algorithm. We currently only support the
  default "NHWC" `data_format`.

  In detail, the grayscale morphological 2-D erosion is given by:

      output[b, y, x, c] =
         min_{dy, dx} value[b,
                            strides[1] * y - rates[1] * dy,
                            strides[2] * x - rates[2] * dx,
                            c] -
                      kernel[dy, dx, c]

  Duality: The erosion of `value` by the `kernel` is equal to the negation of
  the dilation of `-value` by the reflected `kernel`.

  Args:
    value: A `Tensor`. 4-D with shape `[batch, in_height, in_width, depth]`.
    kernel: A `Tensor`. Must have the same type as `value`.
      3-D with shape `[kernel_height, kernel_width, depth]`.
    strides: A list of `ints` that has length `>= 4`.
      1-D of length 4. The stride of the sliding window for each dimension of
      the input tensor. Must be: `[1, stride_height, stride_width, 1]`.
    rates: A list of `ints` that has length `>= 4`.
      1-D of length 4. The input stride for atrous morphological dilation.
      Must be: `[1, rate_height, rate_width, 1]`.
    padding: A `string` from: `"SAME", "VALID"`.
      The type of padding algorithm to use.
    name: A name for the operation (optional). If not specified "erosion2d"
      is used.

  Returns:
    A `Tensor`. Has the same type as `value`.
    4-D with shape `[batch, out_height, out_width, depth]`.

  Raises:
    ValueError: If the `value` depth does not match the shape of `kernel`, or
      if `padding` is other than `'VALID'` or `'SAME'`.
  """
  with ops.name_scope(name, "erosion2d", [value, kernel]) as name:
    # Reduce erosion to dilation by duality.
    return math_ops.negative(
        gen_nn_ops.dilation2d(
            input=math_ops.negative(value),
            filter=array_ops.reverse_v2(kernel, [0, 1]),
            strides=strides,
            rates=rates,
            padding=padding,
            name=name))


@tf_export("nn.erosion2d", v1=[])
@dispatch.add_dispatch_support
def erosion2d_v2(value,
                 filters,
                 strides,
                 padding,
                 data_format,
                 dilations,
                 name=None):
  """Computes the grayscale erosion of 4-D `value` and 3-D `filters` tensors.

  The `value` tensor has shape `[batch, in_height, in_width, depth]` and the
  `filters` tensor has shape `[filters_height, filters_width, depth]`, i.e.,
  each input channel is processed independently of the others with its own
  structuring function. The `output` tensor has shape
  `[batch, out_height, out_width, depth]`. The spatial dimensions of the
  output tensor depend on the `padding` algorithm. We currently only support the
  default "NHWC" `data_format`.

  In detail, the grayscale morphological 2-D erosion is given by:

      output[b, y, x, c] =
         min_{dy, dx} value[b,
                            strides[1] * y - dilations[1] * dy,
                            strides[2] * x - dilations[2] * dx,
                            c] -
                      filters[dy, dx, c]

  Duality: The erosion of `value` by the `filters` is equal to the negation of
  the dilation of `-value` by the reflected `filters`.

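  For example, a minimal sketch (assuming eager execution; the values are
  illustrative): a 1x1 all-zero structuring function leaves the input
  unchanged.

  >>> value = tf.reshape(tf.constant([1., 2., 3., 4.]), [1, 2, 2, 1])
  >>> filters = tf.zeros([1, 1, 1])
  >>> tf.nn.erosion2d(value, filters, strides=[1, 1, 1, 1], padding='VALID',
  ...                 data_format='NHWC', dilations=[1, 1, 1, 1])[0, ..., 0]
  <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
  array([[1., 2.],
         [3., 4.]], dtype=float32)>
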
  Args:
    value: A `Tensor`. 4-D with shape `[batch, in_height, in_width, depth]`.
    filters: A `Tensor`. Must have the same type as `value`.
      3-D with shape `[filters_height, filters_width, depth]`.
    strides: A list of `ints` that has length `>= 4`.
      1-D of length 4. The stride of the sliding window for each dimension of
      the input tensor. Must be: `[1, stride_height, stride_width, 1]`.
    padding: A `string` from: `"SAME", "VALID"`.
      The type of padding algorithm to use.
    data_format: A `string`, only `"NHWC"` is currently supported.
    dilations: A list of `ints` that has length `>= 4`.
      1-D of length 4. The input stride for atrous morphological dilation.
      Must be: `[1, rate_height, rate_width, 1]`.
    name: A name for the operation (optional). If not specified "erosion2d"
      is used.

  Returns:
    A `Tensor`. Has the same type as `value`.
    4-D with shape `[batch, out_height, out_width, depth]`.

  Raises:
    ValueError: If the `value` depth does not match the shape of `filters`, or
      if `padding` is other than `'VALID'` or `'SAME'`.
6066  """
6067  if data_format != "NHWC":
6068    raise ValueError("Data formats other than NHWC are not yet supported")
6069
6070  with ops.name_scope(name, "erosion2d", [value, filters]) as name:
6071    # Reduce erosion to dilation by duality.
6072    return math_ops.negative(
6073        gen_nn_ops.dilation2d(
6074            input=math_ops.negative(value),
6075            filter=array_ops.reverse_v2(filters, [0, 1]),
6076            strides=strides,
6077            rates=dilations,
6078            padding=padding,
6079            name=name))
6080
6081
6082@tf_export(v1=["math.in_top_k", "nn.in_top_k"])
6083@dispatch.add_dispatch_support
6084def in_top_k(predictions, targets, k, name=None):
6085  r"""Says whether the targets are in the top `K` predictions.
6086
6087  This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the
6088  prediction for the target class is finite (not inf, -inf, or nan) and among
6089  the top `k` predictions among all predictions for example `i`. Note that the
6090  behavior of `InTopK` differs from the `TopK` op in its handling of ties; if
6091  multiple classes have the same prediction value and straddle the top-`k`
6092  boundary, all of those classes are considered to be in the top `k`.
6093
6094  More formally, let
6095
6096    \\(predictions_i\\) be the predictions for all classes for example `i`,
6097    \\(targets_i\\) be the target class for example `i`,
6098    \\(out_i\\) be the output for example `i`,
6099
6100  $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$
6101
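  For example, a minimal sketch (assuming eager execution; the values are
  illustrative):

  >>> predictions = tf.constant([[0.1, 0.9], [0.4, 0.6]])
  >>> targets = tf.constant([0, 1])
  >>> tf.compat.v1.math.in_top_k(predictions, targets, k=1).numpy()
  array([False,  True])
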
  Args:
    predictions: A `Tensor` of type `float32`.
      A `batch_size` x `classes` tensor.
    targets: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A `batch_size` vector of class ids.
    k: An `int`. Number of top elements to look at for computing precision.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `bool`. Computed Precision at `k` as a `bool Tensor`.
  """
  with ops.name_scope(name, "in_top_k"):
    return gen_nn_ops.in_top_kv2(predictions, targets, k, name=name)


@tf_export("math.in_top_k", "nn.in_top_k", v1=[])
@dispatch.add_dispatch_support
def in_top_k_v2(targets, predictions, k, name=None):
  return in_top_k(predictions, targets, k, name)


in_top_k_v2.__doc__ = in_top_k.__doc__


tf_export(v1=["nn.quantized_avg_pool"])(
    dispatch.add_dispatch_support(gen_nn_ops.quantized_avg_pool))
tf_export(v1=["nn.quantized_conv2d"])(
    dispatch.add_dispatch_support(gen_nn_ops.quantized_conv2d))
tf_export(v1=["nn.quantized_relu_x"])(
    dispatch.add_dispatch_support(gen_nn_ops.quantized_relu_x))
tf_export(v1=["nn.quantized_max_pool"])(
    dispatch.add_dispatch_support(gen_nn_ops.quantized_max_pool))


@tf_export("nn.isotonic_regression", v1=[])
@dispatch.add_dispatch_support
def isotonic_regression(inputs, decreasing=True, axis=-1):
  r"""Solves isotonic regression problems along the given axis.

  For each vector x, the problem solved is

  $$\operatorname{argmin}_{y_1 \geq y_2 \geq \dots \geq y_n} \sum_i (x_i - y_i)^2.$$

  As the solution is component-wise constant, a second tensor is returned that
  encodes the segments. The problems are solved over the given axis.

  Consider the following example, where we solve a batch of two problems. The
  first input is [3, 1, 2], while the second is [1, 3, 4] (as the axis is 1).

  >>> x = tf.constant([[3, 1, 2], [1, 3, 4]], dtype=tf.float32)
  >>> y, segments = tf.nn.isotonic_regression(x, axis=1)
  >>> y  # The solution.
  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
  array([[3.       , 1.5      , 1.5      ],
         [2.6666667, 2.6666667, 2.6666667]], dtype=float32)>

  Note that the first solution has two blocks [3] and [1.5, 1.5]. The second
  solution is constant, and thus has a single segment. These segments are
  exactly what the second returned tensor encodes:

  >>> segments
  <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
  array([[0, 1, 1],
         [0, 0, 0]], dtype=int32)>
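
  Setting `decreasing=False` solves the non-decreasing problem instead. As a
  small sketch on the same `x` (the expected values below were hand-checked
  with the pool-adjacent-violators reasoning):

  >>> y, _ = tf.nn.isotonic_regression(x, decreasing=False, axis=1)
  >>> y
  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
  array([[2., 2., 2.],
         [1., 3., 4.]], dtype=float32)>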


  Args:
    inputs: A tensor holding the inputs.
    decreasing: If set to False, the inequalities in the optimization
      constraints are flipped, and a non-decreasing solution is computed
      instead.
    axis: The axis along which the problems should be solved.

  Returns:
    output: The solutions, same shape and type as the input.
    segments: An int32 tensor, same shape as the input, indicating the segments
      that have the same value. Specifically, those positions that have the same
      value correspond to the same segment. These values start at zero, and are
      monotonically increasing for each solution.
6179  """
6180  type_promotions = {
6181      # Float types get mapped to themselves, int8/16 to float32, rest to double
6182      dtypes.float32:
6183          dtypes.float32,
6184      dtypes.half:
6185          dtypes.half,
6186      dtypes.bfloat16:
6187          dtypes.bfloat16,
6188      dtypes.int8:
6189          dtypes.float32,
6190      dtypes.int16:
6191          dtypes.float32,
6192  }
6193  inputs = ops.convert_to_tensor(inputs)
6194  try:
6195    output_dtype = type_promotions[inputs.dtype]
6196  except KeyError:
6197    output_dtype = dtypes.float64
6198
6199  def compute_on_matrix(matrix, name=None):
6200    iso_fn = functools.partial(
6201        gen_nn_ops.isotonic_regression, output_dtype=output_dtype, name=name)
6202    if decreasing:
6203      return iso_fn(matrix)
6204    else:
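      # Solve the non-decreasing problem by negating the input, solving the
      # non-increasing problem, and negating the solution back; the segment
      # ids are unaffected by the sign flip.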
      output, segments = iso_fn(-matrix)
      return -output, segments

  return _wrap_2d_function(inputs, compute_on_matrix, axis)
