# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Primitive Neural Net (NN) Operations.

## Notes on padding

Several neural network operations, such as `tf.nn.conv2d` and
`tf.nn.max_pool2d`, take a `padding` parameter, which controls how the input is
padded before running the operation. The input is padded by inserting values
(typically zeros) before and after the tensor in each spatial dimension. The
`padding` parameter can either be the string `'VALID'`, which means use no
padding, or `'SAME'` which adds padding according to a formula which is
described below. Certain ops also allow the amount of padding per dimension to
be explicitly specified by passing a list to `padding`.

In the case of convolutions, the input is padded with zeros. In the case of
pools, the padded input values are ignored. For example, in a max pool, the
sliding window ignores padded values, which is equivalent to the padded values
being `-infinity`.

### `'VALID'` padding

Passing `padding='VALID'` to an op means no padding is used. As a result, the
output size is typically smaller than the input size, even when the stride is
one. In the 2D case, the output size is computed as:

```python
out_height = ceil((in_height - filter_height + 1) / stride_height)
out_width  = ceil((in_width - filter_width + 1) / stride_width)
```

The 1D and 3D cases are similar. Note `filter_height` and `filter_width` refer
to the filter size after dilations (if any) for convolutions, and refer to the
window size for pools.
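
Here is an example of `'VALID'` padding (an illustrative case, using the same
shapes as the `'SAME'` example below):

>>> inp = tf.ones((2, 5, 2, 2))      # in_height = 5, in_width = 2
>>> filter = tf.ones((3, 2, 2, 2))   # filter_height = 3, filter_width = 2
>>> strides = [2, 1]                 # stride_height = 2, stride_width = 1
>>> output = tf.nn.conv2d(inp, filter, strides, padding='VALID')
>>> output.shape[1]  # output_height: ceil((5 - 3 + 1) / 2)
2
>>> output.shape[2]  # output_width: ceil((2 - 2 + 1) / 1)
1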

### `'SAME'` padding

With `'SAME'` padding, padding is applied to each spatial dimension. When the
strides are 1, the input is padded such that the output size is the same as the
input size. In the 2D case, the output size is computed as:

```python
out_height = ceil(in_height / stride_height)
out_width  = ceil(in_width / stride_width)
```

The amount of padding used is the smallest amount that results in the output
size. The formula for the total amount of padding per dimension is:

```python
if (in_height % stride_height == 0):
  pad_along_height = max(filter_height - stride_height, 0)
else:
  pad_along_height = max(filter_height - (in_height % stride_height), 0)
if (in_width % stride_width == 0):
  pad_along_width = max(filter_width - stride_width, 0)
else:
  pad_along_width = max(filter_width - (in_width % stride_width), 0)
```

Finally, the padding on the top, bottom, left and right are:

```python
pad_top = pad_along_height // 2
pad_bottom = pad_along_height - pad_top
pad_left = pad_along_width // 2
pad_right = pad_along_width - pad_left
```

Note that the division by 2 means that there are cases when the padding on the
two sides (top vs bottom, left vs right) differs by one. In this case, the
bottom and right sides always get the one additional padded pixel. For example,
when `pad_along_height` is 5, we pad 2 pixels at the top and 3 pixels at the
bottom. Note that this is different from libraries such as PyTorch and Caffe,
which explicitly specify the number of padded pixels and always pad the same
number of pixels on both sides.

Here is an example of `'SAME'` padding:

>>> in_height = 5
>>> filter_height = 3
>>> stride_height = 2
>>>
>>> in_width = 2
>>> filter_width = 2
>>> stride_width = 1
>>>
>>> inp = tf.ones((2, in_height, in_width, 2))
>>> filter = tf.ones((filter_height, filter_width, 2, 2))
>>> strides = [stride_height, stride_width]
>>> output = tf.nn.conv2d(inp, filter, strides, padding='SAME')
>>> output.shape[1]  # output_height: ceil(5 / 2)
3
>>> output.shape[2]  # output_width: ceil(2 / 1)
2
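
Working through the padding formulas for this example (an illustrative check
of the arithmetic above):

```python
# Height: in_height % stride_height = 5 % 2 = 1, which is non-zero, so
pad_along_height = max(filter_height - (in_height % stride_height), 0)  # 2
pad_top, pad_bottom = 1, 1  # pad_along_height // 2 and the remainder
# Width: in_width % stride_width = 2 % 1 = 0, so
pad_along_width = max(filter_width - stride_width, 0)  # 1
pad_left, pad_right = 0, 1
```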

### Explicit padding

Certain ops, like `tf.nn.conv2d`, also allow a list of explicit padding amounts
to be passed to the `padding` parameter. This list is in the same format as what
is passed to `tf.pad`, except the padding must be a nested list, not a tensor.
For example, in the 2D case, the list is in the format `[[0, 0], [pad_top,
pad_bottom], [pad_left, pad_right], [0, 0]]` when `data_format` is its default
value of `'NHWC'`. The two `[0, 0]` pairs indicate the batch and channel
dimensions have no padding, which is required, as only spatial dimensions can
have padding.

For example:

>>> inp = tf.ones((1, 3, 3, 1))
>>> filter = tf.ones((2, 2, 1, 1))
>>> strides = [1, 1]
>>> padding = [[0, 0], [1, 2], [0, 1], [0, 0]]
>>> output = tf.nn.conv2d(inp, filter, strides, padding=padding)
>>> tuple(output.shape)
(1, 5, 3, 1)
>>> # Equivalently, tf.pad can be used, since convolutions pad with zeros.
>>> inp = tf.pad(inp, padding)
>>> # 'VALID' means to use no padding in conv2d (we already padded inp)
>>> output2 = tf.nn.conv2d(inp, filter, strides, padding='VALID')
>>> tf.debugging.assert_equal(output, output2)
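
In this example the output height is `3 + 1 + 2 - 2 + 1 = 5` (the input height
of 3 plus 1 row of padding above and 2 below, convolved with a height-2 filter
under the `'VALID'` formula), and the output width is `3 + 0 + 1 - 2 + 1 = 3`.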

### Difference between convolution and pooling layers

Padding is used differently in convolution layers and pooling layers. In a
convolution, the padded values are zeros and take part in the computation: they
are multiplied by the kernel like any other input value. In a pooling layer,
the padded values are excluded from the computation. For example, an average
pool divides only by the number of real input values in each window, so the
amount of padding added does not affect the output. Here is an example that
demonstrates the difference.

>>> x_in = np.array([[
...   [[2], [2]],
...   [[1], [1]],
...   [[1], [1]]]])
>>> kernel_in = np.array([  # simulate the avg_pool with conv2d
...  [ [[0.25]], [[0.25]] ],
...  [ [[0.25]], [[0.25]] ]])
>>> x = tf.constant(x_in, dtype=tf.float32)
>>> kernel = tf.constant(kernel_in, dtype=tf.float32)
>>> conv_out = tf.nn.conv2d(x, kernel, strides=[1, 1, 1, 1], padding='SAME')
>>> pool_out = tf.nn.avg_pool(x, [2, 2], strides=[1, 1, 1, 1], padding='SAME')
>>> print(conv_out.shape, pool_out.shape)
(1, 3, 2, 1) (1, 3, 2, 1)
>>> tf.reshape(conv_out, [3, 2]).numpy()  # conv2d takes account of padding
array([[1.5 , 0.75],
       [1.  , 0.5 ],
       [0.5 , 0.25]], dtype=float32)
>>> tf.reshape(pool_out, [3, 2]).numpy()  # avg_pool excludes padding
array([[1.5, 1.5],
       [1. , 1. ],
       [1. , 1. ]], dtype=float32)
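
In the rightmost output column (and bottom row), the sliding window extends
over the padded region: the convolution includes the padded zeros in its
weighted sum, which halves its results there, while the average pool divides
only by the number of real input values in each window, so its results are
unchanged.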

"""

import functools
import numbers

import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.ops import variables as variables_lib
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_nn_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.platform import device_context
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.deprecation import deprecated_argument_lookup

from tensorflow.python.util.tf_export import tf_export

# Aliases for some automatically-generated names.
local_response_normalization = gen_nn_ops.lrn

# pylint: disable=protected-access
# pylint: disable=g-classes-have-attributes

# Acceptable channels last formats (robust to H, W, D order).
_CHANNELS_LAST_FORMATS = frozenset({
    "NWC", "NHC", "NHWC", "NWHC", "NDHWC", "NDWHC", "NHDWC", "NHWDC", "NWDHC",
    "NWHDC"
})


def _get_sequence(value, n, channel_index, name):
  """Formats a value input for gen_nn_ops."""
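  # Illustrative examples (assuming NHWC-style inputs, i.e. channel_index=3):
  #   _get_sequence(2, 2, 3, "strides")      -> [1, 2, 2, 1]
  #   _get_sequence([2, 3], 2, 3, "strides") -> [1, 2, 3, 1]
  #   _get_sequence([2, 3], 2, 1, "strides") -> [1, 1, 2, 3]  # channels-first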
  # Performance is fast-pathed for common cases:
  # `None`, `list`, `tuple` and `int`.
  if value is None:
    return [1] * (n + 2)

  # Always convert `value` to a `list`.
  if isinstance(value, list):
    pass
  elif isinstance(value, tuple):
    value = list(value)
  elif isinstance(value, int):
    value = [value]
  elif not isinstance(value, collections_abc.Sized):
    value = [value]
  else:
    value = list(value)  # Try casting to a list.

  len_value = len(value)

  # Fully specified, including batch and channel dims.
  if len_value == n + 2:
    return value

  # Apply value to spatial dims only.
  if len_value == 1:
    value = value * n  # Broadcast to spatial dimensions.
  elif len_value != n:
    raise ValueError(f"{name} should be of length 1, {n} or {n + 2}. "
                     f"Received: {name}={value} of length {len_value}")

  # Add batch and channel dims (always 1).
  if channel_index == 1:
    return [1, 1] + value
  else:
    return [1] + value + [1]


def _non_atrous_convolution(
    input,  # pylint: disable=redefined-builtin
    filter,  # pylint: disable=redefined-builtin
    padding,
    data_format=None,  # pylint: disable=redefined-builtin
    strides=None,
    name=None):
  """Computes sums of N-D convolutions (actually cross-correlation).

  It is required that 1 <= N <= 3.

  This is used to implement the more generic `convolution` function, which
  extends the interface of this function with a `dilation_rate` parameter.

  Args:

    input: Rank N+2 tensor of type T of shape
      `[batch_size] + input_spatial_shape + [in_channels]` if `data_format`
      does not start with `"NC"`, or
      `[batch_size, in_channels] + input_spatial_shape` if `data_format` starts
      with `"NC"`.
    filter: Rank N+2 tensor of type T of shape
      `filter_spatial_shape + [in_channels, out_channels]`.  Rank of either
      `input` or `filter` must be known.
    padding: Padding method to use, must be either "VALID" or "SAME".
    data_format: A string or None.  Specifies whether the channel dimension of
      the `input` and output is the last dimension (default, or if `data_format`
      does not start with "NC"), or the second dimension (if `data_format`
      starts with "NC").  For N=1, the valid values are "NWC" (default) and
      "NCW".  For N=2, the valid values are "NHWC" (default) and "NCHW".
      For N=3, the valid values are "NDHWC" (default) and "NCDHW".
    strides: Sequence of N positive integers, defaults to `[1] * N`.
    name: Name prefix to use.

  Returns:
    Rank N+2 tensor of type T of shape
    `[batch_size] + output_spatial_shape + [out_channels]`, where
    if padding == "SAME":
      output_spatial_shape = input_spatial_shape
    if padding == "VALID":
      output_spatial_shape = input_spatial_shape - filter_spatial_shape + 1.

  Raises:
    ValueError: if ranks are incompatible.

  """
  with ops.name_scope(name, "non_atrous_convolution", [input, filter]) as scope:
    input = ops.convert_to_tensor(input, name="input")  # pylint: disable=redefined-builtin
    input_shape = input.shape
    filter = ops.convert_to_tensor(filter, name="filter")  # pylint: disable=redefined-builtin
    filter_shape = filter.shape
    op = _NonAtrousConvolution(
        input_shape,
        filter_shape=filter_shape,
        padding=padding,
        data_format=data_format,
        strides=strides,
        name=scope)
    return op(input, filter)


class _NonAtrousConvolution:
  """Helper class for _non_atrous_convolution.

  Note that this class assumes that shapes of input and filter passed to
  `__call__` are compatible with `input_shape` and `filter_shape` passed to the
  constructor.

  Args:
    input_shape: static input shape, i.e. input.shape.
    filter_shape: static filter shape, i.e. filter.shape.
    padding: see _non_atrous_convolution.
    data_format: see _non_atrous_convolution.
    strides: see _non_atrous_convolution.
    name: see _non_atrous_convolution.
    num_batch_dims: (Optional.)  The number of batch dimensions in the input;
     if not provided, the default of `1` is used.
  """

  def __init__(
      self,
      input_shape,
      filter_shape,
      padding,
      data_format=None,
      strides=None,
      name=None,
      num_batch_dims=1):
    # filter shape is always rank num_spatial_dims + 2
    # and num_spatial_dims == input_shape.ndims - num_batch_dims - 1
    if input_shape.ndims is not None:
      filter_shape = filter_shape.with_rank(
          input_shape.ndims - num_batch_dims + 1)
    self.padding = padding
    self.name = name
    # input rank is num_spatial_dims + num_batch_dims + 1
    # and filter_shape is always rank num_spatial_dims + 2
    if filter_shape.ndims is not None:
      input_shape = input_shape.with_rank(
          filter_shape.ndims + num_batch_dims - 1)
    if input_shape.ndims is None:
      raise ValueError(
          "Rank of convolution must be known. "
          f"Received: input_shape={input_shape} of rank {input_shape.rank}")
    if input_shape.ndims < 3 or input_shape.ndims - num_batch_dims + 1 > 5:
      raise ValueError(
          "`input_shape.rank - num_batch_dims + 1` must be at least 3 and at "
          f"most 5. Received: input_shape.rank={input_shape.rank} and "
          f"num_batch_dims={num_batch_dims}")
    conv_dims = input_shape.ndims - num_batch_dims - 1
    if strides is None:
      strides = [1] * conv_dims
    elif len(strides) != conv_dims:
      raise ValueError(
          f"`len(strides)` should be {conv_dims}. "
          f"Received: strides={strides} of length {len(strides)}")
    if conv_dims == 1:
      # conv1d uses the 2-d data format names
      if data_format is None:
        data_format = "NWC"
      elif data_format not in {"NCW", "NWC", "NCHW", "NHWC"}:
        raise ValueError("`data_format` must be 'NWC' or 'NCW'. "
                         f"Received: data_format={data_format}")
      self.strides = strides[0]
      self.data_format = data_format
      self.conv_op = self._conv1d
    elif conv_dims == 2:
      if data_format is None or data_format == "NHWC":
        data_format = "NHWC"
        strides = [1] + list(strides) + [1]
      elif data_format == "NCHW":
        strides = [1, 1] + list(strides)
      else:
        raise ValueError("`data_format` must be 'NHWC' or 'NCHW'. "
                         f"Received: data_format={data_format}")
      self.strides = strides
      self.data_format = data_format
      self.conv_op = conv2d
    elif conv_dims == 3:
      if data_format is None or data_format == "NDHWC":
        strides = [1] + list(strides) + [1]
      elif data_format == "NCDHW":
        strides = [1, 1] + list(strides)
      else:
        raise ValueError("`data_format` must be 'NDHWC' or 'NCDHW'. "
                         f"Received: data_format={data_format}")
      self.strides = strides
      self.data_format = data_format
      self.conv_op = _conv3d_expanded_batch

  # Note that we need this adapter since argument names for conv1d don't match
  # those for gen_nn_ops.conv2d and gen_nn_ops.conv3d.
  # pylint: disable=redefined-builtin
  def _conv1d(self, input, filter, strides, padding, data_format, name):
    return conv1d(
        value=input,
        filters=filter,
        stride=strides,
        padding=padding,
        data_format=data_format,
        name=name)
  # pylint: enable=redefined-builtin

  def __call__(self, inp, filter):  # pylint: disable=redefined-builtin
    return self.conv_op(
        input=inp,
        filter=filter,
        strides=self.strides,
        padding=self.padding,
        data_format=self.data_format,
        name=self.name)


def squeeze_batch_dims(inp, op, inner_rank, name=None):
  """Returns `unsqueeze_batch(op(squeeze_batch(inp)))`.

  Where `squeeze_batch` reshapes `inp` to shape
  `[prod(inp.shape[:-inner_rank])] + inp.shape[-inner_rank:]`
  and `unsqueeze_batch` does the reverse reshape but on the output.

  Args:
    inp: A tensor with dims `batch_shape + inner_shape` where `inner_shape`
      is length `inner_rank`.
    op: A callable that takes a single input tensor and returns a single
      output tensor.
    inner_rank: A python integer.
    name: A string.

  Returns:
    `unsqueeze_batch(op(squeeze_batch(inp)))`.
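
  For example (an illustrative sketch): if `inp` has shape `[2, 3, 4, 5, 6]`
  and `inner_rank` is 3, then `op` is applied to a tensor of shape
  `[6, 4, 5, 6]` and its output is reshaped back to a leading `[2, 3]` batch
  shape.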
  """
  with ops.name_scope(name, "squeeze_batch_dims", [inp]):
    inp = ops.convert_to_tensor(inp, name="input")
    shape = inp.shape

    inner_shape = shape[-inner_rank:]
    if not inner_shape.is_fully_defined():
      inner_shape = array_ops.shape(inp)[-inner_rank:]

    batch_shape = shape[:-inner_rank]
    if not batch_shape.is_fully_defined():
      batch_shape = array_ops.shape(inp)[:-inner_rank]

    if isinstance(inner_shape, tensor_shape.TensorShape):
      inp_reshaped = array_ops.reshape(inp, [-1] + inner_shape.as_list())
    else:
      inp_reshaped = array_ops.reshape(
          inp, array_ops.concat(([-1], inner_shape), axis=-1))

    out_reshaped = op(inp_reshaped)

    out_inner_shape = out_reshaped.shape[-inner_rank:]
    if not out_inner_shape.is_fully_defined():
      out_inner_shape = array_ops.shape(out_reshaped)[-inner_rank:]

    out = array_ops.reshape(
        out_reshaped, array_ops.concat((batch_shape, out_inner_shape), axis=-1))

    out.set_shape(inp.shape[:-inner_rank] + out.shape[-inner_rank:])
    return out


@tf_export("nn.dilation2d", v1=[])
@dispatch.add_dispatch_support
def dilation2d_v2(
    input,   # pylint: disable=redefined-builtin
    filters,  # pylint: disable=redefined-builtin
    strides,
    padding,
    data_format,
    dilations,
    name=None):
  """Computes the grayscale dilation of 4-D `input` and 3-D `filters` tensors.

  The `input` tensor has shape `[batch, in_height, in_width, depth]` and the
  `filters` tensor has shape `[filter_height, filter_width, depth]`, i.e., each
  input channel is processed independently of the others with its own
  structuring function. The `output` tensor has shape
  `[batch, out_height, out_width, depth]`. The spatial dimensions of the output
  tensor depend on the `padding` algorithm. We currently only support the
  default "NHWC" `data_format`.

  In detail, the grayscale morphological 2-D dilation is the max-sum correlation
  (for consistency with `conv2d`, we use unmirrored filters):

      output[b, y, x, c] =
         max_{dy, dx} input[b,
                            strides[1] * y + rates[1] * dy,
                            strides[2] * x + rates[2] * dx,
                            c] +
                      filters[dy, dx, c]

  Max-pooling is a special case when the filter has size equal to the pooling
  kernel size and contains all zeros.
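
  For example (an illustrative sketch), with an all-zero 2x2 filter the op
  reduces to 2x2 max pooling:

      x = tf.reshape(tf.constant([1., 2., 3., 4.]), [1, 2, 2, 1])
      f = tf.zeros([2, 2, 1])
      y = tf.nn.dilation2d(x, f, strides=[1, 1, 1, 1], padding="VALID",
                           data_format="NHWC", dilations=[1, 1, 1, 1])
      # y is [[[[4.]]]], the maximum over the single 2x2 window.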

  Note on duality: The dilation of `input` by the `filters` is equal to the
  negation of the erosion of `-input` by the reflected `filters`.

  Args:
    input: A `Tensor`. Must be one of the following types: `float32`, `float64`,
      `int32`, `uint8`, `int16`, `int8`, `int64`, `bfloat16`, `uint16`, `half`,
      `uint32`, `uint64`.
      4-D with shape `[batch, in_height, in_width, depth]`.
    filters: A `Tensor`. Must have the same type as `input`.
      3-D with shape `[filter_height, filter_width, depth]`.
    strides: A list of `ints` that has length `>= 4`.
      The stride of the sliding window for each dimension of the input
      tensor. Must be: `[1, stride_height, stride_width, 1]`.
    padding: A `string` from: `"SAME", "VALID"`.
      The type of padding algorithm to use. See
      [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2)
      for more information.
    data_format: A `string`, only `"NHWC"` is currently supported.
    dilations: A list of `ints` that has length `>= 4`.
      The input stride for atrous morphological dilation. Must be:
      `[1, rate_height, rate_width, 1]`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  if data_format != "NHWC":
    raise ValueError("`data_format` values other than 'NHWC' are not "
                     f"supported. Received: data_format={data_format}")

  return gen_nn_ops.dilation2d(input=input,
                               filter=filters,
                               strides=strides,
                               rates=dilations,
                               padding=padding,
                               name=name)


@tf_export(v1=["nn.dilation2d"])
@dispatch.add_dispatch_support
def dilation2d_v1(  # pylint: disable=missing-docstring
    input,  # pylint: disable=redefined-builtin
    filter=None,  # pylint: disable=redefined-builtin
    strides=None,
    rates=None,
    padding=None,
    name=None,
    filters=None,
    dilations=None):
  filter = deprecated_argument_lookup("filters", filters, "filter", filter)
  rates = deprecated_argument_lookup("dilations", dilations, "rates", rates)
  return gen_nn_ops.dilation2d(input, filter, strides, rates, padding, name)


dilation2d_v1.__doc__ = gen_nn_ops.dilation2d.__doc__


@tf_export("nn.with_space_to_batch")
@dispatch.add_dispatch_support
def with_space_to_batch(
    input,  # pylint: disable=redefined-builtin
    dilation_rate,
    padding,
    op,
    filter_shape=None,
    spatial_dims=None,
    data_format=None):
  """Performs `op` on the space-to-batch representation of `input`.

  This has the effect of transforming sliding window operations into the
  corresponding "atrous" operation in which the input is sampled at the
  specified `dilation_rate`.

  In the special case that `dilation_rate` is uniformly 1, this simply returns:

    op(input, num_spatial_dims, padding)

  Otherwise, it returns:

    batch_to_space_nd(
      op(space_to_batch_nd(input, adjusted_dilation_rate, adjusted_paddings),
         num_spatial_dims,
         "VALID"),
      adjusted_dilation_rate,
      adjusted_crops)

  where:

    adjusted_dilation_rate is an int64 tensor of shape [max(spatial_dims)],
    adjusted_{paddings,crops} are int64 tensors of shape [max(spatial_dims), 2]

  defined as follows:

  We first define two int64 tensors `paddings` and `crops` of shape
  `[num_spatial_dims, 2]` based on the value of `padding` and the spatial
  dimensions of the `input`:

  If `padding = "VALID"`, then:

    paddings, crops = required_space_to_batch_paddings(
      input_shape[spatial_dims],
      dilation_rate)

  If `padding = "SAME"`, then:

    dilated_filter_shape =
      filter_shape + (filter_shape - 1) * (dilation_rate - 1)

    paddings, crops = required_space_to_batch_paddings(
      input_shape[spatial_dims],
      dilation_rate,
      [(dilated_filter_shape - 1) // 2,
       dilated_filter_shape - 1 - (dilated_filter_shape - 1) // 2])

  Because `space_to_batch_nd` and `batch_to_space_nd` assume that the spatial
  dimensions are contiguous starting at the second dimension, but the specified
  `spatial_dims` may not be, we must adjust `dilation_rate`, `paddings` and
  `crops` in order to be usable with these operations.  For a given dimension,
  if the block size is 1, and both the starting and ending padding and crop
  amounts are 0, then space_to_batch_nd effectively leaves that dimension alone,
  which is what is needed for dimensions not part of `spatial_dims`.
  Furthermore, `space_to_batch_nd` and `batch_to_space_nd` handle this case
  efficiently for any number of leading and trailing dimensions.

  For 0 <= i < len(spatial_dims), we assign:

    adjusted_dilation_rate[spatial_dims[i] - 1] = dilation_rate[i]
    adjusted_paddings[spatial_dims[i] - 1, :] = paddings[i, :]
    adjusted_crops[spatial_dims[i] - 1, :] = crops[i, :]

  All unassigned values of `adjusted_dilation_rate` default to 1, while all
  unassigned values of `adjusted_paddings` and `adjusted_crops` default to 0.

  Note in the case that `dilation_rate` is not uniformly 1, specifying "VALID"
  padding is equivalent to specifying `padding = "SAME"` with a filter_shape of
  `[1]*N`.

  Advanced usage. Note the following optimization: A sequence of
  `with_space_to_batch` operations with identical (not uniformly 1)
  `dilation_rate` parameters and "VALID" padding

    net = with_space_to_batch(net, dilation_rate, "VALID", op_1)
    ...
    net = with_space_to_batch(net, dilation_rate, "VALID", op_k)

  can be combined into a single `with_space_to_batch` operation as follows:

    def combined_op(converted_input, num_spatial_dims, _):
      result = op_1(converted_input, num_spatial_dims, "VALID")
      ...
      result = op_k(result, num_spatial_dims, "VALID")

    net = with_space_to_batch(net, dilation_rate, "VALID", combined_op)

  This eliminates the overhead of `k-1` calls to `space_to_batch_nd` and
  `batch_to_space_nd`.

  Similarly, a sequence of `with_space_to_batch` operations with identical (not
  uniformly 1) `dilation_rate` parameters, "SAME" padding, and odd filter
  dimensions

    net = with_space_to_batch(net, dilation_rate, "SAME", op_1, filter_shape_1)
    ...
    net = with_space_to_batch(net, dilation_rate, "SAME", op_k, filter_shape_k)

  can be combined into a single `with_space_to_batch` operation as follows:

    def combined_op(converted_input, num_spatial_dims, _):
      result = op_1(converted_input, num_spatial_dims, "SAME")
      ...
      result = op_k(result, num_spatial_dims, "SAME")

    net = with_space_to_batch(net, dilation_rate, "VALID", combined_op)
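
  A minimal usage sketch (illustrative, with arbitrary shapes): an atrous 2-D
  convolution can be expressed by wrapping a plain `tf.nn.conv2d` call:

    inp = tf.ones([1, 8, 8, 3])
    filters = tf.ones([3, 3, 3, 16])

    def conv_op(converted_input, num_spatial_dims, padding):
      return tf.nn.conv2d(converted_input, filters, strides=[1, 1, 1, 1],
                          padding=padding)

    out = tf.nn.with_space_to_batch(
        inp, dilation_rate=[2, 2], padding="SAME", op=conv_op,
        filter_shape=[3, 3])
    # out has shape [1, 8, 8, 16], matching a rate-2 atrous convolution.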

  Args:
    input: Tensor of rank > max(spatial_dims).
    dilation_rate: int32 Tensor of *known* shape [num_spatial_dims].
    padding: str constant equal to "VALID" or "SAME"
    op: Function that maps (input, num_spatial_dims, padding) -> output
    filter_shape: If padding = "SAME", specifies the shape of the convolution
      kernel/pooling window as an integer Tensor of shape [>=num_spatial_dims].
      If padding = "VALID", filter_shape is ignored and need not be specified.
    spatial_dims: Monotonically increasing sequence of `num_spatial_dims`
      integers (which are >= 1) specifying the spatial dimensions of `input`
      and output.  Defaults to: `range(1, num_spatial_dims+1)`.
    data_format: A string or None.  Specifies whether the channel dimension of
      the `input` and output is the last dimension (default, or if `data_format`
      does not start with "NC"), or the second dimension (if `data_format`
      starts with "NC").  For N=1, the valid values are "NWC" (default) and
      "NCW".  For N=2, the valid values are "NHWC" (default) and "NCHW".
      For N=3, the valid values are "NDHWC" (default) and "NCDHW".

  Returns:
    The output Tensor as described above, dimensions will vary based on the op
    provided.

  Raises:
    ValueError: if `padding` is invalid or the arguments are incompatible.
    ValueError: if `spatial_dims` are invalid.
  """
  input = ops.convert_to_tensor(input, name="input")  # pylint: disable=redefined-builtin
  input_shape = input.shape

  def build_op(num_spatial_dims, padding):
    return lambda inp, _: op(inp, num_spatial_dims, padding)

  new_op = _WithSpaceToBatch(
      input_shape,
      dilation_rate,
      padding,
      build_op,
      filter_shape=filter_shape,
      spatial_dims=spatial_dims,
      data_format=data_format)
  return new_op(input, None)


class _WithSpaceToBatch:
  """Helper class for with_space_to_batch.

  Note that this class assumes that shapes of input and filter passed to
  `__call__` are compatible with `input_shape`, `filter_shape`, and
  `spatial_dims` passed to the constructor.

  Arguments
    input_shape: static shape of input. i.e. input.shape.
    dilation_rate: see `with_space_to_batch`.
    padding: see `with_space_to_batch`.
    build_op: Function that maps (num_spatial_dims, paddings) -> (function that
      maps (input, filter) -> output).
    filter_shape: see `with_space_to_batch`.
    spatial_dims: see `with_space_to_batch`.
    data_format: see `with_space_to_batch`.
    num_batch_dims: (Optional).  Number of batch dims in `input_shape`.
  """

  def __init__(self,
               input_shape,
               dilation_rate,
               padding,
               build_op,
               filter_shape=None,
               spatial_dims=None,
               data_format=None,
               num_batch_dims=1):
    """Helper class for _with_space_to_batch."""
    dilation_rate = ops.convert_to_tensor(
        dilation_rate, dtypes.int32, name="dilation_rate")
    if dilation_rate.shape.ndims not in (None, 1):
      raise ValueError(
          "`dilation_rate.shape.rank` must be 1. Received: "
          f"dilation_rate={dilation_rate} of rank {dilation_rate.shape.rank}")

    if not dilation_rate.shape.is_fully_defined():
      raise ValueError(
          "`dilation_rate.shape` must be fully defined. Received: "
          f"dilation_rate={dilation_rate} with shape "
          f"{dilation_rate.shape}")

    num_spatial_dims = dilation_rate.shape.dims[0].value

    if data_format is not None and data_format.startswith("NC"):
      starting_spatial_dim = num_batch_dims + 1
    else:
      starting_spatial_dim = num_batch_dims

    if spatial_dims is None:
      spatial_dims = range(starting_spatial_dim,
                           num_spatial_dims + starting_spatial_dim)
    orig_spatial_dims = list(spatial_dims)
    spatial_dims = sorted(set(int(x) for x in orig_spatial_dims))
    if spatial_dims != orig_spatial_dims or any(x < 1 for x in spatial_dims):
      raise ValueError(
          "`spatial_dims` must be a monotonically increasing sequence of "
          f"positive integers. Received: spatial_dims={orig_spatial_dims}")

    if data_format is not None and data_format.startswith("NC"):
      expected_input_rank = spatial_dims[-1]
    else:
      expected_input_rank = spatial_dims[-1] + 1

    try:
      input_shape.with_rank_at_least(expected_input_rank)
    except ValueError:
      raise ValueError(
          f"`input.shape.rank` must be at least {expected_input_rank}. "
          f"Received: input.shape={input_shape} with rank {input_shape.rank}")

    const_rate = tensor_util.constant_value(dilation_rate)
    rate_or_const_rate = dilation_rate
    if const_rate is not None:
      rate_or_const_rate = const_rate
      if np.any(const_rate < 1):
        raise ValueError(
            "`dilation_rate` must be positive. "
            f"Received: dilation_rate={const_rate}")
      if np.all(const_rate == 1):
        self.call = build_op(num_spatial_dims, padding)
        return

    padding, explicit_paddings = convert_padding(padding)

    # We have two padding contributions. The first is used for converting "SAME"
    # to "VALID". The second is required so that the height and width of the
    # zero-padded value tensor are multiples of rate.

    # Padding required to reduce to "VALID" convolution
    if padding == "SAME":
      if filter_shape is None:
        raise ValueError(
            "`filter_shape` must be specified for `padding='SAME'`. "
            f"Received: filter_shape={filter_shape} and padding={padding}")
      filter_shape = ops.convert_to_tensor(filter_shape, name="filter_shape")
      const_filter_shape = tensor_util.constant_value(filter_shape)
      if const_filter_shape is not None:
        filter_shape = const_filter_shape
        self.base_paddings = _with_space_to_batch_base_paddings(
            const_filter_shape, num_spatial_dims, rate_or_const_rate)
      else:
        self.num_spatial_dims = num_spatial_dims
        self.rate_or_const_rate = rate_or_const_rate
        self.base_paddings = None
    elif padding == "VALID":
      self.base_paddings = np.zeros([num_spatial_dims, 2], np.int32)
    elif padding == "EXPLICIT":
      base_paddings = (np.array(explicit_paddings)
                       .reshape([num_spatial_dims + 2, 2]))
      # Remove batch and channel dimensions
      if data_format is not None and data_format.startswith("NC"):
        self.base_paddings = base_paddings[2:]
      else:
        self.base_paddings = base_paddings[1:-1]
    else:
      raise ValueError("`padding` must be one of 'SAME' or 'VALID'. "
                       f"Received: padding={padding}")

    self.input_shape = input_shape
    self.spatial_dims = spatial_dims
    self.dilation_rate = dilation_rate
    self.data_format = data_format
    self.op = build_op(num_spatial_dims, "VALID")
    self.call = self._with_space_to_batch_call

  def _with_space_to_batch_call(self, inp, filter):  # pylint: disable=redefined-builtin
    """Call functionality for with_space_to_batch."""
    # Handle input whose shape is unknown during graph creation.
    input_spatial_shape = None
    input_shape = self.input_shape
    spatial_dims = self.spatial_dims
    if input_shape.ndims is not None:
      input_shape_list = input_shape.as_list()
      input_spatial_shape = [input_shape_list[i] for i in spatial_dims]
    if input_spatial_shape is None or None in input_spatial_shape:
      input_shape_tensor = array_ops.shape(inp)
      input_spatial_shape = array_ops.stack(
          [input_shape_tensor[i] for i in spatial_dims])

    base_paddings = self.base_paddings
    if base_paddings is None:
      # base_paddings could not be computed at build time since static filter
      # shape was not fully defined.
      filter_shape = array_ops.shape(filter)
      base_paddings = _with_space_to_batch_base_paddings(
          filter_shape, self.num_spatial_dims, self.rate_or_const_rate)

    paddings, crops = array_ops.required_space_to_batch_paddings(
        input_shape=input_spatial_shape,
        base_paddings=base_paddings,
        block_shape=self.dilation_rate)

    dilation_rate = _with_space_to_batch_adjust(self.dilation_rate, 1,
                                                spatial_dims)
    paddings = _with_space_to_batch_adjust(paddings, 0, spatial_dims)
    crops = _with_space_to_batch_adjust(crops, 0, spatial_dims)
    input_converted = array_ops.space_to_batch_nd(
        input=inp, block_shape=dilation_rate, paddings=paddings)

    result = self.op(input_converted, filter)

    result_converted = array_ops.batch_to_space_nd(
        input=result, block_shape=dilation_rate, crops=crops)

    # Recover channel information for output shape if channels are not last.
    if self.data_format is not None and self.data_format.startswith("NC"):
      if not result_converted.shape.dims[1].value and filter is not None:
        output_shape = result_converted.shape.as_list()
        output_shape[1] = filter.shape[-1]
        result_converted.set_shape(output_shape)

    return result_converted

  def __call__(self, inp, filter):  # pylint: disable=redefined-builtin
    return self.call(inp, filter)


def _with_space_to_batch_base_paddings(filter_shape, num_spatial_dims,
                                       rate_or_const_rate):
  """Helper function to compute base_paddings."""
  # Spatial dimensions of the filters and the upsampled filters in which we
  # introduce (rate - 1) zeros between consecutive filter values.
  filter_spatial_shape = filter_shape[:num_spatial_dims]
  pad_extra_shape = (filter_spatial_shape - 1) * rate_or_const_rate
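  # For example (illustrative): with a filter of size 3 and rate 2, the
  # effective filter spans 5 input pixels, so pad_extra_shape = (3 - 1) * 2 = 4,
  # split as 2 before and 2 after.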

  # When pad_extra_shape is odd, we pad more at the end, following the same
  # convention as conv2d.
  pad_extra_start = pad_extra_shape // 2
  pad_extra_end = pad_extra_shape - pad_extra_start
  base_paddings = array_ops.stack(
      [[pad_extra_start[i], pad_extra_end[i]] for i in range(num_spatial_dims)])
  return base_paddings


def _with_space_to_batch_adjust(orig, fill_value, spatial_dims):
  """Returns an `adjusted` version of `orig` based on `spatial_dims`.

  Tensor of the same type as `orig` and with shape
  `[max(spatial_dims), ...]` where:

    adjusted[spatial_dims[i] - 1, ...] = orig[i, ...]

  for 0 <= i < len(spatial_dims), and

    adjusted[j, ...] = fill_value

  for j != spatial_dims[i] - 1 for some i.

  If `orig` is a constant value, then the result will be a constant value.

  Args:
    orig: Tensor of rank > max(spatial_dims).
    fill_value: Numpy scalar (of same data type as `orig`) specifying the fill
      value for non-spatial dimensions.
    spatial_dims: See with_space_to_batch.

  Returns:
    `adjusted` tensor.
  """
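  # For example (illustrative): with spatial_dims=[1, 3] and fill_value=0,
  # orig=[[1, 1], [2, 2]] is adjusted to [[1, 1], [0, 0], [2, 2]].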
  fill_dims = orig.get_shape().as_list()[1:]
  dtype = orig.dtype.as_numpy_dtype
  parts = []
  const_orig = tensor_util.constant_value(orig)
  const_or_orig = const_orig if const_orig is not None else orig
  prev_spatial_dim = 0
  i = 0
  while i < len(spatial_dims):
    start_i = i
    start_spatial_dim = spatial_dims[i]
    if start_spatial_dim > 1:
      # Fill in any gap from the previous spatial dimension (or dimension 1 if
      # this is the first spatial dimension) with `fill_value`.
      parts.append(
          np.full(
              [start_spatial_dim - 1 - prev_spatial_dim] + fill_dims,
              fill_value,
              dtype=dtype))
    # Find the largest value of i such that:
    #   [spatial_dims[start_i], ..., spatial_dims[i]]
    #     == [start_spatial_dim, ..., start_spatial_dim + i - start_i],
    # i.e. the end of a contiguous group of spatial dimensions.
    while (i + 1 < len(spatial_dims) and
           spatial_dims[i + 1] == spatial_dims[i] + 1):
      i += 1
    parts.append(const_or_orig[start_i:i + 1])
    prev_spatial_dim = spatial_dims[i]
    i += 1
  if const_orig is not None:
    return np.concatenate(parts)
  else:
    return array_ops.concat(parts, 0)


def _get_strides_and_dilation_rate(num_spatial_dims, strides, dilation_rate):
  """Helper function for verifying strides and dilation_rate arguments.

  This is used by `convolution` and `pool`.

  Args:
    num_spatial_dims: int
    strides: Optional.  List of N ints >= 1.  Defaults to `[1]*N`.  If any value
      of strides is > 1, then all values of dilation_rate must be 1.
    dilation_rate: Optional.  List of N ints >= 1.  Defaults to `[1]*N`.  If any
      value of dilation_rate is > 1, then all values of strides must be 1.

  Returns:
    Normalized (strides, dilation_rate) as int32 numpy arrays of shape
    [num_spatial_dims].

  Raises:
    ValueError: if the parameters are invalid.
  """
  if dilation_rate is None:
    dilation_rate = [1] * num_spatial_dims
  elif len(dilation_rate) != num_spatial_dims:
    raise ValueError(f"`len(dilation_rate)` should be {num_spatial_dims}. "
                     f"Received: dilation_rate={dilation_rate} of length "
                     f"{len(dilation_rate)}")
  dilation_rate = np.array(dilation_rate, dtype=np.int32)
  if np.any(dilation_rate < 1):
    raise ValueError("all values of `dilation_rate` must be positive. "
                     f"Received: dilation_rate={dilation_rate}")

  if strides is None:
    strides = [1] * num_spatial_dims
  elif len(strides) != num_spatial_dims:
    raise ValueError(f"`len(strides)` should be {num_spatial_dims}. "
                     f"Received: strides={strides} of length {len(strides)}")
  strides = np.array(strides, dtype=np.int32)
  if np.any(strides < 1):
    raise ValueError("all values of `strides` must be positive. "
                     f"Received: strides={strides}")

  if np.any(strides > 1) and np.any(dilation_rate > 1):
    raise ValueError(
        "`strides > 1` not supported in conjunction with `dilation_rate > 1`. "
        f"Received: strides={strides} and dilation_rate={dilation_rate}")
  return strides, dilation_rate


@tf_export(v1=["nn.convolution"])
@dispatch.add_dispatch_support
def convolution(
    input,  # pylint: disable=redefined-builtin
    filter,  # pylint: disable=redefined-builtin
    padding,
    strides=None,
    dilation_rate=None,
    name=None,
    data_format=None,
    filters=None,
    dilations=None):  # pylint: disable=g-doc-args
  """Computes sums of N-D convolutions (actually cross-correlation).

  This also supports either output striding via the optional `strides` parameter
  or atrous convolution (also known as convolution with holes or dilated
  convolution, based on the French word "trous" meaning holes in English) via
  the optional `dilation_rate` parameter.  Currently, however, output striding
  is not supported for atrous convolutions.

  Specifically, in the case that `data_format` does not start with "NC", given
  a rank (N+2) `input` Tensor of shape

    [num_batches,
     input_spatial_shape[0],
     ...,
     input_spatial_shape[N-1],
     num_input_channels],

  a rank (N+2) `filter` Tensor of shape

    [spatial_filter_shape[0],
     ...,
     spatial_filter_shape[N-1],
     num_input_channels,
     num_output_channels],

  an optional `dilation_rate` tensor of shape N (defaults to `[1]*N`) specifying
  the filter upsampling/input downsampling rate, and an optional list of N
  `strides` (defaults to `[1]*N`), this computes for each N-D spatial output
  position `(x[0], ..., x[N-1])`:

  ```
  output[b, x[0], ..., x[N-1], k] =
      sum_{z[0], ..., z[N-1], q}
          filter[z[0], ..., z[N-1], q, k] *
          padded_input[b,
                       x[0]*strides[0] + dilation_rate[0]*z[0],
                       ...,
                       x[N-1]*strides[N-1] + dilation_rate[N-1]*z[N-1],
                       q]
  ```

  where b is the index into the batch, k is the output channel number, q is the
  input channel number, and z is the N-D spatial offset within the filter. Here,
  `padded_input` is obtained by zero padding the input using an effective
  spatial filter shape of `(spatial_filter_shape-1) * dilation_rate + 1` and
  output striding `strides`.

  In the case that `data_format` does start with `"NC"`, the `input` and output
  (but not the `filter`) are simply transposed as follows:

  ```python
  convolution(input, data_format, **kwargs) =
    tf.transpose(convolution(tf.transpose(input, [0] + range(2,N+2) + [1]),
                             **kwargs),
                 [0, N+1] + range(1, N+1))
  ```

  It is required that 1 <= N <= 3.
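
  A small illustrative example (1-D, channels-last):

  >>> x = tf.reshape(tf.range(6, dtype=tf.float32), [1, 6, 1])
  >>> w = tf.ones([3, 1, 1])
  >>> tf.nn.convolution(x, w, padding="VALID").shape
  TensorShape([1, 4, 1])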

  Args:
    input: An (N+2)-D `Tensor` of type `T`, of shape
      `[batch_size] + input_spatial_shape + [in_channels]` if data_format does
      not start with "NC" (default), or
      `[batch_size, in_channels] + input_spatial_shape` if data_format starts
      with "NC".
    filter: An (N+2)-D `Tensor` with the same type as `input` and shape
      `spatial_filter_shape + [in_channels, out_channels]`.
    padding: A string, either `"VALID"` or `"SAME"`. The padding algorithm.
      `"valid"` means no padding. `"same"` results in padding evenly to
      the left/right or up/down of the input such that output has the same
      height/width dimension as the input when the strides are 1. See
      [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2)
      for more information.
    strides: Optional.  Sequence of N ints >= 1.  Specifies the output stride.
      Defaults to `[1]*N`.  If any value of strides is > 1, then all values of
      dilation_rate must be 1.
    dilation_rate: Optional.  Sequence of N ints >= 1.  Specifies the filter
      upsampling/input downsampling rate.  In the literature, the same parameter
      is sometimes called `input stride` or `dilation`.  The effective filter
      size used for the convolution will be `spatial_filter_shape +
      (spatial_filter_shape - 1) * (rate - 1)`, obtained by inserting
      (dilation_rate[i]-1) zeros between consecutive elements of the original
      filter in each spatial dimension i.  If any value of dilation_rate is > 1,
      then all values of strides must be 1.
    name: Optional name for the returned tensor.
    data_format: A string or None.  Specifies whether the channel dimension of
      the `input` and output is the last dimension (default, or if `data_format`
      does not start with "NC"), or the second dimension (if `data_format`
      starts with "NC").  For N=1, the valid values are "NWC" (default) and
      "NCW".  For N=2, the valid values are "NHWC" (default) and "NCHW".
      For N=3, the valid values are "NDHWC" (default) and "NCDHW".

  Returns:
    A `Tensor` with the same type as `input` of shape

        `[batch_size] + output_spatial_shape + [out_channels]`

    if data_format is None or does not start with "NC", or

        `[batch_size, out_channels] + output_spatial_shape`

    if data_format starts with "NC",
    where `output_spatial_shape` depends on the value of `padding`.

    If padding == "SAME":
      output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i])

    If padding == "VALID":
      output_spatial_shape[i] =
        ceil((input_spatial_shape[i] -
              (spatial_filter_shape[i]-1) * dilation_rate[i])
             / strides[i]).

  Raises:
    ValueError: If input/output depth does not match `filter` shape, if padding
      is other than `"VALID"` or `"SAME"`, or if data_format is invalid.

  """
  filter = deprecated_argument_lookup("filters", filters, "filter", filter)
  dilation_rate = deprecated_argument_lookup(
      "dilations", dilations, "dilation_rate", dilation_rate)
  return convolution_internal(
      input,
      filter,
      strides=strides,
      padding=padding,
      data_format=data_format,
      dilations=dilation_rate,
      name=name)


@tf_export("nn.convolution", v1=[])
@dispatch.add_dispatch_support
def convolution_v2(  # pylint: disable=missing-docstring
    input,  # pylint: disable=redefined-builtin
    filters,
    strides=None,
    padding="VALID",
    data_format=None,
    dilations=None,
    name=None):
  return convolution_internal(
      input,  # pylint: disable=redefined-builtin
      filters,
      strides=strides,
      padding=padding,
      data_format=data_format,
      dilations=dilations,
      name=name)


convolution_v2.__doc__ = deprecation.rewrite_argument_docstring(
    deprecation.rewrite_argument_docstring(
        convolution.__doc__, "dilation_rate", "dilations"),
    "filter", "filters")


def convolution_internal(
    input,  # pylint: disable=redefined-builtin
    filters,
    strides=None,
    padding="VALID",
    data_format=None,
    dilations=None,
    name=None,
    call_from_convolution=True,
    num_spatial_dims=None):
  """Internal function which performs rank agnostic convolution.

  Args:
    input: See `convolution`.
    filters: See `convolution`.
    strides: See `convolution`.
    padding: See `convolution`.
    data_format: See `convolution`.
    dilations: See `convolution`.
    name: See `convolution`.
    call_from_convolution: See `convolution`.
    num_spatial_dims: (Optional.)  It is an integer describing the
      rank of the spatial dimensions.  For `1-D`, `2-D` and `3-D` convolutions,
      the value of `num_spatial_dims` is `1`, `2`, and `3`, respectively.
      This argument is only required to disambiguate the rank of `batch_shape`
      when `filter_shape.ndims is None` and `len(batch_shape) > 1`.  For
      backwards compatibility, if `num_spatial_dims is None` and
      `filter_shape.ndims is None`, then `len(batch_shape)` is assumed to be
      `1` (i.e., the input is expected to be
      `[batch_size, num_channels] + input_spatial_shape`
      or `[batch_size] + input_spatial_shape + [num_channels]`).

  Returns:
    A tensor of shape and dtype matching that of `input`.

  Raises:
    ValueError: If input and filter both have unknown shapes, or if
      `num_spatial_dims` is provided and incompatible with the value
      estimated from `filters.shape`.
  """
  if (not isinstance(filters, variables_lib.Variable) and
      not tensor_util.is_tf_type(filters)):
    with ops.name_scope("convolution_internal", None, [filters, input]):
      filters = ops.convert_to_tensor(filters, name='filters')
  if (not isinstance(input, ops.Tensor) and not tensor_util.is_tf_type(input)):
    with ops.name_scope("convolution_internal", None, [filters, input]):
      input = ops.convert_to_tensor(input, name="input")

  filters_rank = filters.shape.rank
  inputs_rank = input.shape.rank
  if num_spatial_dims is None:
    if filters_rank:
      num_spatial_dims = filters_rank - 2
    elif inputs_rank:
      num_spatial_dims = inputs_rank - 2
    else:
      raise ValueError(
          "When `num_spatial_dims` is not set, one of `input.shape.rank` or "
          "`filters.shape.rank` must be known. "
          f"Received: input.shape={input.shape} of rank {inputs_rank} and "
          f"filters.shape={filters.shape} of rank {filters_rank}")
  elif filters_rank and filters_rank - 2 != num_spatial_dims:
    raise ValueError(
        "`filters.shape.rank - 2` should equal `num_spatial_dims`. Received: "
        f"filters.shape={filters.shape} of rank {filters_rank} and "
        f"num_spatial_dims={num_spatial_dims}")

  if inputs_rank:
    num_batch_dims = inputs_rank - num_spatial_dims - 1  # Channel dimension.
  else:
    num_batch_dims = 1  # By default, assume single batch dimension.

  if num_spatial_dims not in {1, 2, 3}:
    raise ValueError(
        "`num_spatial_dims` must be 1, 2, or 3. "
        f"Received: num_spatial_dims={num_spatial_dims}.")

  if data_format is None or data_format in _CHANNELS_LAST_FORMATS:
    channel_index = num_batch_dims + num_spatial_dims
  else:
    channel_index = num_batch_dims

  if dilations is None:
    dilations = _get_sequence(dilations, num_spatial_dims, channel_index,
                              "dilations")
    is_dilated_conv = False
  else:
    dilations = _get_sequence(dilations, num_spatial_dims, channel_index,
                              "dilations")
    is_dilated_conv = any(i != 1 for i in dilations)

  strides = _get_sequence(strides, num_spatial_dims, channel_index, "strides")
  has_tpu_context = device_context.enclosing_tpu_context() is not None

  if name:
    default_name = None
  elif not has_tpu_context or call_from_convolution:
    default_name = "convolution"
  elif num_spatial_dims == 2:  # Most common case.
    default_name = "Conv2D"
  elif num_spatial_dims == 3:
    default_name = "Conv3D"
  else:
    default_name = "conv1d"

  with ops.name_scope(name, default_name, [input, filters]) as name:
    # Fast path for TPU or if no dilation, as gradient only supported on TPU
    # for dilations.
    if not is_dilated_conv or has_tpu_context:
      if num_spatial_dims == 2:  # Most common case.
        op = _conv2d_expanded_batch
      elif num_spatial_dims == 3:
        op = _conv3d_expanded_batch
      else:
        op = conv1d

      return op(
          input,
          filters,
          strides,
          padding=padding,
          data_format=data_format,
          dilations=dilations,
          name=name)
    else:
      if channel_index == 1:
        strides = strides[2:]
        dilations = dilations[2:]
      else:
        strides = strides[1:-1]
        dilations = dilations[1:-1]

      op = Convolution(
          tensor_shape.as_shape(input.shape),
          tensor_shape.as_shape(filters.shape),
          padding,
          strides=strides,
          dilation_rate=dilations,
          name=name,
          data_format=data_format,
          num_spatial_dims=num_spatial_dims)
      return op(input, filters)


class Convolution:
  """Helper class for convolution.

  Note that this class assumes that shapes of input and filter passed to
  `__call__` are compatible with `input_shape`, `filter_shape`, and
  `num_spatial_dims` passed to the constructor.

  Arguments
    input_shape: static shape of input, i.e. input.shape.  It is
      `batch_shape + input_spatial_shape + [num_channels]` if `data_format`
      does not start with `NC`, or
      `batch_shape + [num_channels] + input_spatial_shape` if `data_format`
      starts with `NC`.
1354    filter_shape: static shape of the filter. i.e. filter.shape.
1355    padding: The padding algorithm, must be "SAME" or "VALID".
1356    strides: see convolution.
1357    dilation_rate: see convolution.
1358    name: see convolution.
1359    data_format: A string or `None`.  Specifies whether the channel dimension of
1360      the `input` and output is the last dimension (if `data_format` is `None`
1361      or does not start with `NC`), or the first post-batch dimension (i.e. if
1362      `data_format` starts with `NC`).
1363    num_spatial_dims: (Usually optional.)  Python integer, the number of
1364      spatial dimensions.  For `1-D`, `2-D` and `3-D` convolutions, the value
1365      of `num_spatial_dims` is `1`, `2`, and `3`, respectively.
1366      This argument is only required to disambiguate the rank of `batch_shape`
1367      when `filter_shape.ndims is None` and `len(batch_shape) > 1`.  For
1368      backwards compatibility, if `num_spatial_dims is None` and
1369      `filter_shape.ndims is None`, then `len(batch_shape)` is assumed to be
1370      `1` (i.e., the input is expected to be
1371      `[batch_size, num_channels] + input_spatial_shape`
1372      or `[batch_size] + input_spatial_shape + [num_channels]`).
1373  """
1374
1375  def __init__(self,
1376               input_shape,
1377               filter_shape,
1378               padding,
1379               strides=None,
1380               dilation_rate=None,
1381               name=None,
1382               data_format=None,
1383               num_spatial_dims=None):
1384    """Helper function for convolution."""
1385    num_batch_dims = None
1386    filter_shape = tensor_shape.as_shape(filter_shape)
1387    input_shape = tensor_shape.as_shape(input_shape)
1388
1389    if filter_shape.ndims is not None:
1390      if (num_spatial_dims is not None and
1391          filter_shape.ndims != num_spatial_dims + 2):
1392        raise ValueError(
1393            "`filters.shape.rank` must be `num_spatial_dims + 2`. Received: "
1394            f"filters.shape={filter_shape} of rank {filter_shape.rank} and "
1395            f"num_spatial_dims={num_spatial_dims}")
1396      else:
1397        num_spatial_dims = filter_shape.ndims - 2
1398
1399    if input_shape.ndims is not None and num_spatial_dims is not None:
1400      num_batch_dims = input_shape.ndims - num_spatial_dims - 1
1401
1402    if num_spatial_dims is None:
1403      num_spatial_dims = input_shape.ndims - 2
1404    else:
1405      if input_shape.ndims is not None:
1406        if input_shape.ndims < num_spatial_dims + 2:
1407          raise ValueError(
1408              "`input.shape.rank` must be >= `num_spatial_dims + 2`. "
1409              f"Received: input.shape={input_shape} of rank {input_shape.rank} "
1410              f"and num_spatial_dims={num_spatial_dims}")
1411        else:
1412          if num_batch_dims is None:
1413            num_batch_dims = input_shape.ndims - num_spatial_dims - 1
1414
1415    if num_spatial_dims is None:
1416      raise ValueError(
1417          "When `num_spatial_dims` is not set, one of `input.shape.rank` or "
1418          "`filters.shape.rank` must be known. "
1419          f"Received: input.shape={input_shape} of rank {input_shape.rank} and "
1420          f"`filters.shape={filter_shape}` of rank {filter_shape.rank}")
1421
1422    if num_batch_dims is None:
1423      num_batch_dims = 1
1424
1425    if num_batch_dims < 1:
1426      raise ValueError(
1427          f"Batch dims should be >= 1, but found {num_batch_dims}. "
1428          "Batch dims was estimated as "
1429          "`input.shape.rank - num_spatial_dims - 1` and `num_spatial_dims` "
1430          "was either provided or estimated as `filters.shape.rank - 2`. "
1431          f"Received: input.shape={input_shape} of rank {input_shape.rank}, "
1432          f"filters.shape={filter_shape} of rank {filter_shape.rank}, and "
1433          f"num_spatial_dims={num_spatial_dims}")
1434
1435    if data_format is None or not data_format.startswith("NC"):
1436      input_channels_dim = tensor_shape.dimension_at_index(
1437          input_shape, num_spatial_dims + num_batch_dims)
1438      spatial_dims = range(num_batch_dims, num_spatial_dims + num_batch_dims)
1439    else:
1440      input_channels_dim = tensor_shape.dimension_at_index(
1441          input_shape, num_batch_dims)
1442      spatial_dims = range(
1443          num_batch_dims + 1, num_spatial_dims + num_batch_dims + 1)
1444
1445    filter_dim = tensor_shape.dimension_at_index(filter_shape, num_spatial_dims)
1446    if not (input_channels_dim % filter_dim).is_compatible_with(0):
1447      raise ValueError(
1448          "The number of input channels is not divisible by the corresponding "
1449          f"number of output filters. Received: input.shape={input_shape} with "
1450          f"{input_channels_dim} channels and filters.shape={filter_shape} "
1451          f"with {filter_dim} output filters.")
1452
1453    strides, dilation_rate = _get_strides_and_dilation_rate(
1454        num_spatial_dims, strides, dilation_rate)
1455
1456    self.input_shape = input_shape
1457    self.filter_shape = filter_shape
1458    self.data_format = data_format
1459    self.strides = strides
1460    self.padding = padding
1461    self.name = name
1462    self.dilation_rate = dilation_rate
1463    self.num_batch_dims = num_batch_dims
1464    self.num_spatial_dims = num_spatial_dims
1465    self.conv_op = _WithSpaceToBatch(
1466        input_shape,
1467        dilation_rate=dilation_rate,
1468        padding=padding,
1469        build_op=self._build_op,
1470        filter_shape=filter_shape,
1471        spatial_dims=spatial_dims,
1472        data_format=data_format,
1473        num_batch_dims=num_batch_dims)
1474
1475  def _build_op(self, _, padding):
1476    return _NonAtrousConvolution(
1477        self.input_shape,
1478        filter_shape=self.filter_shape,
1479        padding=padding,
1480        data_format=self.data_format,
1481        strides=self.strides,
1482        name=self.name,
1483        num_batch_dims=self.num_batch_dims)
1484
1485  def __call__(self, inp, filter):  # pylint: disable=redefined-builtin
1486    # TPU convolution supports dilations greater than 1.
1487    if device_context.enclosing_tpu_context() is not None:
1488      return convolution_internal(
1489          inp,
1490          filter,
1491          strides=self.strides,
1492          padding=self.padding,
1493          data_format=self.data_format,
1494          dilations=self.dilation_rate,
1495          name=self.name,
1496          call_from_convolution=False,
1497          num_spatial_dims=self.num_spatial_dims)
1498    else:
1499      return self.conv_op(inp, filter)
1500
1501
1502@tf_export(v1=["nn.pool"])
1503@dispatch.add_dispatch_support
1504def pool(
1505    input,  # pylint: disable=redefined-builtin
1506    window_shape,
1507    pooling_type,
1508    padding,
1509    dilation_rate=None,
1510    strides=None,
1511    name=None,
1512    data_format=None,
1513    dilations=None):
1514  """Performs an N-D pooling operation.
1515
1516  In the case that `data_format` does not start with "NC", computes for
1517      0 <= b < batch_size,
1518      0 <= x[i] < output_spatial_shape[i],
1519      0 <= c < num_channels:
1520
1521  ```
1522  output[b, x[0], ..., x[N-1], c] =
1523    REDUCE_{z[0], ..., z[N-1]}
1524      input[b,
1525            x[0] * strides[0] - pad_before[0] + dilation_rate[0]*z[0],
1526            ...
1527            x[N-1]*strides[N-1] - pad_before[N-1] + dilation_rate[N-1]*z[N-1],
1528            c],
1529  ```
1530
1531  where the reduction function REDUCE depends on the value of `pooling_type`,
1532  and pad_before is defined based on the value of `padding`, as described in
1533  the "returns" section of `tf.nn.convolution`.
1534  The reduction never includes out-of-bounds positions.
1535
1536  In the case that `data_format` starts with `"NC"`, the `input` and output are
1537  simply transposed as follows:
1538
1539  ```python
1540  pool(input, data_format, **kwargs) =
1541    tf.transpose(pool(tf.transpose(input, [0] + range(2,N+2) + [1]),
1542                      **kwargs),
1543                 [0, N+1] + range(1, N+1))
1544  ```
1545
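  For example, a minimal sketch of dilated average pooling (shapes and values
  are illustrative only):

  ```python
  x = tf.ones([1, 6, 6, 1])
  y = tf.compat.v1.nn.pool(x, window_shape=[3, 3], pooling_type="AVG",
                           padding="VALID", dilation_rate=[2, 2])
  # y has shape [1, 2, 2, 1]: 6 - (3 - 1) * 2 = 2 along each spatial dimension.
  ```
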
1546  Args:
1547    input: Tensor of rank N+2, of shape
1548      `[batch_size] + input_spatial_shape + [num_channels]` if data_format does
1549      not start with "NC" (default), or
1550      `[batch_size, num_channels] + input_spatial_shape` if data_format starts
1551      with "NC".  Pooling happens over the spatial dimensions only.
1552    window_shape: Sequence of N ints >= 1.
1553    pooling_type: Specifies pooling operation, must be "AVG" or "MAX".
1554    padding: The padding algorithm, must be "SAME" or "VALID".
1555      See the "returns" section of `tf.nn.convolution` for details.
1556    dilation_rate: Optional.  Dilation rate.  List of N ints >= 1.
1557      Defaults to `[1]*N`.  If any value of dilation_rate is > 1, then all
1558      values of strides must be 1.
1559    strides: Optional.  Sequence of N ints >= 1.  Defaults to `[1]*N`.
1560      If any value of strides is > 1, then all values of dilation_rate must be
1561      1.
1562    name: Optional. Name of the op.
1563    data_format: A string or None.  Specifies whether the channel dimension of
1564      the `input` and output is the last dimension (default, or if `data_format`
1565      does not start with "NC"), or the second dimension (if `data_format`
1566      starts with "NC").  For N=1, the valid values are "NWC" (default) and
1567      "NCW".  For N=2, the valid values are "NHWC" (default) and "NCHW".
1568      For N=3, the valid values are "NDHWC" (default) and "NCDHW".
1569    dilations: Alias for `dilation_rate`.
1570
1571  Returns:
1572    Tensor of rank N+2, of shape
1573      [batch_size] + output_spatial_shape + [num_channels]
1574
1575    if data_format is None or does not start with "NC", or
1576
1577      [batch_size, num_channels] + output_spatial_shape
1578
1579    if data_format starts with "NC",
1580    where `output_spatial_shape` depends on the value of padding:
1581
1582    If padding = "SAME":
1583      output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i])
1584
1585    If padding = "VALID":
1586      output_spatial_shape[i] =
1587        ceil((input_spatial_shape[i] - (window_shape[i] - 1) * dilation_rate[i])
1588             / strides[i]).
1589
1590  Raises:
1591    ValueError: if arguments are invalid.
1592
1593  """
1594  dilation_rate = deprecated_argument_lookup(
1595      "dilations", dilations, "dilation_rate", dilation_rate)
1596  # pylint: enable=line-too-long
1597  with ops.name_scope(name, "%s_pool" % (pooling_type.lower()),
1598                      [input]) as scope:
1599    input = ops.convert_to_tensor(input, name="input")  # pylint: disable=redefined-builtin
1600
1601    num_spatial_dims = len(window_shape)
1602    if num_spatial_dims < 1 or num_spatial_dims > 3:
1603      raise ValueError("`len(window_shape)` must be 1, 2, or 3. Received: "
1604                       f"window_shape={window_shape} of length "
1605                       f"{len(window_shape)}")
1606
1607    input.get_shape().with_rank(num_spatial_dims + 2)
1608
1609    strides, dilation_rate = _get_strides_and_dilation_rate(
1610        num_spatial_dims, strides, dilation_rate)
1611
1612    if padding == "SAME" and np.any(dilation_rate > 1):
1613      raise ValueError(
1614          "pooling with 'SAME' padding is not implemented for "
1615          f"`dilation_rate` > 1. Received: padding={padding} and "
1616          f"dilation_rate={dilation_rate}")
1617
1618    if np.any(strides > window_shape):
1619      raise ValueError(
1620          "`strides` > `window_shape` not supported due to inconsistency "
1621          f"between CPU and GPU implementations. Received: strides={strides} "
1622          f"and window_shape={window_shape}")
1623
1624    pooling_ops = {
1625        ("MAX", 1): max_pool,
1626        ("MAX", 2): max_pool,
1627        ("MAX", 3): max_pool3d,  # pylint: disable=undefined-variable
1628        ("AVG", 1): avg_pool,
1629        ("AVG", 2): avg_pool,
1630        ("AVG", 3): avg_pool3d,  # pylint: disable=undefined-variable
1631    }
1632    op_key = (pooling_type, num_spatial_dims)
1633    if op_key not in pooling_ops:
1634      raise ValueError(
1635          f"{num_spatial_dims}-D {pooling_type} pooling is not supported.")
1636
1637    if data_format is None or not data_format.startswith("NC"):
1638      adjusted_window_shape = [1] + list(window_shape) + [1]
1639      adjusted_strides = [1] + list(strides) + [1]
1640      spatial_dims = range(1, num_spatial_dims + 1)
1641    else:
1642      adjusted_window_shape = [1, 1] + list(window_shape)
1643      adjusted_strides = [1, 1] + list(strides)
1644      spatial_dims = range(2, num_spatial_dims + 2)
1645
1646    if num_spatial_dims == 1:
1647      if data_format is None or data_format == "NWC":
1648        data_format_kwargs = dict(data_format="NHWC")
1649      elif data_format == "NCW":
1650        data_format_kwargs = dict(data_format="NCHW")
1651      else:
1652        raise ValueError("data_format must be either 'NWC' or 'NCW'. "
1653                         f"Received: data_format={data_format}")
1654      adjusted_window_shape = [1] + adjusted_window_shape
1655      adjusted_strides = [1] + adjusted_strides
1656    else:
1657      data_format_kwargs = dict(data_format=data_format)
1658
1659    def op(converted_input, _, converted_padding):  # pylint: disable=missing-docstring
1660      if num_spatial_dims == 1:
1661        converted_input = array_ops.expand_dims(converted_input,
1662                                                spatial_dims[0])
1663      result = pooling_ops[op_key](
1664          converted_input,
1665          adjusted_window_shape,
1666          adjusted_strides,
1667          converted_padding,
1668          name=scope,
1669          **data_format_kwargs)
1670      if num_spatial_dims == 1:
1671        result = array_ops.squeeze(result, [spatial_dims[0]])
1672      return result
1673
1674    return with_space_to_batch(
1675        input=input,
1676        dilation_rate=dilation_rate,
1677        padding=padding,
1678        op=op,
1679        spatial_dims=spatial_dims,
1680        filter_shape=window_shape)
1681
1682
1683@tf_export("nn.pool", v1=[])
1684@dispatch.add_dispatch_support
1685def pool_v2(
1686    input,  # pylint: disable=redefined-builtin
1687    window_shape,
1688    pooling_type,
1689    strides=None,
1690    padding="VALID",
1691    data_format=None,
1692    dilations=None,
1693    name=None):
1694  # pylint: disable=line-too-long
1695  """Performs an N-D pooling operation.
1696
1697  In the case that `data_format` does not start with "NC", computes for
1698      0 <= b < batch_size,
1699      0 <= x[i] < output_spatial_shape[i],
1700      0 <= c < num_channels:
1701
1702  ```
1703  output[b, x[0], ..., x[N-1], c] =
1704    REDUCE_{z[0], ..., z[N-1]}
1705      input[b,
1706            x[0] * strides[0] - pad_before[0] + dilation_rate[0]*z[0],
1707            ...
1708            x[N-1]*strides[N-1] - pad_before[N-1] + dilation_rate[N-1]*z[N-1],
1709            c],
1710  ```
1711
1712  where the reduction function REDUCE depends on the value of `pooling_type`,
1713  and pad_before is defined based on the value of `padding`, as described in
1714  the "returns" section of `tf.nn.convolution`.
1715  The reduction never includes out-of-bounds positions.
1716
1717  In the case that `data_format` starts with `"NC"`, the `input` and output are
1718  simply transposed as follows:
1719
1720  ```python
1721  pool(input, data_format, **kwargs) =
1722    tf.transpose(pool(tf.transpose(input, [0] + range(2,N+2) + [1]),
1723                      **kwargs),
1724                 [0, N+1] + range(1, N+1))
1725  ```
1726
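  For example, a minimal 2-D max pooling sketch (shapes and values are
  illustrative only):

  ```python
  x = tf.reshape(tf.range(16, dtype=tf.float32), [1, 4, 4, 1])
  y = tf.nn.pool(x, window_shape=[2, 2], pooling_type="MAX",
                 strides=[2, 2], padding="VALID")
  # y has shape [1, 2, 2, 1] and holds the maximum of each 2x2 block.
  ```
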
1727  Args:
1728    input: Tensor of rank N+2, of shape `[batch_size] + input_spatial_shape +
1729      [num_channels]` if data_format does not start with "NC" (default), or
1730      `[batch_size, num_channels] + input_spatial_shape` if data_format starts
1731      with "NC".  Pooling happens over the spatial dimensions only.
1732    window_shape: Sequence of N ints >= 1.
1733    pooling_type: Specifies pooling operation, must be "AVG" or "MAX".
1734    strides: Optional. Sequence of N ints >= 1.  Defaults to `[1]*N`. If any value of
1735      strides is > 1, then all values of dilation_rate must be 1.
1736    padding: The padding algorithm, must be "SAME" or "VALID". Defaults to "VALID".
1737      See
1738      [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2)
1739      for more information.
1740    data_format: A string or None.  Specifies whether the channel dimension of
1741      the `input` and output is the last dimension (default, or if `data_format`
1742      does not start with "NC"), or the second dimension (if `data_format`
1743      starts with "NC").  For N=1, the valid values are "NWC" (default) and
1744      "NCW".  For N=2, the valid values are "NHWC" (default) and "NCHW". For
1745      N=3, the valid values are "NDHWC" (default) and "NCDHW".
1746    dilations: Optional.  Dilation rate.  List of N ints >= 1. Defaults to
1747      `[1]*N`.  If any value of dilation_rate is > 1, then all values of strides
1748      must be 1.
1749    name: Optional. Name of the op.
1750
1751  Returns:
1752    Tensor of rank N+2, of shape
1753      [batch_size] + output_spatial_shape + [num_channels]
1754
1755    if data_format is None or does not start with "NC", or
1756
1757      [batch_size, num_channels] + output_spatial_shape
1758
1759    if data_format starts with "NC",
1760    where `output_spatial_shape` depends on the value of padding:
1761
1762    If padding = "SAME":
1763      output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i])
1764
1765    If padding = "VALID":
1766      output_spatial_shape[i] =
1767        ceil((input_spatial_shape[i] - (window_shape[i] - 1) * dilation_rate[i])
1768             / strides[i]).
1769
1770  Raises:
1771    ValueError: if arguments are invalid.
1772  """
1773  return pool(
1774      input=input,
1775      window_shape=window_shape,
1776      pooling_type=pooling_type,
1777      padding=padding,
1778      dilation_rate=dilations,
1779      strides=strides,
1780      name=name,
1781      data_format=data_format)
1782
1783
1784@tf_export("nn.atrous_conv2d")
1785@dispatch.add_dispatch_support
1786def atrous_conv2d(value, filters, rate, padding, name=None):
1787  """Atrous convolution (a.k.a. convolution with holes or dilated convolution).
1788
1789  This function is a simpler wrapper around the more general
1790  `tf.nn.convolution`, and exists only for backwards compatibility. You can
1791  use `tf.nn.convolution` to perform 1-D, 2-D, or 3-D atrous convolution.
1792
1793  Computes a 2-D atrous convolution, also known as convolution with holes or
1794  dilated convolution, given 4-D `value` and `filters` tensors. If the `rate`
1795  parameter is equal to one, it performs regular 2-D convolution. If the `rate`
1796  parameter is greater than one, it performs convolution with holes, sampling
1797  the input values every `rate` pixels in the `height` and `width` dimensions.
1798  This is equivalent to convolving the input with a set of upsampled filters,
1799  produced by inserting `rate - 1` zeros between two consecutive values of the
1800  filters along the `height` and `width` dimensions, hence the name atrous
1801  convolution or convolution with holes (the French word trous means holes in
1802  English).
1803
1804  More specifically:
1805
1806  ```
1807  output[batch, height, width, out_channel] =
1808      sum_{dheight, dwidth, in_channel} (
1809          filters[dheight, dwidth, in_channel, out_channel] *
1810          value[batch, height + rate*dheight, width + rate*dwidth, in_channel]
1811      )
1812  ```
1813
1814  Atrous convolution allows us to explicitly control how densely to compute
1815  feature responses in fully convolutional networks. Used in conjunction with
1816  bilinear interpolation, it offers an alternative to `conv2d_transpose` in
1817  dense prediction tasks such as semantic image segmentation, optical flow
1818  computation, or depth estimation. It also allows us to effectively enlarge
1819  the field of view of filters without increasing the number of parameters or
1820  the amount of computation.
1821
1822  For a description of atrous convolution and how it can be used for dense
1823  feature extraction, please see: (Chen et al., 2015). The same operation is
1824  investigated further in (Yu et al., 2016). Previous works that effectively
1825  use atrous convolution in different ways are, among others,
1826  (Sermanet et al., 2014) and (Giusti et al., 2013).
1827  Atrous convolution is also closely related to the so-called noble identities
1828  in multi-rate signal processing.
1829
1830  There are many different ways to implement atrous convolution (see the refs
1831  above). The implementation here reduces
1832
1833  ```python
1834  atrous_conv2d(value, filters, rate, padding=padding)
1835  ```
1836
1837  to the following three operations:
1838
1839  ```python
1840  paddings = ...
1841  net = space_to_batch(value, paddings, block_size=rate)
1842  net = conv2d(net, filters, strides=[1, 1, 1, 1], padding="VALID")
1843  crops = ...
1844  net = batch_to_space(net, crops, block_size=rate)
1845  ```
1846
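  As a sanity check (shapes here are illustrative only), the result matches a
  dilated `tf.nn.conv2d`:

  ```python
  x = tf.random.normal([1, 8, 8, 3])
  w = tf.random.normal([3, 3, 3, 16])
  y1 = tf.nn.atrous_conv2d(x, w, rate=2, padding="VALID")
  y2 = tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding="VALID",
                    dilations=[1, 2, 2, 1])
  # y1 and y2 both have shape [1, 4, 4, 16] and agree up to numerical error.
  ```
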
1847  Advanced usage. Note the following optimization: A sequence of `atrous_conv2d`
1848  operations with identical `rate` parameters, 'SAME' `padding`, and filters
1849  with odd heights/widths:
1850
1851  ```python
1852  net = atrous_conv2d(net, filters1, rate, padding="SAME")
1853  net = atrous_conv2d(net, filters2, rate, padding="SAME")
1854  ...
1855  net = atrous_conv2d(net, filtersK, rate, padding="SAME")
1856  ```
1857
1858  can be equivalently performed cheaper in terms of computation and memory as:
1859
1860  ```python
1861  pad = ...  # padding so that the input dims are multiples of rate
1862  net = space_to_batch(net, paddings=pad, block_size=rate)
1863  net = conv2d(net, filters1, strides=[1, 1, 1, 1], padding="SAME")
1864  net = conv2d(net, filters2, strides=[1, 1, 1, 1], padding="SAME")
1865  ...
1866  net = conv2d(net, filtersK, strides=[1, 1, 1, 1], padding="SAME")
1867  net = batch_to_space(net, crops=pad, block_size=rate)
1868  ```
1869
1870  because a pair of consecutive `space_to_batch` and `batch_to_space` ops with
1871  the same `block_size` cancel out when their respective `paddings` and `crops`
1872  inputs are identical.
1873
1874  Args:
1875    value: A 4-D `Tensor` of type `float`. It needs to be in the default "NHWC"
1876      format. Its shape is `[batch, in_height, in_width, in_channels]`.
1877    filters: A 4-D `Tensor` with the same type as `value` and shape
1878      `[filter_height, filter_width, in_channels, out_channels]`. `filters`'
1879      `in_channels` dimension must match that of `value`. Atrous convolution is
1880      equivalent to standard convolution with upsampled filters with effective
1881      height `filter_height + (filter_height - 1) * (rate - 1)` and effective
1882      width `filter_width + (filter_width - 1) * (rate - 1)`, produced by
1883      inserting `rate - 1` zeros along consecutive elements across the
1884      `filters`' spatial dimensions.
1885    rate: A positive int32. The stride with which we sample input values across
1886      the `height` and `width` dimensions. Equivalently, the rate by which we
1887      upsample the filter values by inserting zeros across the `height` and
1888      `width` dimensions. In the literature, the same parameter is sometimes
1889      called `input stride` or `dilation`.
1890    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
1891      [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2)
1892      for more information.
1893    name: Optional name for the returned tensor.
1894
1895  Returns:
1896    A `Tensor` with the same type as `value`.
1897    Output shape with `'VALID'` padding is:
1898
1899        [batch, height - rate * (filter_height - 1),
1900         width - rate * (filter_width - 1), out_channels].
1901
1902    Output shape with `'SAME'` padding is:
1903
1904        [batch, height, width, out_channels].
1905
1906  Raises:
1907    ValueError: If input/output depth does not match `filters`' shape, or if
1908      padding is other than `'VALID'` or `'SAME'`.
1909
1910  References:
1911    Multi-Scale Context Aggregation by Dilated Convolutions:
1912      [Yu et al., 2016](https://arxiv.org/abs/1511.07122)
1913      ([pdf](https://arxiv.org/pdf/1511.07122.pdf))
1914    Semantic Image Segmentation with Deep Convolutional Nets and Fully
1915    Connected CRFs:
1916      [Chen et al., 2015](http://arxiv.org/abs/1412.7062)
1917      ([pdf](https://arxiv.org/pdf/1412.7062))
1918    OverFeat - Integrated Recognition, Localization and Detection using
1919    Convolutional Networks:
1920      [Sermanet et al., 2014](https://arxiv.org/abs/1312.6229)
1921      ([pdf](https://arxiv.org/pdf/1312.6229.pdf))
1922    Fast Image Scanning with Deep Max-Pooling Convolutional Neural Networks:
1923      [Giusti et al., 2013]
1924      (https://ieeexplore.ieee.org/abstract/document/6738831)
1925      ([pdf](https://arxiv.org/pdf/1302.1700.pdf))
1926  """
1927  return convolution(
1928      input=value,
1929      filter=filters,
1930      padding=padding,
1931      dilation_rate=np.broadcast_to(rate, (2,)),
1932      name=name)
1933
1934
1935def convert_padding(padding, expected_length=4):
1936  """Converts Python padding to C++ padding for ops which take EXPLICIT padding.
1937
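  For example, with the default `expected_length=4`:

  ```python
  convert_padding("SAME")
  # -> ("SAME", [])
  convert_padding([[0, 0], [1, 2], [3, 4], [0, 0]])
  # -> ("EXPLICIT", [0, 0, 1, 2, 3, 4, 0, 0])
  ```
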
1938  Args:
1939    padding: the `padding` argument for a Python op which supports EXPLICIT
1940      padding.
1941    expected_length: Expected number of entries in the padding list when
1942      explicit padding is used.
1943
1944  Returns:
1945    (padding, explicit_paddings) pair, which should be passed as attributes to a
1946    C++ op.
1947
1948  Raises:
1949    ValueError: If padding is invalid.
1950  """
1951  explicit_paddings = []
1952  if padding == "EXPLICIT":
1953    raise ValueError("'EXPLICIT' is not a valid value for `padding`. To use "
1954                     "explicit padding, `padding` must be a list.")
1955  if isinstance(padding, (list, tuple)):
1956    for i, dim_paddings in enumerate(padding):
1957      if not isinstance(dim_paddings, (list, tuple)):
1958        raise ValueError("When `padding` is a list, each element of `padding` "
1959                         "must be a list/tuple of size 2. Received: "
1960                         f"padding={padding} with element at index {i} of type "
1961                         f"{type(dim_paddings)}")
1962      if len(dim_paddings) != 2:
1963        raise ValueError("When `padding` is a list, each element of `padding` "
1964                         "must be a list/tuple of size 2. Received: "
1965                         f"padding={padding} with element at index {i} of size "
1966                         f"{len(dim_paddings)}")
1967      explicit_paddings.extend(dim_paddings)
1968    if len(padding) != expected_length:
1969      raise ValueError(
1970          f"When padding is a list, it must be of size {expected_length}. "
1971          f"Received: padding={padding} of size {len(padding)}")
1972    padding = "EXPLICIT"
1973  return padding, explicit_paddings
1974
1975
1976@tf_export(v1=["nn.conv1d"])
1977@dispatch.add_dispatch_support
1978@deprecation.deprecated_arg_values(
1979    None,
1980    "`NCHW` for data_format is deprecated, use `NCW` instead",
1981    warn_once=True,
1982    data_format="NCHW")
1983@deprecation.deprecated_arg_values(
1984    None,
1985    "`NHWC` for data_format is deprecated, use `NWC` instead",
1986    warn_once=True,
1987    data_format="NHWC")
1988def conv1d(
1989    value=None,
1990    filters=None,
1991    stride=None,
1992    padding=None,
1993    use_cudnn_on_gpu=None,
1994    data_format=None,
1995    name=None,
1996    input=None,  # pylint: disable=redefined-builtin
1997    dilations=None):
1998  r"""Computes a 1-D convolution of input with rank `>=3` and a `3-D` filter.
1999
2000  Given an input tensor of shape
2001    `batch_shape + [in_width, in_channels]`
2002  if `data_format` is `"NWC"`, or
2003    `batch_shape + [in_channels, in_width]`
2004  if `data_format` is `"NCW"`,
2005  and a filter / kernel tensor of shape
2006  `[filter_width, in_channels, out_channels]`, this op reshapes
2007  the arguments to pass them to `conv2d` to perform the equivalent
2008  convolution operation.
2009
2010  Internally, this op reshapes the input tensors and invokes `tf.nn.conv2d`.
2011  For example, if `data_format` does not start with "NC", a tensor of shape
2012    `batch_shape + [in_width, in_channels]`
2013  is reshaped to
2014    `batch_shape + [1, in_width, in_channels]`,
2015  and the filter is reshaped to
2016    `[1, filter_width, in_channels, out_channels]`.
2017  The result is then reshaped back to
2018    `batch_shape + [out_width, out_channels]`
2019  \(where out_width is a function of the stride and padding as in conv2d\) and
2020  returned to the caller.
2021
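  For example, a minimal sketch using the `"NCW"` layout (shapes are
  illustrative only):

  ```python
  x = tf.random.normal([1, 4, 10])   # batch=1, in_channels=4, in_width=10
  w = tf.random.normal([3, 4, 8])    # filter_width=3, in_channels=4, out_channels=8
  y = tf.compat.v1.nn.conv1d(x, w, stride=2, padding="SAME", data_format="NCW")
  # y has shape [1, 8, 5]: 'SAME' padding with stride 2 gives ceil(10 / 2) = 5.
  ```
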
2022  Args:
2023    value: A Tensor of rank at least 3. Must be of type `float16`, `float32`, or
2024      `float64`.
2025    filters: A Tensor of rank at least 3.  Must have the same type as `value`.
2026    stride: An int or list of `ints` that has length `1` or `3`.  The number of
2027      entries by which the filter is moved right at each step.
2028    padding: 'SAME' or 'VALID'
2029    use_cudnn_on_gpu: An optional `bool`.  Defaults to `True`.
2030    data_format: An optional `string` from `"NWC", "NCW"`.  Defaults to `"NWC"`,
2031      the data is stored in the order of `batch_shape + [in_width,
2032      in_channels]`.  The `"NCW"` format stores data as `batch_shape +
2033      [in_channels, in_width]`.
2034    name: A name for the operation (optional).
2035    input: Alias for value.
2036    dilations: An int or list of `ints` that has length `1` or `3` which
2037      defaults to 1. The dilation factor for each dimension of input. If set to
2038      k > 1, there will be k-1 skipped cells between each filter element on that
2039      dimension. Dilations in the batch and depth dimensions must be 1.
2040
2041  Returns:
2042    A `Tensor`.  Has the same type as input.
2043
2044  Raises:
2045    ValueError: if `data_format` is invalid.
2046  """
2047  value = deprecation.deprecated_argument_lookup("input", input, "value", value)
2048  with ops.name_scope(name, "conv1d", [value, filters]) as name:
2049    # Reshape the input tensor to batch_shape + [1, in_width, in_channels]
2050    if data_format is None or data_format == "NHWC" or data_format == "NWC":
2051      data_format = "NHWC"
2052      spatial_start_dim = -3
2053      channel_index = 2
2054    elif data_format == "NCHW" or data_format == "NCW":
2055      data_format = "NCHW"
2056      spatial_start_dim = -2
2057      channel_index = 1
2058    else:
2059      raise ValueError("`data_format` must be 'NWC' or 'NCW'. "
2060                       f"Received: data_format={data_format}")
2061    strides = [1] + _get_sequence(stride, 1, channel_index, "stride")
2062    dilations = [1] + _get_sequence(dilations, 1, channel_index, "dilations")
2063
2064    value = array_ops.expand_dims(value, spatial_start_dim)
2065    filters = array_ops.expand_dims(filters, 0)
2066    if value.shape.ndims in (4, 3, 2, 1, 0, None):
2067      result = gen_nn_ops.conv2d(
2068          value,
2069          filters,
2070          strides,
2071          padding,
2072          use_cudnn_on_gpu=use_cudnn_on_gpu,
2073          data_format=data_format,
2074          dilations=dilations,
2075          name=name)
2076    else:
2077      result = squeeze_batch_dims(
2078          value,
2079          functools.partial(
2080              gen_nn_ops.conv2d,
2081              filter=filters,
2082              strides=strides,
2083              padding=padding,
2084              use_cudnn_on_gpu=use_cudnn_on_gpu,
2085              data_format=data_format,
2086              dilations=dilations,
2087          ),
2088          inner_rank=3,
2089          name=name)
2090    return array_ops.squeeze(result, [spatial_start_dim])
2091
2092
2093@tf_export("nn.conv1d", v1=[])
2094@dispatch.add_dispatch_support
2095def conv1d_v2(
2096    input,  # pylint: disable=redefined-builtin
2097    filters,
2098    stride,
2099    padding,
2100    data_format="NWC",
2101    dilations=None,
2102    name=None):
2103  r"""Computes a 1-D convolution given 3-D input and filter tensors.
2104
2105  Given an input tensor of shape
2106    `batch_shape + [in_width, in_channels]`
2107  if `data_format` is `"NWC"`, or
2108    `batch_shape + [in_channels, in_width]`
2109  if `data_format` is `"NCW"`,
2110  and a filter / kernel tensor of shape
2111  `[filter_width, in_channels, out_channels]`, this op reshapes
2112  the arguments to pass them to `conv2d` to perform the equivalent
2113  convolution operation.
2114
2115  Internally, this op reshapes the input tensors and invokes `tf.nn.conv2d`.
2116  For example, if `data_format` does not start with `"NC"`, a tensor of shape
2117    `batch_shape + [in_width, in_channels]`
2118  is reshaped to
2119    `batch_shape + [1, in_width, in_channels]`,
2120  and the filter is reshaped to
2121    `[1, filter_width, in_channels, out_channels]`.
2122  The result is then reshaped back to
2123    `batch_shape + [out_width, out_channels]`
2124  \(where out_width is a function of the stride and padding as in conv2d\) and
2125  returned to the caller.
2126
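  For example, a minimal sketch (shapes are illustrative only):

  ```python
  x = tf.random.normal([1, 10, 4])   # batch=1, in_width=10, in_channels=4
  w = tf.random.normal([3, 4, 8])    # filter_width=3, in_channels=4, out_channels=8
  y = tf.nn.conv1d(x, w, stride=1, padding="VALID")
  # y has shape [1, 8, 8]: out_width = 10 - 3 + 1 = 8 with 8 output channels.
  ```
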
2127  Args:
2128    input: A Tensor of rank at least 3. Must be of type `float16`, `float32`, or
2129      `float64`.
2130    filters: A Tensor of rank at least 3.  Must have the same type as `input`.
2131    stride: An int or list of `ints` that has length `1` or `3`.  The number of
2132      entries by which the filter is moved right at each step.
2133    padding: 'SAME' or 'VALID'. See
2134      [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2)
2135      for more information.
2136    data_format: An optional `string` from `"NWC", "NCW"`.  Defaults to `"NWC"`,
2137      the data is stored in the order of
2138      `batch_shape + [in_width, in_channels]`.  The `"NCW"` format stores data
2139      as `batch_shape + [in_channels, in_width]`.
2140    dilations: An int or list of `ints` that has length `1` or `3` which
2141      defaults to 1. The dilation factor for each dimension of input. If set to
2142      k > 1, there will be k-1 skipped cells between each filter element on that
2143      dimension. Dilations in the batch and depth dimensions must be 1.
2144    name: A name for the operation (optional).
2145
2146  Returns:
2147    A `Tensor`.  Has the same type as input.
2148
2149  Raises:
2150    ValueError: if `data_format` is invalid.
2151  """
2152  return conv1d(
2153      input,  # pylint: disable=redefined-builtin
2154      filters,
2155      stride,
2156      padding,
2157      use_cudnn_on_gpu=True,
2158      data_format=data_format,
2159      name=name,
2160      dilations=dilations)
2161
2162
2163@tf_export("nn.conv1d_transpose")
2164@dispatch.add_dispatch_support
2165def conv1d_transpose(
2166    input,  # pylint: disable=redefined-builtin
2167    filters,
2168    output_shape,
2169    strides,
2170    padding="SAME",
2171    data_format="NWC",
2172    dilations=None,
2173    name=None):
2174  """The transpose of `conv1d`.
2175
2176  This operation is sometimes called "deconvolution" after
2177  (Zeiler et al., 2010), but is actually the transpose (gradient) of `conv1d`
2178  rather than an actual deconvolution.
2179
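  For example, a minimal sketch (shapes are illustrative only):

  ```python
  x = tf.random.normal([2, 8, 16])    # [batch, in_width, in_channels]
  w = tf.random.normal([3, 32, 16])   # [filter_width, output_channels, in_channels]
  y = tf.nn.conv1d_transpose(x, w, output_shape=[2, 16, 32], strides=2)
  # With the default 'SAME' padding and stride 2, y has shape [2, 16, 32].
  ```
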
2180  Args:
2181    input: A 3-D `Tensor` of type `float` and shape
2182      `[batch, in_width, in_channels]` for `NWC` data format or
2183      `[batch, in_channels, in_width]` for `NCW` data format.
2184    filters: A 3-D `Tensor` with the same type as `input` and shape
2185      `[filter_width, output_channels, in_channels]`.  `filters`'
2186      `in_channels` dimension must match that of `input`.
2187    output_shape: A 1-D `Tensor`, containing three elements, representing the
2188      output shape of the deconvolution op.
2189    strides: An int or list of `ints` that has length `1` or `3`.  The number of
2190      entries by which the filter is moved right at each step.
2191    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
2192      [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2)
2193      for more information.
2194    data_format: A string. `'NWC'` and `'NCW'` are supported.
2195    dilations: An int or list of `ints` that has length `1` or `3` which
2196      defaults to 1. The dilation factor for each dimension of input. If set to
2197      k > 1, there will be k-1 skipped cells between each filter element on that
2198      dimension. Dilations in the batch and depth dimensions must be 1.
2199    name: Optional name for the returned tensor.
2200
2201  Returns:
2202    A `Tensor` with the same type as `input`.
2203
2204  Raises:
2205    ValueError: If input/output depth does not match `filter`'s shape, if
2206      `output_shape` is not a 3-element vector, if `padding` is other than
2207      `'VALID'` or `'SAME'`, or if `data_format` is invalid.
2208
2209  References:
2210    Deconvolutional Networks:
2211      [Zeiler et al., 2010]
2212      (https://ieeexplore.ieee.org/abstract/document/5539957)
2213      ([pdf]
2214      (http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.232.4023&rep=rep1&type=pdf))
2215  """
2216  with ops.name_scope(name, "conv1d_transpose",
2217                      [input, filters, output_shape]) as name:
2218    # The format could be either NWC or NCW, map to NHWC or NCHW
2219    if data_format is None or data_format == "NWC":
2220      data_format = "NHWC"
2221      spatial_start_dim = 1
2222      channel_index = 2
2223    elif data_format == "NCW":
2224      data_format = "NCHW"
2225      spatial_start_dim = 2
2226      channel_index = 1
2227    else:
2228      raise ValueError("`data_format` must be 'NWC' or 'NCW'. "
2229                       f"Received: data_format={data_format}")
2230
2231    # Reshape the input tensor to [batch, 1, in_width, in_channels]
2232    strides = [1] + _get_sequence(strides, 1, channel_index, "stride")
2233    dilations = [1] + _get_sequence(dilations, 1, channel_index, "dilations")
2234
2235    input = array_ops.expand_dims(input, spatial_start_dim)
2236    filters = array_ops.expand_dims(filters, 0)
2237    output_shape = list(output_shape) if not isinstance(
2238        output_shape, ops.Tensor) else output_shape
2239    output_shape = array_ops.concat([output_shape[: spatial_start_dim], [1],
2240                                     output_shape[spatial_start_dim:]], 0)
2241
2242    result = gen_nn_ops.conv2d_backprop_input(
2243        input_sizes=output_shape,
2244        filter=filters,
2245        out_backprop=input,
2246        strides=strides,
2247        padding=padding,
2248        data_format=data_format,
2249        dilations=dilations,
2250        name=name)
2251    return array_ops.squeeze(result, spatial_start_dim)
2252
2253
2254@tf_export("nn.conv2d", v1=[])
2255@dispatch.add_dispatch_support
2256def conv2d_v2(input,  # pylint: disable=redefined-builtin
2257              filters,
2258              strides,
2259              padding,
2260              data_format="NHWC",
2261              dilations=None,
2262              name=None):
2263  # pylint: disable=line-too-long
2264  r"""Computes a 2-D convolution given `input` and 4-D `filters` tensors.
2265
2266  The `input` tensor may have rank `4` or higher, where shape dimensions `[:-3]`
2267  are considered batch dimensions (`batch_shape`).
2268
2269  Given an input tensor of shape
2270  `batch_shape + [in_height, in_width, in_channels]` and a filter / kernel
2271  tensor of shape `[filter_height, filter_width, in_channels, out_channels]`,
2272  this op performs the following:
2273
2274  1. Flattens the filter to a 2-D matrix with shape
2275     `[filter_height * filter_width * in_channels, output_channels]`.
2276  2. Extracts image patches from the input tensor to form a *virtual*
2277     tensor of shape `[batch, out_height, out_width,
2278     filter_height * filter_width * in_channels]`.
2279  3. For each patch, right-multiplies the filter matrix and the image patch
2280     vector.
2281
2282  In detail, with the default NHWC format,
2283
2284      output[b, i, j, k] =
2285          sum_{di, dj, q} input[b, strides[1] * i + di, strides[2] * j + dj, q] *
2286                          filter[di, dj, q, k]
2287
2288  Must have `strides[0] = strides[3] = 1`.  For the most common case of the same
2289  horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
2290
2291  Usage Example:
2292
2293  >>> x_in = np.array([[
2294  ...   [[2], [1], [2], [0], [1]],
2295  ...   [[1], [3], [2], [2], [3]],
2296  ...   [[1], [1], [3], [3], [0]],
2297  ...   [[2], [2], [0], [1], [1]],
2298  ...   [[0], [0], [3], [1], [2]], ]])
2299  >>> kernel_in = np.array([
2300  ...  [ [[2, 0.1]], [[3, 0.2]] ],
2301  ...  [ [[0, 0.3]], [[1, 0.4]] ], ])
2302  >>> x = tf.constant(x_in, dtype=tf.float32)
2303  >>> kernel = tf.constant(kernel_in, dtype=tf.float32)
2304  >>> tf.nn.conv2d(x, kernel, strides=[1, 1, 1, 1], padding='VALID')
2305  <tf.Tensor: shape=(1, 4, 4, 2), dtype=float32, numpy=..., dtype=float32)>
2306
2307  Args:
2308    input: A `Tensor`. Must be one of the following types:
2309      `half`, `bfloat16`, `float32`, `float64`.
2310      A Tensor of rank at least 4. The dimension order is interpreted according
2311      to the value of `data_format`; with the all-but-inner-3 dimensions acting
2312      as batch dimensions. See below for details.
2313    filters: A `Tensor`. Must have the same type as `input`.
2314      A 4-D tensor of shape
2315      `[filter_height, filter_width, in_channels, out_channels]`
2316    strides: An int or list of `ints` that has length `1`, `2` or `4`.  The
2317      stride of the sliding window for each dimension of `input`. If a single
2318      value is given it is replicated in the `H` and `W` dimension. By default
2319      the `N` and `C` dimensions are set to 1. The dimension order is determined
2320      by the value of `data_format`, see below for details.
2321    padding: Either the `string` `"SAME"` or `"VALID"` indicating the type of
2322      padding algorithm to use, or a list indicating the explicit paddings at
2323      the start and end of each dimension. See
2324      [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2)
2325      for more information. When explicit padding is used and data_format is
2326      `"NHWC"`, this should be in the form `[[0, 0], [pad_top, pad_bottom],
2327      [pad_left, pad_right], [0, 0]]`. When explicit padding used and
2328      data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
2329      [pad_top, pad_bottom], [pad_left, pad_right]]`.
2330    data_format: An optional `string` from: `"NHWC", "NCHW"`.
2331      Defaults to `"NHWC"`.
2332      Specify the data format of the input and output data. With the
2333      default format "NHWC", the data is stored in the order of:
2334          `batch_shape + [height, width, channels]`.
2335      Alternatively, the format could be "NCHW", the data storage order of:
2336          `batch_shape + [channels, height, width]`.
2337    dilations: An int or list of `ints` that has length `1`, `2` or `4`,
2338      defaults to 1. The dilation factor for each dimension of `input`. If a
2339      single value is given it is replicated in the `H` and `W` dimension. By
2340      default the `N` and `C` dimensions are set to 1. If set to k > 1, there
2341      will be k-1 skipped cells between each filter element on that dimension.
2342      The dimension order is determined by the value of `data_format`, see above
2343      for details. Dilations in the batch and depth dimensions if a 4-d tensor
2344      must be 1.
2345    name: A name for the operation (optional).
2346
2347  Returns:
2348    A `Tensor`. Has the same type as `input` and the same outer batch shape.
2349  """
2350  # pylint: enable=line-too-long
2351  return conv2d(input,  # pylint: disable=redefined-builtin
2352                filters,
2353                strides,
2354                padding,
2355                use_cudnn_on_gpu=True,
2356                data_format=data_format,
2357                dilations=dilations,
2358                name=name)
2359
2360
2361@tf_export(v1=["nn.conv2d"])
2362@dispatch.add_dispatch_support
2363def conv2d(  # pylint: disable=redefined-builtin,dangerous-default-value
2364    input,
2365    filter=None,
2366    strides=None,
2367    padding=None,
2368    use_cudnn_on_gpu=True,
2369    data_format="NHWC",
2370    dilations=[1, 1, 1, 1],
2371    name=None,
2372    filters=None):
2373  r"""Computes a 2-D convolution given 4-D `input` and `filter` tensors.
2374
2375  Given an input tensor of shape `[batch, in_height, in_width, in_channels]`
2376  and a filter / kernel tensor of shape
2377  `[filter_height, filter_width, in_channels, out_channels]`, this op
2378  performs the following:
2379
2380  1. Flattens the filter to a 2-D matrix with shape
2381     `[filter_height * filter_width * in_channels, output_channels]`.
2382  2. Extracts image patches from the input tensor to form a *virtual*
2383     tensor of shape `[batch, out_height, out_width,
2384     filter_height * filter_width * in_channels]`.
2385  3. For each patch, right-multiplies the filter matrix and the image patch
2386     vector.
2387
2388  In detail, with the default NHWC format,
2389
2390      output[b, i, j, k] =
2391          sum_{di, dj, q} input[b, strides[1] * i + di, strides[2] * j + dj, q]
2392                          * filter[di, dj, q, k]
2393
2394  Must have `strides[0] = strides[3] = 1`.  For the most common case of the same
2395  horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
2396
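  For example, a minimal sketch of explicit padding (shapes are illustrative
  only):

  ```python
  x = tf.random.normal([1, 5, 5, 1])
  w = tf.random.normal([2, 2, 1, 2])
  # Pad one row of zeros above and two columns of zeros to the left.
  y = tf.compat.v1.nn.conv2d(
      x, w, strides=[1, 1, 1, 1],
      padding=[[0, 0], [1, 0], [2, 0], [0, 0]])
  # y has shape [1, 5, 6, 2].
  ```
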
2397  Args:
2398    input: A `Tensor`. Must be one of the following types:
2399      `half`, `bfloat16`, `float32`, `float64`.
2400      A 4-D tensor. The dimension order is interpreted according to the value
2401      of `data_format`, see below for details.
2402    filter: A `Tensor`. Must have the same type as `input`.
2403      A 4-D tensor of shape
2404      `[filter_height, filter_width, in_channels, out_channels]`
2405    strides: An int or list of `ints` that has length `1`, `2` or `4`.  The
2406      stride of the sliding window for each dimension of `input`. If a single
2407      value is given it is replicated in the `H` and `W` dimension. By default
2408      the `N` and `C` dimensions are set to 1. The dimension order is determined
2409      by the value of `data_format`, see below for details.
2410    padding: Either the `string` `"SAME"` or `"VALID"` indicating the type of
2411      padding algorithm to use, or a list indicating the explicit paddings at
2412      the start and end of each dimension. When explicit padding is used and
2413      data_format is `"NHWC"`, this should be in the form `[[0, 0], [pad_top,
2414      pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit padding used
2415      and data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
2416      [pad_top, pad_bottom], [pad_left, pad_right]]`.
2417    use_cudnn_on_gpu: An optional `bool`. Defaults to `True`.
2418    data_format: An optional `string` from: `"NHWC", "NCHW"`.
2419      Defaults to `"NHWC"`.
2420      Specify the data format of the input and output data. With the
2421      default format "NHWC", the data is stored in the order of:
2422          [batch, height, width, channels].
2423      Alternatively, the format could be "NCHW", the data storage order of:
2424          [batch, channels, height, width].
2425    dilations: An int or list of `ints` that has length `1`, `2` or `4`,
2426      defaults to 1. The dilation factor for each dimension of `input`. If a
2427      single value is given it is replicated in the `H` and `W` dimension. By
2428      default the `N` and `C` dimensions are set to 1. If set to k > 1, there
2429      will be k-1 skipped cells between each filter element on that dimension.
2430      The dimension order is determined by the value of `data_format`, see above
2431      for details. Dilations in the batch and depth dimensions if a 4-d tensor
2432      must be 1.
2433    name: A name for the operation (optional).
2434    filters: Alias for filter.
2435
2436  Returns:
2437    A `Tensor`. Has the same type as `input`.
2438  """
2439  filter = deprecation.deprecated_argument_lookup(
2440      "filters", filters, "filter", filter)
2441  padding, explicit_paddings = convert_padding(padding)
2442  if data_format is None:
2443    data_format = "NHWC"
2444  channel_index = 1 if data_format.startswith("NC") else 3
2445
2446  strides = _get_sequence(strides, 2, channel_index, "strides")
2447  dilations = _get_sequence(dilations, 2, channel_index, "dilations")
2448
2449  shape = input.shape
2450  # shape object may lack ndims, e.g., if input is an np.ndarray.  In that case,
2451  # we fall back to len(shape).
2452  ndims = getattr(shape, "ndims", -1)
2453  if ndims == -1:
2454    ndims = len(shape)
2455  if ndims in (4, 3, 2, 1, 0, None):
2456    # We avoid calling squeeze_batch_dims to reduce extra python function
2457    # call slowdown in eager mode.  This branch doesn't require reshapes.
2458    return gen_nn_ops.conv2d(
2459        input,
2460        filter=filter,
2461        strides=strides,
2462        padding=padding,
2463        use_cudnn_on_gpu=use_cudnn_on_gpu,
2464        explicit_paddings=explicit_paddings,
2465        data_format=data_format,
2466        dilations=dilations,
2467        name=name)
2468  return squeeze_batch_dims(
2469      input,
2470      functools.partial(
2471          gen_nn_ops.conv2d,
2472          filter=filter,
2473          strides=strides,
2474          padding=padding,
2475          use_cudnn_on_gpu=use_cudnn_on_gpu,
2476          explicit_paddings=explicit_paddings,
2477          data_format=data_format,
2478          dilations=dilations),
2479      inner_rank=3,
2480      name=name)
2481
2482
2483@tf_export(v1=["nn.conv2d_backprop_filter"])
2484@dispatch.add_dispatch_support
2485def conv2d_backprop_filter(  # pylint: disable=redefined-builtin,dangerous-default-value
2486    input,
2487    filter_sizes,
2488    out_backprop,
2489    strides,
2490    padding,
2491    use_cudnn_on_gpu=True,
2492    data_format="NHWC",
2493    dilations=[1, 1, 1, 1],
2494    name=None):
2495  r"""Computes the gradients of convolution with respect to the filter.
2496
2497  Args:
2498    input: A `Tensor`. Must be one of the following types:
2499      `half`, `bfloat16`, `float32`, `float64`.
2500      4-D with shape `[batch, in_height, in_width, in_channels]`.
2501    filter_sizes: A `Tensor` of type `int32`.
2502      An integer vector representing the tensor shape of `filter`,
2503      where `filter` is a 4-D
2504      `[filter_height, filter_width, in_channels, out_channels]` tensor.
2505    out_backprop: A `Tensor`. Must have the same type as `input`.
2506      4-D with shape `[batch, out_height, out_width, out_channels]`.
2507      Gradients w.r.t. the output of the convolution.
2508    strides: A list of `ints`.
2509      The stride of the sliding window for each dimension of the input
2510      of the convolution. Must be in the same order as the dimension specified
2511      with format.
2512    padding: Either the `string` `"SAME"` or `"VALID"` indicating the type of
2513      padding algorithm to use, or a list indicating the explicit paddings at
2514      the start and end of each dimension. When explicit padding is used and
2515      data_format is `"NHWC"`, this should be in the form `[[0, 0], [pad_top,
2516      pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit padding used
2517      and data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
2518      [pad_top, pad_bottom], [pad_left, pad_right]]`.
2519    use_cudnn_on_gpu: An optional `bool`. Defaults to `True`.
2520    data_format: An optional `string` from: `"NHWC", "NCHW"`.
2521      Defaults to `"NHWC"`.
2522      Specify the data format of the input and output data. With the
2523      default format "NHWC", the data is stored in the order of:
2524          [batch, in_height, in_width, in_channels].
2525      Alternatively, the format could be "NCHW", the data storage order of:
2526          [batch, in_channels, in_height, in_width].
2527    dilations: An optional list of `ints`. Defaults to `[1, 1, 1, 1]`.
2528      1-D tensor of length 4.  The dilation factor for each dimension of
2529      `input`. If set to k > 1, there will be k-1 skipped cells between each
2530      filter element on that dimension. The dimension order is determined by
2531      the value of `data_format`, see above for details. Dilations in the batch
2532      and depth dimensions must be 1.
2533    name: A name for the operation (optional).
2534
2535  Returns:
2536    A `Tensor`. Has the same type as `input`.
2537  """
2538  padding, explicit_paddings = convert_padding(padding)
2539  return gen_nn_ops.conv2d_backprop_filter(
2540      input, filter_sizes, out_backprop, strides, padding, use_cudnn_on_gpu,
2541      explicit_paddings, data_format, dilations, name)
2542
2543
2544@tf_export(v1=["nn.conv2d_backprop_input"])
2545@dispatch.add_dispatch_support
2546def conv2d_backprop_input(  # pylint: disable=redefined-builtin,dangerous-default-value
2547    input_sizes,
2548    filter=None,
2549    out_backprop=None,
2550    strides=None,
2551    padding=None,
2552    use_cudnn_on_gpu=True,
2553    data_format="NHWC",
2554    dilations=[1, 1, 1, 1],
2555    name=None,
2556    filters=None):
2557  r"""Computes the gradients of convolution with respect to the input.
2558
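  For context, a rough sketch of how this op relates to automatic
  differentiation (eager execution assumed; shapes are illustrative only):

  ```python
  x = tf.random.normal([1, 5, 5, 3])
  w = tf.random.normal([2, 2, 3, 4])
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding="VALID")
  dy = tf.ones_like(y)
  g1 = tape.gradient(y, x, output_gradients=dy)
  g2 = tf.compat.v1.nn.conv2d_backprop_input(
      input_sizes=tf.shape(x), filter=w, out_backprop=dy,
      strides=[1, 1, 1, 1], padding="VALID")
  # g1 and g2 agree up to numerical error.
  ```
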
2559  Args:
2560    input_sizes: A `Tensor` of type `int32`.
2561      An integer vector representing the shape of `input`,
2562      where `input` is a 4-D `[batch, height, width, channels]` tensor.
2563    filter: A `Tensor`. Must be one of the following types:
2564      `half`, `bfloat16`, `float32`, `float64`.
2565      4-D with shape
2566      `[filter_height, filter_width, in_channels, out_channels]`.
2567    out_backprop: A `Tensor`. Must have the same type as `filter`.
2568      4-D with shape `[batch, out_height, out_width, out_channels]`.
2569      Gradients w.r.t. the output of the convolution.
2570    strides: A list of `ints`.
2571      The stride of the sliding window for each dimension of the input
2572      of the convolution. Must be in the same order as the dimension specified
2573      with format.
2574    padding: Either the `string` `"SAME"` or `"VALID"` indicating the type of
2575      padding algorithm to use, or a list indicating the explicit paddings at
2576      the start and end of each dimension. When explicit padding is used and
2577      data_format is `"NHWC"`, this should be in the form `[[0, 0], [pad_top,
2578      pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit padding used
2579      and data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
2580      [pad_top, pad_bottom], [pad_left, pad_right]]`.
2581    use_cudnn_on_gpu: An optional `bool`. Defaults to `True`.
2582    data_format: An optional `string` from: `"NHWC", "NCHW"`.
2583      Defaults to `"NHWC"`.
2584      Specify the data format of the input and output data. With the
2585      default format "NHWC", the data is stored in the order of:
2586          [batch, in_height, in_width, in_channels].
2587      Alternatively, the format could be "NCHW", the data storage order of:
2588          [batch, in_channels, in_height, in_width].
2589    dilations: An optional list of `ints`. Defaults to `[1, 1, 1, 1]`.
2590      1-D tensor of length 4.  The dilation factor for each dimension of
2591      `input`. If set to k > 1, there will be k-1 skipped cells between each
2592      filter element on that dimension. The dimension order is determined by
2593      the value of `data_format`, see above for details. Dilations in the batch
2594      and depth dimensions must be 1.
2595    name: A name for the operation (optional).
2596    filters: Alias for filter.
2597
2598  Returns:
2599    A `Tensor`. Has the same type as `filter`.
2600  """
2601  filter = deprecation.deprecated_argument_lookup(
2602      "filters", filters, "filter", filter)
2603  padding, explicit_paddings = convert_padding(padding)
2604  return gen_nn_ops.conv2d_backprop_input(
2605      input_sizes, filter, out_backprop, strides, padding, use_cudnn_on_gpu,
2606      explicit_paddings, data_format, dilations, name)
2607
2608
2609@tf_export(v1=["nn.conv2d_transpose"])
2610@dispatch.add_dispatch_support
2611def conv2d_transpose(
2612    value=None,
2613    filter=None,  # pylint: disable=redefined-builtin
2614    output_shape=None,
2615    strides=None,
2616    padding="SAME",
2617    data_format="NHWC",
2618    name=None,
2619    input=None,  # pylint: disable=redefined-builtin
2620    filters=None,
2621    dilations=None):
2622  """The transpose of `conv2d`.
2623
2624  This operation is sometimes called "deconvolution" after
2625  (Zeiler et al., 2010), but is really the transpose (gradient) of `conv2d`
2626  rather than an actual deconvolution.
2627
2628  Args:
2629    value: A 4-D `Tensor` of type `float` and shape
2630      `[batch, height, width, in_channels]` for `NHWC` data format or
2631      `[batch, in_channels, height, width]` for `NCHW` data format.
2632    filter: A 4-D `Tensor` with the same type as `value` and shape
2633      `[height, width, output_channels, in_channels]`.  `filter`'s
2634      `in_channels` dimension must match that of `value`.
2635    output_shape: A 1-D `Tensor` representing the output shape of the
2636      deconvolution op.
2637    strides: An int or list of `ints` that has length `1`, `2` or `4`.  The
2638      stride of the sliding window for each dimension of `input`. If a single
      value is given it is replicated in the `H` and `W` dimensions. By default
      the `N` and `C` dimensions are set to 1. The dimension order is determined
2641      by the value of `data_format`, see below for details.
2642    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
      See
      [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2)
      for more information.
2644    data_format: A string. 'NHWC' and 'NCHW' are supported.
2645    name: Optional name for the returned tensor.
2646    input: Alias for value.
2647    filters: Alias for filter.
    dilations: An int or list of `ints` that has length `1`, `2` or `4`,
      defaults to 1. The dilation factor for each dimension of `input`. If a
      single value is given it is replicated in the `H` and `W` dimensions. By
      default the `N` and `C` dimensions are set to 1. If set to k > 1, there
      will be k-1 skipped cells between each filter element on that dimension.
      The dimension order is determined by the value of `data_format`, see above
      for details. If a length-4 value is given, the dilations in the batch and
      depth dimensions must be 1.
2656
2657  Returns:
2658    A `Tensor` with the same type as `value`.
2659
2660  Raises:
2661    ValueError: If input/output depth does not match `filter`'s shape, or if
2662      padding is other than `'VALID'` or `'SAME'`.
2663
2664  References:
2665    Deconvolutional Networks:
2666      [Zeiler et al., 2010]
2667      (https://ieeexplore.ieee.org/abstract/document/5539957)
2668      ([pdf]
2669      (http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.232.4023&rep=rep1&type=pdf))
2670  """
2671  value = deprecated_argument_lookup("input", input, "value", value)
2672  filter = deprecated_argument_lookup("filters", filters, "filter", filter)
2673  with ops.name_scope(name, "conv2d_transpose",
2674                      [value, filter, output_shape]) as name:
2675    return conv2d_transpose_v2(
2676        value,
2677        filter,
2678        output_shape,
2679        strides,
2680        padding=padding,
2681        data_format=data_format,
2682        dilations=dilations,
2683        name=name)
2684
2685
2686@tf_export("nn.conv2d_transpose", v1=[])
2687@dispatch.add_dispatch_support
2688def conv2d_transpose_v2(
2689    input,  # pylint: disable=redefined-builtin
2690    filters,  # pylint: disable=redefined-builtin
2691    output_shape,
2692    strides,
2693    padding="SAME",
2694    data_format="NHWC",
2695    dilations=None,
2696    name=None):
2697  """The transpose of `conv2d`.
2698
  This operation is sometimes called "deconvolution" after
  (Zeiler et al., 2010), but is really the transpose (gradient) of
  `conv2d` rather than an actual deconvolution.
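
  For example, a minimal sketch with arbitrary shapes, illustrating that with
  the default `'SAME'` padding and a stride of 2 each spatial dimension of the
  output is twice that of the input:

  >>> x = tf.random.normal([1, 3, 3, 4])
  >>> filters = tf.random.normal([2, 2, 8, 4])
  >>> y = tf.nn.conv2d_transpose(x, filters, output_shape=[1, 6, 6, 8],
  ...                            strides=2)
  >>> y.shape
  TensorShape([1, 6, 6, 8])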
2702
2703  Args:
2704    input: A 4-D `Tensor` of type `float` and shape `[batch, height, width,
2705      in_channels]` for `NHWC` data format or `[batch, in_channels, height,
2706      width]` for `NCHW` data format.
2707    filters: A 4-D `Tensor` with the same type as `input` and shape `[height,
2708      width, output_channels, in_channels]`.  `filter`'s `in_channels` dimension
2709      must match that of `input`.
2710    output_shape: A 1-D `Tensor` representing the output shape of the
2711      deconvolution op.
2712    strides: An int or list of `ints` that has length `1`, `2` or `4`.  The
2713      stride of the sliding window for each dimension of `input`. If a single
      value is given it is replicated in the `H` and `W` dimensions. By default
      the `N` and `C` dimensions are set to 1. The dimension order is determined
2716      by the value of `data_format`, see below for details.
2717    padding: Either the `string` `"SAME"` or `"VALID"` indicating the type of
2718      padding algorithm to use, or a list indicating the explicit paddings at
2719      the start and end of each dimension. See
2720      [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2)
2721      for more information.  When explicit padding is used and data_format is
2722      `"NHWC"`, this should be in the form `[[0, 0], [pad_top, pad_bottom],
      [pad_left, pad_right], [0, 0]]`. When explicit padding is used and
2724      data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
2725      [pad_top, pad_bottom], [pad_left, pad_right]]`.
2726    data_format: A string. 'NHWC' and 'NCHW' are supported.
    dilations: An int or list of `ints` that has length `1`, `2` or `4`,
      defaults to 1. The dilation factor for each dimension of `input`. If a
      single value is given it is replicated in the `H` and `W` dimensions. By
      default the `N` and `C` dimensions are set to 1. If set to k > 1, there
      will be k-1 skipped cells between each filter element on that dimension.
      The dimension order is determined by the value of `data_format`, see above
      for details. If a length-4 value is given, the dilations in the batch and
      depth dimensions must be 1.
2735    name: Optional name for the returned tensor.
2736
2737  Returns:
2738    A `Tensor` with the same type as `input`.
2739
2740  Raises:
2741    ValueError: If input/output depth does not match `filter`'s shape, or if
2742      padding is other than `'VALID'` or `'SAME'`.
2743
2744  References:
2745    Deconvolutional Networks:
2746      [Zeiler et al., 2010]
2747      (https://ieeexplore.ieee.org/abstract/document/5539957)
2748      ([pdf]
2749      (http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.232.4023&rep=rep1&type=pdf))
2750  """
2751  with ops.name_scope(name, "conv2d_transpose",
                      [input, filters, output_shape]) as name:
2753    if data_format is None:
2754      data_format = "NHWC"
2755    channel_index = 1 if data_format.startswith("NC") else 3
2756
2757    strides = _get_sequence(strides, 2, channel_index, "strides")
2758    dilations = _get_sequence(dilations, 2, channel_index, "dilations")
2759    padding, explicit_paddings = convert_padding(padding)
2760
2761    return gen_nn_ops.conv2d_backprop_input(
2762        input_sizes=output_shape,
2763        filter=filters,
2764        out_backprop=input,
2765        strides=strides,
2766        padding=padding,
2767        explicit_paddings=explicit_paddings,
2768        data_format=data_format,
2769        dilations=dilations,
2770        name=name)
2771
2772
2773def _conv2d_expanded_batch(
2774    input,  # pylint: disable=redefined-builtin
2775    filters,
2776    strides,
2777    padding,
2778    data_format,
2779    dilations,
2780    name):
2781  """Helper function for `convolution_internal`; handles expanded batches."""
2782  # Try really hard to avoid modifying the legacy name scopes - return early.
2783  input_rank = input.shape.rank
2784  if input_rank is None or input_rank < 5:
2785    # We avoid calling squeeze_batch_dims to reduce extra python function
2786    # call slowdown in eager mode.  This branch doesn't require reshapes.
2787    return gen_nn_ops.conv2d(
2788        input,
2789        filter=filters,
2790        strides=strides,
2791        padding=padding,
2792        data_format=data_format,
2793        dilations=dilations,
2794        name=name)
2795  return squeeze_batch_dims(
2796      input,
2797      functools.partial(
2798          gen_nn_ops.conv2d,
2799          filter=filters,
2800          strides=strides,
2801          padding=padding,
2802          data_format=data_format,
2803          dilations=dilations),
2804      inner_rank=3,
2805      name=name)
2806
2807
2808@tf_export("nn.atrous_conv2d_transpose")
2809@dispatch.add_dispatch_support
2810def atrous_conv2d_transpose(value,
2811                            filters,
2812                            output_shape,
2813                            rate,
2814                            padding,
2815                            name=None):
2816  """The transpose of `atrous_conv2d`.
2817
2818  This operation is sometimes called "deconvolution" after
2819  (Zeiler et al., 2010), but is really the transpose (gradient) of
2820  `atrous_conv2d` rather than an actual deconvolution.
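
  For example, a minimal sketch with arbitrary shapes, illustrating that with
  `'SAME'` padding the spatial dimensions of `output_shape` can equal those of
  `value`:

  >>> value = tf.random.normal([1, 4, 4, 2])
  >>> filters = tf.random.normal([3, 3, 5, 2])
  >>> out = tf.nn.atrous_conv2d_transpose(
  ...     value, filters, output_shape=[1, 4, 4, 5], rate=2, padding='SAME')
  >>> out.shape
  TensorShape([1, 4, 4, 5])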
2821
2822  Args:
2823    value: A 4-D `Tensor` of type `float`. It needs to be in the default `NHWC`
2824      format. Its shape is `[batch, in_height, in_width, in_channels]`.
2825    filters: A 4-D `Tensor` with the same type as `value` and shape
2826      `[filter_height, filter_width, out_channels, in_channels]`. `filters`'
2827      `in_channels` dimension must match that of `value`. Atrous convolution is
2828      equivalent to standard convolution with upsampled filters with effective
2829      height `filter_height + (filter_height - 1) * (rate - 1)` and effective
2830      width `filter_width + (filter_width - 1) * (rate - 1)`, produced by
2831      inserting `rate - 1` zeros along consecutive elements across the
2832      `filters`' spatial dimensions.
2833    output_shape: A 1-D `Tensor` of shape representing the output shape of the
2834      deconvolution op, of form `[batch, out_height, out_width, out_channels]`.
2835    rate: A positive int32. The stride with which we sample input values across
2836      the `height` and `width` dimensions. Equivalently, the rate by which we
2837      upsample the filter values by inserting zeros across the `height` and
2838      `width` dimensions. In the literature, the same parameter is sometimes
2839      called `input stride` or `dilation`.
2840    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
2841      [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2)
2842      for more information.
2843    name: Optional name for the returned tensor.
2844
2845  Returns:
2846    A `Tensor` with the same type as `value`.
2847
2848  Raises:
2849    ValueError: If input/output depth does not match `filters`' shape, or if
2850      padding is other than `'VALID'` or `'SAME'`, or if the `rate` is less
2851      than one, or if the output_shape is not a tensor with 4 elements.
2852
2853  References:
2854    Deconvolutional Networks:
2855      [Zeiler et al., 2010]
2856      (https://ieeexplore.ieee.org/abstract/document/5539957)
2857      ([pdf]
2858      (http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.232.4023&rep=rep1&type=pdf))
2859  """
2860  with ops.name_scope(name, "atrous_conv2d_transpose",
2861                      [value, filters, output_shape]) as name:
2862    value = ops.convert_to_tensor(value, name="value")
2863    filters = ops.convert_to_tensor(filters, name="filters")
2864    if not value.get_shape().dims[3].is_compatible_with(filters.get_shape()[3]):
2865      raise ValueError(
2866          "`value` channel count must be compatible with `filters` input "
2867          f"channel count. Received: value.shape={value.get_shape()} with "
2868          f"channel count {value.get_shape()[3]} and "
2869          f"filters.shape={filters.get_shape()} with input channel count "
2870          f"{filters.get_shape()[3]}.")
2871    if rate < 1:
2872      raise ValueError(f"`rate` cannot be less than one. Received: rate={rate}")
2873
2874    if rate == 1:
2875      return conv2d_transpose(
2876          value,
2877          filters,
2878          output_shape,
2879          strides=[1, 1, 1, 1],
2880          padding=padding,
2881          data_format="NHWC")
2882
2883    output_shape_ = ops.convert_to_tensor(output_shape, name="output_shape")
2884    if not output_shape_.get_shape().is_compatible_with(
2885        tensor_shape.TensorShape([4])):
2886      raise ValueError("`output_shape` must have shape (4,). "
2887                       f"Received: output_shape={output_shape_.get_shape()}")
2888
2889    if isinstance(output_shape, tuple):
2890      output_shape = list(output_shape)
2891
2892    if isinstance(output_shape, (list, np.ndarray)):
2893      # output_shape's shape should be == [4] if reached this point.
2894      if not filters.get_shape().dims[2].is_compatible_with(output_shape[3]):
2895        raise ValueError(
2896            "`output_shape` channel count must be compatible with `filters` "
2897            f"output channel count. Received: output_shape={output_shape} with "
2898            f"channel count {output_shape[3]} and "
2899            f"filters.shape={filters.get_shape()} with output channel count "
2900            f"{filters.get_shape()[3]}.")
2901
2902    # We have two padding contributions. The first is used for converting "SAME"
2903    # to "VALID". The second is required so that the height and width of the
2904    # zero-padded value tensor are multiples of rate.
2905
2906    # Padding required to reduce to "VALID" convolution
2907    if padding == "SAME":
2908      # Handle filters whose shape is unknown during graph creation.
2909      if filters.get_shape().is_fully_defined():
2910        filter_shape = filters.get_shape().as_list()
2911      else:
2912        filter_shape = array_ops.shape(filters)
2913      filter_height, filter_width = filter_shape[0], filter_shape[1]
2914
2915      # Spatial dimensions of the filters and the upsampled filters in which we
2916      # introduce (rate - 1) zeros between consecutive filter values.
2917      filter_height_up = filter_height + (filter_height - 1) * (rate - 1)
2918      filter_width_up = filter_width + (filter_width - 1) * (rate - 1)
2919
2920      pad_height = filter_height_up - 1
2921      pad_width = filter_width_up - 1
2922
2923      # When pad_height (pad_width) is odd, we pad more to bottom (right),
2924      # following the same convention as conv2d().
2925      pad_top = pad_height // 2
2926      pad_bottom = pad_height - pad_top
2927      pad_left = pad_width // 2
2928      pad_right = pad_width - pad_left
2929    elif padding == "VALID":
2930      pad_top = 0
2931      pad_bottom = 0
2932      pad_left = 0
2933      pad_right = 0
2934    else:
2935      raise ValueError("`padding` must be either 'VALID' or 'SAME'. "
2936                       f"Received: padding={padding}")
2937
2938    in_height = output_shape[1] + pad_top + pad_bottom
2939    in_width = output_shape[2] + pad_left + pad_right
2940
2941    # More padding so that rate divides the height and width of the input.
2942    pad_bottom_extra = (rate - in_height % rate) % rate
2943    pad_right_extra = (rate - in_width % rate) % rate
2944
2945    # The paddings argument to space_to_batch is just the extra padding
2946    # component.
2947    space_to_batch_pad = [[0, pad_bottom_extra], [0, pad_right_extra]]
2948
2949    value = array_ops.space_to_batch(
2950        input=value, paddings=space_to_batch_pad, block_size=rate)
2951
2952    input_sizes = [
2953        rate * rate * output_shape[0], (in_height + pad_bottom_extra) // rate,
2954        (in_width + pad_right_extra) // rate, output_shape[3]
2955    ]
2956
2957    value = gen_nn_ops.conv2d_backprop_input(
2958        input_sizes=input_sizes,
2959        filter=filters,
2960        out_backprop=value,
2961        strides=[1, 1, 1, 1],
2962        padding="VALID",
2963        data_format="NHWC")
2964
2965    # The crops argument to batch_to_space includes both padding components.
2966    batch_to_space_crop = [[pad_top, pad_bottom + pad_bottom_extra],
2967                           [pad_left, pad_right + pad_right_extra]]
2968
2969    return array_ops.batch_to_space(
2970        input=value, crops=batch_to_space_crop, block_size=rate)
2971
2972
2973@tf_export(v1=["nn.depthwise_conv2d_native"])
2974@dispatch.add_dispatch_support
2975@deprecation.deprecated_endpoints("nn.depthwise_conv2d_native")
2976def depthwise_conv2d_native(  # pylint: disable=redefined-builtin,dangerous-default-value
2977    input,
2978    filter,
2979    strides,
2980    padding,
2981    data_format="NHWC",
2982    dilations=[1, 1, 1, 1],
2983    name=None):
2984  r"""Computes a 2-D depthwise convolution.
2985
2986  Given an input tensor of shape `[batch, in_height, in_width, in_channels]`
2987  and a filter / kernel tensor of shape
2988  `[filter_height, filter_width, in_channels, channel_multiplier]`, containing
2989  `in_channels` convolutional filters of depth 1, `depthwise_conv2d` applies
2990  a different filter to each input channel (expanding from 1 channel to
2991  `channel_multiplier` channels for each), then concatenates the results
2992  together. Thus, the output has `in_channels * channel_multiplier` channels.
2993
2994  ```
2995  for k in 0..in_channels-1
2996    for q in 0..channel_multiplier-1
2997      output[b, i, j, k * channel_multiplier + q] =
2998        sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, k] *
2999                          filter[di, dj, k, q]
3000  ```
3001
  Must have `strides[0] = strides[3] = 1`.  For the most common case of the same
  horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
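
  For example, a minimal sketch with arbitrary shapes (via the v1 endpoint),
  illustrating that the output has `in_channels * channel_multiplier` channels
  (here `3 * 2 = 6`):

  >>> x = tf.random.normal([1, 8, 8, 3])
  >>> kernel = tf.random.normal([2, 2, 3, 2])
  >>> y = tf.compat.v1.nn.depthwise_conv2d_native(
  ...     x, kernel, strides=[1, 1, 1, 1], padding='VALID')
  >>> y.shape
  TensorShape([1, 7, 7, 6])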
3004
3005  Args:
3006    input: A `Tensor`. Must be one of the following types: `half`, `bfloat16`,
3007      `float32`, `float64`.
3008    filter: A `Tensor`. Must have the same type as `input`.
3009    strides: A list of `ints`. 1-D of length 4.  The stride of the sliding
3010      window for each dimension of `input`.
3011    padding: Controls how to pad the image before applying the convolution. Can
3012      be the string `"SAME"` or `"VALID"` indicating the type of padding
3013      algorithm to use, or a list indicating the explicit paddings at the start
3014      and end of each dimension. When explicit padding is used and data_format
3015      is `"NHWC"`, this should be in the form `[[0, 0], [pad_top, pad_bottom],
      [pad_left, pad_right], [0, 0]]`. When explicit padding is used and
3017      data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
3018      [pad_top, pad_bottom], [pad_left, pad_right]]`.
3019    data_format: An optional `string` from: `"NHWC", "NCHW"`. Defaults to
3020      `"NHWC"`. Specify the data format of the input and output data. With the
3021      default format "NHWC", the data is stored in the order of: [batch, height,
3022        width, channels].
3023      Alternatively, the format could be "NCHW", the data storage order of:
3024        [batch, channels, height, width].
3025    dilations: An optional list of `ints`. Defaults to `[1, 1, 1, 1]`. 1-D
3026      tensor of length 4.  The dilation factor for each dimension of `input`. If
3027      set to k > 1, there will be k-1 skipped cells between each filter element
3028      on that dimension. The dimension order is determined by the value of
3029      `data_format`, see above for details. Dilations in the batch and depth
3030      dimensions must be 1.
3031    name: A name for the operation (optional).
3032
3033  Returns:
3034    A `Tensor`. Has the same type as `input`.
3035  """
3036  padding, explicit_paddings = convert_padding(padding)
3037  return gen_nn_ops.depthwise_conv2d_native(
3038      input,
3039      filter,
3040      strides,
3041      padding,
3042      explicit_paddings=explicit_paddings,
3043      data_format=data_format,
3044      dilations=dilations,
3045      name=name)
3046
3047
3048@tf_export(
3049    "nn.depthwise_conv2d_backprop_input",
3050    v1=[
3051        "nn.depthwise_conv2d_native_backprop_input",
3052        "nn.depthwise_conv2d_backprop_input"
3053    ])
3054@dispatch.add_dispatch_support
3055@deprecation.deprecated_endpoints("nn.depthwise_conv2d_native_backprop_input")
3056def depthwise_conv2d_native_backprop_input(  # pylint: disable=redefined-builtin,dangerous-default-value
3057    input_sizes,
3058    filter,
3059    out_backprop,
3060    strides,
3061    padding,
3062    data_format="NHWC",
3063    dilations=[1, 1, 1, 1],
3064    name=None):
3065  r"""Computes the gradients of depthwise convolution with respect to the input.
3066
3067  Args:
3068    input_sizes: A `Tensor` of type `int32`. An integer vector representing the
3069      shape of `input`, based on `data_format`.  For example, if `data_format`
3070      is 'NHWC' then `input` is a 4-D `[batch, height, width, channels]` tensor.
3071    filter: A `Tensor`. Must be one of the following types: `half`, `bfloat16`,
3072      `float32`, `float64`. 4-D with shape `[filter_height, filter_width,
3073      in_channels, depthwise_multiplier]`.
3074    out_backprop: A `Tensor`. Must have the same type as `filter`. 4-D with
3075      shape  based on `data_format`. For example, if `data_format` is 'NHWC'
3076      then out_backprop shape is `[batch, out_height, out_width, out_channels]`.
3077      Gradients w.r.t. the output of the convolution.
3078    strides: A list of `ints`. The stride of the sliding window for each
3079      dimension of the input of the convolution.
3080    padding: Controls how to pad the image before applying the convolution. Can
3081      be the string `"SAME"` or `"VALID"` indicating the type of padding
3082      algorithm to use, or a list indicating the explicit paddings at the start
3083      and end of each dimension. See
3084      [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2)
3085      for more information. When explicit padding is used and data_format is
3086      `"NHWC"`, this should be in the form `[[0, 0], [pad_top, pad_bottom],
      [pad_left, pad_right], [0, 0]]`. When explicit padding is used and
3088      data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
3089      [pad_top, pad_bottom], [pad_left, pad_right]]`.
3090    data_format: An optional `string` from: `"NHWC", "NCHW"`. Defaults to
3091      `"NHWC"`. Specify the data format of the input and output data. With the
3092      default format "NHWC", the data is stored in the order of: [batch, height,
3093        width, channels].
3094      Alternatively, the format could be "NCHW", the data storage order of:
3095        [batch, channels, height, width].
3096    dilations: An optional list of `ints`. Defaults to `[1, 1, 1, 1]`. 1-D
3097      tensor of length 4.  The dilation factor for each dimension of `input`. If
3098      set to k > 1, there will be k-1 skipped cells between each filter element
3099      on that dimension. The dimension order is determined by the value of
3100      `data_format`, see above for details. Dilations in the batch and depth
3101      dimensions must be 1.
3102    name: A name for the operation (optional).
3103
3104  Returns:
3105    A `Tensor`. Has the same type as `filter`.
3106  """
3107  padding, explicit_paddings = convert_padding(padding)
3108  return gen_nn_ops.depthwise_conv2d_native_backprop_input(
3109      input_sizes,
3110      filter,
3111      out_backprop,
3112      strides,
3113      padding,
3114      explicit_paddings=explicit_paddings,
3115      data_format=data_format,
3116      dilations=dilations,
3117      name=name)
3118
3119
3120@tf_export(
3121    "nn.depthwise_conv2d_backprop_filter",
3122    v1=[
3123        "nn.depthwise_conv2d_native_backprop_filter",
3124        "nn.depthwise_conv2d_backprop_filter"
3125    ])
3126@dispatch.add_dispatch_support
3127@deprecation.deprecated_endpoints("nn.depthwise_conv2d_native_backprop_filter")
3128def depthwise_conv2d_native_backprop_filter(  # pylint: disable=redefined-builtin,dangerous-default-value
3129    input,
3130    filter_sizes,
3131    out_backprop,
3132    strides,
3133    padding,
3134    data_format="NHWC",
3135    dilations=[1, 1, 1, 1],
3136    name=None):
3137  r"""Computes the gradients of depthwise convolution with respect to the filter.
3138
3139  Args:
3140    input: A `Tensor`. Must be one of the following types: `half`, `bfloat16`,
3141      `float32`, `float64`. 4-D with shape based on `data_format`.  For example,
3142      if `data_format` is 'NHWC' then `input` is a 4-D `[batch, in_height,
3143      in_width, in_channels]` tensor.
3144    filter_sizes: A `Tensor` of type `int32`. An integer vector representing the
3145      tensor shape of `filter`, where `filter` is a 4-D `[filter_height,
3146      filter_width, in_channels, depthwise_multiplier]` tensor.
3147    out_backprop: A `Tensor`. Must have the same type as `input`. 4-D with shape
3148      based on `data_format`. For example, if `data_format` is 'NHWC' then
3149      out_backprop shape is `[batch, out_height, out_width, out_channels]`.
3150      Gradients w.r.t. the output of the convolution.
3151    strides: A list of `ints`. The stride of the sliding window for each
3152      dimension of the input of the convolution.
3153    padding: Controls how to pad the image before applying the convolution. Can
3154      be the string `"SAME"` or `"VALID"` indicating the type of padding
3155      algorithm to use, or a list indicating the explicit paddings at the start
3156      and end of each dimension. See
3157      [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2)
3158      for more information. When explicit padding is used and data_format is
3159      `"NHWC"`, this should be in the form `[[0, 0], [pad_top, pad_bottom],
      [pad_left, pad_right], [0, 0]]`. When explicit padding is used and
3161      data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
3162      [pad_top, pad_bottom], [pad_left, pad_right]]`.
3163    data_format: An optional `string` from: `"NHWC", "NCHW"`. Defaults to
3164      `"NHWC"`. Specify the data format of the input and output data. With the
3165      default format "NHWC", the data is stored in the order of: [batch, height,
3166        width, channels].
3167      Alternatively, the format could be "NCHW", the data storage order of:
3168        [batch, channels, height, width].
3169    dilations: An optional list of `ints`. Defaults to `[1, 1, 1, 1]`. 1-D
3170      tensor of length 4.  The dilation factor for each dimension of `input`. If
3171      set to k > 1, there will be k-1 skipped cells between each filter element
3172      on that dimension. The dimension order is determined by the value of
3173      `data_format`, see above for details. Dilations in the batch and depth
3174      dimensions must be 1.
3175    name: A name for the operation (optional).
3176
3177  Returns:
3178    A `Tensor`. Has the same type as `input`.
3179  """
3180  padding, explicit_paddings = convert_padding(padding)
3181  return gen_nn_ops.depthwise_conv2d_native_backprop_filter(
3182      input,
3183      filter_sizes,
3184      out_backprop,
3185      strides,
3186      padding,
3187      explicit_paddings=explicit_paddings,
3188      data_format=data_format,
3189      dilations=dilations,
3190      name=name)
3191
3192
3193def _conv3d_expanded_batch(
3194    input,  # pylint: disable=redefined-builtin
3195    filter,  # pylint: disable=redefined-builtin
3196    strides,
3197    padding,
3198    data_format,
3199    dilations=None,
3200    name=None):
3201  """Helper function for `conv3d`; handles expanded batches."""
3202  shape = input.shape
3203  # shape object may lack ndims, e.g., if input is an np.ndarray.  In that case,
3204  # we fall back to len(shape).
3205  ndims = getattr(shape, "ndims", -1)
3206  if ndims == -1:
3207    ndims = len(shape)
3208  if ndims in (5, 4, 3, 2, 1, 0, None):
3209    # We avoid calling squeeze_batch_dims to reduce extra python function
3210    # call slowdown in eager mode.  This branch doesn't require reshapes.
3211    return gen_nn_ops.conv3d(
3212        input,
3213        filter,
3214        strides,
3215        padding,
3216        data_format=data_format,
3217        dilations=dilations,
3218        name=name)
3219  else:
3220    return squeeze_batch_dims(
3221        input,
3222        functools.partial(
3223            gen_nn_ops.conv3d,
3224            filter=filter,
3225            strides=strides,
3226            padding=padding,
3227            data_format=data_format,
3228            dilations=dilations),
3229        inner_rank=4,
3230        name=name)
3231
3232
3233@tf_export("nn.conv3d", v1=[])
3234@dispatch.add_dispatch_support
3235def conv3d_v2(input,  # pylint: disable=redefined-builtin,missing-docstring
3236              filters,
3237              strides,
3238              padding,
3239              data_format="NDHWC",
3240              dilations=None,
3241              name=None):
3242  if dilations is None:
3243    dilations = [1, 1, 1, 1, 1]
3244  return _conv3d_expanded_batch(input, filters, strides, padding, data_format,
3245                                dilations, name)
3246
3247
3248@tf_export(v1=["nn.conv3d"])
3249@dispatch.add_dispatch_support
3250def conv3d_v1(  # pylint: disable=missing-docstring,dangerous-default-value
3251    input,  # pylint: disable=redefined-builtin
3252    filter=None,  # pylint: disable=redefined-builtin
3253    strides=None,
3254    padding=None,
3255    data_format="NDHWC",
3256    dilations=[1, 1, 1, 1, 1],
3257    name=None,
3258    filters=None):
3259  filter = deprecated_argument_lookup("filters", filters, "filter", filter)
3260  return gen_nn_ops.conv3d(
3261      input, filter, strides, padding, data_format, dilations, name)
3262
3263
3264conv3d_v2.__doc__ = deprecation.rewrite_argument_docstring(
3265    gen_nn_ops.conv3d.__doc__, "filter", "filters")
3266conv3d_v1.__doc__ = gen_nn_ops.conv3d.__doc__
3267
3268
3269@tf_export(v1=["nn.conv3d_transpose"])
3270@dispatch.add_dispatch_support
3271def conv3d_transpose(
3272    value,
3273    filter=None,  # pylint: disable=redefined-builtin
3274    output_shape=None,
3275    strides=None,
3276    padding="SAME",
3277    data_format="NDHWC",
3278    name=None,
3279    input=None,  # pylint: disable=redefined-builtin
3280    filters=None,
3281    dilations=None):
3282  """The transpose of `conv3d`.
3283
3284  This operation is sometimes called "deconvolution" after
3285  (Zeiler et al., 2010), but is really the transpose (gradient) of `conv3d`
3286  rather than an actual deconvolution.
3287
3288  Args:
3289    value: A 5-D `Tensor` of type `float` and shape
3290      `[batch, depth, height, width, in_channels]`.
3291    filter: A 5-D `Tensor` with the same type as `value` and shape
3292      `[depth, height, width, output_channels, in_channels]`.  `filter`'s
3293      `in_channels` dimension must match that of `value`.
3294    output_shape: A 1-D `Tensor` representing the output shape of the
3295      deconvolution op.
3296    strides: A list of ints. The stride of the sliding window for each
3297      dimension of the input tensor.
3298    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
      See
      [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2)
      for more information.
    data_format: A string, either `'NDHWC'` or `'NCDHW'` specifying the layout
3301      of the input and output tensors. Defaults to `'NDHWC'`.
3302    name: Optional name for the returned tensor.
3303    input: Alias of value.
3304    filters: Alias of filter.
    dilations: An int or list of `ints` that has length `1`, `3` or `5`,
      defaults to 1. The dilation factor for each dimension of `input`. If a
      single value is given it is replicated in the `D`, `H` and `W` dimensions.
      By default the `N` and `C` dimensions are set to 1. If set to k > 1, there
      will be k-1 skipped cells between each filter element on that dimension.
      The dimension order is determined by the value of `data_format`, see above
      for details. If a length-5 value is given, the dilations in the batch and
      depth dimensions must be 1.
3313
3314  Returns:
3315    A `Tensor` with the same type as `value`.
3316
3317  Raises:
3318    ValueError: If input/output depth does not match `filter`'s shape, or if
3319      padding is other than `'VALID'` or `'SAME'`.
3320
3321  References:
3322    Deconvolutional Networks:
3323      [Zeiler et al., 2010]
3324      (https://ieeexplore.ieee.org/abstract/document/5539957)
3325      ([pdf]
3326      (http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.232.4023&rep=rep1&type=pdf))
3327  """
3328  filter = deprecated_argument_lookup("filters", filters, "filter", filter)
3329  value = deprecated_argument_lookup("input", input, "value", value)
3330  return conv3d_transpose_v2(
3331      value,
3332      filter,
3333      output_shape,
3334      strides,
3335      padding=padding,
3336      data_format=data_format,
3337      dilations=dilations,
3338      name=name)
3339
3340
3341@tf_export("nn.conv3d_transpose", v1=[])
3342@dispatch.add_dispatch_support
3343def conv3d_transpose_v2(input,  # pylint: disable=redefined-builtin
3344                        filters,
3345                        output_shape,
3346                        strides,
3347                        padding="SAME",
3348                        data_format="NDHWC",
3349                        dilations=None,
3350                        name=None):
3351  """The transpose of `conv3d`.
3352
3353  This operation is sometimes called "deconvolution" after
3354  (Zeiler et al., 2010), but is really the transpose (gradient) of `conv3d`
3355  rather than an actual deconvolution.
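
  For example, a minimal sketch with arbitrary shapes, illustrating that with
  the default `'SAME'` padding and a stride of 2 each spatial dimension of the
  output is twice that of the input:

  >>> x = tf.random.normal([1, 2, 2, 2, 3])
  >>> filters = tf.random.normal([2, 2, 2, 6, 3])
  >>> y = tf.nn.conv3d_transpose(x, filters, output_shape=[1, 4, 4, 4, 6],
  ...                            strides=2)
  >>> y.shape
  TensorShape([1, 4, 4, 4, 6])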
3356
3357  Args:
3358    input: A 5-D `Tensor` of type `float` and shape `[batch, depth, height,
3359      width, in_channels]` for `NDHWC` data format or `[batch, in_channels,
3360      depth, height, width]` for `NCDHW` data format.
3361    filters: A 5-D `Tensor` with the same type as `input` and shape `[depth,
3362      height, width, output_channels, in_channels]`.  `filter`'s `in_channels`
3363      dimension must match that of `input`.
3364    output_shape: A 1-D `Tensor` representing the output shape of the
3365      deconvolution op.
3366    strides: An int or list of `ints` that has length `1`, `3` or `5`.  The
3367      stride of the sliding window for each dimension of `input`. If a single
      value is given it is replicated in the `D`, `H` and `W` dimensions. By
      default the `N` and `C` dimensions are set to 1. The dimension order is
3370      determined by the value of `data_format`, see below for details.
3371    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
3372      [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2)
3373      for more information.
3374    data_format: A string. 'NDHWC' and 'NCDHW' are supported.
    dilations: An int or list of `ints` that has length `1`, `3` or `5`,
      defaults to 1. The dilation factor for each dimension of `input`. If a
      single value is given it is replicated in the `D`, `H` and `W` dimensions.
      By default the `N` and `C` dimensions are set to 1. If set to k > 1, there
      will be k-1 skipped cells between each filter element on that dimension.
      The dimension order is determined by the value of `data_format`, see above
      for details. If a length-5 value is given, the dilations in the batch and
      depth dimensions must be 1.
3383    name: Optional name for the returned tensor.
3384
3385  Returns:
3386    A `Tensor` with the same type as `input`.
3387
3388  References:
3389    Deconvolutional Networks:
3390      [Zeiler et al., 2010]
3391      (https://ieeexplore.ieee.org/abstract/document/5539957)
3392      ([pdf]
3393      (http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.232.4023&rep=rep1&type=pdf))
3394  """
3395  with ops.name_scope(name, "conv3d_transpose",
                      [input, filters, output_shape]) as name:
3397    if data_format is None:
3398      data_format = "NDHWC"
3399    channel_index = 1 if data_format.startswith("NC") else 4
3400
3401    strides = _get_sequence(strides, 3, channel_index, "strides")
3402    dilations = _get_sequence(dilations, 3, channel_index, "dilations")
3403
3404    return gen_nn_ops.conv3d_backprop_input_v2(
3405        input_sizes=output_shape,
3406        filter=filters,
3407        out_backprop=input,
3408        strides=strides,
3409        padding=padding,
3410        data_format=data_format,
3411        dilations=dilations,
3412        name=name)
3413
3414
3415CONV_TRANSPOSE_OPS = (
3416    conv1d_transpose,
3417    conv2d_transpose_v2,
3418    conv3d_transpose_v2,
3419)
3420
3421
3422@tf_export("nn.conv_transpose")
3423@dispatch.add_dispatch_support
3424def conv_transpose(input,  # pylint: disable=redefined-builtin
3425                   filters,
3426                   output_shape,
3427                   strides,
3428                   padding="SAME",
3429                   data_format=None,
3430                   dilations=None,
3431                   name=None):
3432  """The transpose of `convolution`.
3433
  This operation is sometimes called "deconvolution" after
  (Zeiler et al., 2010), but is really the transpose (gradient) of
  `convolution` rather than an actual deconvolution.
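
  For example, a minimal sketch with arbitrary shapes: an `output_shape` of
  length 3 dispatches to the 1-D transpose op:

  >>> x = tf.random.normal([1, 5, 2])
  >>> filters = tf.random.normal([3, 4, 2])
  >>> y = tf.nn.conv_transpose(x, filters, output_shape=[1, 10, 4], strides=2,
  ...                          data_format='NWC')
  >>> y.shape
  TensorShape([1, 10, 4])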
3437
3438  Args:
3439    input: An N+2 dimensional `Tensor` of shape
3440      `[batch_size] + input_spatial_shape + [in_channels]` if data_format does
3441      not start with "NC" (default), or
3442      `[batch_size, in_channels] + input_spatial_shape` if data_format starts
3443      with "NC". It must be one of the following types:
3444      `half`, `bfloat16`, `float32`, `float64`.
3445    filters: An N+2 dimensional `Tensor` with the same type as `input` and
3446      shape `spatial_filter_shape + [in_channels, out_channels]`.
3447    output_shape: A 1-D `Tensor` representing the output shape of the
3448      deconvolution op.
3449    strides: An int or list of `ints` that has length `1`, `N` or `N+2`.  The
3450      stride of the sliding window for each dimension of `input`. If a single
3451      value is given it is replicated in the spatial dimensions. By default
      the `N` and `C` dimensions are set to 1. The dimension order is determined
3453      by the value of `data_format`, see below for details.
3454    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
3455      [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2)
3456      for more information.
3457    data_format: A string or None.  Specifies whether the channel dimension of
3458      the `input` and output is the last dimension (default, or if `data_format`
3459      does not start with "NC"), or the second dimension (if `data_format`
3460      starts with "NC").  For N=1, the valid values are "NWC" (default) and
3461      "NCW".  For N=2, the valid values are "NHWC" (default) and "NCHW".
3462      For N=3, the valid values are "NDHWC" (default) and "NCDHW".
3463    dilations: An int or list of `ints` that has length `1`, `N` or `N+2`,
      defaults to 1. The dilation factor for each dimension of `input`. If a
3465      single value is given it is replicated in the spatial dimensions. By
3466      default the `N` and `C` dimensions are set to 1. If set to k > 1, there
3467      will be k-1 skipped cells between each filter element on that dimension.
3468      The dimension order is determined by the value of `data_format`, see above
3469      for details.
3470    name: A name for the operation (optional). If not specified "conv_transpose"
3471      is used.
3472
3473  Returns:
    A `Tensor` with the same type as `input`.
3475
3476  References:
3477    Deconvolutional Networks:
3478      [Zeiler et al., 2010]
3479      (https://ieeexplore.ieee.org/abstract/document/5539957)
3480      ([pdf]
3481      (http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.232.4023&rep=rep1&type=pdf))
3482  """
3483  with ops.name_scope(name, "conv_transpose",
                      [input, filters, output_shape]) as name:
3485    if tensor_util.is_tf_type(output_shape):
3486      n = output_shape.shape[0] - 2
3487    elif isinstance(output_shape, collections_abc.Sized):
3488      n = len(output_shape) - 2
3489    else:
3490      raise ValueError("`output_shape` must be a tensor or sized collection. "
3491                       f"Received: output_shape={output_shape}")
3492
3493    if not 1 <= n <= 3:
3494      raise ValueError(
3495          f"`output_shape` must be of length 3, 4 or 5. "
3496          f"Received: output_shape={output_shape} of length {n + 2}.")
3497
3498    op = CONV_TRANSPOSE_OPS[n-1]
3499    return op(
3500        input,
3501        filters,
3502        output_shape,
3503        strides,
3504        padding=padding,
3505        data_format=data_format,
3506        dilations=dilations,
3507        name=name)
3508
3509
3510@tf_export("nn.bias_add")
3511@dispatch.add_dispatch_support
3512def bias_add(value, bias, data_format=None, name=None):
3513  """Adds `bias` to `value`.
3514
3515  This is (mostly) a special case of `tf.add` where `bias` is restricted to 1-D.
3516  Broadcasting is supported, so `value` may have any number of dimensions.
3517  Unlike `tf.add`, the type of `bias` is allowed to differ from `value` in the
3518  case where both types are quantized.
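
  For example, a minimal sketch:

  >>> value = tf.reshape(tf.range(6.0), [2, 3])
  >>> bias = tf.constant([1.0, -1.0, 0.5])
  >>> tf.nn.bias_add(value, bias).numpy().tolist()
  [[1.0, 0.0, 2.5], [4.0, 3.0, 5.5]]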
3519
3520  Args:
3521    value: A `Tensor` with type `float`, `double`, `int64`, `int32`, `uint8`,
3522      `int16`, `int8`, `complex64`, or `complex128`.
3523    bias: A 1-D `Tensor` with size matching the channel dimension of `value`.
3524      Must be the same type as `value` unless `value` is a quantized type,
3525      in which case a different quantized type may be used.
3526    data_format: A string. 'N...C' and 'NC...' are supported. If `None` (the
      default) is specified then 'N...C' is assumed.
3528    name: A name for the operation (optional).
3529
3530  Returns:
3531    A `Tensor` with the same type as `value`.
3532
3533  Raises:
    ValueError: If `data_format` is unrecognized, if `value` has fewer than two
    dimensions when `data_format` is 'N...C'/`None`, if `value` has fewer than
    three dimensions when `data_format` is 'NC...', if `bias` does not have
    exactly one dimension (i.e. is not a vector), or if the size of `bias`
    does not match the size of the channel dimension of `value`.
3539  """
3540  with ops.name_scope(name, "BiasAdd", [value, bias]) as name:
3541    if data_format is not None:
3542      if data_format.startswith("NC"):
3543        data_format = "NCHW"
3544      elif data_format.startswith("N") and data_format.endswith("C"):
3545        data_format = "NHWC"
3546      else:
3547        raise ValueError("`data_format` must be of the form `N...C` or "
3548                         f"`NC...`. Received: data_format={data_format}")
3549
3550    if not context.executing_eagerly():
3551      value = ops.convert_to_tensor(value, name="input")
3552      bias = ops.convert_to_tensor(bias, dtype=value.dtype, name="bias")
3553
3554    return gen_nn_ops.bias_add(value, bias, data_format=data_format, name=name)
3555
3556
3557def bias_add_v1(value, bias, name=None):
3558  """Adds `bias` to `value`.
3559
  This is a deprecated version of `bias_add` and will soon be removed.
3561
3562  This is (mostly) a special case of `tf.add` where `bias` is restricted to 1-D.
3563  Broadcasting is supported, so `value` may have any number of dimensions.
3564  Unlike `tf.add`, the type of `bias` is allowed to differ from `value` in the
3565  case where both types are quantized.
3566
3567  Args:
3568    value: A `Tensor` with type `float`, `double`, `int64`, `int32`, `uint8`,
3569      `int16`, `int8`, `complex64`, or `complex128`.
3570    bias: A 1-D `Tensor` with size matching the last dimension of `value`.
3571      Must be the same type as `value` unless `value` is a quantized type,
3572      in which case a different quantized type may be used.
3573    name: A name for the operation (optional).
3574
3575  Returns:
3576    A `Tensor` with the same type as `value`.
3577  """
3578  with ops.name_scope(name, "BiasAddV1", [value, bias]) as name:
3579    value = ops.convert_to_tensor(value, name="input")
3580    bias = ops.convert_to_tensor(bias, dtype=value.dtype, name="bias")
3581    return gen_nn_ops.bias_add_v1(value, bias, name=name)
3582
3583
3584@tf_export(v1=["nn.crelu"])
3585@dispatch.add_dispatch_support
3586def crelu(features, name=None, axis=-1):
3587  """Computes Concatenated ReLU.
3588
3589  Concatenates a ReLU which selects only the positive part of the activation
3590  with a ReLU which selects only the *negative* part of the activation.
3591  Note that as a result this non-linearity doubles the depth of the activations.
3592  Source: [Understanding and Improving Convolutional Neural Networks via
3593  Concatenated Rectified Linear Units. W. Shang, et
3594  al.](https://arxiv.org/abs/1603.05201)
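
  For example, a minimal sketch showing the doubled channel dimension:

  >>> x = tf.constant([[-1.0, 2.0]])
  >>> tf.nn.crelu(x).numpy().tolist()
  [[0.0, 2.0, 1.0, 0.0]]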
3595
3596  Args:
3597    features: A `Tensor` with type `float`, `double`, `int32`, `int64`, `uint8`,
3598      `int16`, or `int8`.
3599    name: A name for the operation (optional).
3600    axis: The axis that the output values are concatenated along. Default is -1.
3601
3602  Returns:
3603    A `Tensor` with the same type as `features`.
3604
3605  References:
3606    Understanding and Improving Convolutional Neural Networks via Concatenated
3607    Rectified Linear Units:
3608      [Shang et al., 2016](http://proceedings.mlr.press/v48/shang16)
3609      ([pdf](http://proceedings.mlr.press/v48/shang16.pdf))
3610  """
3611  with ops.name_scope(name, "CRelu", [features]) as name:
3612    features = ops.convert_to_tensor(features, name="features")
3613    c = array_ops.concat([features, -features], axis, name=name)  # pylint: disable=invalid-unary-operand-type
3614    return gen_nn_ops.relu(c)
3615
3616
3617@tf_export("nn.crelu", v1=[])
3618@dispatch.add_dispatch_support
3619def crelu_v2(features, axis=-1, name=None):
3620  return crelu(features, name=name, axis=axis)
3621crelu_v2.__doc__ = crelu.__doc__
3622
3623
3624@tf_export("nn.relu6")
3625@dispatch.register_unary_elementwise_api
3626@dispatch.add_dispatch_support
3627def relu6(features, name=None):
3628  """Computes Rectified Linear 6: `min(max(features, 0), 6)`.
3629
  In comparison with `tf.nn.relu`, the relu6 activation has been shown
  empirically to perform better under low-precision conditions (e.g. fixed point
  inference) by encouraging the model to learn sparse features earlier.
3633  Source: [Convolutional Deep Belief Networks on CIFAR-10: Krizhevsky et al.,
3634  2010](http://www.cs.utoronto.ca/~kriz/conv-cifar10-aug2010.pdf).
3635
3636  For example:
3637
3638  >>> x = tf.constant([-3.0, -1.0, 0.0, 6.0, 10.0], dtype=tf.float32)
3639  >>> y = tf.nn.relu6(x)
3640  >>> y.numpy()
3641  array([0., 0., 0., 6., 6.], dtype=float32)
3642
3643  Args:
3644    features: A `Tensor` with type `float`, `double`, `int32`, `int64`, `uint8`,
3645      `int16`, or `int8`.
3646    name: A name for the operation (optional).
3647
3648  Returns:
3649    A `Tensor` with the same type as `features`.
3650
3651  References:
3652    Convolutional Deep Belief Networks on CIFAR-10:
3653      Krizhevsky et al., 2010
3654      ([pdf](http://www.cs.utoronto.ca/~kriz/conv-cifar10-aug2010.pdf))
3655  """
3656  with ops.name_scope(name, "Relu6", [features]) as name:
3657    features = ops.convert_to_tensor(features, name="features")
3658    return gen_nn_ops.relu6(features, name=name)
3659
3660
3661@tf_export("nn.leaky_relu")
3662@dispatch.register_unary_elementwise_api
3663@dispatch.add_dispatch_support
3664def leaky_relu(features, alpha=0.2, name=None):
3665  """Compute the Leaky ReLU activation function.
3666
3667  Source: [Rectifier Nonlinearities Improve Neural Network Acoustic Models.
3668  AL Maas, AY Hannun, AY Ng - Proc. ICML, 2013]
3669  (https://ai.stanford.edu/~amaas/papers/relu_hybrid_icml2013_final.pdf).
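
  For example, a minimal sketch:

  >>> x = tf.constant([-2.0, 0.0, 3.0])
  >>> tf.nn.leaky_relu(x, alpha=0.5).numpy().tolist()
  [-1.0, 0.0, 3.0]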
3670
3671  Args:
3672    features: A `Tensor` representing preactivation values. Must be one of
3673      the following types: `float16`, `float32`, `float64`, `int32`, `int64`.
3674    alpha: Slope of the activation function at x < 0.
3675    name: A name for the operation (optional).
3676
3677  Returns:
3678    The activation value.
3679
3680  References:
3681    Rectifier Nonlinearities Improve Neural Network Acoustic Models:
3682      [Maas et al., 2013]
3683      (http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.693.1422)
3684      ([pdf]
3685      (http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.693.1422&rep=rep1&type=pdf))
3686  """
3687  with ops.name_scope(name, "LeakyRelu", [features, alpha]) as name:
3688    features = ops.convert_to_tensor(features, name="features")
3689    if features.dtype.is_integer:
3690      features = math_ops.cast(features, dtypes.float32)
3691    if isinstance(alpha, np.ndarray):
3692      alpha = alpha.item()
3693    return gen_nn_ops.leaky_relu(features, alpha=alpha, name=name)
3694
3695
3696@tf_export("nn.gelu", v1=[])
3697@dispatch.register_unary_elementwise_api
3698@dispatch.add_dispatch_support
3699def gelu(features, approximate=False, name=None):
3700  """Compute the Gaussian Error Linear Unit (GELU) activation function.
3701
3702  Gaussian error linear unit (GELU) computes
3703  `x * P(X <= x)`, where `P(X) ~ N(0, 1)`.
3704  The (GELU) nonlinearity weights inputs by their value, rather than gates
3705  inputs by their sign as in ReLU.
3706
3707  For example:
3708
3709  >>> x = tf.constant([-3.0, -1.0, 0.0, 1.0, 3.0], dtype=tf.float32)
3710  >>> y = tf.nn.gelu(x)
3711  >>> y.numpy()
3712  array([-0.00404951, -0.15865529,  0.        ,  0.8413447 ,  2.9959507 ],
3713      dtype=float32)
3714  >>> y = tf.nn.gelu(x, approximate=True)
3715  >>> y.numpy()
3716  array([-0.00363752, -0.15880796,  0.        ,  0.841192  ,  2.9963627 ],
3717      dtype=float32)
3718
3719  Args:
3720    features: A `Tensor` representing preactivation values.
3721    approximate: An optional `bool`. Defaults to `False`. Whether to enable
3722      approximation.
3723    name: A name for the operation (optional).
3724
3725  Returns:
3726    A `Tensor` with the same type as `features`.
3727
3728  References:
3729    [Gaussian Error Linear Units (GELUs)](https://arxiv.org/abs/1606.08415).
3730  """
3731  with ops.name_scope(name, "Gelu", [features]):
3732    features = ops.convert_to_tensor(features, name="features")
3733    if not features.dtype.is_floating:
3734      raise ValueError(
          "`features.dtype` must be a floating point dtype. "
          f"Received: features.dtype={features.dtype}")
3737    if approximate:
3738      coeff = math_ops.cast(0.044715, features.dtype)
3739      return 0.5 * features * (
3740          1.0 + math_ops.tanh(0.7978845608028654 *
3741                              (features + coeff * math_ops.pow(features, 3))))
3742    else:
3743      return 0.5 * features * (1.0 + math_ops.erf(
3744          features / math_ops.cast(1.4142135623730951, features.dtype)))
3745
3746
3747def _flatten_outer_dims(logits):
  """Flattens logits' outer dimensions and keeps its last dimension."""
3749  rank = array_ops.rank(logits)
3750  last_dim_size = array_ops.slice(
3751      array_ops.shape(logits), [math_ops.subtract(rank, 1)], [1])
3752  output = array_ops.reshape(logits, array_ops.concat([[-1], last_dim_size], 0))
3753
3754  # Set output shape if known.
3755  if not context.executing_eagerly():
3756    shape = logits.get_shape()
3757    if shape is not None and shape.dims is not None:
3758      shape = shape.as_list()
3759      product = 1
3760      product_valid = True
3761      for d in shape[:-1]:
3762        if d is None:
3763          product_valid = False
3764          break
3765        else:
3766          product *= d
3767      if product_valid:
3768        output_shape = [product, shape[-1]]
3769        output.set_shape(output_shape)
3770
3771  return output
3772
3773
3774def _wrap_2d_function(inputs, compute_op, dim=-1, name=None):
3775  """Helper function for ops that accept and return 2d inputs of same shape.
3776
3777  It reshapes and transposes the inputs into a 2-D Tensor and then invokes
  the given function. The output is then transposed and reshaped back.
3779  If the given function returns a tuple of tensors, each of them will be
3780  transposed and reshaped.
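
  A minimal sketch of how this helper is used inside this module (the shape is
  arbitrary):

  ```python
  x = array_ops.ones([2, 3, 4])
  # Softmax along axis 1: the helper moves axis 1 to the last position, applies
  # the op on the last dimension, then transposes the result back to [2, 3, 4].
  y = _wrap_2d_function(x, gen_nn_ops.softmax, dim=1)
  ```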
3781
3782  Args:
3783    inputs: A non-empty `Tensor`. Must be one of the following types: `half`,
3784      `float32`, `float64`.
3785    compute_op: The function to wrap. Must accept the input tensor as its first
      argument, and a second keyword argument `name`.
    dim: The dimension the operation would be performed on. The default is -1 which
3788      indicates the last dimension.
3789    name: A name for the operation (optional).
3790
3791  Returns:
3792    A `Tensor`. Has the same shape as inputs. If compute_op returns multiple
      tensors, each of them has the same shape as the input.
3794  Raises:
3795    InvalidArgumentError: if `inputs` is empty or `dim` is beyond the last
3796      dimension of `inputs`.
3797  """
3798
3799  def _swap_axis(input_tensor, dim_index, last_index, name=None):
    """Swaps the input tensor's dimensions dim_index and last_index."""
3801    return array_ops.transpose(
3802        input_tensor,
3803        array_ops.concat([
3804            math_ops.range(dim_index), [last_index],
3805            math_ops.range(dim_index + 1, last_index), [dim_index]
3806        ], 0),
3807        name=name)
3808
3809  inputs = ops.convert_to_tensor(inputs)
3810
3811  # We need its original shape for shape inference.
3812  shape = inputs.get_shape()
3813  is_last_dim = (dim == -1) or (dim == shape.ndims - 1)
3814
3815  if is_last_dim:
3816    return compute_op(inputs, name=name)
3817
3818  dim_val = dim
3819  if isinstance(dim, ops.Tensor):
3820    dim_val = tensor_util.constant_value(dim)
3821  if dim_val is not None and not -shape.ndims <= dim_val < shape.ndims:
3822    raise errors_impl.InvalidArgumentError(
3823        None, None,
3824        f"`dim` must be in the range [{-shape.ndims}, {shape.ndims}) where "
3825        f"{shape.ndims} is the number of dimensions in the input. "
3826        f"Received: dim={dim_val}")
3827
3828  # If dim is not the last dimension, we have to do a transpose so that we can
3829  # still perform the op on its last dimension.
3830
3831  # In case dim is negative (and is not last dimension -1), add shape.ndims
3832  ndims = array_ops.rank(inputs)
3833  if not isinstance(dim, ops.Tensor):
3834    if dim < 0:
3835      dim += ndims
3836  else:
3837    dim = array_ops.where(math_ops.less(dim, 0), dim + ndims, dim)
3838
3839  # Swap logits' dimension of dim and its last dimension.
3840  input_rank = array_ops.rank(inputs)
3841  dim_axis = dim % shape.ndims
3842  inputs = _swap_axis(inputs, dim_axis, math_ops.subtract(input_rank, 1))
3843
3844  # Do the actual call on its last dimension.
3845  def fix_output(output):
3846    output = _swap_axis(
3847        output, dim_axis, math_ops.subtract(input_rank, 1), name=name)
3848
3849    # Make shape inference work since transpose may erase its static shape.
3850    output.set_shape(shape)
3851    return output
3852
3853  outputs = compute_op(inputs)
3854  if isinstance(outputs, tuple):
3855    return tuple(fix_output(output) for output in outputs)
3856  else:
3857    return fix_output(outputs)
3858
3859
3860@tf_export("nn.softmax", "math.softmax", v1=[])
3861@dispatch.add_dispatch_support
3862def softmax_v2(logits, axis=None, name=None):
3863  """Computes softmax activations.
3864
3865  Used for multi-class predictions. The sum of all outputs generated by softmax
3866  is 1.
3867
3868  This function performs the equivalent of
3869
3870  ```python
3871  softmax = tf.exp(logits) / tf.reduce_sum(tf.exp(logits), axis, keepdims=True)
3872  ```
3873  Example usage:
3874
3875  >>> softmax = tf.nn.softmax([-1, 0., 1.])
3876  >>> softmax
3877  <tf.Tensor: shape=(3,), dtype=float32,
3878  numpy=array([0.09003057, 0.24472848, 0.66524094], dtype=float32)>
3879  >>> sum(softmax)
3880  <tf.Tensor: shape=(), dtype=float32, numpy=1.0>
3881
3882  Args:
3883    logits: A non-empty `Tensor`. Must be one of the following types: `half`,
3884      `float32`, `float64`.
3885    axis: The dimension softmax would be performed on. The default is -1 which
3886      indicates the last dimension.
3887    name: A name for the operation (optional).
3888
3889  Returns:
3890    A `Tensor`. Has the same type and shape as `logits`.
3891
3892  Raises:
3893    InvalidArgumentError: if `logits` is empty or `axis` is beyond the last
3894      dimension of `logits`.
3895  """
3896  if axis is None:
3897    axis = -1
3898  return _wrap_2d_function(logits, gen_nn_ops.softmax, axis, name)
3899
3900
3901@tf_export(v1=["nn.softmax", "math.softmax"])
3902@dispatch.add_dispatch_support
3903@deprecation.deprecated_args(None, "dim is deprecated, use axis instead", "dim")
3904def softmax(logits, axis=None, name=None, dim=None):
3905  axis = deprecation.deprecated_argument_lookup("axis", axis, "dim", dim)
3906  if axis is None:
3907    axis = -1
3908  return _wrap_2d_function(logits, gen_nn_ops.softmax, axis, name)
3909
3910
3911softmax.__doc__ = softmax_v2.__doc__
3912
3913
3914@tf_export(v1=["nn.log_softmax", "math.log_softmax"])
3915@dispatch.register_unary_elementwise_api
3916@dispatch.add_dispatch_support
3917@deprecation.deprecated_args(None, "dim is deprecated, use axis instead", "dim")
3918def log_softmax(logits, axis=None, name=None, dim=None):
3919  """Computes log softmax activations.
3920
3921  For each batch `i` and class `j` we have
3922
3923      logsoftmax = logits - log(reduce_sum(exp(logits), axis))
3924
3925  Args:
3926    logits: A non-empty `Tensor`. Must be one of the following types: `half`,
3927      `float32`, `float64`.
3928    axis: The dimension softmax would be performed on. The default is -1 which
3929      indicates the last dimension.
3930    name: A name for the operation (optional).
3931    dim: Deprecated alias for `axis`.
3932
3933  Returns:
3934    A `Tensor`. Has the same type as `logits`. Same shape as `logits`.
3935
3936  Raises:
3937    InvalidArgumentError: if `logits` is empty or `axis` is beyond the last
3938      dimension of `logits`.
3939  """
3940  axis = deprecation.deprecated_argument_lookup("axis", axis, "dim", dim)
3941  if axis is None:
3942    axis = -1
3943  return _wrap_2d_function(logits, gen_nn_ops.log_softmax, axis, name)
3944
3945
3946@tf_export("nn.log_softmax", "math.log_softmax", v1=[])
3947@dispatch.add_dispatch_support
3948def log_softmax_v2(logits, axis=None, name=None):
3949  """Computes log softmax activations.
3950
3951  For each batch `i` and class `j` we have
3952
3953      logsoftmax = logits - log(reduce_sum(exp(logits), axis))
3954
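  For example (an illustrative sketch; values rounded to four decimal places):

  ```python
  logits = tf.constant([-1.0, 0.0, 1.0])
  tf.nn.log_softmax(logits)  # approximately [-2.4076, -1.4076, -0.4076]
  ```
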
3955  Args:
3956    logits: A non-empty `Tensor`. Must be one of the following types: `half`,
3957      `float32`, `float64`.
3958    axis: The dimension softmax would be performed on. The default is -1 which
3959      indicates the last dimension.
3960    name: A name for the operation (optional).
3961
3962  Returns:
3963    A `Tensor`. Has the same type as `logits`. Same shape as `logits`.
3964
3965  Raises:
3966    InvalidArgumentError: if `logits` is empty or `axis` is beyond the last
3967      dimension of `logits`.
3968  """
3969  if axis is None:
3970    axis = -1
3971  return _wrap_2d_function(logits, gen_nn_ops.log_softmax, axis, name)
3972
3973
3974def _ensure_xent_args(name, sentinel, labels, logits):
3975  # Make sure that all arguments were passed as named arguments.
3976  if sentinel is not None:
3977    raise ValueError(
3978        f"Only call {name} with named arguments (labels=..., logits=..., ...). "
3979        f"Received unnamed argument: {sentinel}")
3980  if labels is None or logits is None:
3981    raise ValueError("Both `labels` and `logits` must be provided. "
3982                     f"Received: labels={labels} and logits={logits}")
3983
3984
3985@tf_export("nn.softmax_cross_entropy_with_logits", v1=[])
3986@dispatch.add_dispatch_support
3987def softmax_cross_entropy_with_logits_v2(labels, logits, axis=-1, name=None):
3988  """Computes softmax cross entropy between `logits` and `labels`.
3989
3990  Measures the probability error in discrete classification tasks in which the
3991  classes are mutually exclusive (each entry is in exactly one class).  For
3992  example, each CIFAR-10 image is labeled with one and only one label: an image
3993  can be a dog or a truck, but not both.
3994
3995  **NOTE:**  While the classes are mutually exclusive, their probabilities
3996  need not be.  All that is required is that each row of `labels` is
3997  a valid probability distribution.  If they are not, the computation of the
3998  gradient will be incorrect.
3999
4000  If using exclusive `labels` (wherein one and only
4001  one class is true at a time), see `sparse_softmax_cross_entropy_with_logits`.
4002
4003  Usage:
4004
4005  >>> logits = [[4.0, 2.0, 1.0], [0.0, 5.0, 1.0]]
4006  >>> labels = [[1.0, 0.0, 0.0], [0.0, 0.8, 0.2]]
4007  >>> tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
4008  <tf.Tensor: shape=(2,), dtype=float32,
4009  numpy=array([0.16984604, 0.82474494], dtype=float32)>
4010
4011  **WARNING:** This op expects unscaled logits, since it performs a `softmax`
4012  on `logits` internally for efficiency.  Do not call this op with the
4013  output of `softmax`, as it will produce incorrect results.
4014
4015  A common use case is to have logits and labels of shape
4016  `[batch_size, num_classes]`, but higher dimensions are supported, with
4017  the `axis` argument specifying the class dimension.
4018
4019  `logits` and `labels` must have the same dtype (either `float16`, `float32`,
4020  or `float64`).
4021
4022  Backpropagation will happen into both `logits` and `labels`.  To disallow
4023  backpropagation into `labels`, pass label tensors through `tf.stop_gradient`
4024  before feeding it to this function.
4025
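  A minimal sketch of that pattern:

  ```python
  labels = tf.stop_gradient(labels)  # gradients will not flow into `labels`
  loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
  ```
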
4026  **Note that to avoid confusion, it is required to pass only named arguments to
4027  this function.**
4028
4029  Args:
4030    labels: Each vector along the class dimension should hold a valid
4031      probability distribution e.g. for the case in which labels are of shape
4032      `[batch_size, num_classes]`, each row of `labels[i]` must be a valid
4033      probability distribution.
4034    logits: Per-label activations, typically a linear output. These activation
4035      energies are interpreted as unnormalized log probabilities.
4036    axis: The class dimension. Defaults to -1, which is the last dimension.
4037    name: A name for the operation (optional).
4038
4039  Returns:
4040    A `Tensor` that contains the softmax cross entropy loss. Its type is the
4041    same as `logits` and its shape is the same as `labels` except that it does
4042    not have the last dimension of `labels`.
4043  """
4044  return softmax_cross_entropy_with_logits_v2_helper(
4045      labels=labels, logits=logits, axis=axis, name=name)
4046
4047
4048@tf_export(v1=["nn.softmax_cross_entropy_with_logits_v2"])
4049@dispatch.add_dispatch_support
4050@deprecated_args(None, "dim is deprecated, use axis instead", "dim")
4051def softmax_cross_entropy_with_logits_v2_helper(
4052    labels, logits, axis=None, name=None, dim=None):
4053  """Computes softmax cross entropy between `logits` and `labels`.
4054
4055  Measures the probability error in discrete classification tasks in which the
4056  classes are mutually exclusive (each entry is in exactly one class).  For
4057  example, each CIFAR-10 image is labeled with one and only one label: an image
4058  can be a dog or a truck, but not both.
4059
4060  **NOTE:**  While the classes are mutually exclusive, their probabilities
4061  need not be.  All that is required is that each row of `labels` is
4062  a valid probability distribution.  If they are not, the computation of the
4063  gradient will be incorrect.
4064
4065  If using exclusive `labels` (wherein one and only
4066  one class is true at a time), see `sparse_softmax_cross_entropy_with_logits`.
4067
4068  **WARNING:** This op expects unscaled logits, since it performs a `softmax`
4069  on `logits` internally for efficiency.  Do not call this op with the
4070  output of `softmax`, as it will produce incorrect results.
4071
4072  A common use case is to have logits and labels of shape
4073  `[batch_size, num_classes]`, but higher dimensions are supported, with
4074  the `axis` argument specifying the class dimension.
4075
4076  `logits` and `labels` must have the same dtype (either `float16`, `float32`,
4077  or `float64`).
4078
4079  Backpropagation will happen into both `logits` and `labels`.  To disallow
4080  backpropagation into `labels`, pass label tensors through `tf.stop_gradient`
4081  before feeding it to this function.
4082
4083  **Note that to avoid confusion, it is required to pass only named arguments to
4084  this function.**
4085
4086  Args:
4087    labels: Each vector along the class dimension should hold a valid
4088      probability distribution e.g. for the case in which labels are of shape
4089      `[batch_size, num_classes]`, each row of `labels[i]` must be a valid
4090      probability distribution.
4091    logits: Unscaled log probabilities.
4092    axis: The class dimension. Defaults to -1, which is the last dimension.
4093    name: A name for the operation (optional).
4094    dim: Deprecated alias for axis.
4095
4096  Returns:
4097    A `Tensor` that contains the softmax cross entropy loss. Its type is the
4098    same as `logits` and its shape is the same as `labels` except that it does
4099    not have the last dimension of `labels`.
4100  """
4101  # TODO(pcmurray) Raise an error when the labels do not sum to 1. Note: This
4102  # could break users who call this with bad labels, but disregard the bad
4103  # results.
4104  axis = deprecated_argument_lookup("axis", axis, "dim", dim)
4105  del dim
4106  if axis is None:
4107    axis = -1
4108
4109  with ops.name_scope(name, "softmax_cross_entropy_with_logits",
4110                      [logits, labels]) as name:
4111    logits = ops.convert_to_tensor(logits, name="logits")
4112    labels = ops.convert_to_tensor(labels, name="labels")
4113    convert_to_float32 = (
4114        logits.dtype == dtypes.float16 or logits.dtype == dtypes.bfloat16)
4115    precise_logits = math_ops.cast(
4116        logits, dtypes.float32) if convert_to_float32 else logits
4117    # labels and logits must be of the same type
4118    labels = math_ops.cast(labels, precise_logits.dtype)
4119    input_rank = array_ops.rank(precise_logits)
4120    # For shape inference.
4121    shape = logits.get_shape()
4122
4123    # Move the dim to the end if dim is not the last dimension.
4124    if axis != -1:
4125
4126      def _move_dim_to_end(tensor, dim_index, rank):
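        # Cycles `dim_index` to the last position, e.g. for rank 4 and
        # dim_index=1 the resulting permutation is [0, 2, 3, 1].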
4127        return array_ops.transpose(
4128            tensor,
4129            array_ops.concat([
4130                math_ops.range(dim_index),
4131                math_ops.range(dim_index + 1, rank), [dim_index]
4132            ], 0))
4133
4134      precise_logits = _move_dim_to_end(precise_logits, axis, input_rank)
4135      labels = _move_dim_to_end(labels, axis, input_rank)
4136
4137    input_shape = array_ops.shape(precise_logits)
4138
4139    # Make precise_logits and labels into matrices.
4140    precise_logits = _flatten_outer_dims(precise_logits)
4141    labels = _flatten_outer_dims(labels)
4142
4143    # Do the actual op computation.
4144    if config.is_op_determinism_enabled():
4145      log_probs = log_softmax_v2(precise_logits)
4146      cost = -math_ops.reduce_sum(labels * log_probs, axis=1)
4147    else:
4148      # The second output tensor contains the gradients.  We use it in
4149      # CrossEntropyGrad() in nn_grad but not here.
4150      cost, unused_backprop = gen_nn_ops.softmax_cross_entropy_with_logits(
4151          precise_logits, labels, name=name)
4152
4153    # The output cost shape should be the input minus axis.
4154    output_shape = array_ops.slice(input_shape, [0],
4155                                   [math_ops.subtract(input_rank, 1)])
4156    cost = array_ops.reshape(cost, output_shape)
4157
4158    # Make shape inference work since reshape and transpose may erase its static
4159    # shape.
4160    if not context.executing_eagerly(
4161    ) and shape is not None and shape.dims is not None:
4162      shape = shape.as_list()
4163      del shape[axis]
4164      cost.set_shape(shape)
4165
4166    if convert_to_float32:
4167      return math_ops.cast(cost, logits.dtype)
4168    else:
4169      return cost
4170
4171
4172_XENT_DEPRECATION = """
4173Future major versions of TensorFlow will allow gradients to flow
4174into the labels input on backprop by default.
4175
4176See `tf.nn.softmax_cross_entropy_with_logits_v2`.
4177"""
4178
4179
4180@tf_export(v1=["nn.softmax_cross_entropy_with_logits"])
4181@dispatch.add_dispatch_support
4182@deprecation.deprecated(date=None, instructions=_XENT_DEPRECATION)
4183def softmax_cross_entropy_with_logits(
4184    _sentinel=None,  # pylint: disable=invalid-name
4185    labels=None,
4186    logits=None,
4187    dim=-1,
4188    name=None,
4189    axis=None):
4190  """Computes softmax cross entropy between `logits` and `labels`.
4191
4192  Measures the probability error in discrete classification tasks in which the
4193  classes are mutually exclusive (each entry is in exactly one class).  For
4194  example, each CIFAR-10 image is labeled with one and only one label: an image
4195  can be a dog or a truck, but not both.
4196
4197  **NOTE:**  While the classes are mutually exclusive, their probabilities
4198  need not be.  All that is required is that each row of `labels` is
4199  a valid probability distribution.  If they are not, the computation of the
4200  gradient will be incorrect.
4201
4202  If using exclusive `labels` (wherein one and only
4203  one class is true at a time), see `sparse_softmax_cross_entropy_with_logits`.
4204
4205  **WARNING:** This op expects unscaled logits, since it performs a `softmax`
4206  on `logits` internally for efficiency.  Do not call this op with the
4207  output of `softmax`, as it will produce incorrect results.
4208
4209  A common use case is to have logits and labels of shape
4210  `[batch_size, num_classes]`, but higher dimensions are supported, with
4211  the `dim` argument specifying the class dimension.
4212
4213  Backpropagation will happen only into `logits`.  To calculate a cross entropy
4214  loss that allows backpropagation into both `logits` and `labels`, see
4215  `tf.nn.softmax_cross_entropy_with_logits_v2`.
4216
4217  **Note that to avoid confusion, it is required to pass only named arguments to
4218  this function.**
4219
4220  Args:
4221    _sentinel: Used to prevent positional parameters. Internal, do not use.
4222    labels: Each vector along the class dimension should hold a valid
4223      probability distribution e.g. for the case in which labels are of shape
4224      `[batch_size, num_classes]`, each row of `labels[i]` must be a valid
4225      probability distribution.
4226    logits: Per-label activations, typically a linear output. These activation
4227      energies are interpreted as unnormalized log probabilities.
4228    dim: The class dimension. Defaults to -1, which is the last dimension.
4229    name: A name for the operation (optional).
4230    axis: Alias for dim.
4231
4232  Returns:
4233    A `Tensor` that contains the softmax cross entropy loss. Its type is the
4234    same as `logits` and its shape is the same as `labels` except that it does
4235    not have the last dimension of `labels`.
4236  """
4237  dim = deprecated_argument_lookup("axis", axis, "dim", dim)
4238  _ensure_xent_args("softmax_cross_entropy_with_logits", _sentinel, labels,
4239                    logits)
4240
4241  with ops.name_scope(name, "softmax_cross_entropy_with_logits_sg",
4242                      [logits, labels]) as name:
4243    labels = array_ops.stop_gradient(labels, name="labels_stop_gradient")
4244
4245  return softmax_cross_entropy_with_logits_v2(
4246      labels=labels, logits=logits, axis=dim, name=name)
4247
4248
4249def _sparse_softmax_cross_entropy_with_rank_2_logits(logits, labels, name):
4250  if config.is_op_determinism_enabled():
4251    # TODO(duncanriach): Implement a GPU-deterministic version of this op at
4252    #     the C++/CUDA level.
4253
4254    # The actual op functionality
4255    log_probs = log_softmax_v2(logits)
4256    cost = math_ops.negative(array_ops.gather(log_probs, labels, batch_dims=1))
4257
4258    # Force the output to be NaN when the corresponding label is invalid.
4259    # Without the selective gradient gating provided by the following code,
4260    # backprop into the actual op functionality above, when there are invalid
4261    # labels, leads to corruption of the gradients associated with valid labels.
4262    # TODO(duncanriach): Uncover the source of the aforementioned corruption.
4263    nan_tensor = constant_op.constant(float("Nan"), dtype=logits.dtype)
4264    cost_all_nans = array_ops.broadcast_to(nan_tensor, array_ops.shape(cost))
4265    class_count = math_ops.cast(array_ops.shape(logits)[-1], labels.dtype)
4266    cost = array_ops.where(
4267        math_ops.logical_or(
4268            math_ops.less(labels, 0),
4269            math_ops.greater_equal(labels, class_count)), cost_all_nans, cost)
4270  else:
4271    # The second output tensor contains the gradients. We use it in
4272    # _CrossEntropyGrad() in nn_grad but not here.
4273    cost, _ = gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
4274        logits, labels, name=name)
4275  return cost
4276
4277
4278@tf_export(v1=["nn.sparse_softmax_cross_entropy_with_logits"])
4279@dispatch.add_dispatch_support
4280def sparse_softmax_cross_entropy_with_logits(
4281    _sentinel=None,  # pylint: disable=invalid-name
4282    labels=None,
4283    logits=None,
4284    name=None):
4285  """Computes sparse softmax cross entropy between `logits` and `labels`.
4286
4287  Measures the probability error in discrete classification tasks in which the
4288  classes are mutually exclusive (each entry is in exactly one class).  For
4289  example, each CIFAR-10 image is labeled with one and only one label: an image
4290  can be a dog or a truck, but not both.
4291
4292  **NOTE:**  For this operation, the probability of a given label is considered
4293  exclusive.  That is, soft classes are not allowed, and the `labels` vector
4294  must provide a single specific index for the true class for each row of
4295  `logits` (each minibatch entry).  For soft softmax classification with
4296  a probability distribution for each entry, see
4297  `softmax_cross_entropy_with_logits_v2`.
4298
4299  **WARNING:** This op expects unscaled logits, since it performs a `softmax`
4300  on `logits` internally for efficiency.  Do not call this op with the
4301  output of `softmax`, as it will produce incorrect results.
4302
4303  A common use case is to have logits of shape
4304  `[batch_size, num_classes]` and have labels of shape
4305  `[batch_size]`, but higher dimensions are supported, in which
4306  case the `dim`-th dimension is assumed to be of size `num_classes`.
4307  `logits` must have the dtype of `float16`, `float32`, or `float64`, and
4308  `labels` must have the dtype of `int32` or `int64`.
4309
4310  **Note that to avoid confusion, it is required to pass only named arguments to
4311  this function.**
4312
4313  Args:
4314    _sentinel: Used to prevent positional parameters. Internal, do not use.
4315    labels: `Tensor` of shape `[d_0, d_1, ..., d_{r-1}]` (where `r` is rank of
4316      `labels` and result) and dtype `int32` or `int64`. Each entry in `labels`
4317      must be an index in `[0, num_classes)`. Other values will raise an
4318      exception when this op is run on CPU, and return `NaN` for corresponding
4319      loss and gradient rows on GPU.
4320    logits: Per-label activations (typically a linear output) of shape
4321      `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float16`, `float32`, or
4322      `float64`. These activation energies are interpreted as unnormalized log
4323      probabilities.
4324    name: A name for the operation (optional).
4325
4326  Returns:
4327    A `Tensor` of the same shape as `labels` and of the same type as `logits`
4328    with the softmax cross entropy loss.
4329
4330  Raises:
4331    ValueError: If logits are scalars (need to have rank >= 1) or if the rank
4332      of the labels is not equal to the rank of the logits minus one.
4333  """
4334  _ensure_xent_args("sparse_softmax_cross_entropy_with_logits", _sentinel,
4335                    labels, logits)
4336
4337  # TODO(pcmurray) Raise an error when the label is not an index in
4338  # [0, num_classes). Note: This could break users who call this with bad
4339  # labels, but disregard the bad results.
4340
4341  # Reshape logits and labels to rank 2.
4342  with ops.name_scope(name, "SparseSoftmaxCrossEntropyWithLogits",
4343                      [labels, logits]):
4344    labels = ops.convert_to_tensor(labels)
4345    logits = ops.convert_to_tensor(logits)
4346    precise_logits = math_ops.cast(logits, dtypes.float32) if (dtypes.as_dtype(
4347        logits.dtype) == dtypes.float16) else logits
4348
4349    # Store label shape for result later.
4350    labels_static_shape = labels.get_shape()
4351    labels_shape = array_ops.shape(labels)
4352    static_shapes_fully_defined = (
4353        labels_static_shape.is_fully_defined() and
4354        logits.get_shape()[:-1].is_fully_defined())
4355    if logits.get_shape().ndims is not None and logits.get_shape().ndims == 0:
4356      raise ValueError(
4357          f"`logits` cannot be a scalar. Received logits={logits}`")
4358    if logits.get_shape().ndims is not None and (
4359        labels_static_shape.ndims is not None and
4360        labels_static_shape.ndims != logits.get_shape().ndims - 1):
4361      raise ValueError(
4362          "`labels.shape.rank` must equal `logits.shape.rank - 1`. "
4363          f"Received: labels.shape={labels_static_shape} of rank "
4364          f"{labels_static_shape.rank} and logits.shape={logits.get_shape()} "
4365          f"of rank {logits.get_shape().rank}")
4366    if (static_shapes_fully_defined and
4367        labels_static_shape != logits.get_shape()[:-1]):
4368      raise ValueError(
4369          "`labels.shape` must equal `logits.shape` except for "
4370          f"the last dimension. Received: labels.shape={labels_static_shape} "
4371          f"and logits.shape={logits.get_shape()}")
4372    # Check if no reshapes are required.
4373    if logits.get_shape().ndims == 2:
4374      cost = _sparse_softmax_cross_entropy_with_rank_2_logits(
4375          precise_logits, labels, name=name)
4376      if logits.dtype == dtypes.float16:
4377        return math_ops.cast(cost, dtypes.float16)
4378      else:
4379        return cost
4380
4381    # Perform a check of the dynamic shapes if the static shapes are not fully
4382    # defined.
4383    shape_checks = []
4384    if not static_shapes_fully_defined:
4385      shape_checks.append(
4386          check_ops.assert_equal(
4387              array_ops.shape(labels),
4388              array_ops.shape(logits)[:-1]))
4389    with ops.control_dependencies(shape_checks):
4390      # Reshape logits to 2 dim, labels to 1 dim.
4391      num_classes = array_ops.shape(logits)[array_ops.rank(logits) - 1]
4392      precise_logits = array_ops.reshape(precise_logits, [-1, num_classes])
4393      labels = array_ops.reshape(labels, [-1])
4394      cost = _sparse_softmax_cross_entropy_with_rank_2_logits(
4395          precise_logits, labels, name=name)
4396      cost = array_ops.reshape(cost, labels_shape)
4397      cost.set_shape(labels_static_shape)
4398      if logits.dtype == dtypes.float16:
4399        return math_ops.cast(cost, dtypes.float16)
4400      else:
4401        return cost
4402
4403
4404@tf_export("nn.sparse_softmax_cross_entropy_with_logits", v1=[])
4405@dispatch.add_dispatch_support
4406def sparse_softmax_cross_entropy_with_logits_v2(labels, logits, name=None):
4407  """Computes sparse softmax cross entropy between `logits` and `labels`.
4408
4409  Measures the probability error in discrete classification tasks in which the
4410  classes are mutually exclusive (each entry is in exactly one class).  For
4411  example, each CIFAR-10 image is labeled with one and only one label: an image
4412  can be a dog or a truck, but not both.
4413
4414  Note:  For this operation, the probability of a given label is considered
4415  exclusive.  That is, soft classes are not allowed, and the `labels` vector
4416  must provide a single specific index for the true class for each row of
4417  `logits` (each minibatch entry).  For soft softmax classification with
4418  a probability distribution for each entry, see
4419  `softmax_cross_entropy_with_logits_v2`.
4420
4421  Warning: This op expects unscaled logits, since it performs a `softmax`
4422  on `logits` internally for efficiency.  Do not call this op with the
4423  output of `softmax`, as it will produce incorrect results.
4424
4425  A common use case is to have logits of shape
4426  `[batch_size, num_classes]` and have labels of shape
4427  `[batch_size]`, but higher dimensions are supported, in which
4428  case the `dim`-th dimension is assumed to be of size `num_classes`.
4429  `logits` must have the dtype of `float16`, `float32`, or `float64`, and
4430  `labels` must have the dtype of `int32` or `int64`.
4431
4432  >>> logits = tf.constant([[2., -5., .5, -.1],
4433  ...                       [0., 0., 1.9, 1.4],
4434  ...                       [-100., 100., -100., -100.]])
4435  >>> labels = tf.constant([0, 3, 1])
4436  >>> tf.nn.sparse_softmax_cross_entropy_with_logits(
4437  ...     labels=labels, logits=logits).numpy()
4438  array([0.29750752, 1.1448325 , 0.        ], dtype=float32)
4439
4440  To avoid confusion, passing only named arguments to this function is
4441  recommended.
4442
4443  Args:
4444    labels: `Tensor` of shape `[d_0, d_1, ..., d_{r-1}]` (where `r` is rank of
4445      `labels` and result) and dtype `int32` or `int64`. Each entry in `labels`
4446      must be an index in `[0, num_classes)`. Other values will raise an
4447      exception when this op is run on CPU, and return `NaN` for corresponding
4448      loss and gradient rows on GPU.
4449    logits: Unscaled log probabilities of shape `[d_0, d_1, ..., d_{r-1},
4450      num_classes]` and dtype `float16`, `float32`, or `float64`.
4451    name: A name for the operation (optional).
4452
4453  Returns:
4454    A `Tensor` of the same shape as `labels` and of the same type as `logits`
4455    with the softmax cross entropy loss.
4456
4457  Raises:
4458    ValueError: If logits are scalars (need to have rank >= 1) or if the rank
4459      of the labels is not equal to the rank of the logits minus one.
4460  """
4461  return sparse_softmax_cross_entropy_with_logits(
4462      labels=labels, logits=logits, name=name)
4463
4464
4465@tf_export("nn.avg_pool", v1=["nn.avg_pool_v2"])
4466@dispatch.add_dispatch_support
4467def avg_pool_v2(input, ksize, strides, padding, data_format=None, name=None):  # pylint: disable=redefined-builtin
4468  """Performs the avg pooling on the input.
4469
4470  Each entry in `output` is the mean of the corresponding size `ksize`
4471  window in `value`.
4472
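  For example, averaging a single 2x2 window (an illustrative sketch):

  ```python
  x = tf.reshape(tf.constant([1., 2., 3., 4.]), (1, 2, 2, 1))
  tf.nn.avg_pool(x, ksize=2, strides=2, padding="VALID")
  # A tensor of shape (1, 1, 1, 1) containing 2.5, the mean of the window.
  ```
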
4473  Args:
4474    input:  Tensor of rank N+2, of shape `[batch_size] + input_spatial_shape +
4475      [num_channels]` if `data_format` does not start with "NC" (default), or
4476      `[batch_size, num_channels] + input_spatial_shape` if data_format starts
4477      with "NC". Pooling happens over the spatial dimensions only.
4478    ksize: An int or list of `ints` that has length `1`, `N` or `N+2`. The size
4479      of the window for each dimension of the input tensor.
4480    strides: An int or list of `ints` that has length `1`, `N` or `N+2`. The
4481      stride of the sliding window for each dimension of the input tensor.
4482    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
4483      [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2)
4484      for more information.
4485    data_format: A string. Specifies the channel dimension. For N=1 it can be
4486      either "NWC" (default) or "NCW", for N=2 it can be either "NHWC" (default)
4487      or "NCHW" and for N=3 either "NDHWC" (default) or "NCDHW".
4488    name: Optional name for the operation.
4489
4490  Returns:
4491    A `Tensor` of format specified by `data_format`.
4492    The average pooled output tensor.
4493  """
4494  if input.shape is not None:
4495    n = len(input.shape) - 2
4496  elif data_format is not None:
4497    n = len(data_format) - 2
4498  else:
4499    raise ValueError(
4500        "`input` must have a static shape or `data_format` must be given. "
4501        f"Received: input.shape={input.shape} and "
4502        f"data_format={data_format}")
4503  if not 1 <= n <= 3:
4504    raise ValueError(
4505        f"`input.shape.rank` must be 3, 4 or 5. Received: "
4506        f"input.shape={input.shape} of rank {n + 2}.")
4507
4508  if data_format is None:
4509    channel_index = n + 1
4510  else:
4511    channel_index = 1 if data_format.startswith("NC") else n + 1
4512
4513  ksize = _get_sequence(ksize, n, channel_index, "ksize")
4514  strides = _get_sequence(strides, n, channel_index, "strides")
4515
4516  avg_pooling_ops = {
4517      1: avg_pool1d,
4518      2: gen_nn_ops.avg_pool,
4519      3: gen_nn_ops.avg_pool3d
4520  }
4521
4522  op = avg_pooling_ops[n]
4523  return op(
4524      input,
4525      ksize=ksize,
4526      strides=strides,
4527      padding=padding,
4528      data_format=data_format,
4529      name=name)
4530
4531
4532@tf_export(v1=["nn.avg_pool", "nn.avg_pool2d"])
4533@dispatch.add_dispatch_support
4534def avg_pool(value, ksize, strides, padding, data_format="NHWC",
4535             name=None, input=None):  # pylint: disable=redefined-builtin
4536  """Performs the average pooling on the input.
4537
4538  Each entry in `output` is the mean of the corresponding size `ksize`
4539  window in `value`.
4540
4541  Args:
4542    value: A 4-D `Tensor` of shape `[batch, height, width, channels]` and type
4543      `float32`, `float64`, `qint8`, `quint8`, or `qint32`.
4544    ksize: An int or list of `ints` that has length `1`, `2` or `4`. The size of
4545      the window for each dimension of the input tensor.
4546    strides: An int or list of `ints` that has length `1`, `2` or `4`. The
4547      stride of the sliding window for each dimension of the input tensor.
4548    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
4549      See the "returns" section of `tf.nn.convolution` for details.
4550    data_format: A string. 'NHWC' and 'NCHW' are supported.
4551    name: Optional name for the operation.
4552    input: Alias for value.
4553
4554  Returns:
4555    A `Tensor` with the same type as `value`.  The average pooled output tensor.
4556  """
4557  with ops.name_scope(name, "AvgPool", [value]) as name:
4558    value = deprecation.deprecated_argument_lookup(
4559        "input", input, "value", value)
4560
4561    if data_format is None:
4562      data_format = "NHWC"
4563    channel_index = 1 if data_format.startswith("NC") else 3
4564
4565    ksize = _get_sequence(ksize, 2, channel_index, "ksize")
4566    strides = _get_sequence(strides, 2, channel_index, "strides")
4567
4568    return gen_nn_ops.avg_pool(
4569        value,
4570        ksize=ksize,
4571        strides=strides,
4572        padding=padding,
4573        data_format=data_format,
4574        name=name)
4575
4576
4577@tf_export("nn.avg_pool2d", v1=[])
4578@dispatch.add_dispatch_support
4579def avg_pool2d(input, ksize, strides, padding, data_format="NHWC", name=None):  # pylint: disable=redefined-builtin
4580  """Performs the average pooling on the input.
4581
4582  Each entry in `output` is the mean of the corresponding size `ksize`
4583  window in `value`.
4584
4585  Args:
4586    input: A 4-D `Tensor` of shape `[batch, height, width, channels]` and type
4587      `float32`, `float64`, `qint8`, `quint8`, or `qint32`.
4588    ksize: An int or list of `ints` that has length `1`, `2` or `4`. The size of
4589      the window for each dimension of the input tensor.
4590    strides: An int or list of `ints` that has length `1`, `2` or `4`. The
4591      stride of the sliding window for each dimension of the input tensor.
4592    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
4593      [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2)
4594      for more information.
4595    data_format: A string. 'NHWC' and 'NCHW' are supported.
4596    name: Optional name for the operation.
4597
4598  Returns:
4599    A `Tensor` with the same type as `input`.  The average pooled output tensor.
4600  """
4601  with ops.name_scope(name, "AvgPool2D", [input]) as name:
4602    if data_format is None:
4603      data_format = "NHWC"
4604    channel_index = 1 if data_format.startswith("NC") else 3
4605
4606    ksize = _get_sequence(ksize, 2, channel_index, "ksize")
4607    strides = _get_sequence(strides, 2, channel_index, "strides")
4608
4609    return gen_nn_ops.avg_pool(
4610        input,
4611        ksize=ksize,
4612        strides=strides,
4613        padding=padding,
4614        data_format=data_format,
4615        name=name)
4616
4617
4618@tf_export("nn.avg_pool1d")
4619@dispatch.add_dispatch_support
4620def avg_pool1d(input, ksize, strides, padding, data_format="NWC", name=None):  # pylint: disable=redefined-builtin
4621  """Performs the average pooling on the input.
4622
4623  Each entry in `output` is the mean of the corresponding size `ksize`
4624  window in `value`.
4625
4626  Note internally this op reshapes and uses the underlying 2d operation.
4627
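  For example, with non-overlapping windows of size 2 (an illustrative sketch):

  ```python
  x = tf.reshape(tf.constant([1., 2., 3., 4.]), (1, 4, 1))
  tf.nn.avg_pool1d(x, ksize=2, strides=2, padding="VALID")
  # A tensor of shape (1, 2, 1) containing [1.5, 3.5].
  ```
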
4628  Args:
4629    input: A 3-D `Tensor` of the format specified by `data_format`.
4630    ksize: An int or list of `ints` that has length `1` or `3`. The size of the
4631      window for each dimension of the input tensor.
4632    strides: An int or list of `ints` that has length `1` or `3`. The stride of
4633      the sliding window for each dimension of the input tensor.
4634    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
4635      [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2)
4636      for more information.
4637    data_format: An optional string from: "NWC", "NCW". Defaults to "NWC".
4638    name: A name for the operation (optional).
4639
4640  Returns:
4641    A `Tensor` of format specified by `data_format`.
4642    The average pooled output tensor.
4643  """
4644  with ops.name_scope(name, "AvgPool1D", [input]) as name:
4645    if data_format is None:
4646      data_format = "NWC"
4647    channel_index = 1 if data_format.startswith("NC") else 2
4648    ksize = [1] + _get_sequence(ksize, 1, channel_index, "ksize")
4649    strides = [1] + _get_sequence(strides, 1, channel_index, "strides")
4650
4651    expanding_dim = 1 if data_format == "NWC" else 2
4652    data_format = "NHWC" if data_format == "NWC" else "NCHW"
4653
4654    input = array_ops.expand_dims_v2(input, expanding_dim)
4655    result = gen_nn_ops.avg_pool(
4656        input,
4657        ksize=ksize,
4658        strides=strides,
4659        padding=padding,
4660        data_format=data_format,
4661        name=name)
4662    return array_ops.squeeze(result, expanding_dim)
4663
4664
4665@tf_export("nn.avg_pool3d")
4666@dispatch.add_dispatch_support
4667def avg_pool3d(input, ksize, strides, padding, data_format="NDHWC", name=None):  # pylint: disable=redefined-builtin
4668  """Performs the average pooling on the input.
4669
4670  Each entry in `output` is the mean of the corresponding size `ksize`
4671  window in `value`.
4672
4673  Args:
4674    input: A 5-D `Tensor` of shape `[batch, depth, height, width, channels]`
4675      and type `float32`, `float64`, `qint8`, `quint8`, or `qint32`.
4676    ksize: An int or list of `ints` that has length `1`, `3` or `5`. The size of
4677      the window for each dimension of the input tensor.
4678    strides: An int or list of `ints` that has length `1`, `3` or `5`. The
4679      stride of the sliding window for each dimension of the input tensor.
4680    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
4681      [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2)
4682      for more information.
4683    data_format: A string. 'NDHWC' and 'NCDHW' are supported.
4684    name: Optional name for the operation.
4685
4686  Returns:
4687    A `Tensor` with the same type as `input`.  The average pooled output tensor.
4688  """
4689  with ops.name_scope(name, "AvgPool3D", [input]) as name:
4690    if data_format is None:
4691      data_format = "NDHWC"
4692    channel_index = 1 if data_format.startswith("NC") else 3
4693
4694    ksize = _get_sequence(ksize, 3, channel_index, "ksize")
4695    strides = _get_sequence(strides, 3, channel_index, "strides")
4696
4697    return gen_nn_ops.avg_pool3d(
4698        input,
4699        ksize=ksize,
4700        strides=strides,
4701        padding=padding,
4702        data_format=data_format,
4703        name=name)
4704
4705
4706# pylint: disable=redefined-builtin
4707@tf_export("nn.max_pool", v1=["nn.max_pool_v2"])
4708@dispatch.add_dispatch_support
4709def max_pool_v2(input, ksize, strides, padding, data_format=None, name=None):
4710  """Performs max pooling on the input.
4711
4712  For a given window of `ksize`, takes the maximum value within that window.
4713  Used for reducing computation and preventing overfitting.
4714
4715  Consider an example of pooling with 2x2, non-overlapping windows:
4716
4717  >>> matrix = tf.constant([
4718  ...     [0, 0, 1, 7],
4719  ...     [0, 2, 0, 0],
4720  ...     [5, 2, 0, 0],
4721  ...     [0, 0, 9, 8],
4722  ... ])
4723  >>> reshaped = tf.reshape(matrix, (1, 4, 4, 1))
4724  >>> tf.nn.max_pool(reshaped, ksize=2, strides=2, padding="SAME")
4725  <tf.Tensor: shape=(1, 2, 2, 1), dtype=int32, numpy=
4726  array([[[[2],
4727           [7]],
4728          [[5],
4729           [9]]]], dtype=int32)>
4730
4731  We can adjust the window size using the `ksize` parameter. For example, if we
4732  were to expand the window to 3:
4733
4734  >>> tf.nn.max_pool(reshaped, ksize=3, strides=2, padding="SAME")
4735  <tf.Tensor: shape=(1, 2, 2, 1), dtype=int32, numpy=
4736  array([[[[5],
4737           [7]],
4738          [[9],
4739           [9]]]], dtype=int32)>
4740
4741  We've now picked up two additional large numbers (5 and 9) in two of the
4742  pooled spots.
4743
4744  Note that our windows are now overlapping, since we're still moving by 2 units
4745  on each iteration. This is causing us to see the same 9 repeated twice, since
4746  it is part of two overlapping windows.
4747
4748  We can adjust how far we move our window with each iteration using the
4749  `strides` parameter. Updating this to the same value as our window size
4750  eliminates the overlap:
4751
4752  >>> tf.nn.max_pool(reshaped, ksize=3, strides=3, padding="SAME")
4753  <tf.Tensor: shape=(1, 2, 2, 1), dtype=int32, numpy=
4754  array([[[[2],
4755           [7]],
4756          [[5],
4757           [9]]]], dtype=int32)>
4758
4759  Because the window does not neatly fit into our input, padding is added around
4760  the edges, giving us the same result as when we used a 2x2 window. We can skip
4761  padding altogether and simply drop the windows that do not fully fit into our
4762  input by instead passing `"VALID"` to the `padding` argument:
4763
4764  >>> tf.nn.max_pool(reshaped, ksize=3, strides=3, padding="VALID")
4765  <tf.Tensor: shape=(1, 1, 1, 1), dtype=int32, numpy=array([[[[5]]]],
4766   dtype=int32)>
4767
4768  Now we've grabbed the largest value in the 3x3 window starting from the upper-
4769  left corner. Since no other windows fit in our input, they are dropped.
4770
4771  Args:
4772    input:  Tensor of rank N+2, of shape `[batch_size] + input_spatial_shape +
4773      [num_channels]` if `data_format` does not start with "NC" (default), or
4774      `[batch_size, num_channels] + input_spatial_shape` if data_format starts
4775      with "NC". Pooling happens over the spatial dimensions only.
4776    ksize: An int or list of `ints` that has length `1`, `N` or `N+2`. The size
4777      of the window for each dimension of the input tensor.
4778    strides: An int or list of `ints` that has length `1`, `N` or `N+2`. The
4779      stride of the sliding window for each dimension of the input tensor.
4780    padding: Either the `string` `"SAME"` or `"VALID"` indicating the type of
4781      padding algorithm to use, or a list indicating the explicit paddings at
4782      the start and end of each dimension. See
4783      [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2)
4784      for more information. When explicit padding is used and data_format is
4785      `"NHWC"`, this should be in the form `[[0, 0], [pad_top, pad_bottom],
4786      [pad_left, pad_right], [0, 0]]`. When explicit padding used and
4787      data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
4788      [pad_top, pad_bottom], [pad_left, pad_right]]`. When using explicit
4789      padding, the size of the paddings cannot be greater than the sliding
4790      window size.
4791    data_format: A string. Specifies the channel dimension. For N=1 it can be
4792      either "NWC" (default) or "NCW", for N=2 it can be either "NHWC" (default)
4793      or "NCHW" and for N=3 either "NDHWC" (default) or "NCDHW".
4794    name: Optional name for the operation.
4795
4796  Returns:
4797    A `Tensor` of format specified by `data_format`.
4798    The max pooled output tensor.
4799  """
4800  if input.shape is not None:
4801    n = len(input.shape) - 2
4802  elif data_format is not None:
4803    n = len(data_format) - 2
4804  else:
4805    raise ValueError(
4806        "`input` must have a static shape or a data format must be given. "
4807        f"Received: input.shape={input.shape} and "
4808        f"data_format={data_format}")
4809  if not 1 <= n <= 3:
4810    raise ValueError(
4811        f"`input.shape.rank` must be 3, 4 or 5. Received: "
4812        f"input.shape={input.shape} of rank {n + 2}.")
4813  if data_format is None:
4814    channel_index = n + 1
4815  else:
4816    channel_index = 1 if data_format.startswith("NC") else n + 1
4817
4818  if isinstance(padding, (list, tuple)) and data_format == "NCHW_VECT_C":
4819    raise ValueError("`data_format='NCHW_VECT_C'` is not supported with "
4820                     f"explicit padding. Received: padding={padding}")
4821
4822  ksize = _get_sequence(ksize, n, channel_index, "ksize")
4823  strides = _get_sequence(strides, n, channel_index, "strides")
4824
4825  if (isinstance(padding, (list, tuple)) and n == 3):
4826    raise ValueError("Explicit padding is not supported with an input "
4827                     f"tensor of rank 5. Received: padding={padding}")
4828
4829  max_pooling_ops = {
4830      1: max_pool1d,
4831      2: max_pool2d,
4832      3: gen_nn_ops.max_pool3d
4833  }
4834
4835  op = max_pooling_ops[n]
4836  return op(
4837      input,
4838      ksize=ksize,
4839      strides=strides,
4840      padding=padding,
4841      data_format=data_format,
4842      name=name)
4843# pylint: enable=redefined-builtin
4844
4845
4846@tf_export(v1=["nn.max_pool"])
4847@dispatch.add_dispatch_support
4848def max_pool(value,
4849             ksize,
4850             strides,
4851             padding,
4852             data_format="NHWC",
4853             name=None,
4854             input=None):  # pylint: disable=redefined-builtin
4855  """Performs the max pooling on the input.
4856
4857  Args:
4858    value: A 4-D `Tensor` of the format specified by `data_format`.
4859    ksize: An int or list of `ints` that has length `1`, `2` or `4`.
4860      The size of the window for each dimension of the input tensor.
4861    strides: An int or list of `ints` that has length `1`, `2` or `4`.
4862      The stride of the sliding window for each dimension of the input tensor.
4863    padding: Either the `string` `"SAME"` or `"VALID"` indicating the type of
4864      padding algorithm to use, or a list indicating the explicit paddings at
4865      the start and end of each dimension. When explicit padding is used and
4866      data_format is `"NHWC"`, this should be in the form `[[0, 0], [pad_top,
4867      pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit padding is used
4868      and data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
4869      [pad_top, pad_bottom], [pad_left, pad_right]]`. When using explicit
4870      padding, the size of the paddings cannot be greater than the sliding
4871      window size.
4872    data_format: A string. 'NHWC', 'NCHW' and 'NCHW_VECT_C' are supported.
4873    name: Optional name for the operation.
4874    input: Alias for value.
4875
4876  Returns:
4877    A `Tensor` of format specified by `data_format`.
4878    The max pooled output tensor.
4879  """
4880  value = deprecation.deprecated_argument_lookup("input", input, "value", value)
4881  with ops.name_scope(name, "MaxPool", [value]) as name:
4882    if data_format is None:
4883      data_format = "NHWC"
4884    channel_index = 1 if data_format.startswith("NC") else 3
4885
4886    ksize = _get_sequence(ksize, 2, channel_index, "ksize")
4887    strides = _get_sequence(strides, 2, channel_index, "strides")
4888    if isinstance(padding, (list, tuple)) and data_format == "NCHW_VECT_C":
4889      raise ValueError("`data_format='NCHW_VECT_C'` is not supported with "
4890                       f"explicit padding. Received: padding={padding}")
4891    padding, explicit_paddings = convert_padding(padding)
4892    if ((np.isscalar(ksize) and ksize == 0) or
4893        (isinstance(ksize,
4894                    (list, tuple, np.ndarray)) and any(v == 0 for v in ksize))):
4895      raise ValueError(f"`ksize` cannot be zero. Received: ksize={ksize}")
4896
4897    return gen_nn_ops.max_pool(
4898        value,
4899        ksize=ksize,
4900        strides=strides,
4901        padding=padding,
4902        explicit_paddings=explicit_paddings,
4903        data_format=data_format,
4904        name=name)
4905
4906
4907# pylint: disable=redefined-builtin
4908@tf_export("nn.max_pool1d")
4909@dispatch.add_dispatch_support
4910def max_pool1d(input, ksize, strides, padding, data_format="NWC", name=None):
4911  """Performs the max pooling on the input.
4912
4913  Note internally this op reshapes and uses the underlying 2d operation.
4914
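  For example, with non-overlapping windows of size 2 (an illustrative sketch):

  ```python
  x = tf.reshape(tf.constant([1., 2., 3., 4.]), (1, 4, 1))
  tf.nn.max_pool1d(x, ksize=2, strides=2, padding="VALID")
  # A tensor of shape (1, 2, 1) containing [2., 4.].
  ```
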
4915  Args:
4916    input: A 3-D `Tensor` of the format specified by `data_format`.
4917    ksize: An int or list of `ints` that has length `1` or `3`. The size of the
4918      window for each dimension of the input tensor.
4919    strides: An int or list of `ints` that has length `1` or `3`. The stride of
4920      the sliding window for each dimension of the input tensor.
4921    padding: Either the `string` `"SAME"` or `"VALID"` indicating the type of
4922      padding algorithm to use, or a list indicating the explicit paddings at
4923      the start and end of each dimension. See
4924      [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2)
4925      for more information. When explicit padding is used and data_format is
4926      `"NWC"`, this should be in the form `[[0, 0], [pad_left, pad_right], [0,
4927      0]]`. When explicit padding is used and data_format is `"NCW"`, this should
4928      be in the form `[[0, 0], [0, 0], [pad_left, pad_right]]`. When using
4929      explicit padding, the size of the paddings cannot be greater than the
4930      sliding window size.
4931    data_format: An optional string from: "NWC", "NCW". Defaults to "NWC".
4932    name: A name for the operation (optional).
4933
4934  Returns:
4935    A `Tensor` of format specified by `data_format`.
4936    The max pooled output tensor.
4937  """
4938  with ops.name_scope(name, "MaxPool1d", [input]) as name:
4939    if isinstance(padding, (list, tuple)) and data_format == "NCHW_VECT_C":
4940      raise ValueError("`data_format='NCHW_VECT_C'` is not supported with "
4941                       f"explicit padding. Received: padding={padding}")
4942    if data_format is None:
4943      data_format = "NWC"
4944    channel_index = 1 if data_format.startswith("NC") else 2
4945    ksize = [1] + _get_sequence(ksize, 1, channel_index, "ksize")
4946    strides = [1] + _get_sequence(strides, 1, channel_index, "strides")
4947    padding, explicit_paddings = convert_padding(padding, 3)
4948    if padding == "EXPLICIT":
4949      explicit_paddings = [0, 0] + explicit_paddings
4950
4951    expanding_dim = 1 if data_format == "NWC" else 2
4952    data_format = "NHWC" if data_format == "NWC" else "NCHW"
4953
4954    input = array_ops.expand_dims_v2(input, expanding_dim)
4955    result = gen_nn_ops.max_pool(
4956        input,
4957        ksize=ksize,
4958        strides=strides,
4959        padding=padding,
4960        explicit_paddings=explicit_paddings,
4961        data_format=data_format,
4962        name=name)
4963    return array_ops.squeeze(result, expanding_dim)
4964# pylint: enable=redefined-builtin
4965
4966
4967# pylint: disable=redefined-builtin
4968@tf_export("nn.max_pool2d")
4969@dispatch.add_dispatch_support
4970def max_pool2d(input, ksize, strides, padding, data_format="NHWC", name=None):
4971  """Performs max pooling on 2D spatial data such as images.
4972
4973  This is a more specific version of `tf.nn.max_pool` where the input tensor
4974  is 4D, representing 2D spatial data such as images. Using the two APIs is
4975  equivalent for such inputs.
4976
4977  Downsamples the input images along their spatial dimensions (height and
4978  width) by taking the maximum over an input window defined by `ksize`.
4979  The window is shifted by `strides` along each dimension.
4980
4981  For example, for `strides=(2, 2)` and `padding=VALID` windows that extend
4982  outside of the input are not included in the output:
4983
4984  >>> x = tf.constant([[1., 2., 3., 4.],
4985  ...                  [5., 6., 7., 8.],
4986  ...                  [9., 10., 11., 12.]])
4987  >>> # Add the `batch` and `channels` dimensions.
4988  >>> x = x[tf.newaxis, :, :, tf.newaxis]
4989  >>> result = tf.nn.max_pool2d(x, ksize=(2, 2), strides=(2, 2),
4990  ...                           padding="VALID")
4991  >>> result[0, :, :, 0]
4992  <tf.Tensor: shape=(1, 2), dtype=float32, numpy=
4993  array([[6., 8.]], dtype=float32)>
4994
4995  With `padding=SAME`, we get:
4996
4997  >>> x = tf.constant([[1., 2., 3., 4.],
4998  ...                  [5., 6., 7., 8.],
4999  ...                  [9., 10., 11., 12.]])
5000  >>> x = x[tf.newaxis, :, :, tf.newaxis]
5001  >>> result = tf.nn.max_pool2d(x, ksize=(2, 2), strides=(2, 2),
5002  ...                           padding='SAME')
5003  >>> result[0, :, :, 0]
5004  <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
5005  array([[ 6., 8.],
5006         [10.,12.]], dtype=float32)>
5007
5008  We can also specify padding explicitly. The following example adds width-1
5009  padding on all sides (top, bottom, left, right):
5010
5011  >>> x = tf.constant([[1., 2., 3., 4.],
5012  ...                  [5., 6., 7., 8.],
5013  ...                  [9., 10., 11., 12.]])
5014  >>> x = x[tf.newaxis, :, :, tf.newaxis]
5015  >>> result = tf.nn.max_pool2d(x, ksize=(2, 2), strides=(2, 2),
5016  ...                           padding=[[0, 0], [1, 1], [1, 1], [0, 0]])
5017  >>> result[0, :, :, 0]
5018  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
5019  array([[ 1., 3., 4.],
5020         [ 9., 11., 12.]], dtype=float32)>
5021
5022  For more examples and detail, see `tf.nn.max_pool`.
5023
5024  Args:
5025    input: A 4-D `Tensor` of the format specified by `data_format`.
5026    ksize: An int or list of `ints` that has length `1`, `2` or `4`. The size of
5027      the window for each dimension of the input tensor. If only one integer is
5028      specified, then we apply the same window for all 4 dims. If two are
5029      provided then we use those for H, W dimensions and keep N, C dimension
5030      window size = 1.
5031    strides: An int or list of `ints` that has length `1`, `2` or `4`. The
5032      stride of the sliding window for each dimension of the input tensor. If
5033      only one integer is specified, we apply the same stride to all 4 dims. If
5034      two are provided we use those for the H, W dimensions and keep N, C of
5035      stride = 1.
5036    padding: Either the `string` `"SAME"` or `"VALID"` indicating the type of
5037      padding algorithm to use, or a list indicating the explicit paddings at
5038      the start and end of each dimension. See
5039      [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2)
5040        for more information. When explicit padding is used and data_format is
5041        `"NHWC"`, this should be in the form `[[0, 0], [pad_top, pad_bottom],
5042        [pad_left, pad_right], [0, 0]]`. When explicit padding is used and
5043        data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
5044        [pad_top, pad_bottom], [pad_left, pad_right]]`. When using explicit
5045        padding, the size of the paddings cannot be greater than the sliding
5046        window size.
5047    data_format: A string. 'NHWC', 'NCHW' and 'NCHW_VECT_C' are supported.
5048    name: Optional name for the operation.
5049
5050  Returns:
5051    A `Tensor` of format specified by `data_format`.
5052    The max pooled output tensor.
5053  """
5054  with ops.name_scope(name, "MaxPool2d", [input]) as name:
5055    if data_format is None:
5056      data_format = "NHWC"
5057    channel_index = 1 if data_format.startswith("NC") else 3
5058
5059    ksize = _get_sequence(ksize, 2, channel_index, "ksize")
5060    strides = _get_sequence(strides, 2, channel_index, "strides")
5061    if isinstance(padding, (list, tuple)) and data_format == "NCHW_VECT_C":
5062      raise ValueError("`data_format='NCHW_VECT_C'` is not supported with "
5063                       f"explicit padding. Received: padding={padding}")
5064    padding, explicit_paddings = convert_padding(padding)
5065
5066    return gen_nn_ops.max_pool(
5067        input,
5068        ksize=ksize,
5069        strides=strides,
5070        padding=padding,
5071        explicit_paddings=explicit_paddings,
5072        data_format=data_format,
5073        name=name)
5074# pylint: enable=redefined-builtin
5075
5076
5077# pylint: disable=redefined-builtin
5078@tf_export("nn.max_pool3d")
5079@dispatch.add_dispatch_support
5080def max_pool3d(input, ksize, strides, padding, data_format="NDHWC", name=None):
5081  """Performs the max pooling on the input.
5082
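  For example, taking the maximum over a single 2x2x2 window (an illustrative
  sketch):

  ```python
  x = tf.reshape(tf.range(8, dtype=tf.float32), (1, 2, 2, 2, 1))
  tf.nn.max_pool3d(x, ksize=2, strides=2, padding="VALID")
  # A tensor of shape (1, 1, 1, 1, 1) containing 7.0.
  ```
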
5083  Args:
5084    input: A 5-D `Tensor` of the format specified by `data_format`.
5085    ksize: An int or list of `ints` that has length `1`, `3` or `5`. The size of
5086      the window for each dimension of the input tensor.
5087    strides: An int or list of `ints` that has length `1`, `3` or `5`. The
5088      stride of the sliding window for each dimension of the input tensor.
5089    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
5090      [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2)
5091      for more information.
5092    data_format: An optional string from: "NDHWC", "NCDHW". Defaults to "NDHWC".
5093      The data format of the input and output data. With the default format
5094      "NDHWC", the data is stored in the order of: [batch, in_depth, in_height,
5095        in_width, in_channels]. Alternatively, the format could be "NCDHW", the
5096      data storage order is: [batch, in_channels, in_depth, in_height,
5097        in_width].
5098    name: A name for the operation (optional).
5099
5100  Returns:
5101    A `Tensor` of format specified by `data_format`.
5102    The max pooled output tensor.
5103  """
5104  with ops.name_scope(name, "MaxPool3D", [input]) as name:
5105    if data_format is None:
5106      data_format = "NDHWC"
5107    channel_index = 1 if data_format.startswith("NC") else 4
5108
5109    ksize = _get_sequence(ksize, 3, channel_index, "ksize")
5110    strides = _get_sequence(strides, 3, channel_index, "strides")
5111
5112    return gen_nn_ops.max_pool3d(
5113        input,
5114        ksize=ksize,
5115        strides=strides,
5116        padding=padding,
5117        data_format=data_format,
5118        name=name)
5119# pylint: enable=redefined-builtin
5120
5121
5122@tf_export("nn.max_pool_with_argmax", v1=[])
5123@dispatch.add_dispatch_support
5124def max_pool_with_argmax_v2(
5125    input,  # pylint: disable=redefined-builtin
5126    ksize,
5127    strides,
5128    padding,
5129    data_format="NHWC",
5130    output_dtype=dtypes.int64,
5131    include_batch_in_index=False,
5132    name=None):
5133  """Performs max pooling on the input and outputs both max values and indices.
5134
5135  The indices in `argmax` are flattened, so that a maximum value at position
5136  `[b, y, x, c]` becomes flattened index: `(y * width + x) * channels + c` if
5137  `include_batch_in_index` is False;
5138  `((b * height + y) * width + x) * channels + c`
5139  if `include_batch_in_index` is True.
5140
5141  The indices returned are always in `[0, height) x [0, width)` before
5142  flattening, even if padding is involved and the mathematically correct answer
5143  is outside (either negative or too large).  This is a bug, but fixing it is
5144  difficult to do in a safe backwards compatible way, especially due to
5145  flattening.
5146
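  For illustration (values chosen arbitrarily), consider a single 2x2 window
  over a 2x2 input. The maximum 4.0 sits at (y=1, x=0), so with width 2 and a
  single channel its flattened index is (1 * 2 + 0) * 1 + 0 = 2:

  >>> x = tf.reshape(tf.constant([1., 2., 4., 3.]), [1, 2, 2, 1])
  >>> out, argmax = tf.nn.max_pool_with_argmax(
  ...     x, ksize=2, strides=2, padding='VALID')
  >>> float(out[0, 0, 0, 0])
  4.0
  >>> int(argmax[0, 0, 0, 0])
  2
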
5147  Args:
5148    input: A `Tensor`. Must be one of the following types: `float32`, `float64`,
5149      `int32`, `uint8`, `int16`, `int8`, `int64`, `bfloat16`, `uint16`, `half`,
5150      `uint32`, `uint64`.
5151      4-D with shape `[batch, height, width, channels]`.  Input to pool over.
5152    ksize: An int or list of `ints` that has length `1`, `2` or `4`.
5153      The size of the window for each dimension of the input tensor.
5154    strides: An int or list of `ints` that has length `1`, `2` or `4`.
5155      The stride of the sliding window for each dimension of the
5156      input tensor.
5157    padding: A `string` from: `"SAME", "VALID"`.
5158      The type of padding algorithm to use. See
5159      [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2)
5160      for more information.
5161    data_format: An optional `string`, must be set to `"NHWC"`. Defaults to
5162      `"NHWC"`.
5163      Specify the data format of the input and output data.
5164    output_dtype: An optional `tf.DType` from: `tf.int32, tf.int64`.
5165      Defaults to `tf.int64`.
5166      The dtype of the returned argmax tensor.
5167    include_batch_in_index: An optional `boolean`. Defaults to `False`.
5168      Whether to include batch dimension in flattened index of `argmax`.
5169    name: A name for the operation (optional).
5170
5171  Returns:
5172    A tuple of `Tensor` objects (output, argmax).
5173
5174    output: A `Tensor`. Has the same type as `input`.
5175    argmax: A `Tensor` of type `output_dtype`.
5176  """
5177
5178  if data_format != "NHWC":
5179    raise ValueError("`data_format` values other than 'NHWC' are not "
5180                     f"supported. Received: data_format={data_format}")
5181
5182  ksize = _get_sequence(ksize, 2, 3, "ksize")
5183  strides = _get_sequence(strides, 2, 3, "strides")
5184
5185  return gen_nn_ops.max_pool_with_argmax(
5186      input=input,
5187      ksize=ksize,
5188      strides=strides,
5189      padding=padding,
5190      Targmax=output_dtype,
5191      include_batch_in_index=include_batch_in_index,
5192      name=name)
5193
5194
5195@tf_export(v1=["nn.max_pool_with_argmax"])
5196@dispatch.add_dispatch_support
5197def max_pool_with_argmax_v1(  # pylint: disable=missing-docstring,invalid-name
5198    input,  # pylint: disable=redefined-builtin
5199    ksize,
5200    strides,
5201    padding,
5202    data_format="NHWC",
5203    Targmax=None,
5204    name=None,
5205    output_dtype=None,
5206    include_batch_in_index=False):
5207  if data_format != "NHWC":
5208    raise ValueError("`data_format` values other than 'NHWC' are not "
5209                     f"supported. Received: data_format={data_format}")
5210
5211  Targmax = deprecated_argument_lookup(
5212      "output_dtype", output_dtype, "Targmax", Targmax)
5213  if Targmax is None:
5214    Targmax = dtypes.int64
5215  return gen_nn_ops.max_pool_with_argmax(
5216      input=input,
5217      ksize=ksize,
5218      strides=strides,
5219      padding=padding,
5220      Targmax=Targmax,
5221      include_batch_in_index=include_batch_in_index,
5222      name=name)
5223
5224
5225max_pool_with_argmax_v1.__doc__ = gen_nn_ops.max_pool_with_argmax.__doc__
5226
5227
5228@ops.RegisterStatistics("Conv3D", "flops")
5229def _calc_conv3d_flops(graph, node):
5230  """Calculates the compute resources needed for Conv3D."""
5231  input_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
5232  input_shape.assert_is_fully_defined()
5233  filter_shape = graph_util.tensor_shape_from_node_def_name(
5234      graph, node.input[1])
5235  filter_shape.assert_is_fully_defined()
5236  output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
5237  output_shape.assert_is_fully_defined()
5238  filter_time = int(filter_shape[0])
5239  filter_height = int(filter_shape[1])
5240  filter_width = int(filter_shape[2])
5241  filter_in_depth = int(filter_shape[3])
5242  output_count = np.prod(output_shape.as_list(), dtype=np.int64)
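  # Each output element requires filter_time * filter_height * filter_width *
  # filter_in_depth multiply-accumulates; the factor of 2 counts the multiply
  # and the add of each MAC separately (a common FLOP-counting convention).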
5243  return ops.OpStats("flops", (output_count * filter_in_depth * filter_time *
5244                               filter_height * filter_width * 2))
5245
5246
5247@ops.RegisterStatistics("Conv2D", "flops")
5248def _calc_conv_flops(graph, node):
5249  """Calculates the compute resources needed for Conv2D."""
5250  input_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
5251  input_shape.assert_is_fully_defined()
5252  filter_shape = graph_util.tensor_shape_from_node_def_name(
5253      graph, node.input[1])
5254  filter_shape.assert_is_fully_defined()
5255  output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
5256  output_shape.assert_is_fully_defined()
5257  filter_height = int(filter_shape[0])
5258  filter_width = int(filter_shape[1])
5259  filter_in_depth = int(filter_shape[2])
5260  output_count = np.prod(output_shape.as_list(), dtype=np.int64)
5261  return ops.OpStats(
5262      "flops",
5263      (output_count * filter_in_depth * filter_height * filter_width * 2))
5264
5265
5266@ops.RegisterStatistics("DepthwiseConv2dNative", "flops")
5267def _calc_depthwise_conv_flops(graph, node):
5268  """Calculates the compute resources needed for DepthwiseConv2dNative."""
5269  input_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
5270  input_shape.assert_is_fully_defined()
5271  filter_shape = graph_util.tensor_shape_from_node_def_name(
5272      graph, node.input[1])
5273  filter_shape.assert_is_fully_defined()
5274  output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
5275  output_shape.assert_is_fully_defined()
5276  filter_height = int(filter_shape[0])
5277  filter_width = int(filter_shape[1])
5278  output_count = np.prod(output_shape.as_list(), dtype=np.int64)
5279  return ops.OpStats("flops", (output_count * filter_height * filter_width * 2))
5280
5281
5282@ops.RegisterStatistics("BiasAdd", "flops")
5283def _calc_bias_add_flops(graph, node):
5284  """Calculates the computing needed for BiasAdd."""
5285  input_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
5286  input_shape.assert_is_fully_defined()
5287  input_count = np.prod(input_shape.as_list())
5288  return ops.OpStats("flops", input_count)
5289
5290
5291@tf_export(v1=["nn.xw_plus_b"])
5292@dispatch.add_dispatch_support
5293def xw_plus_b(x, weights, biases, name=None):  # pylint: disable=invalid-name
5294  """Computes matmul(x, weights) + biases.
5295
5296  Args:
5297    x: a 2D tensor.  Dimensions typically: batch, in_units
5298    weights: a 2D tensor.  Dimensions typically: in_units, out_units
5299    biases: a 1D tensor.  Dimensions: out_units
5300    name: A name for the operation (optional).  If not specified
5301      "xw_plus_b" is used.
5302
5303  Returns:
5304    A 2-D Tensor computing matmul(x, weights) + biases.
5305    Dimensions typically: batch, out_units.
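
  A small illustrative example:

  >>> x = tf.constant([[1., 2.]])
  >>> weights = tf.constant([[3.], [4.]])
  >>> biases = tf.constant([5.])
  >>> tf.compat.v1.nn.xw_plus_b(x, weights, biases).numpy()
  array([[16.]], dtype=float32)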
5306  """
5307  with ops.name_scope(name, "xw_plus_b", [x, weights, biases]) as name:
5308    x = ops.convert_to_tensor(x, name="x")
5309    weights = ops.convert_to_tensor(weights, name="weights")
5310    biases = ops.convert_to_tensor(biases, name="biases")
5311    mm = math_ops.matmul(x, weights)
5312    return bias_add(mm, biases, name=name)
5313
5314
5315def xw_plus_b_v1(x, weights, biases, name=None):
5316  """Computes matmul(x, weights) + biases.
5317
5318  This is a deprecated version that will soon be removed.
5319
5320  Args:
5321    x: a 2D tensor.  Dimensions typically: batch, in_units
5322    weights: a 2D tensor.  Dimensions typically: in_units, out_units
5323    biases: a 1D tensor.  Dimensions: out_units
5324    name: A name for the operation (optional).  If not specified
5325      "xw_plus_b_v1" is used.
5326
5327  Returns:
5328    A 2-D Tensor computing matmul(x, weights) + biases.
5329    Dimensions typically: batch, out_units.
5330  """
5331  with ops.name_scope(name, "xw_plus_b_v1", [x, weights, biases]) as name:
5332    x = ops.convert_to_tensor(x, name="x")
5333    weights = ops.convert_to_tensor(weights, name="weights")
5334    biases = ops.convert_to_tensor(biases, name="biases")
5335    mm = math_ops.matmul(x, weights)
5336    return bias_add_v1(mm, biases, name=name)
5337
5338
5339def _get_noise_shape(x, noise_shape):
5340  # If noise_shape is None, return immediately.
5341  if noise_shape is None:
5342    return array_ops.shape(x)
5343
5344  try:
5345    # Best effort to figure out the intended shape.
5346    # If not possible, let the op handle it.
5347    # In eager mode the exception will show up.
5348    noise_shape_ = tensor_shape.as_shape(noise_shape)
5349  except (TypeError, ValueError):
5350    return noise_shape
5351
5352  if x.shape.dims is not None and len(x.shape.dims) == len(noise_shape_.dims):
5353    new_dims = []
5354    for i, dim in enumerate(x.shape.dims):
5355      if noise_shape_.dims[i].value is None and dim.value is not None:
5356        new_dims.append(dim.value)
5357      else:
5358        new_dims.append(noise_shape_.dims[i].value)
5359    return tensor_shape.TensorShape(new_dims)
5360
5361  return noise_shape
5362
5363
5364@tf_export(v1=["nn.dropout"])
5365@dispatch.add_dispatch_support
5366@deprecation.deprecated_args(None, "Please use `rate` instead of `keep_prob`. "
5367                             "Rate should be set to `rate = 1 - keep_prob`.",
5368                             "keep_prob")
5369def dropout(x, keep_prob=None, noise_shape=None, seed=None, name=None,
5370            rate=None):
5371  """Computes dropout.
5372
5373  For each element of `x`, with probability `rate`, outputs `0`, and otherwise
5374  scales up the input by `1 / (1-rate)`. The scaling is such that the expected
5375  sum is unchanged.
5376
5377  By default, each element is kept or dropped independently.  If `noise_shape`
5378  is specified, it must be
5379  [broadcastable](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
5380  to the shape of `x`, and only dimensions with `noise_shape[i] == shape(x)[i]`
5381  will make independent decisions.  For example, if `shape(x) = [k, l, m, n]`
5382  and `noise_shape = [k, 1, 1, n]`, each batch and channel component will be
5383  kept independently and each row and column will be kept or not kept together.
5384
5385  Args:
5386    x: A floating point tensor.
5387    keep_prob: (deprecated) A deprecated alias for `(1-rate)`.
5388    noise_shape: A 1-D integer `Tensor`, representing the
5389      shape for randomly generated keep/drop flags.
5390    seed: A Python integer. Used to create random seeds. See
5391      `tf.random.set_seed` for behavior.
5392    name: A name for this operation (optional).
5393    rate: A scalar `Tensor` with the same type as `x`. The probability that each
5394      element of `x` is discarded.
5395
5396  Returns:
5397    A Tensor of the same shape of `x`.
5398
5399  Raises:
5400    ValueError: If `rate` is not in `[0, 1)` or if `x` is not a floating
5401      point tensor.
5402  """
5403  try:
5404    rate_from_keep_prob = 1. - keep_prob if keep_prob is not None else None
5405  except TypeError:
5406    raise ValueError("`keep_prob` must be a floating point number or Tensor. "
5407                     f"Received: keep_prob={keep_prob}")
5408
5409  rate = deprecation.deprecated_argument_lookup(
5410      "rate", rate,
5411      "keep_prob", rate_from_keep_prob)
5412
5413  if rate is None:
5414    raise ValueError(f"`rate` must be provided. Received: rate={rate}")
5415
5416  return dropout_v2(x, rate, noise_shape=noise_shape, seed=seed, name=name)
5417
5418
5419@tf_export("nn.dropout", v1=[])
5420@dispatch.add_dispatch_support
5421def dropout_v2(x, rate, noise_shape=None, seed=None, name=None):
5422  """Computes dropout: randomly sets elements to zero to prevent overfitting.
5423
5424  Warning: You should consider using
5425  `tf.nn.experimental.stateless_dropout` instead of this function. The
5426  difference between `tf.nn.experimental.stateless_dropout` and this
5427  function is analogous to the difference between
5428  `tf.random.stateless_uniform` and `tf.random.uniform`. Please see
5429  [Random number
5430  generation](https://www.tensorflow.org/guide/random_numbers) guide
5431  for a detailed description of the various RNG systems in TF. As the
5432  guide states, legacy stateful RNG ops like `tf.random.uniform` and
5433  `tf.nn.dropout` are not deprecated yet but highly discouraged,
5434  because their states are hard to control.
5435
5436  Note: The behavior of dropout has changed between TensorFlow 1.x and 2.x.
5437  When converting 1.x code, please use named arguments to ensure behavior stays
5438  consistent.
5439
5440  See also: `tf.keras.layers.Dropout` for a dropout layer.
5441
5442  [Dropout](https://arxiv.org/abs/1207.0580) is useful for regularizing DNN
5443  models. Input elements are randomly set to zero (and the other elements are
5444  rescaled). This encourages each node to be independently useful, as it cannot
5445  rely on the output of other nodes.
5446
5447  More precisely: With probability `rate` elements of `x` are set to `0`.
5448  The remaining elements are scaled up by `1.0 / (1 - rate)`, so that the
5449  expected value is preserved.
5450
5451  >>> tf.random.set_seed(0)
5452  >>> x = tf.ones([3,5])
5453  >>> tf.nn.dropout(x, rate = 0.5, seed = 1).numpy()
5454  array([[2., 0., 0., 2., 2.],
5455       [2., 2., 2., 2., 2.],
5456       [2., 0., 2., 0., 2.]], dtype=float32)
5457
5458  >>> tf.random.set_seed(0)
5459  >>> x = tf.ones([3,5])
5460  >>> tf.nn.dropout(x, rate = 0.8, seed = 1).numpy()
5461  array([[0., 0., 0., 5., 5.],
5462       [0., 5., 0., 5., 0.],
5463       [5., 0., 5., 0., 5.]], dtype=float32)
5464
5465  >>> tf.nn.dropout(x, rate = 0.0) == x
5466  <tf.Tensor: shape=(3, 5), dtype=bool, numpy=
5467    array([[ True,  True,  True,  True,  True],
5468           [ True,  True,  True,  True,  True],
5469           [ True,  True,  True,  True,  True]])>
5470
5471
5472  By default, each element is kept or dropped independently.  If `noise_shape`
5473  is specified, it must be
5474  [broadcastable](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
5475  to the shape of `x`, and only dimensions with `noise_shape[i] == shape(x)[i]`
5476  will make independent decisions. This is useful for dropping whole
5477  channels from an image or sequence. For example:
5478
5479  >>> tf.random.set_seed(0)
5480  >>> x = tf.ones([3,10])
5481  >>> tf.nn.dropout(x, rate = 2/3, noise_shape=[1,10], seed=1).numpy()
5482  array([[0., 0., 0., 3., 3., 0., 3., 3., 3., 0.],
5483       [0., 0., 0., 3., 3., 0., 3., 3., 3., 0.],
5484       [0., 0., 0., 3., 3., 0., 3., 3., 3., 0.]], dtype=float32)
5485
5486  Args:
5487    x: A floating point tensor.
5488    rate: A scalar `Tensor` with the same type as x. The probability
5489      that each element is dropped. For example, setting rate=0.1 would drop
5490      10% of input elements.
5491    noise_shape: A 1-D integer `Tensor`, representing the
5492      shape for randomly generated keep/drop flags.
5493    seed: A Python integer. Used to create random seeds. See
5494      `tf.random.set_seed` for behavior.
5495    name: A name for this operation (optional).
5496
5497  Returns:
5498    A Tensor of the same shape of `x`.
5499
5500  Raises:
5501    ValueError: If `rate` is not in `[0, 1)` or if `x` is not a floating point
5502      tensor. `rate=1` is disallowed, because the output would be all zeros,
5503      which is likely not what was intended.
5504  """
5505  uniform_sampler = functools.partial(random_ops.random_uniform, seed=seed)
5506  def dummy_rng_step():
5507    random_seed.get_seed(seed)
5508  return _dropout(x=x, rate=rate, noise_shape=noise_shape,
5509                  uniform_sampler=uniform_sampler,
5510                  dummy_rng_step=dummy_rng_step, name=name,
5511                  default_name="dropout")
5512
5513
5514@tf_export("nn.experimental.stateless_dropout")
5515@dispatch.add_dispatch_support
5516def stateless_dropout(x, rate, seed, rng_alg=None, noise_shape=None, name=None):
5517  """Computes dropout: randomly sets elements to zero to prevent overfitting.
5518
5519  [Dropout](https://arxiv.org/abs/1207.0580) is useful for regularizing DNN
5520  models. Input elements are randomly set to zero (and the other elements are
5521  rescaled). This encourages each node to be independently useful, as it cannot
5522  rely on the output of other nodes.
5523
5524  More precisely: With probability `rate` elements of `x` are set to `0`.
5525  The remaining elements are scaled up by `1.0 / (1 - rate)`, so that the
5526  expected value is preserved.
5527
5528  >>> x = tf.ones([3,5])
5529  >>> tf.nn.experimental.stateless_dropout(x, rate=0.5, seed=[1, 0])
5530  <tf.Tensor: shape=(3, 5), dtype=float32, numpy=
5531  array([[2., 0., 2., 0., 0.],
5532         [0., 0., 2., 0., 2.],
5533         [0., 0., 0., 0., 2.]], dtype=float32)>
5534
5535  >>> x = tf.ones([3,5])
5536  >>> tf.nn.experimental.stateless_dropout(x, rate=0.8, seed=[1, 0])
5537  <tf.Tensor: shape=(3, 5), dtype=float32, numpy=
5538  array([[5., 0., 0., 0., 0.],
5539         [0., 0., 0., 0., 5.],
5540         [0., 0., 0., 0., 5.]], dtype=float32)>
5541
5542  >>> tf.nn.experimental.stateless_dropout(x, rate=0.0, seed=[1, 0]) == x
5543  <tf.Tensor: shape=(3, 5), dtype=bool, numpy=
5544  array([[ True,  True,  True,  True,  True],
5545         [ True,  True,  True,  True,  True],
5546         [ True,  True,  True,  True,  True]])>
5547
5548
5549  This function is a stateless version of `tf.nn.dropout`, in the
5550  sense that no matter how many times you call this function, the same
5551  `seed` will lead to the same results, and different `seed` will lead
5552  to different results.
5553
5554  >>> x = tf.ones([3,5])
5555  >>> tf.nn.experimental.stateless_dropout(x, rate=0.8, seed=[1, 0])
5556  <tf.Tensor: shape=(3, 5), dtype=float32, numpy=
5557  array([[5., 0., 0., 0., 0.],
5558         [0., 0., 0., 0., 5.],
5559         [0., 0., 0., 0., 5.]], dtype=float32)>
5560  >>> tf.nn.experimental.stateless_dropout(x, rate=0.8, seed=[1, 0])
5561  <tf.Tensor: shape=(3, 5), dtype=float32, numpy=
5562  array([[5., 0., 0., 0., 0.],
5563         [0., 0., 0., 0., 5.],
5564         [0., 0., 0., 0., 5.]], dtype=float32)>
5565  >>> tf.nn.experimental.stateless_dropout(x, rate=0.8, seed=[2, 0])
5566  <tf.Tensor: shape=(3, 5), dtype=float32, numpy=
5567  array([[5., 0., 0., 0., 0.],
5568         [0., 0., 0., 5., 0.],
5569         [0., 0., 0., 0., 0.]], dtype=float32)>
5570  >>> tf.nn.experimental.stateless_dropout(x, rate=0.8, seed=[2, 0])
5571  <tf.Tensor: shape=(3, 5), dtype=float32, numpy=
5572  array([[5., 0., 0., 0., 0.],
5573         [0., 0., 0., 5., 0.],
5574         [0., 0., 0., 0., 0.]], dtype=float32)>
5575
5576  Compare the above results to those of `tf.nn.dropout` below. The
5577  second time `tf.nn.dropout` is called with the same seed, it will
5578  give a different output.
5579
5580  >>> tf.random.set_seed(0)
5581  >>> x = tf.ones([3,5])
5582  >>> tf.nn.dropout(x, rate=0.8, seed=1)
5583  <tf.Tensor: shape=(3, 5), dtype=float32, numpy=
5584  array([[0., 0., 0., 5., 5.],
5585         [0., 5., 0., 5., 0.],
5586         [5., 0., 5., 0., 5.]], dtype=float32)>
5587  >>> tf.nn.dropout(x, rate=0.8, seed=1)
5588  <tf.Tensor: shape=(3, 5), dtype=float32, numpy=
5589  array([[0., 0., 0., 0., 0.],
5590         [0., 0., 0., 5., 0.],
5591         [0., 0., 0., 0., 0.]], dtype=float32)>
5592  >>> tf.nn.dropout(x, rate=0.8, seed=2)
5593  <tf.Tensor: shape=(3, 5), dtype=float32, numpy=
5594  array([[0., 0., 0., 0., 0.],
5595         [0., 5., 0., 5., 0.],
5596         [0., 0., 0., 0., 0.]], dtype=float32)>
5597  >>> tf.nn.dropout(x, rate=0.8, seed=2)
5598  <tf.Tensor: shape=(3, 5), dtype=float32, numpy=
5599  array([[0., 0., 0., 0., 0.],
5600         [5., 0., 5., 0., 5.],
5601         [0., 5., 0., 0., 5.]], dtype=float32)>
5602
5603  The difference between this function and `tf.nn.dropout` is
5604  analogous to the difference between `tf.random.stateless_uniform`
5605  and `tf.random.uniform`. Please see [Random number
5606  generation](https://www.tensorflow.org/guide/random_numbers) guide
5607  for a detailed description of the various RNG systems in TF. As the
5608  guide states, legacy stateful RNG ops like `tf.random.uniform` and
5609  `tf.nn.dropout` are not deprecated yet but highly discouraged,
5610  because their states are hard to control.
5611
5612  By default, each element is kept or dropped independently.  If `noise_shape`
5613  is specified, it must be
5614  [broadcastable](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
5615  to the shape of `x`, and only dimensions with `noise_shape[i] == shape(x)[i]`
5616  will make independent decisions. This is useful for dropping whole
5617  channels from an image or sequence. For example:
5618
5619  >>> x = tf.ones([3,10])
5620  >>> tf.nn.experimental.stateless_dropout(x, rate=2/3, noise_shape=[1,10],
5621  ...                                      seed=[1, 0])
5622  <tf.Tensor: shape=(3, 10), dtype=float32, numpy=
5623  array([[3., 0., 0., 0., 0., 0., 0., 3., 0., 3.],
5624         [3., 0., 0., 0., 0., 0., 0., 3., 0., 3.],
5625         [3., 0., 0., 0., 0., 0., 0., 3., 0., 3.]], dtype=float32)>
5626
5627  Args:
5628    x: A floating point tensor.
5629    rate: A scalar `Tensor` with the same type as x. The probability
5630      that each element is dropped. For example, setting rate=0.1 would drop
5631      10% of input elements.
5632    seed: An integer tensor of shape `[2]`. The seed of the random numbers.
5633    rng_alg: The algorithm used to generate the random numbers
5634      (default to `"auto_select"`). See the `alg` argument of
5635      `tf.random.stateless_uniform` for the supported values.
5636    noise_shape: A 1-D integer `Tensor`, representing the
5637      shape for randomly generated keep/drop flags.
5638    name: A name for this operation.
5639
5640  Returns:
5641    A Tensor of the same shape and dtype of `x`.
5642
5643  Raises:
5644    ValueError: If `rate` is not in `[0, 1)` or if `x` is not a floating point
5645      tensor. `rate=1` is disallowed, because the output would be all zeros,
5646      which is likely not what was intended.
5647  """
5648  uniform_sampler = functools.partial(
5649      stateless_random_ops.stateless_random_uniform, seed=seed, alg=rng_alg)
5650  def dummy_rng_step():
5651    pass
5652  return _dropout(x=x, rate=rate, noise_shape=noise_shape,
5653                  uniform_sampler=uniform_sampler,
5654                  dummy_rng_step=dummy_rng_step, name=name,
5655                  default_name="stateless_dropout")
5656
5657
5658def _dropout(x, rate, noise_shape, uniform_sampler, dummy_rng_step, name,
5659             default_name):
5660  """Shared implementation of the various dropout functions.
5661
5662  Args:
5663    x: same as the namesake in `dropout_v2`.
5664    rate: same as the namesake in `dropout_v2`.
5665    noise_shape: same as the namesake in `dropout_v2`.
5666    uniform_sampler: a callable of signature `(shape, dtype) ->
5667      Tensor`, used to generate a tensor of uniformly-distributed
5668      random numbers, of the given shape and dtype.
5669    dummy_rng_step: a callable of signature `() -> None`, to make a
5670      dummy RNG call in the fast path. In the fast path where rate is
5671      0, we don't need to generate random numbers, but some samplers
5672      still require you to make an RNG call, to make sure that RNG
5673      states won't depend on whether the fast path is taken.
5674    name: same as the namesake in `dropout_v2`.
5675    default_name: a default name in case `name` is `None`.
5676
5677  Returns:
5678    A Tensor of the same shape and dtype of `x`.
5679  """
5680  with ops.name_scope(name, default_name, [x]) as name:
5681    is_rate_number = isinstance(rate, numbers.Real)
5682    if is_rate_number and (rate < 0 or rate >= 1):
5683      raise ValueError("`rate` must be a scalar tensor or a float in the "
5684                       f"range [0, 1). Received: rate={rate}")
5685    x = ops.convert_to_tensor(x, name="x")
5686    x_dtype = x.dtype
5687    if not x_dtype.is_floating:
5688      raise ValueError(
5689          "`x.dtype` must be a floating point tensor as `x` will be "
5690          f"scaled. Received: x_dtype={x_dtype}")
5691    if is_rate_number and rate == 0:
5692      # Fast-path: Return the input immediately if rate is non-tensor & is `0`.
5693      # We trigger this after all error checking
5694      # and after `x` has been converted to a tensor, to prevent inconsistent
5695      # tensor conversions/error raising if rate is changed to/from 0.
5696      #
5697      # We also explicitly call `dummy_rng_step` to make sure
5698      # we don't change the random number generation behavior of
5699      # stateful random ops by entering a fastpath,
5700      # despite not generating a random tensor in the fastpath
5701      dummy_rng_step()
5702      return x
5703
5704    is_executing_eagerly = context.executing_eagerly()
5705    if not tensor_util.is_tf_type(rate):
5706      if is_rate_number:
5707        keep_prob = 1 - rate
5708        scale = 1 / keep_prob
5709        scale = ops.convert_to_tensor(scale, dtype=x_dtype)
5710        ret = gen_math_ops.mul(x, scale)
5711      else:
5712        raise ValueError(
5713            f"`rate` must be a scalar or scalar tensor. Received: rate={rate}")
5714    else:
5715      rate.get_shape().assert_has_rank(0)
5716      rate_dtype = rate.dtype
5717      if rate_dtype != x_dtype:
5718        if not rate_dtype.is_compatible_with(x_dtype):
5719          raise ValueError(
5720              "`x.dtype` must be compatible with `rate.dtype`. "
5721              f"Received: x.dtype={x_dtype} and rate.dtype={rate_dtype}")
5722        rate = gen_math_ops.cast(rate, x_dtype, name="rate")
5723      one_tensor = constant_op.constant(1, dtype=x_dtype)
5724      ret = gen_math_ops.real_div(x, gen_math_ops.sub(one_tensor, rate))
5725
5726    noise_shape = _get_noise_shape(x, noise_shape)
5727    # Sample a uniform distribution on [0.0, 1.0) and select values larger
5728    # than or equal to `rate`.
5729    random_tensor = uniform_sampler(shape=noise_shape, dtype=x_dtype)
5730    keep_mask = random_tensor >= rate
5731    ret = gen_math_ops.mul(ret, gen_math_ops.cast(keep_mask, x_dtype))
5732    if not is_executing_eagerly:
5733      ret.set_shape(x.get_shape())
5734    return ret
5735
5736
5737@tf_export("math.top_k", "nn.top_k")
5738@dispatch.add_dispatch_support
5739def top_k(input, k=1, sorted=True, name=None):  # pylint: disable=redefined-builtin
5740  """Finds values and indices of the `k` largest entries for the last dimension.
5741
5742  If the input is a vector (rank=1), finds the `k` largest entries in the vector
5743  and outputs their values and indices as vectors.  Thus `values[j]` is the
5744  `j`-th largest entry in `input`, and its index is `indices[j]`.
5745
5746  >>> result = tf.math.top_k([1, 2, 98, 1, 1, 99, 3, 1, 3, 96, 4, 1],
5747  ...                         k=3)
5748  >>> result.values.numpy()
5749  array([99, 98, 96], dtype=int32)
5750  >>> result.indices.numpy()
5751  array([5, 2, 9], dtype=int32)
5752
5753  For matrices (resp. higher rank input), computes the top `k` entries in each
5754  row (resp. vector along the last dimension).  Thus,
5755
5756  >>> input = tf.random.normal(shape=(3,4,5,6))
5757  >>> k = 2
5758  >>> values, indices  = tf.math.top_k(input, k=k)
5759  >>> values.shape.as_list()
5760  [3, 4, 5, 2]
5761  >>>
5762  >>> values.shape == indices.shape == input.shape[:-1] + [k]
5763  True
5764
5765  The indices can be used to `gather` from a tensor whose shape matches `input`.
5766
5767  >>> gathered_values = tf.gather(input, indices, batch_dims=-1)
5768  >>> assert tf.reduce_all(gathered_values == values)
5769
5770  If two elements are equal, the lower-index element appears first.
5771
5772  >>> result = tf.math.top_k([1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0],
5773  ...                        k=3)
5774  >>> result.indices.numpy()
5775  array([0, 1, 3], dtype=int32)
5776
5777  Args:
5778    input: 1-D or higher `Tensor` with last dimension at least `k`.
5779    k: 0-D `int32` `Tensor`.  Number of top elements to look for along the last
5780      dimension (along each row for matrices).
5781    sorted: If true the resulting `k` elements will be sorted by the values in
5782      descending order.
5783    name: Optional name for the operation.
5784
5785  Returns:
5786    A tuple with two named fields:
5787    values: The `k` largest elements along each last dimensional slice.
5788    indices: The indices of `values` within the last dimension of `input`.
5789  """
5790  return gen_nn_ops.top_kv2(input, k=k, sorted=sorted, name=name)
5791
5792
5793@tf_export("math.approx_max_k", "nn.approx_max_k")
5794@dispatch.add_dispatch_support
5795def approx_max_k(operand,
5796                 k,
5797                 reduction_dimension=-1,
5798                 recall_target=0.95,
5799                 reduction_input_size_override=-1,
5800                 aggregate_to_topk=True,
5801                 name=None):
5802  """Returns max `k` values and their indices of the input `operand` in an approximate manner.
5803
5804  See https://arxiv.org/abs/2206.14286 for the algorithm details. This op is
5805  only optimized on TPU currently.
5806
5807  Args:
5808    operand : Array to search for max-k. Must be a floating-point type.
5809    k : Specifies the number of max-k.
5810    reduction_dimension : Integer dimension along which to search. Default: -1.
5811    recall_target : Recall target for the approximation.
5812    reduction_input_size_override : When set to a positive value, it overrides
5813      the size determined by `operand[reduction_dim]` for evaluating the recall.
5814      This option is useful when the given `operand` is only a subset of the
5815      overall computation in SPMD or distributed pipelines, where the true input
5816      size cannot be inferred from the `operand` shape.
5817    aggregate_to_topk : When true, aggregates approximate results to top-k. When
5818      false, returns the approximate results. The number of the approximate
5819      results is implementation-defined and is greater than or equal to the
5820      specified `k`.
5821    name: Optional name for the operation.
5822
5823  Returns:
5824    Tuple of two arrays. The arrays are the max `k` values and the
5825    corresponding indices along the `reduction_dimension` of the input
5826    `operand`. The arrays' dimensions are the same as the input `operand`
5827    except for the `reduction_dimension`: when `aggregate_to_topk` is true,
5828    the reduction dimension is `k`; otherwise, it is greater than or equal to
5829    `k`, with the exact size being implementation-defined.
5830
5831  We encourage users to wrap `approx_max_k` with jit. See the following
5832  example for maximal inner product search (MIPS):
5833
5834  >>> import tensorflow as tf
5835  >>> @tf.function(jit_compile=True)
5836  ... def mips(qy, db, k=10, recall_target=0.95):
5837  ...   dists = tf.einsum('ik,jk->ij', qy, db)
5838  ...   # returns (f32[qy_size, k], i32[qy_size, k])
5839  ...   return tf.nn.approx_max_k(dists, k=k, recall_target=recall_target)
5840  >>>
5841  >>> qy = tf.random.uniform((256,128))
5842  >>> db = tf.random.uniform((2048,128))
5843  >>> dot_products, neighbors = mips(qy, db, k=20)
5844  """
5845  return gen_nn_ops.approx_top_k(
5846      operand,
5847      k=k,
5848      reduction_dimension=reduction_dimension,
5849      recall_target=recall_target,
5850      is_max_k=True,
5851      reduction_input_size_override=reduction_input_size_override,
5852      aggregate_to_topk=aggregate_to_topk,
5853      name=name)
5854
5855
5856@tf_export("math.approx_min_k", "nn.approx_min_k")
5857@dispatch.add_dispatch_support
5858def approx_min_k(operand,
5859                 k,
5860                 reduction_dimension=-1,
5861                 recall_target=0.95,
5862                 reduction_input_size_override=-1,
5863                 aggregate_to_topk=True,
5864                 name=None):
5865  """Returns min `k` values and their indices of the input `operand` in an approximate manner.
5866
5867  See https://arxiv.org/abs/2206.14286 for the algorithm details. This op is
5868  only optimized on TPU currently.
5869
5870  Args:
5871    operand : Array to search for min-k. Must be a floating-point type.
5872    k : Specifies the number of min-k.
5873    reduction_dimension: Integer dimension along which to search. Default: -1.
5874    recall_target: Recall target for the approximation.
5875    reduction_input_size_override : When set to a positive value, it overrides
5876      the size determined by `operand[reduction_dim]` for evaluating the recall.
5877      This option is useful when the given `operand` is only a subset of the
5878      overall computation in SPMD or distributed pipelines, where the true input
5879      size cannot be inferred from the `operand` shape.
5880    aggregate_to_topk: When true, aggregates approximate results to top-k. When
5881      false, returns the approximate results. The number of the approximate
5882      results is implementation-defined and is greater than or equal to the
5883      specified `k`.
5884    name: Optional name for the operation.
5885
5886  Returns:
5887    Tuple of two arrays. The arrays are the least `k` values and the
5888    corresponding indices along the `reduction_dimension` of the input
5889    `operand`.  The arrays' dimensions are the same as the input `operand`
5890    except for the `reduction_dimension`: when `aggregate_to_topk` is true,
5891    the reduction dimension is `k`; otherwise, it is greater than or equal to
5892    `k`, with the exact size being implementation-defined.
5893
5894  We encourage users to wrap `approx_min_k` with jit. See the following example
5895  for nearest neighbor search over the squared l2 distance:
5896
5897  >>> import tensorflow as tf
5898  >>> @tf.function(jit_compile=True)
5899  ... def l2_ann(qy, db, half_db_norms, k=10, recall_target=0.95):
5900  ...   dists = half_db_norms - tf.einsum('ik,jk->ij', qy, db)
5901  ...   return tf.nn.approx_min_k(dists, k=k, recall_target=recall_target)
5902  >>>
5903  >>> qy = tf.random.uniform((256,128))
5904  >>> db = tf.random.uniform((2048,128))
5905  >>> half_db_norms = tf.norm(db, axis=1) / 2
5906  >>> dists, neighbors = l2_ann(qy, db, half_db_norms)
5907
5908  In the example above, we compute `db_norms/2 - dot(qy, db^T)` instead of
5909  `qy^2 - 2 dot(qy, db^T) + db^2` for performance reasons. The former uses less
5910  arithmetic and produces the same set of neighbors.
5911  """
5912  return gen_nn_ops.approx_top_k(
5913      operand,
5914      k=k,
5915      reduction_dimension=reduction_dimension,
5916      recall_target=recall_target,
5917      is_max_k=False,
5918      reduction_input_size_override=reduction_input_size_override,
5919      aggregate_to_topk=aggregate_to_topk,
5920      name=name)
5921
5922
5923def nth_element(input, n, reverse=False, name=None):  # pylint: disable=redefined-builtin
5924  r"""Finds values of the `n`-th smallest value for the last dimension.
5925
5926  Note that n is zero-indexed.
5927
5928  If the input is a vector (rank-1), finds the entry which is the nth-smallest
5929  value in the vector and outputs its value as a scalar tensor.
5930
5931  For matrices (resp. higher rank input), computes the entry which is the
5932  nth-smallest value in each row (resp. vector along the last dimension). Thus,
5933
5934      values.shape = input.shape[:-1]
5935
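  For example, with `input = [4, 2, 7, 1]` and `n = 1`, the result is `2`
  (the second-smallest value); with `reverse=True` it would be `4`
  (the second-largest).
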
5936  Args:
5937    input: 1-D or higher `Tensor` with last dimension at least `n+1`.
5938    n: A `Tensor` of type `int32`.
5939      0-D. Position of sorted vector to select along the last dimension (along
5940      each row for matrices). Valid range of n is `[0, input.shape[-1])`.
5941    reverse: An optional `bool`. Defaults to `False`.
5942      When set to True, find the nth-largest value in the vector and vice
5943      versa.
5944    name: A name for the operation (optional).
5945
5946  Returns:
5947    A `Tensor`. Has the same type as `input`.
5948    The `n`-th order statistic along each last dimensional slice.
5949  """
5950  return gen_nn_ops.nth_element(input, n, reverse=reverse, name=name)
5951
5952
5953@tf_export(v1=["nn.fractional_max_pool"])
5954@dispatch.add_dispatch_support
5955@deprecation.deprecated(date=None, instructions="`seed2` and `deterministic` "
5956                        "args are deprecated.  Use fractional_max_pool_v2.")
5957def fractional_max_pool(value,
5958                        pooling_ratio,
5959                        pseudo_random=False,
5960                        overlapping=False,
5961                        deterministic=False,
5962                        seed=0,
5963                        seed2=0,
5964                        name=None):   # pylint: disable=redefined-builtin
5965  r"""Performs fractional max pooling on the input.
5966
5967  This is a deprecated version of `fractional_max_pool`.
5968
5969  Fractional max pooling is slightly different than regular max pooling.  In
5970  regular max pooling, you downsize an input set by taking the maximum value of
5971  smaller N x N subsections of the set (often 2x2), and try to reduce the set by
5972  a factor of N, where N is an integer.  Fractional max pooling, as you might
5973  expect from the word "fractional", means that the overall reduction ratio N
5974  does not have to be an integer.
5975
5976  The sizes of the pooling regions are generated randomly but are fairly
5977  uniform.  For example, let's look at the height dimension, and the constraints
5978  on the list of rows that will be pool boundaries.
5979
5980  First we define the following:
5981
5982  1.  input_row_length : the number of rows from the input set
5983  2.  output_row_length : which will be smaller than the input
5984  3.  alpha = input_row_length / output_row_length : our reduction ratio
5985  4.  K = floor(alpha)
5986  5.  row_pooling_sequence : this is the result list of pool boundary rows
5987
5988  Then, row_pooling_sequence should satisfy:
5989
5990  1.  a[0] = 0 : the first value of the sequence is 0
5991  2.  a[end] = input_row_length : the last value of the sequence is the size
5992  3.  K <= (a[i+1] - a[i]) <= K+1 : all intervals are K or K+1 size
5993  4.  length(row_pooling_sequence) = output_row_length+1
5994
5995  Args:
5996    value: A `Tensor`. 4-D with shape `[batch, height, width, channels]`.
5997    pooling_ratio: A list of `floats` that has length >= 4.  Pooling ratio for
5998      each dimension of `value`, currently only supports row and col dimension
5999      and should be >= 1.0. For example, a valid pooling ratio looks like [1.0,
6000      1.44, 1.73, 1.0]. The first and last elements must be 1.0 because we don't
6001      allow pooling on batch and channels dimensions.  1.44 and 1.73 are pooling
6002      ratio on height and width dimensions respectively.
6003    pseudo_random: An optional `bool`.  Defaults to `False`. When set to `True`,
6004      generates the pooling sequence in a pseudorandom fashion, otherwise, in a
6005      random fashion. Check (Graham, 2015) for difference between
6006      pseudorandom and random.
6007    overlapping: An optional `bool`.  Defaults to `False`.  When set to `True`,
6008      it means when pooling, the values at the boundary of adjacent pooling
6009      cells are used by both cells. For example:
6010      `index  0  1  2  3  4`
6011      `value  20 5  16 3  7`
6012      If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used
6013      twice.  The result would be [20, 16] for fractional max pooling.
6014    deterministic: An optional `bool`.  Deprecated; use `fractional_max_pool_v2`
6015      instead.
6016    seed: An optional `int`.  Defaults to `0`.  If set to be non-zero, the
6017      random number generator is seeded by the given seed.  Otherwise it is
6018      seeded by a random seed.
6019    seed2: An optional `int`.  Deprecated; use `fractional_max_pool_v2` instead.
6020    name: A name for the operation (optional).
6021
6022  Returns:
6023    A tuple of `Tensor` objects (`output`, `row_pooling_sequence`,
6024    `col_pooling_sequence`).
6025    output: Output `Tensor` after fractional max pooling.  Has the same type as
6026      `value`.
6027    row_pooling_sequence: A `Tensor` of type `int64`.
6028    col_pooling_sequence: A `Tensor` of type `int64`.
6029
6030  Raises:
6031    ValueError: If op determinism is enabled and either the seeds are not set or
6032      the "deterministic" argument is False.
6033
6034  References:
6035    Fractional Max-Pooling:
6036      [Graham, 2015](https://arxiv.org/abs/1412.6071)
6037      ([pdf](https://arxiv.org/pdf/1412.6071.pdf))
6038  """
6039  if config.is_op_determinism_enabled() and (not seed or not seed2 or
6040                                             not deterministic):
6041    raise ValueError(
6042        f'tf.compat.v1.nn.fractional_max_pool requires "seed" and '
6043        f'"seed2" to be non-zero and "deterministic" to be true when op '
6044        f"determinism is enabled. Please pass in such values, e.g. by passing "
6045        f'"seed=1, seed2=1, deterministic=True". Got: seed={seed}, '
6046        f'seed2={seed2}, deterministic={deterministic}')
6047  return gen_nn_ops.fractional_max_pool(value, pooling_ratio, pseudo_random,
6048                                        overlapping, deterministic, seed, seed2,
6049                                        name)
6050
6051
6052@tf_export("nn.fractional_max_pool", v1=[])
6053@dispatch.add_dispatch_support
6054def fractional_max_pool_v2(value,
6055                           pooling_ratio,
6056                           pseudo_random=False,
6057                           overlapping=False,
6058                           seed=0,
6059                           name=None):  # pylint: disable=redefined-builtin
6060  r"""Performs fractional max pooling on the input.
6061
6062  Fractional max pooling is slightly different than regular max pooling.  In
6063  regular max pooling, you downsize an input set by taking the maximum value of
6064  smaller N x N subsections of the set (often 2x2), and try to reduce the set by
6065  a factor of N, where N is an integer.  Fractional max pooling, as you might
6066  expect from the word "fractional", means that the overall reduction ratio N
6067  does not have to be an integer.
6068
6069  The sizes of the pooling regions are generated randomly but are fairly
6070  uniform.  For example, let's look at the height dimension, and the constraints
6071  on the list of rows that will be pool boundaries.
6072
6073  First we define the following:
6074
6075  1.  input_row_length : the number of rows from the input set
6076  2.  output_row_length : which will be smaller than the input
6077  3.  alpha = input_row_length / output_row_length : our reduction ratio
6078  4.  K = floor(alpha)
6079  5.  row_pooling_sequence : this is the result list of pool boundary rows
6080
6081  Then, row_pooling_sequence should satisfy:
6082
6083  1.  a[0] = 0 : the first value of the sequence is 0
6084  2.  a[end] = input_row_length : the last value of the sequence is the size
6085  3.  K <= (a[i+1] - a[i]) <= K+1 : all intervals are K or K+1 size
6086  4.  length(row_pooling_sequence) = output_row_length+1
6087
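  For example, with input_row_length = 5 and output_row_length = 3 (so
  alpha = 5/3 and K = 1), one valid row_pooling_sequence is [0, 2, 3, 5]:
  it starts at 0, ends at 5, every interval is either 1 or 2, and its length
  is output_row_length + 1 = 4.
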
6088  Args:
6089    value: A `Tensor`. 4-D with shape `[batch, height, width, channels]`.
6090    pooling_ratio: An int or float, or a list of `floats` of length `1`, `2` or
6091      `4`. Pooling ratio for each dimension of `value`, currently only supports row
6092      and col dimension and should be >= 1.0. For example, a valid pooling ratio
6093      looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements must be 1.0
6094      because we don't allow pooling on batch and channels dimensions.  1.44 and
6095      1.73 are pooling ratio on height and width dimensions respectively.
6096    pseudo_random: An optional `bool`.  Defaults to `False`. When set to `True`,
6097      generates the pooling sequence in a pseudorandom fashion, otherwise, in a
6098      random fashion. Check paper (Graham, 2015) for difference between
6099      pseudorandom and random.
6100    overlapping: An optional `bool`.  Defaults to `False`.  When set to `True`,
6101      it means when pooling, the values at the boundary of adjacent pooling
6102      cells are used by both cells. For example:
6103      `index  0  1  2  3  4`
6104      `value  20 5  16 3  7`
6105      If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used
6106      twice.  The result would be [20, 16] for fractional max pooling.
6107    seed: An optional `int`.  Defaults to `0`.  If set to be non-zero, the
6108      random number generator is seeded by the given seed.  Otherwise it is
6109      seeded by a random seed.
6110    name: A name for the operation (optional).
6111
6112  Returns:
6113    A tuple of `Tensor` objects (`output`, `row_pooling_sequence`,
6114    `col_pooling_sequence`).
6115    output: Output `Tensor` after fractional max pooling.  Has the same type as
6116      `value`.
6117    row_pooling_sequence: A `Tensor` of type `int64`.
6118    col_pooling_sequence: A `Tensor` of type `int64`.
6119
6120  Raises:
6121    ValueError: If no seed is specified and op determinism is enabled.
6122
6123  References:
6124    Fractional Max-Pooling:
6125      [Graham, 2015](https://arxiv.org/abs/1412.6071)
6126      ([pdf](https://arxiv.org/pdf/1412.6071.pdf))
6127  """
6128  if isinstance(pooling_ratio, (list, tuple)):
6129    if pooling_ratio[0] != 1.0 or pooling_ratio[-1] != 1.0:
6130      raise ValueError(
6131          "`pooling_ratio` should have first and last elements with value 1.0. "
6132          f"Received: pooling_ratio={pooling_ratio}")
6133    for element in pooling_ratio:
6134      if element < 1.0:
6135        raise ValueError(
6136            f"`pooling_ratio` elements should be >= 1.0. "
6137            f"Received: pooling_ratio={pooling_ratio}")
6138  elif isinstance(pooling_ratio, (int, float)):
6139    if pooling_ratio < 1.0:
6140      raise ValueError(
6141          "`pooling_ratio` should be >= 1.0. "
6142          f"Received: pooling_ratio={pooling_ratio}")
6143  else:
6144    raise ValueError(
6145        "`pooling_ratio` should be an int, a float, or a list of ints or floats. "
6146        f"Received: pooling_ratio={pooling_ratio}")
6147
6148  pooling_ratio = _get_sequence(pooling_ratio, 2, 3, "pooling_ratio")
6149
6150  if seed == 0:
6151    if config.is_op_determinism_enabled():
6152      raise ValueError(
6153          f"tf.nn.fractional_max_pool requires a non-zero seed to be passed in "
6154          f"when determinism is enabled, but got seed={seed}. Please pass in a "
6155          f'non-zero seed, e.g. by passing "seed=1".')
6156    return gen_nn_ops.fractional_max_pool(value, pooling_ratio, pseudo_random,
6157                                          overlapping, deterministic=False,
6158                                          seed=0, seed2=0, name=name)
6159  else:
6160    seed1, seed2 = random_seed.get_seed(seed)
6161    return gen_nn_ops.fractional_max_pool(value, pooling_ratio, pseudo_random,
6162                                          overlapping, deterministic=True,
6163                                          seed=seed1, seed2=seed2, name=name)
6164
6165
6166@tf_export(v1=["nn.fractional_avg_pool"])
6167@dispatch.add_dispatch_support
6168@deprecation.deprecated(date=None, instructions="`seed2` and `deterministic` "
6169                        "args are deprecated.  Use fractional_avg_pool_v2.")
6170def fractional_avg_pool(value,
6171                        pooling_ratio,
6172                        pseudo_random=False,
6173                        overlapping=False,
6174                        deterministic=False,
6175                        seed=0,
6176                        seed2=0,
6177                        name=None):  # pylint: disable=redefined-builtin
6178  r"""Performs fractional average pooling on the input.
6179
6180  This is a deprecated version of `fractional_avg_pool`.
6181
6182  Fractional average pooling is similar to fractional max pooling in the pooling
6183  region generation step. The only difference is that after pooling regions are
6184  generated, a mean operation is performed instead of a max operation in each
6185  pooling region.
6186
6187  Args:
6188    value: A `Tensor`. 4-D with shape `[batch, height, width, channels]`.
6189    pooling_ratio: A list of `floats` that has length >= 4.  Pooling ratio for
6190      each dimension of `value`, currently only supports row and col dimension
6191      and should be >= 1.0. For example, a valid pooling ratio looks like [1.0,
6192      1.44, 1.73, 1.0]. The first and last elements must be 1.0 because we don't
6193      allow pooling on batch and channels dimensions.  1.44 and 1.73 are pooling
6194      ratio on height and width dimensions respectively.
6195    pseudo_random: An optional `bool`.  Defaults to `False`. When set to `True`,
6196      generates the pooling sequence in a pseudorandom fashion, otherwise, in a
6197      random fashion. Check paper (Graham, 2015) for difference between
6198      pseudorandom and random.
6199    overlapping: An optional `bool`.  Defaults to `False`.  When set to `True`,
6200      it means when pooling, the values at the boundary of adjacent pooling
6201      cells are used by both cells. For example:
6202      `index  0  1  2  3  4`
6203      `value  20 5  16 3  7`
6204      If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used
6205      twice.  The result would be [20, 16] for fractional avg pooling.
6206    deterministic: An optional `bool`.  Deprecated; use `fractional_avg_pool_v2`
6207      instead.
6208    seed: An optional `int`.  Defaults to `0`.  If set to be non-zero, the
6209      random number generator is seeded by the given seed.  Otherwise it is
6210      seeded by a random seed.
6211    seed2: An optional `int`.  Deprecated; use `fractional_avg_pool_v2` instead.
6212    name: A name for the operation (optional).
6213
6214  Returns:
6215    A tuple of `Tensor` objects (`output`, `row_pooling_sequence`,
6216    `col_pooling_sequence`).
6217    output: Output `Tensor` after fractional avg pooling.  Has the same type as
6218      `value`.
6219    row_pooling_sequence: A `Tensor` of type `int64`.
6220    col_pooling_sequence: A `Tensor` of type `int64`.
6221
6222  References:
6223    Fractional Max-Pooling:
6224      [Graham, 2015](https://arxiv.org/abs/1412.6071)
6225      ([pdf](https://arxiv.org/pdf/1412.6071.pdf))
6226  """
6227  return gen_nn_ops.fractional_avg_pool(value, pooling_ratio, pseudo_random,
6228                                        overlapping, deterministic, seed, seed2,
6229                                        name=name)
6230
6231
6232@tf_export("nn.fractional_avg_pool", v1=[])
6233@dispatch.add_dispatch_support
6234def fractional_avg_pool_v2(value,
6235                           pooling_ratio,
6236                           pseudo_random=False,
6237                           overlapping=False,
6238                           seed=0,
6239                           name=None):  # pylint: disable=redefined-builtin
6240  r"""Performs fractional average pooling on the input.
6241
6242  Fractional average pooling is similar to fractional max pooling in the pooling
6243  region generation step. The only difference is that after pooling regions are
6244  generated, a mean operation is performed instead of a max operation in each
6245  pooling region.
6246
6247  Args:
6248    value: A `Tensor`. 4-D with shape `[batch, height, width, channels]`.
6249    pooling_ratio: A list of `floats` that has length >= 4.  Pooling ratio for
6250      each dimension of `value`, currently only supports row and col dimension
6251      and should be >= 1.0. For example, a valid pooling ratio looks like [1.0,
6252      1.44, 1.73, 1.0]. The first and last elements must be 1.0 because we don't
6253      allow pooling on batch and channels dimensions.  1.44 and 1.73 are pooling
6254      ratio on height and width dimensions respectively.
6255    pseudo_random: An optional `bool`.  Defaults to `False`. When set to `True`,
6256      generates the pooling sequence in a pseudorandom fashion, otherwise, in a
6257      random fashion. Check paper (Graham, 2015) for difference between
6258      pseudorandom and random.
6259    overlapping: An optional `bool`.  Defaults to `False`.  When set to `True`,
6260      it means when pooling, the values at the boundary of adjacent pooling
6261      cells are used by both cells. For example:
6262      `index  0  1  2  3  4`
6263      `value  20 5  16 3  7`
6264      If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used
6265      twice.  The result would be [20, 16] for fractional avg pooling.
6266    seed: An optional `int`.  Defaults to `0`.  If set to be non-zero, the
6267      random number generator is seeded by the given seed.  Otherwise it is
6268      seeded by a random seed.
6269    name: A name for the operation (optional).
6270
6271  Returns:
6272    A tuple of `Tensor` objects (`output`, `row_pooling_sequence`,
6273    `col_pooling_sequence`).
6274    output: Output `Tensor` after fractional avg pooling.  Has the same type as
6275      `value`.
6276    row_pooling_sequence: A `Tensor` of type `int64`.
6277    col_pooling_sequence: A `Tensor` of type `int64`.
6278
6279  References:
6280    Fractional Max-Pooling:
6281      [Graham, 2015](https://arxiv.org/abs/1412.6071)
6282      ([pdf](https://arxiv.org/pdf/1412.6071.pdf))
6283  """
6284  if seed == 0:
6285    return gen_nn_ops.fractional_avg_pool(value, pooling_ratio, pseudo_random,
6286                                          overlapping, deterministic=False,
6287                                          seed=0, seed2=0, name=name)
6288  else:
6289    seed1, seed2 = random_seed.get_seed(seed)
6290    return gen_nn_ops.fractional_avg_pool(value, pooling_ratio, pseudo_random,
6291                                          overlapping, deterministic=True,
6292                                          seed=seed1, seed2=seed2, name=name)
6293
6294
6295@ops.RegisterStatistics("Dilation2D", "flops")
6296def _calc_dilation2d_flops(graph, node):
6297  """Calculates the compute resources needed for Dilation2D."""
6298  input_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
6299  input_shape.assert_is_fully_defined()
6300  filter_shape = graph_util.tensor_shape_from_node_def_name(
6301      graph, node.input[1])
6302  filter_shape.assert_is_fully_defined()
6303  output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
6304  output_shape.assert_is_fully_defined()
6305  filter_height = int(filter_shape[0])
6306  filter_width = int(filter_shape[1])
6307  output_count = np.prod(output_shape.as_list(), dtype=np.int64)
6308  return ops.OpStats("flops", (output_count * filter_height * filter_width * 2))
6309
6310
@tf_export(v1=["nn.erosion2d"])
@dispatch.add_dispatch_support
def erosion2d(value, kernel, strides, rates, padding, name=None):
  """Computes the grayscale erosion of 4-D `value` and 3-D `kernel` tensors.

  The `value` tensor has shape `[batch, in_height, in_width, depth]` and the
  `kernel` tensor has shape `[kernel_height, kernel_width, depth]`, i.e.,
  each input channel is processed independently of the others with its own
  structuring function. The `output` tensor has shape
  `[batch, out_height, out_width, depth]`. The spatial dimensions of the
  output tensor depend on the `padding` algorithm. We currently only support the
  default "NHWC" `data_format`.

  In detail, the grayscale morphological 2-D erosion is given by:

      output[b, y, x, c] =
         min_{dy, dx} value[b,
                            strides[1] * y - rates[1] * dy,
                            strides[2] * x - rates[2] * dx,
                            c] -
                      kernel[dy, dx, c]

  Duality: The erosion of `value` by the `kernel` is equal to the negation of
  the dilation of `-value` by the reflected `kernel`.

  Args:
    value: A `Tensor`. 4-D with shape `[batch, in_height, in_width, depth]`.
    kernel: A `Tensor`. Must have the same type as `value`.
      3-D with shape `[kernel_height, kernel_width, depth]`.
    strides: A list of `ints` that has length `>= 4`.
      1-D of length 4. The stride of the sliding window for each dimension of
      the input tensor. Must be: `[1, stride_height, stride_width, 1]`.
    rates: A list of `ints` that has length `>= 4`.
      1-D of length 4. The input stride for atrous morphological dilation.
      Must be: `[1, rate_height, rate_width, 1]`.
    padding: A `string` from: `"SAME", "VALID"`.
      The type of padding algorithm to use.
    name: A name for the operation (optional). If not specified "erosion2d"
      is used.

  Returns:
    A `Tensor`. Has the same type as `value`.
    4-D with shape `[batch, out_height, out_width, depth]`.

  Raises:
    ValueError: If the `value` depth does not match `kernel`'s shape, or if
      padding is other than `'VALID'` or `'SAME'`.
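
  For example, the sketch below (hypothetical values, called through the TF1
  endpoint `tf.compat.v1.nn.erosion2d`) erodes an image with a flat, all-zero
  3x3 structuring element, which amounts to a sliding-window minimum:

  ```python
  image = tf.random.uniform([1, 5, 5, 1])  # NHWC input
  kernel = tf.zeros([3, 3, 1])             # flat structuring element
  eroded = tf.compat.v1.nn.erosion2d(
      image, kernel, strides=[1, 1, 1, 1], rates=[1, 1, 1, 1], padding="SAME")
  ```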
  """
  with ops.name_scope(name, "erosion2d", [value, kernel]) as name:
    # Reduce erosion to dilation by duality.
    return math_ops.negative(
        gen_nn_ops.dilation2d(
            input=math_ops.negative(value),
            filter=array_ops.reverse_v2(kernel, [0, 1]),
            strides=strides,
            rates=rates,
            padding=padding,
            name=name))


@tf_export("nn.erosion2d", v1=[])
@dispatch.add_dispatch_support
def erosion2d_v2(value,
                 filters,
                 strides,
                 padding,
                 data_format,
                 dilations,
                 name=None):
  """Computes the grayscale erosion of 4-D `value` and 3-D `filters` tensors.

  The `value` tensor has shape `[batch, in_height, in_width, depth]` and the
  `filters` tensor has shape `[filters_height, filters_width, depth]`, i.e.,
  each input channel is processed independently of the others with its own
  structuring function. The `output` tensor has shape
  `[batch, out_height, out_width, depth]`. The spatial dimensions of the
  output tensor depend on the `padding` algorithm. We currently only support the
  default "NHWC" `data_format`.

  In detail, the grayscale morphological 2-D erosion is given by:

      output[b, y, x, c] =
         min_{dy, dx} value[b,
                            strides[1] * y - dilations[1] * dy,
                            strides[2] * x - dilations[2] * dx,
                            c] -
                      filters[dy, dx, c]

  Duality: The erosion of `value` by the `filters` is equal to the negation of
  the dilation of `-value` by the reflected `filters`.

  Args:
    value: A `Tensor`. 4-D with shape `[batch, in_height, in_width, depth]`.
    filters: A `Tensor`. Must have the same type as `value`.
      3-D with shape `[filters_height, filters_width, depth]`.
    strides: A list of `ints` that has length `>= 4`.
      1-D of length 4. The stride of the sliding window for each dimension of
      the input tensor. Must be: `[1, stride_height, stride_width, 1]`.
    padding: A `string` from: `"SAME", "VALID"`.
      The type of padding algorithm to use. See
      [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2)
      for more information.
    data_format: A `string`, only `"NHWC"` is currently supported.
    dilations: A list of `ints` that has length `>= 4`.
      1-D of length 4. The input stride for atrous morphological dilation.
      Must be: `[1, rate_height, rate_width, 1]`.
    name: A name for the operation (optional). If not specified "erosion2d"
      is used.

  Returns:
    A `Tensor`. Has the same type as `value`.
    4-D with shape `[batch, out_height, out_width, depth]`.

  Raises:
    ValueError: If the `value` depth does not match `filters`' shape, or if
      padding is other than `'VALID'` or `'SAME'`.
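
  For example, here is a sketch of the duality stated above, assuming the
  companion `tf.nn.dilation2d` op (input values are arbitrary):

  ```python
  value = tf.random.uniform([1, 7, 7, 3])
  filters = tf.random.uniform([3, 3, 3])
  eroded = tf.nn.erosion2d(value, filters, strides=[1, 1, 1, 1],
                           padding="VALID", data_format="NHWC",
                           dilations=[1, 1, 1, 1])
  # Erosion is the negated dilation of `-value` by the spatially
  # reflected `filters`.
  dilated = tf.nn.dilation2d(-value, tf.reverse(filters, axis=[0, 1]),
                             strides=[1, 1, 1, 1], padding="VALID",
                             data_format="NHWC", dilations=[1, 1, 1, 1])
  assert float(tf.reduce_max(tf.abs(eroded + dilated))) < 1e-6
  ```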
  """
  if data_format != "NHWC":
    raise ValueError("`data_format` values other than 'NHWC' are not "
                     f"supported. Received: data_format={data_format}")

  with ops.name_scope(name, "erosion2d", [value, filters]) as name:
    # Reduce erosion to dilation by duality.
    return math_ops.negative(
        gen_nn_ops.dilation2d(
            input=math_ops.negative(value),
            filter=array_ops.reverse_v2(filters, [0, 1]),
            strides=strides,
            rates=dilations,
            padding=padding,
            name=name))


@tf_export(v1=["math.in_top_k", "nn.in_top_k"])
@dispatch.add_dispatch_support
def in_top_k(predictions, targets, k, name=None):
  r"""Says whether the targets are in the top `K` predictions.

  This outputs a `batch_size` bool array; an entry `out[i]` is `true` if the
  prediction for the target class is finite (not inf, -inf, or nan) and is
  among the top `k` predictions across all classes for example `i`. Note that
  the behavior of `InTopK` differs from the `TopK` op in its handling of ties;
  if multiple classes have the same prediction value and straddle the top-`k`
  boundary, all of those classes are considered to be in the top `k`.

  More formally, let

    \\(predictions_i\\) be the predictions for all classes for example `i`,
    \\(targets_i\\) be the target class for example `i`,
    \\(out_i\\) be the output for example `i`,

  $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$

  Args:
    predictions: A `Tensor` of type `float32`.
      A `batch_size` x `classes` tensor.
    targets: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A `batch_size` vector of class ids.
    k: An `int`. Number of top elements to look at for computing precision.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `bool`. Computed Precision at `k` as a `bool Tensor`.
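
  For example, a small hand-worked sketch (values chosen for illustration):

  ```python
  predictions = tf.constant([[0.1, 0.8, 0.1],
                             [0.3, 0.3, 0.4]])
  targets = tf.constant([1, 0])
  tf.compat.v1.nn.in_top_k(predictions, targets, k=1)  # [True, False]
  ```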
  """
  with ops.name_scope(name, "in_top_k"):
    return gen_nn_ops.in_top_kv2(predictions, targets, k, name=name)


@tf_export("math.in_top_k", "nn.in_top_k", v1=[])
@dispatch.add_dispatch_support
def in_top_k_v2(targets, predictions, k, name=None):
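  # The v2 endpoint takes `targets` first; forward the arguments to the v1
  # implementation, which expects (predictions, targets, k).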
  return in_top_k(predictions, targets, k, name)


in_top_k_v2.__doc__ = in_top_k.__doc__


tf_export(v1=["nn.quantized_avg_pool"])(
    dispatch.add_dispatch_support(gen_nn_ops.quantized_avg_pool))
tf_export(v1=["nn.quantized_conv2d"])(
    dispatch.add_dispatch_support(gen_nn_ops.quantized_conv2d))
tf_export(v1=["nn.quantized_relu_x"])(
    dispatch.add_dispatch_support(gen_nn_ops.quantized_relu_x))
tf_export(v1=["nn.quantized_max_pool"])(
    dispatch.add_dispatch_support(gen_nn_ops.quantized_max_pool))


@tf_export("nn.isotonic_regression", v1=[])
@dispatch.add_dispatch_support
def isotonic_regression(inputs, decreasing=True, axis=-1):
  r"""Solves isotonic regression problems along the given axis.

  For each vector x, the problem solved is

  $$\argmin_{y_1 >= y_2 >= ... >= y_n} \sum_i (x_i - y_i)^2.$$

  As the solution is component-wise constant, a second tensor is returned that
  encodes the segments. The problems are solved over the given axis.

  Consider the following example, where we solve a batch of two problems. The
  first input is [3, 1, 2], while the second is [1, 3, 4] (as the axis is 1).

  >>> x = tf.constant([[3, 1, 2], [1, 3, 4]], dtype=tf.float32)
  >>> y, segments = tf.nn.isotonic_regression(x, axis=1)
  >>> y  # The solution.
  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
  array([[3.       , 1.5      , 1.5      ],
         [2.6666667, 2.6666667, 2.6666667]], dtype=float32)>

  Note that the first solution has two blocks, [3] and [1.5, 1.5]. The second
  solution is constant, and thus has a single segment. These segments are
  exactly what the second returned tensor encodes:

  >>> segments
  <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
  array([[0, 1, 1],
         [0, 0, 0]], dtype=int32)>

  Args:
    inputs: A tensor holding the inputs.
    decreasing: If set to False, the inequalities in the optimization
      constraints are flipped.
    axis: The axis along which the problems should be solved.

  Returns:
    output: The solutions, same shape and type as the input.
    segments: An int32 tensor, same shape as the input, indicating the segments
      that have the same value. Specifically, those positions that have the same
      value correspond to the same segment. These values start at zero and are
      monotonically increasing for each solution.
  """
  type_promotions = {
      # Float types get mapped to themselves, int8/16 to float32, rest to double
      dtypes.float32: dtypes.float32,
      dtypes.half: dtypes.half,
      dtypes.bfloat16: dtypes.bfloat16,
      dtypes.int8: dtypes.float32,
      dtypes.int16: dtypes.float32,
  }
  inputs = ops.convert_to_tensor(inputs)
  try:
    output_dtype = type_promotions[inputs.dtype]
  except KeyError:
    output_dtype = dtypes.float64

  def compute_on_matrix(matrix, name=None):
    iso_fn = functools.partial(
        gen_nn_ops.isotonic_regression, output_dtype=output_dtype, name=name)
    if decreasing:
      return iso_fn(matrix)
    else:
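      # The kernel solves the decreasing problem; negate the input and the
      # resulting solution to obtain the increasing (non-decreasing) fit.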
      output, segments = iso_fn(-matrix)
      return -output, segments

  return _wrap_2d_function(inputs, compute_on_matrix, axis)


# Register elementwise ops that don't have Python wrappers.
# Unary elementwise ops.
dispatch.register_unary_elementwise_api(gen_nn_ops.elu)
dispatch.register_unary_elementwise_api(gen_nn_ops.relu)
dispatch.register_unary_elementwise_api(gen_nn_ops.selu)
dispatch.register_unary_elementwise_api(gen_nn_ops.softsign)