1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Implementation of Neural Net (NN) functions."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import math
22
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import function
26from tensorflow.python.framework import ops
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import candidate_sampling_ops
29from tensorflow.python.ops import control_flow_ops
30from tensorflow.python.ops import embedding_ops
31from tensorflow.python.ops import gen_array_ops  # pylint: disable=unused-import
32from tensorflow.python.ops import gen_nn_ops
33from tensorflow.python.ops import math_ops
34from tensorflow.python.ops import nn_ops
35from tensorflow.python.ops import gen_sparse_ops
36from tensorflow.python.ops import variables
37from tensorflow.python.util.deprecation import deprecated_args
38from tensorflow.python.util.deprecation import deprecated_argument_lookup
39from tensorflow.python.util.tf_export import tf_export
40
41
42@tf_export("nn.log_poisson_loss")
43def log_poisson_loss(targets, log_input, compute_full_loss=False, name=None):
44  """Computes log Poisson loss given `log_input`.
45
46  Gives the log-likelihood loss between the prediction and the target under the
47  assumption that the target has a Poisson distribution.
48  Caveat: By default, this is not the exact loss, but the loss minus a
49    constant term [log(z!)]. That has no effect on optimization, but
50    does not play well with relative loss comparisons. To compute an
51    approximation of the log factorial term, specify
52    compute_full_loss=True to enable Stirling's Approximation.
53
54  For brevity, let `c = log(x) = log_input`, `z = targets`.  The log Poisson
55  loss is
56
57        -log(exp(-x) * (x^z) / z!)
58      = -log(exp(-x) * (x^z)) + log(z!)
59      ~ -log(exp(-x)) - log(x^z) [+ z * log(z) - z + 0.5 * log(2 * pi * z)]
60          [ Note the second term is Stirling's Approximation for log(z!).
61            It does not depend on x and does not affect optimization, though
62            it is important for correct relative loss comparisons. It is only
63            computed when compute_full_loss == True. ]
64      = x - z * log(x) [+ z * log(z) - z + 0.5 * log(2 * pi * z)]
65      = exp(c) - z * c [+ z * log(z) - z + 0.5 * log(2 * pi * z)]
66
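  For example, a minimal sketch of calling this op (the values below are
  illustrative only):

  ```python
      targets = tf.constant([1.0, 2.0, 3.0])
      log_input = tf.constant([0.5, 0.5, 0.5])  # log of the predicted rate x
      loss = tf.nn.log_poisson_loss(targets, log_input)
      full_loss = tf.nn.log_poisson_loss(targets, log_input,
                                         compute_full_loss=True)
  ```
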
67  Args:
68    targets: A `Tensor` of the same type and shape as `log_input`.
69    log_input: A `Tensor` of type `float32` or `float64`.
70    compute_full_loss: whether to compute the full loss. If false, a constant
71      term is dropped in favor of more efficient optimization.
72    name: A name for the operation (optional).
73
74  Returns:
75    A `Tensor` of the same shape as `log_input` with the componentwise
76    log Poisson losses.
77
78  Raises:
79    ValueError: If `log_input` and `targets` do not have the same shape.
80  """
81  with ops.name_scope(name, "log_poisson_loss", [log_input, targets]) as name:
82    log_input = ops.convert_to_tensor(log_input, name="log_input")
83    targets = ops.convert_to_tensor(targets, name="targets")
84    try:
85      targets.get_shape().merge_with(log_input.get_shape())
86    except ValueError:
87      raise ValueError(
88          "log_input and targets must have the same shape (%s vs %s)" %
89          (log_input.get_shape(), targets.get_shape()))
90
91    result = math_ops.exp(log_input) - log_input * targets
92    if compute_full_loss:
93      # need to create constant tensors here so that their dtypes can be matched
94      # to that of the targets.
95      point_five = constant_op.constant(0.5, dtype=targets.dtype)
96      two_pi = constant_op.constant(2 * math.pi, dtype=targets.dtype)
97
98      stirling_approx = (targets * math_ops.log(targets)) - targets + (
99          point_five * math_ops.log(two_pi * targets))
100      zeros = array_ops.zeros_like(targets, dtype=targets.dtype)
101      ones = array_ops.ones_like(targets, dtype=targets.dtype)
102      cond = math_ops.logical_and(targets >= zeros, targets <= ones)
103      result += array_ops.where(cond, zeros, stirling_approx)
104    return result
105
106
107@tf_export(v1=["nn.sigmoid_cross_entropy_with_logits"])
108def sigmoid_cross_entropy_with_logits(  # pylint: disable=invalid-name
109    _sentinel=None,
110    labels=None,
111    logits=None,
112    name=None):
113  """Computes sigmoid cross entropy given `logits`.
114
115  Measures the probability error in discrete classification tasks in which each
116  class is independent and not mutually exclusive.  For instance, one could
117  perform multilabel classification where a picture can contain both an elephant
118  and a dog at the same time.
119
120  For brevity, let `x = logits`, `z = labels`.  The logistic loss is
121
122        z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
123      = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
124      = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
125      = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
126      = (1 - z) * x + log(1 + exp(-x))
127      = x - x * z + log(1 + exp(-x))
128
129  For x < 0, to avoid overflow in exp(-x), we reformulate the above
130
131        x - x * z + log(1 + exp(-x))
132      = log(exp(x)) - x * z + log(1 + exp(-x))
133      = - x * z + log(1 + exp(x))
134
135  Hence, to ensure stability and avoid overflow, the implementation uses this
136  equivalent formulation
137
138      max(x, 0) - x * z + log(1 + exp(-abs(x)))
139
140  `logits` and `labels` must have the same type and shape.
141
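  For example, a minimal sketch (the values are illustrative only):

  ```python
      labels = tf.constant([[1.0, 0.0], [0.0, 1.0]])
      logits = tf.constant([[2.0, -1.0], [-1.0, 3.0]])
      loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,
                                                     logits=logits)
  ```
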
142  Args:
143    _sentinel: Used to prevent positional parameters. Internal, do not use.
144    labels: A `Tensor` of the same type and shape as `logits`.
145    logits: A `Tensor` of type `float32` or `float64`.
146    name: A name for the operation (optional).
147
148  Returns:
149    A `Tensor` of the same shape as `logits` with the componentwise
150    logistic losses.
151
152  Raises:
153    ValueError: If `logits` and `labels` do not have the same shape.
154  """
155  # pylint: disable=protected-access
156  nn_ops._ensure_xent_args("sigmoid_cross_entropy_with_logits", _sentinel,
157                           labels, logits)
158  # pylint: enable=protected-access
159
160  with ops.name_scope(name, "logistic_loss", [logits, labels]) as name:
161    logits = ops.convert_to_tensor(logits, name="logits")
162    labels = ops.convert_to_tensor(labels, name="labels")
163    try:
164      labels.get_shape().merge_with(logits.get_shape())
165    except ValueError:
166      raise ValueError("logits and labels must have the same shape (%s vs %s)" %
167                       (logits.get_shape(), labels.get_shape()))
168
169    # The logistic loss formula from above is
170    #   x - x * z + log(1 + exp(-x))
171    # For x < 0, a more numerically stable formula is
172    #   -x * z + log(1 + exp(x))
173    # Note that these two expressions can be combined into the following:
174    #   max(x, 0) - x * z + log(1 + exp(-abs(x)))
175    # To allow computing gradients at zero, we define custom versions of max and
176    # abs functions.
177    zeros = array_ops.zeros_like(logits, dtype=logits.dtype)
178    cond = (logits >= zeros)
179    relu_logits = array_ops.where(cond, logits, zeros)
180    neg_abs_logits = array_ops.where(cond, -logits, logits)
181    return math_ops.add(
182        relu_logits - logits * labels,
183        math_ops.log1p(math_ops.exp(neg_abs_logits)),
184        name=name)
185
186
187# Note: intentionally calling this v2 to not allow existing code with indirect
188# imports to ignore the sentinel behavior.
189@tf_export("nn.sigmoid_cross_entropy_with_logits", v1=[])
190def sigmoid_cross_entropy_with_logits_v2(  # pylint: disable=invalid-name
191    labels=None,
192    logits=None,
193    name=None):
194  """Computes sigmoid cross entropy given `logits`.
195
196  Measures the probability error in discrete classification tasks in which each
197  class is independent and not mutually exclusive.  For instance, one could
198  perform multilabel classification where a picture can contain both an elephant
199  and a dog at the same time.
200
201  For brevity, let `x = logits`, `z = labels`.  The logistic loss is
202
203        z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
204      = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
205      = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
206      = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
207      = (1 - z) * x + log(1 + exp(-x))
208      = x - x * z + log(1 + exp(-x))
209
210  For x < 0, to avoid overflow in exp(-x), we reformulate the above
211
212        x - x * z + log(1 + exp(-x))
213      = log(exp(x)) - x * z + log(1 + exp(-x))
214      = - x * z + log(1 + exp(x))
215
216  Hence, to ensure stability and avoid overflow, the implementation uses this
217  equivalent formulation
218
219      max(x, 0) - x * z + log(1 + exp(-abs(x)))
220
221  `logits` and `labels` must have the same type and shape.
222
223  Args:
224    labels: A `Tensor` of the same type and shape as `logits`.
225    logits: A `Tensor` of type `float32` or `float64`.
226    name: A name for the operation (optional).
227
228  Returns:
229    A `Tensor` of the same shape as `logits` with the componentwise
230    logistic losses.
231
232  Raises:
233    ValueError: If `logits` and `labels` do not have the same shape.
234  """
235  return sigmoid_cross_entropy_with_logits(
236      logits=logits, labels=labels, name=name)
237
238@tf_export("nn.weighted_cross_entropy_with_logits")
239def weighted_cross_entropy_with_logits(targets, logits, pos_weight, name=None):
240  """Computes a weighted cross entropy.
241
242  This is like `sigmoid_cross_entropy_with_logits()` except that `pos_weight`
243  allows one to trade off recall and precision by up- or down-weighting the
244  cost of a positive error relative to a negative error.
245
246  The usual cross-entropy cost is defined as:
247
248      targets * -log(sigmoid(logits)) +
249          (1 - targets) * -log(1 - sigmoid(logits))
250
251  A value `pos_weight > 1` decreases the false negative count, hence increasing
252  the recall.
253  Conversely, setting `pos_weight < 1` decreases the false positive count and
254  increases the precision.
255  This can be seen from the fact that `pos_weight` is introduced as a
256  multiplicative coefficient for the positive targets term
257  in the loss expression:
258
259      targets * -log(sigmoid(logits)) * pos_weight +
260          (1 - targets) * -log(1 - sigmoid(logits))
261
262  For brevity, let `x = logits`, `z = targets`, `q = pos_weight`.
263  The loss is:
264
265        qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
266      = qz * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
267      = qz * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
268      = qz * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
269      = (1 - z) * x + (qz + 1 - z) * log(1 + exp(-x))
270      = (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))
271
272  Setting `l = (1 + (q - 1) * z)`, to ensure stability and avoid overflow,
273  the implementation uses
274
275      (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))
276
277  `logits` and `targets` must have the same type and shape.
278
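  For example, a minimal sketch (the weight value is illustrative only):

  ```python
      targets = tf.constant([[1.0, 0.0]])
      logits = tf.constant([[0.5, -0.5]])
      # Penalize a missed positive 10x more than a false positive.
      loss = tf.nn.weighted_cross_entropy_with_logits(
          targets=targets, logits=logits, pos_weight=10.0)
  ```
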
279  Args:
280    targets: A `Tensor` of the same type and shape as `logits`.
281    logits: A `Tensor` of type `float32` or `float64`.
282    pos_weight: A coefficient to use on the positive examples.
283    name: A name for the operation (optional).
284
285  Returns:
286    A `Tensor` of the same shape as `logits` with the componentwise
287    weighted logistic losses.
288
289  Raises:
290    ValueError: If `logits` and `targets` do not have the same shape.
291  """
292  with ops.name_scope(name, "logistic_loss", [logits, targets]) as name:
293    logits = ops.convert_to_tensor(logits, name="logits")
294    targets = ops.convert_to_tensor(targets, name="targets")
295    try:
296      targets.get_shape().merge_with(logits.get_shape())
297    except ValueError:
298      raise ValueError(
299          "logits and targets must have the same shape (%s vs %s)" %
300          (logits.get_shape(), targets.get_shape()))
301
302    # The logistic loss formula from above is
303    #   (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))
304    # For x < 0, a more numerically stable formula is
305    #   (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(x)) - l * x
306    # To avoid branching, we use the combined version
307    #   (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))
308    log_weight = 1 + (pos_weight - 1) * targets
309    return math_ops.add(
310        (1 - targets) * logits,
311        log_weight * (math_ops.log1p(math_ops.exp(-math_ops.abs(logits))) +
312                      nn_ops.relu(-logits)),
313        name=name)
314
315
316@tf_export(v1=["nn.relu_layer"])
317def relu_layer(x, weights, biases, name=None):
318  """Computes Relu(x * weight + biases).
319
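  For example, a minimal sketch (the shapes are illustrative only):

  ```python
      x = tf.random_normal([32, 100])        # batch, in_units
      weights = tf.random_normal([100, 10])  # in_units, out_units
      biases = tf.zeros([10])
      y = tf.nn.relu_layer(x, weights, biases)  # shape [32, 10]
  ```
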
320  Args:
321    x: a 2D tensor.  Dimensions typically: batch, in_units
322    weights: a 2D tensor.  Dimensions typically: in_units, out_units
323    biases: a 1D tensor.  Dimensions: out_units
324    name: A name for the operation (optional).  If not specified
325      "nn_relu_layer" is used.
326
327  Returns:
328    A 2-D Tensor computing relu(matmul(x, weights) + biases).
329    Dimensions typically: batch, out_units.
330  """
331  with ops.name_scope(name, "relu_layer", [x, weights, biases]) as name:
332    x = ops.convert_to_tensor(x, name="x")
333    weights = ops.convert_to_tensor(weights, name="weights")
334    biases = ops.convert_to_tensor(biases, name="biases")
335    xw_plus_b = nn_ops.bias_add(math_ops.matmul(x, weights), biases)
336    return nn_ops.relu(xw_plus_b, name=name)
337
338
339def _swish_shape(op):
340  """Shape helper function for swish and _swish_grad function below."""
341  return [op.inputs[0].shape]
342
343
344@function.Defun(shape_func=_swish_shape, func_name="swish_grad", noinline=True)
345def _swish_grad(features, grad):
346  """Gradient of Swish function defined below."""
347  sigmoid_features = math_ops.sigmoid(features)
348  activation_grad = (
349      sigmoid_features * (1.0 + features * (1.0 - sigmoid_features)))
350  return grad * activation_grad
351
352
353# Naively, x * tf.nn.sigmoid(x) requires keeping both x and sigmoid(x) around
354# for backprop, effectively doubling the tensor's memory consumption. We use a
355# @Defun decorator with noinline=True so that sigmoid(features) is re-computed
356# during backprop, and we can free the sigmoid(features) expression immediately
357# after use during the forward pass.
358@tf_export("nn.swish")
359@function.Defun(
360    grad_func=_swish_grad,
361    shape_func=_swish_shape,
362    func_name="swish",
363    noinline=True)
364def swish(features):
365  # pylint: disable=g-doc-args
366  """Computes the Swish activation function: `x * sigmoid(x)`.
367
368  Source: "Searching for Activation Functions" (Ramachandran et al. 2017)
369  https://arxiv.org/abs/1710.05941
370
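  For example, a minimal sketch (the values are illustrative only):

  ```python
      x = tf.constant([-1.0, 0.0, 1.0])
      y = tf.nn.swish(x)  # elementwise x * sigmoid(x)
  ```
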
371  Args:
372    features: A `Tensor` representing preactivation values.
373    name: A name for the operation (optional).
374
375  Returns:
376    The activation value.
377  """
378  # pylint: enable=g-doc-args
379  features = ops.convert_to_tensor(features, name="features")
380  return features * math_ops.sigmoid(features)
381
382
383@tf_export(v1=["math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize"])
384@deprecated_args(None, "dim is deprecated, use axis instead", "dim")
385def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None):
386  """Normalizes along dimension `axis` using an L2 norm.
387
388  For a 1-D tensor with `axis = 0`, computes
389
390      output = x / sqrt(max(sum(x**2), epsilon))
391
392  For `x` with more dimensions, independently normalizes each 1-D slice along
393  dimension `axis`.
394
395  Args:
396    x: A `Tensor`.
397    axis: Dimension along which to normalize.  A scalar or a vector of
398      integers.
399    epsilon: A lower bound value for the norm. Will use `sqrt(epsilon)` as the
400      divisor if `norm < sqrt(epsilon)`.
401    name: A name for this operation (optional).
402    dim: Deprecated alias for axis.
403
404  Returns:
405    A `Tensor` with the same shape as `x`.
406  """
407  axis = deprecated_argument_lookup("axis", axis, "dim", dim)
408  return l2_normalize_v2(x, axis, epsilon, name)
409
410
411@tf_export("math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize", v1=[])
412def l2_normalize_v2(x, axis=None, epsilon=1e-12, name=None):
413  """Normalizes along dimension `axis` using an L2 norm.
414
415  For a 1-D tensor with `axis = 0`, computes
416
417      output = x / sqrt(max(sum(x**2), epsilon))
418
419  For `x` with more dimensions, independently normalizes each 1-D slice along
420  dimension `axis`.
421
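  For example, a minimal sketch (the values are illustrative only):

  ```python
      x = tf.constant([3.0, 4.0])
      y = tf.nn.l2_normalize(x, axis=0)  # approximately [0.6, 0.8]
  ```
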
422  Args:
423    x: A `Tensor`.
424    axis: Dimension along which to normalize.  A scalar or a vector of
425      integers.
426    epsilon: A lower bound value for the norm. Will use `sqrt(epsilon)` as the
427      divisor if `norm < sqrt(epsilon)`.
428    name: A name for this operation (optional).
429
430  Returns:
431    A `Tensor` with the same shape as `x`.
432  """
433  with ops.name_scope(name, "l2_normalize", [x]) as name:
434    x = ops.convert_to_tensor(x, name="x")
435    square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keepdims=True)
436    x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon))
437    return math_ops.multiply(x, x_inv_norm, name=name)
438
439
440def _count_nonzero(input_tensor, dtype=dtypes.int64):
441  """Same as math_ops.count_nonzero.
442
443  The reduction is done in dtype, which can be faster for 32-bit dtypes.
444
445  Args:
446      input_tensor: numeric tensor
447      dtype: reduction dtype
448
449  Returns:
450      number of nonzero values with type dtype
451  """
452  with ops.name_scope("count_nonzero", values=[input_tensor]):
453    zero = array_ops.zeros([], dtype=input_tensor.dtype)
454    nonzero_count = math_ops.reduce_sum(
455        math_ops.cast(
456            math_ops.not_equal(input_tensor, zero),
457            dtype=dtype), name="nonzero_count")
458    return nonzero_count
459
460
461@tf_export("math.zero_fraction", "nn.zero_fraction")
462def zero_fraction(value, name=None):
463  """Returns the fraction of zeros in `value`.
464
465  If `value` is empty, the result is `nan`.
466
467  This is useful in summaries to measure and report sparsity.  For example,
468
469  ```python
470      z = tf.nn.relu(...)
471      summ = tf.summary.scalar('sparsity', tf.nn.zero_fraction(z))
472  ```
473
474  Args:
475    value: A tensor of numeric type.
476    name: A name for the operation (optional).
477
478  Returns:
479    The fraction of zeros in `value`, with type `float32`.
480  """
481  with ops.name_scope(name, "zero_fraction", [value]):
482    value = ops.convert_to_tensor(value, name="value")
483    size = array_ops.size(value, out_type=dtypes.int64)
484    # If the count is small, we can save memory/CPU with an int32 reduction.
485    num_nonzero = control_flow_ops.cond(
486        size <= dtypes.int32.max,
487        # pylint: disable=g-long-lambda
488        true_fn=lambda: math_ops.cast(
489            _count_nonzero(value, dtype=dtypes.int32),
490            dtype=dtypes.int64),
491        false_fn=lambda: _count_nonzero(value, dtype=dtypes.int64))
492
493    with ops.name_scope("counts_to_fraction"):
494      num_zero = size - num_nonzero
495      num_zero_float32 = math_ops.cast(num_zero, dtype=dtypes.float32)
496      size_float32 = math_ops.cast(size, dtype=dtypes.float32)
497      zero_fraction_float32 = num_zero_float32 / size_float32
498
499    return array_ops.identity(zero_fraction_float32, "fraction")
500
501
502# pylint: disable=redefined-builtin
503@tf_export(v1=["nn.depthwise_conv2d"])
504def depthwise_conv2d(input,
505                     filter,
506                     strides,
507                     padding,
508                     rate=None,
509                     name=None,
510                     data_format=None,
511                     dilations=None):
512  """Depthwise 2-D convolution.
513
514  Given a 4D input tensor ('NHWC' or 'NCHW' data formats)
515  and a filter tensor of shape
516  `[filter_height, filter_width, in_channels, channel_multiplier]`
517  containing `in_channels` convolutional filters of depth 1, `depthwise_conv2d`
518  applies a different filter to each input channel (expanding from 1 channel
519  to `channel_multiplier` channels for each), then concatenates the results
520  together.  The output has `in_channels * channel_multiplier` channels.
521
522  In detail, with the default NHWC format,
523
524      output[b, i, j, k * channel_multiplier + q] = sum_{di, dj}
525           filter[di, dj, k, q] * input[b, strides[1] * i + rate[0] * di,
526                                           strides[2] * j + rate[1] * dj, k]
527
528  Must have `strides[0] = strides[3] = 1`.  For the most common case of the
529  same horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
530  If any value in `rate` is greater than 1, we perform atrous depthwise
531  convolution, in which case all values in the `strides` tensor must be equal
532  to 1.
533
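  For example, a minimal sketch (the shapes are illustrative only):

  ```python
      x = tf.random_normal([1, 28, 28, 3])      # NHWC input
      filters = tf.random_normal([3, 3, 3, 2])  # channel_multiplier = 2
      y = tf.nn.depthwise_conv2d(x, filters,
                                 strides=[1, 1, 1, 1], padding='SAME')
      # y has shape [1, 28, 28, 6], i.e. in_channels * channel_multiplier.
  ```
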
534  Args:
535    input: 4-D with shape according to `data_format`.
536    filter: 4-D with shape
537      `[filter_height, filter_width, in_channels, channel_multiplier]`.
538    strides: 1-D of size 4.  The stride of the sliding window for each
539      dimension of `input`.
540    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
541      See the "returns" section of `tf.nn.convolution` for details.
542    rate: 1-D of size 2. The dilation rate in which we sample input values
543      across the `height` and `width` dimensions in atrous convolution. If it is
544      greater than 1, then all values of strides must be 1.
545    name: A name for this operation (optional).
546    data_format: The data format for input. Either "NHWC" (default) or "NCHW".
547    dilations: Alias of rate.
548
549  Returns:
550    A 4-D `Tensor` with shape according to `data_format`.  E.g., for
551    "NHWC" format, shape is
552    `[batch, out_height, out_width, in_channels * channel_multiplier].`
553  """
554  rate = deprecated_argument_lookup("dilations", dilations, "rate", rate)
555  with ops.name_scope(name, "depthwise", [input, filter]) as name:
556    input = ops.convert_to_tensor(input, name="tensor_in")
557    filter = ops.convert_to_tensor(filter, name="filter_in")
558    if rate is None:
559      rate = [1, 1]
560
561    def op(input_converted, _, padding):
562      return nn_ops.depthwise_conv2d_native(
563          input=input_converted,
564          filter=filter,
565          strides=strides,
566          padding=padding,
567          data_format=data_format,
568          name=name)
569
570    return nn_ops.with_space_to_batch(
571        input=input,
572        filter_shape=array_ops.shape(filter),
573        dilation_rate=rate,
574        padding=padding,
575        data_format=data_format,
576        op=op)
577
578
579@tf_export("nn.depthwise_conv2d", v1=[])
580def depthwise_conv2d_v2(input,
581                        filter,
582                        strides,
583                        padding,
584                        data_format=None,
585                        dilations=None,
586                        name=None):
587  """Depthwise 2-D convolution.
588
589  Given a 4D input tensor ('NHWC' or 'NCHW' data formats)
590  and a filter tensor of shape
591  `[filter_height, filter_width, in_channels, channel_multiplier]`
592  containing `in_channels` convolutional filters of depth 1, `depthwise_conv2d`
593  applies a different filter to each input channel (expanding from 1 channel
594  to `channel_multiplier` channels for each), then concatenates the results
595  together.  The output has `in_channels * channel_multiplier` channels.
596
597  In detail, with the default NHWC format,
598
599      output[b, i, j, k * channel_multiplier + q] = sum_{di, dj}
600           filter[di, dj, k, q] * input[b, strides[1] * i + rate[0] * di,
601                                           strides[2] * j + rate[1] * dj, k]
602
603  Must have `strides[0] = strides[3] = 1`.  For the most common case of the
604  same horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
605  If any value in `rate` is greater than 1, we perform atrous depthwise
606  convolution, in which case all values in the `strides` tensor must be equal
607  to 1.
608
609  Args:
610    input: 4-D with shape according to `data_format`.
611    filter: 4-D with shape
612      `[filter_height, filter_width, in_channels, channel_multiplier]`.
613    strides: 1-D of size 4.  The stride of the sliding window for each
614      dimension of `input`.
615    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
616      See the "returns" section of `tf.nn.convolution` for details.
617    data_format: The data format for input. Either "NHWC" (default) or "NCHW".
618    dilations: 1-D of size 2. The dilation rate in which we sample input values
619      across the `height` and `width` dimensions in atrous convolution. If it is
620      greater than 1, then all values of strides must be 1.
621    name: A name for this operation (optional).
622
623  Returns:
624    A 4-D `Tensor` with shape according to `data_format`.  E.g., for
625    "NHWC" format, shape is
626    `[batch, out_height, out_width, in_channels * channel_multiplier].`
627  """
628  return depthwise_conv2d(input=input,
629                          filter=filter,
630                          strides=strides,
631                          padding=padding,
632                          rate=dilations,
633                          name=name,
634                          data_format=data_format)
635
636# pylint: enable=redefined-builtin
637
638
639# pylint: disable=redefined-builtin,line-too-long
640@tf_export(v1=["nn.separable_conv2d"])
641def separable_conv2d(input,
642                     depthwise_filter,
643                     pointwise_filter,
644                     strides,
645                     padding,
646                     rate=None,
647                     name=None,
648                     data_format=None,
649                     dilations=None):
650  """2-D convolution with separable filters.
651
652  Performs a depthwise convolution that acts separately on channels followed by
653  a pointwise convolution that mixes channels.  Note that this is separability
654  between dimensions `[1, 2]` and `3`, not spatial separability between
655  dimensions `1` and `2`.
656
657  In detail, with the default NHWC format,
658
659      output[b, i, j, k] = sum_{di, dj, q, r}
660          input[b, strides[1] * i + di, strides[2] * j + dj, q] *
661          depthwise_filter[di, dj, q, r] *
662          pointwise_filter[0, 0, q * channel_multiplier + r, k]
663
664  `strides` controls the strides for the depthwise convolution only, since
665  the pointwise convolution has implicit strides of `[1, 1, 1, 1]`.  Must have
666  `strides[0] = strides[3] = 1`.  For the most common case of the same
667  horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
668  If any value in `rate` is greater than 1, we perform atrous depthwise
669  convolution, in which case all values in the `strides` tensor must be equal
670  to 1.
671
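  For example, a minimal sketch (the shapes are illustrative only):

  ```python
      x = tf.random_normal([1, 28, 28, 3])
      depthwise_filter = tf.random_normal([3, 3, 3, 2])   # channel_multiplier = 2
      pointwise_filter = tf.random_normal([1, 1, 6, 16])  # 3 * 2 -> 16 channels
      y = tf.nn.separable_conv2d(x, depthwise_filter, pointwise_filter,
                                 strides=[1, 1, 1, 1], padding='SAME')
      # y has shape [1, 28, 28, 16].
  ```
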
672  Args:
673    input: 4-D `Tensor` with shape according to `data_format`.
674    depthwise_filter: 4-D `Tensor` with shape
675      `[filter_height, filter_width, in_channels, channel_multiplier]`.
676      Contains `in_channels` convolutional filters of depth 1.
677    pointwise_filter: 4-D `Tensor` with shape
678      `[1, 1, channel_multiplier * in_channels, out_channels]`.  Pointwise
679      filter to mix channels after `depthwise_filter` has convolved spatially.
680    strides: 1-D of size 4.  The strides for the depthwise convolution for
681      each dimension of `input`.
682    padding: A string, either `'VALID'` or `'SAME'`.  The padding algorithm.
683      See the "returns" section of `tf.nn.convolution` for details.
684    rate: 1-D of size 2. The dilation rate in which we sample input values
685      across the `height` and `width` dimensions in atrous convolution. If it is
686      greater than 1, then all values of strides must be 1.
687    name: A name for this operation (optional).
688    data_format: The data format for input. Either "NHWC" (default) or "NCHW".
689    dilations: Alias of rate.
690
691  Returns:
692    A 4-D `Tensor` with shape according to 'data_format'. For
693      example, with data_format="NHWC", shape is [batch, out_height,
694      out_width, out_channels].
695  """
696  rate = deprecated_argument_lookup("dilations", dilations, "rate", rate)
697  with ops.name_scope(name, "separable_conv2d",
698                      [input, depthwise_filter, pointwise_filter]) as name:
699    input = ops.convert_to_tensor(input, name="tensor_in")
700    depthwise_filter = ops.convert_to_tensor(
701        depthwise_filter, name="depthwise_filter")
702    pointwise_filter = ops.convert_to_tensor(
703        pointwise_filter, name="pointwise_filter")
704
705    pointwise_filter_shape = pointwise_filter.get_shape().with_rank(4)
706    pointwise_filter_shape.dims[0].assert_is_compatible_with(1)
707    pointwise_filter_shape.dims[1].assert_is_compatible_with(1)
708
709    if rate is None:
710      rate = [1, 1]
711
712    # The layout of the ops in the graph is expected to be as follows:
713    # depthwise_conv2d  // Conv2D op corresponding to the native depthwise conv.
714    # separable_conv2d  // Conv2D op corresponding to the pointwise conv.
715
716    def op(input_converted, _, padding):
717      return nn_ops.depthwise_conv2d_native(
718          input=input_converted,
719          filter=depthwise_filter,
720          strides=strides,
721          padding=padding,
722          data_format=data_format,
723          name="depthwise")
724
725    depthwise = nn_ops.with_space_to_batch(
726        input=input,
727        filter_shape=array_ops.shape(depthwise_filter),
728        dilation_rate=rate,
729        padding=padding,
730        data_format=data_format,
731        op=op)
732
733    return nn_ops.conv2d(
734        depthwise,
735        pointwise_filter, [1, 1, 1, 1],
736        padding="VALID",
737        data_format=data_format,
738        name=name)
739
740
741@tf_export("nn.separable_conv2d", v1=[])
742def separable_conv2d_v2(
743    input,
744    depthwise_filter,
745    pointwise_filter,
746    strides,
747    padding,
748    data_format=None,
749    dilations=None,
750    name=None,
751):
752  """2-D convolution with separable filters.
753
754  Performs a depthwise convolution that acts separately on channels followed by
755  a pointwise convolution that mixes channels.  Note that this is separability
756  between dimensions `[1, 2]` and `3`, not spatial separability between
757  dimensions `1` and `2`.
758
759  In detail, with the default NHWC format,
760
761      output[b, i, j, k] = sum_{di, dj, q, r}
762          input[b, strides[1] * i + di, strides[2] * j + dj, q] *
763          depthwise_filter[di, dj, q, r] *
764          pointwise_filter[0, 0, q * channel_multiplier + r, k]
765
766  `strides` controls the strides for the depthwise convolution only, since
767  the pointwise convolution has implicit strides of `[1, 1, 1, 1]`.  Must have
768  `strides[0] = strides[3] = 1`.  For the most common case of the same
769  horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
770  If any value in `rate` is greater than 1, we perform atrous depthwise
771  convolution, in which case all values in the `strides` tensor must be equal
772  to 1.
773
774  Args:
775    input: 4-D `Tensor` with shape according to `data_format`.
776    depthwise_filter: 4-D `Tensor` with shape `[filter_height, filter_width,
777      in_channels, channel_multiplier]`. Contains `in_channels` convolutional
778      filters of depth 1.
779    pointwise_filter: 4-D `Tensor` with shape `[1, 1, channel_multiplier *
780      in_channels, out_channels]`.  Pointwise filter to mix channels after
781      `depthwise_filter` has convolved spatially.
782    strides: 1-D of size 4.  The strides for the depthwise convolution for each
783      dimension of `input`.
784    padding: A string, either `'VALID'` or `'SAME'`.  The padding algorithm. See
785      the "returns" section of `tf.nn.convolution` for details.
786    data_format: The data format for input. Either "NHWC" (default) or "NCHW".
787    dilations: 1-D of size 2. The dilation rate in which we sample input values
788      across the `height` and `width` dimensions in atrous convolution. If it is
789      greater than 1, then all values of strides must be 1.
790    name: A name for this operation (optional).
791
792  Returns:
793    A 4-D `Tensor` with shape according to 'data_format'. For
794      example, with data_format="NHWC", shape is [batch, out_height,
795      out_width, out_channels].
796  """
797  return separable_conv2d(
798      input,
799      depthwise_filter,
800      pointwise_filter,
801      strides,
802      padding,
803      rate=dilations,
804      name=name,
805      data_format=data_format)
806
807# pylint: enable=redefined-builtin,line-too-long
808
809
810@tf_export(v1=["nn.sufficient_statistics"])
811def sufficient_statistics(x, axes, shift=None, keep_dims=None, name=None,
812                          keepdims=None):
813  """Calculate the sufficient statistics for the mean and variance of `x`.
814
815  These sufficient statistics are computed using the one pass algorithm on
816  an input that's optionally shifted. See:
817  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data
818
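  For example, a minimal sketch of feeding these statistics to
  `tf.nn.normalize_moments` (the shapes are illustrative only):

  ```python
      x = tf.random_normal([100, 10])
      counts, mean_ss, var_ss, shift = tf.nn.sufficient_statistics(
          x, axes=[0], shift=None)
      mean, variance = tf.nn.normalize_moments(counts, mean_ss, var_ss, shift)
  ```
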
819  Args:
820    x: A `Tensor`.
821    axes: Array of ints. Axes along which to compute mean and variance.
822    shift: A `Tensor` containing the value by which to shift the data for
823      numerical stability, or `None` if no shift is to be performed. A shift
824      close to the true mean provides the most numerically stable results.
825    keep_dims: produce statistics with the same dimensionality as the input.
826    name: Name used to scope the operations that compute the sufficient stats.
827    keepdims: Alias for keep_dims.
828
829  Returns:
830    Four `Tensor` objects of the same type as `x`:
831
832    * the count (number of elements to average over).
833    * the (possibly shifted) sum of the elements in the array.
834    * the (possibly shifted) sum of squares of the elements in the array.
835    * the shift by which the mean must be corrected or None if `shift` is None.
836  """
837  axes = list(set(axes))
838  keep_dims = deprecated_argument_lookup(
839      "keepdims", keepdims, "keep_dims", keep_dims)
840  if keep_dims is None:
841    keep_dims = False
842  with ops.name_scope(name, "sufficient_statistics", [x, shift]):
843    x = ops.convert_to_tensor(x, name="x")
844    x_shape = x.get_shape()
845    if all(x_shape.dims[d].value is not None for d in axes):
846      counts = 1
847      for d in axes:
848        counts *= x_shape.dims[d].value
849      counts = constant_op.constant(counts, dtype=x.dtype)
850    else:  # shape needs to be inferred at runtime.
851      x_dims = array_ops.gather(
852          math_ops.cast(array_ops.shape(x), x.dtype), axes)
853      counts = math_ops.reduce_prod(x_dims, name="count")
854    if shift is not None:
855      shift = ops.convert_to_tensor(shift, name="shift")
856      m_ss = math_ops.subtract(x, shift)
857      v_ss = math_ops.squared_difference(x, shift)
858    else:  # no shift.
859      m_ss = x
860      v_ss = math_ops.square(x)
861    m_ss = math_ops.reduce_sum(m_ss, axes, keepdims=keep_dims, name="mean_ss")
862    v_ss = math_ops.reduce_sum(v_ss, axes, keepdims=keep_dims, name="var_ss")
863  return counts, m_ss, v_ss, shift
864
865
866@tf_export("nn.sufficient_statistics", v1=[])
867def sufficient_statistics_v2(x, axes, shift=None, keepdims=False, name=None):
868  """Calculate the sufficient statistics for the mean and variance of `x`.
869
870  These sufficient statistics are computed using the one pass algorithm on
871  an input that's optionally shifted. See:
872  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data
873
874  Args:
875    x: A `Tensor`.
876    axes: Array of ints. Axes along which to compute mean and variance.
877    shift: A `Tensor` containing the value by which to shift the data for
878      numerical stability, or `None` if no shift is to be performed. A shift
879      close to the true mean provides the most numerically stable results.
880    keepdims: produce statistics with the same dimensionality as the input.
881    name: Name used to scope the operations that compute the sufficient stats.
882
883  Returns:
884    Four `Tensor` objects of the same type as `x`:
885
886    * the count (number of elements to average over).
887    * the (possibly shifted) sum of the elements in the array.
888    * the (possibly shifted) sum of squares of the elements in the array.
889    * the shift by which the mean must be corrected or None if `shift` is None.
890  """
891  return sufficient_statistics(
892      x=x, axes=axes, shift=shift, keep_dims=keepdims, name=name)
893
894
895@tf_export("nn.normalize_moments")
896def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
897  """Calculate the mean and variance of based on the sufficient statistics.
898
899  Args:
900    counts: A `Tensor` containing the total count of the data (one value).
901    mean_ss: A `Tensor` containing the mean sufficient statistics: the (possibly
902      shifted) sum of the elements to average over.
903    variance_ss: A `Tensor` containing the variance sufficient statistics: the
904      (possibly shifted) squared sum of the data to compute the variance over.
905    shift: A `Tensor` containing the value by which the data is shifted for
906      numerical stability, or `None` if no shift was performed.
907    name: Name used to scope the operations that compute the moments.
908
909  Returns:
910    Two `Tensor` objects: `mean` and `variance`.
911  """
912  with ops.name_scope(name, "normalize", [counts, mean_ss, variance_ss, shift]):
913    divisor = math_ops.reciprocal(counts, name="divisor")
914    if shift is not None:
915      shifted_mean = math_ops.multiply(mean_ss, divisor, name="shifted_mean")
916      mean = math_ops.add(shifted_mean, shift, name="mean")
917    else:  # no shift.
918      shifted_mean = math_ops.multiply(mean_ss, divisor, name="mean")
919      mean = shifted_mean
920    variance = math_ops.subtract(
921        math_ops.multiply(variance_ss, divisor),
922        math_ops.square(shifted_mean),
923        name="variance")
924  return (mean, variance)
925
926
927@tf_export(v1=["nn.moments"])
928def moments(
929    x,
930    axes,
931    shift=None,  # pylint: disable=unused-argument
932    name=None,
933    keep_dims=None,
934    keepdims=None):
935  """Calculate the mean and variance of `x`.
936
937  The mean and variance are calculated by aggregating the contents of `x`
938  across `axes`.  If `x` is 1-D and `axes = [0]` this is just the mean
939  and variance of a vector.
940
941  Note: shift is currently not used; the true mean is computed and used.
942
943  When using these moments for batch normalization (see
944  `tf.nn.batch_normalization`):
945
946   * for so-called "global normalization", used with convolutional filters with
947     shape `[batch, height, width, depth]`, pass `axes=[0, 1, 2]`.
948   * for simple batch normalization pass `axes=[0]` (batch only).
949
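  For example, a minimal sketch of the "global normalization" case above (the
  shapes are illustrative only):

  ```python
      x = tf.random_normal([8, 32, 32, 3])               # e.g. NHWC images
      mean, variance = tf.nn.moments(x, axes=[0, 1, 2])  # per-channel moments
  ```
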
950  Args:
951    x: A `Tensor`.
952    axes: Array of ints.  Axes along which to compute mean and
953      variance.
954    shift: Not used in the current implementation
955    name: Name used to scope the operations that compute the moments.
956    keep_dims: produce moments with the same dimensionality as the input.
957    keepdims: Alias for keep_dims.
958
959  Returns:
960    Two `Tensor` objects: `mean` and `variance`.
961  """
962  keep_dims = deprecated_argument_lookup(
963      "keepdims", keepdims, "keep_dims", keep_dims)
964  if keep_dims is None:
965    keep_dims = False
966  with ops.name_scope(name, "moments", [x, axes]):
967    # The dynamic range of fp16 is too limited to support the collection of
968    # sufficient statistics. As a workaround we simply perform the operations
969    # on 32-bit floats before converting the mean and variance back to fp16
970    y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
971    # Compute true mean while keeping the dims for proper broadcasting.
972    mean = math_ops.reduce_mean(y, axes, keepdims=True, name="mean")
973    # sample variance, not unbiased variance
974    # Note: stop_gradient does not change the gradient that gets
975    #       backpropagated to the mean from the variance calculation,
976    #       because that gradient is zero
977    variance = math_ops.reduce_mean(
978        math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
979        axes,
980        keepdims=True,
981        name="variance")
982    if not keep_dims:
983      mean = array_ops.squeeze(mean, axes)
984      variance = array_ops.squeeze(variance, axes)
985    if x.dtype == dtypes.float16:
986      return (math_ops.cast(mean, dtypes.float16),
987              math_ops.cast(variance, dtypes.float16))
988    else:
989      return (mean, variance)
990
991
992@tf_export("nn.moments", v1=[])
993def moments_v2(
994    x,
995    axes,
996    shift=None,
997    keepdims=False,
998    name=None):
999  """Calculates the mean and variance of `x`.
1000
1001  The mean and variance are calculated by aggregating the contents of `x`
1002  across `axes`.  If `x` is 1-D and `axes = [0]` this is just the mean
1003  and variance of a vector.
1004
1005  Note: shift is currently not used; the true mean is computed and used.
1006
1007  When using these moments for batch normalization (see
1008  `tf.nn.batch_normalization`):
1009
1010   * for so-called "global normalization", used with convolutional filters with
1011     shape `[batch, height, width, depth]`, pass `axes=[0, 1, 2]`.
1012   * for simple batch normalization pass `axes=[0]` (batch only).
1013
1014  Args:
1015    x: A `Tensor`.
1016    axes: Array of ints.  Axes along which to compute mean and
1017      variance.
1018    shift: Not used in the current implementation.
1019    keepdims: produce moments with the same dimensionality as the input.
1020    name: Name used to scope the operations that compute the moments.
1021
1022  Returns:
1023    Two `Tensor` objects: `mean` and `variance`.
1024  """
1025  return moments(x=x, axes=axes, shift=shift, name=name, keep_dims=keepdims)
1026
1027
1028@tf_export(v1=["nn.weighted_moments"])
1029def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=None,
1030                     keepdims=None):
1031  """Returns the frequency-weighted mean and variance of `x`.
1032
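  For example, a minimal sketch (the values are illustrative only):

  ```python
      x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
      weights = tf.constant([[1.0], [3.0]])  # broadcast against x
      mean, variance = tf.nn.weighted_moments(
          x, axes=[0], frequency_weights=weights)
  ```
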
1033  Args:
1034    x: A tensor.
1035    axes: 1-d tensor of int32 values; these are the axes along which
1036      to compute mean and variance.
1037    frequency_weights: A tensor of positive weights which can be
1038      broadcast with x.
1039    name: Name used to scope the operation.
1040    keep_dims: Produce moments with the same dimensionality as the input.
1041    keepdims: Alias of keep_dims.
1042
1043  Returns:
1044    Two tensors: `weighted_mean` and `weighted_variance`.
1045  """
1046  keep_dims = deprecated_argument_lookup(
1047      "keepdims", keepdims, "keep_dims", keep_dims)
1048  if keep_dims is None:
1049    keep_dims = False
1050  with ops.name_scope(name, "weighted_moments", [x, frequency_weights, axes]):
1051    x = ops.convert_to_tensor(x, name="x")
1052    frequency_weights = ops.convert_to_tensor(
1053        frequency_weights, name="frequency_weights")
1054
1055    # Unlike moments(), this just uses a simpler two-pass method.
1056
1057    # See comment in moments() WRT precision; it applies here too.
1058    needs_cast = x.dtype == dtypes.float16
1059    if needs_cast:
1060      x = math_ops.cast(x, dtypes.float32)
1061
1062    if frequency_weights.dtype != x.dtype:
1063      frequency_weights = math_ops.cast(frequency_weights, x.dtype)
1064
1065    # Note that we use keep_dims=True for our reductions regardless of the arg;
1066    # this is so that the results remain broadcast-compatible with the inputs.
1067    weighted_input_sum = math_ops.reduce_sum(
1068        frequency_weights * x, axes, name="weighted_input_sum", keepdims=True)
1069
1070    # The shape of the weights isn't necessarily the same as x's
1071    # shape, just broadcast-compatible with it -- so this expression
1072    # performs broadcasting to give a per-item weight, with the same
1073    # shape as (frequency_weights * x). This avoids having to reason
1074    # through all the broadcast logic to compute a correct
1075    # sum_of_weights.
1076    broadcasted_weights = frequency_weights + array_ops.zeros_like(x)
1077
1078    sum_of_weights = math_ops.reduce_sum(
1079        broadcasted_weights, axes, name="sum_of_weights", keepdims=True)
1080
1081    divisor = math_ops.reciprocal(sum_of_weights, name="inv_weight_sum")
1082
1083    weighted_mean = math_ops.multiply(weighted_input_sum, divisor)
1084
1085    # Have the weighted mean; now on to variance:
1086    weighted_distsq = math_ops.reduce_sum(
1087        frequency_weights * math_ops.squared_difference(x, weighted_mean),
1088        axes,
1089        name="weighted_distsq",
1090        keepdims=True)
1091
1092    weighted_variance = math_ops.multiply(weighted_distsq, divisor)
1093
1094    if not keep_dims:
1095      weighted_mean = array_ops.squeeze(weighted_mean, axis=axes)
1096      weighted_variance = array_ops.squeeze(
1097          weighted_variance, axis=axes)
1098
1099    if needs_cast:
1100      weighted_mean = math_ops.cast(weighted_mean, dtypes.float16)
1101      weighted_variance = math_ops.cast(weighted_variance, dtypes.float16)
1102
1103    return weighted_mean, weighted_variance
1104
1105
1106@tf_export("nn.weighted_moments", v1=[])
1107def weighted_moments_v2(x, axes, frequency_weights, keepdims=False, name=None):
1108  """Returns the frequency-weighted mean and variance of `x`.
1109
1110  Args:
1111    x: A tensor.
1112    axes: 1-d tensor of int32 values; these are the axes along which
1113      to compute mean and variance.
1114    frequency_weights: A tensor of positive weights which can be
1115      broadcast with x.
1116    keepdims: Produce moments with the same dimensionality as the input.
1117    name: Name used to scope the operation.
1118
1119  Returns:
1120    Two tensors: `weighted_mean` and `weighted_variance`.
1121  """
1122  return weighted_moments(
1123      x=x,
1124      axes=axes,
1125      frequency_weights=frequency_weights,
1126      name=name,
1127      keep_dims=keepdims)
1128
1129
1130@tf_export("nn.batch_normalization")
1131def batch_normalization(x,
1132                        mean,
1133                        variance,
1134                        offset,
1135                        scale,
1136                        variance_epsilon,
1137                        name=None):
1138  r"""Batch normalization.
1139
1140  As described in http://arxiv.org/abs/1502.03167.
1141  Normalizes a tensor by `mean` and `variance`, and applies (optionally) a
1142  `scale` \\(\gamma\\) to it, as well as an `offset` \\(\beta\\):
1143
1144  \\(\frac{\gamma(x-\mu)}{\sigma}+\beta\\)
1145
1146  `mean`, `variance`, `offset` and `scale` are all expected to be of one of two
1147  shapes:
1148
1149    * In all generality, they can have the same number of dimensions as the
1150      input `x`, with identical sizes as `x` for the dimensions that are not
1151      normalized over (the 'depth' dimension(s)), and dimension 1 for the
1152      others which are being normalized over.
1153      `mean` and `variance` in this case would typically be the outputs of
1154      `tf.nn.moments(..., keep_dims=True)` during training, or running averages
1155      thereof during inference.
1156    * In the common case where the 'depth' dimension is the last dimension in
1157      the input tensor `x`, they may be one dimensional tensors of the same
1158      size as the 'depth' dimension.
1159      This is the case for example for the common `[batch, depth]` layout of
1160      fully-connected layers, and `[batch, height, width, depth]` for
1161      convolutions.
1162      `mean` and `variance` in this case would typically be the outputs of
1163      `tf.nn.moments(..., keep_dims=False)` during training, or running averages
1164      thereof during inference.
1165
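  For example, a minimal sketch of the common depth-last case (the shapes and
  epsilon are illustrative only):

  ```python
      x = tf.random_normal([8, 32, 32, 3])
      mean, variance = tf.nn.moments(x, axes=[0, 1, 2])
      offset = tf.zeros([3])
      scale = tf.ones([3])
      y = tf.nn.batch_normalization(x, mean, variance, offset, scale, 1e-3)
  ```
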
1166  Args:
1167    x: Input `Tensor` of arbitrary dimensionality.
1168    mean: A mean `Tensor`.
1169    variance: A variance `Tensor`.
1170    offset: An offset `Tensor`, often denoted \\(\beta\\) in equations, or
1171      None. If present, will be added to the normalized tensor.
1172    scale: A scale `Tensor`, often denoted \\(\gamma\\) in equations, or
1173      `None`. If present, the scale is applied to the normalized tensor.
1174    variance_epsilon: A small float number to avoid dividing by 0.
1175    name: A name for this operation (optional).
1176
1177  Returns:
1178    the normalized, scaled, offset tensor.
1179  """
1180  with ops.name_scope(name, "batchnorm", [x, mean, variance, scale, offset]):
1181    inv = math_ops.rsqrt(variance + variance_epsilon)
1182    if scale is not None:
1183      inv *= scale
1184    # Note: tensorflow/contrib/quantize/python/fold_batch_norms.py depends on
1185    # the precise order of ops that are generated by the expression below.
1186    return x * math_ops.cast(inv, x.dtype) + math_ops.cast(
1187        offset - mean * inv if offset is not None else -mean * inv, x.dtype)
1188
1189
1190@tf_export(v1=["nn.fused_batch_norm"])
1191def fused_batch_norm(
1192    x,
1193    scale,
1194    offset,  # pylint: disable=invalid-name
1195    mean=None,
1196    variance=None,
1197    epsilon=0.001,
1198    data_format="NHWC",
1199    is_training=True,
1200    name=None):
1201  r"""Batch normalization.
1202
1203  As described in http://arxiv.org/abs/1502.03167.
1204
1205  Args:
1206    x: Input `Tensor` of 4 dimensions.
1207    scale: A `Tensor` of 1 dimension for scaling.
1208    offset: A `Tensor` of 1 dimension for bias.
1209    mean: A `Tensor` of 1 dimension for population mean used for inference.
1210    variance: A `Tensor` of 1 dimension for population variance
1211              used for inference.
1212    epsilon: A small float number added to the variance of x.
1213    data_format: The data format for x. Either "NHWC" (default) or "NCHW".
1214    is_training: A bool value to specify if the operation is used for
1215                 training or inference.
1216    name: A name for this operation (optional).
1217
1218  Returns:
1219    y: A 4D Tensor for the normalized, scaled, offset x.
1220    batch_mean: A 1D Tensor for the mean of x.
1221    batch_var: A 1D Tensor for the variance of x.
1222
1223  Raises:
1224    ValueError: If mean or variance is not None when is_training is True.
1225  """
1226  x = ops.convert_to_tensor(x, name="input")
1227  scale = ops.convert_to_tensor(scale, name="scale")
1228  offset = ops.convert_to_tensor(offset, name="offset")
1229  if is_training:
1230    if (mean is not None) or (variance is not None):
1231      raise ValueError("Both 'mean' and 'variance' must be None "
1232                       "if is_training is True.")
1233  if mean is None:
1234    mean = constant_op.constant([])
1235  if variance is None:
1236    variance = constant_op.constant([])
1237  # Set a minimum epsilon to 1.001e-5, which is a requirement by CUDNN to
1238  # prevent exception (see cudnn.h).
1239  min_epsilon = 1.001e-5
1240  epsilon = epsilon if epsilon > min_epsilon else min_epsilon
1241  # TODO(reedwm): In a few weeks, switch to using the V2 version exclusively. We
1242  # currently only use the V2 version for float16 inputs, which is not supported
1243  # by the V1 version.
1244  if x.dtype == dtypes.float16 or x.dtype == dtypes.bfloat16:
1245    fused_batch_norm_func = gen_nn_ops.fused_batch_norm_v2
1246  else:
1247    fused_batch_norm_func = gen_nn_ops._fused_batch_norm  # pylint: disable=protected-access
1248  y, batch_mean, batch_var, _, _ = fused_batch_norm_func(
1249      x,
1250      scale,
1251      offset,
1252      mean,
1253      variance,
1254      epsilon=epsilon,
1255      data_format=data_format,
1256      is_training=is_training,
1257      name=name)
1258  return y, batch_mean, batch_var
1259
1260
1261@tf_export(v1=["nn.batch_norm_with_global_normalization"])
1262def batch_norm_with_global_normalization(t=None,
1263                                         m=None,
1264                                         v=None,
1265                                         beta=None,
1266                                         gamma=None,
1267                                         variance_epsilon=None,
1268                                         scale_after_normalization=None,
1269                                         name=None,
1270                                         input=None,  # pylint: disable=redefined-builtin
1271                                         mean=None,
1272                                         variance=None):
1273  """Batch normalization.
1274
1275  This op is deprecated. See `tf.nn.batch_normalization`.
1276
1277  Args:
1278    t: A 4D input Tensor.
1279    m: A 1D mean Tensor with size matching the last dimension of t.
1280      This is the first output from tf.nn.moments,
1281      or a saved moving average thereof.
1282    v: A 1D variance Tensor with size matching the last dimension of t.
1283      This is the second output from tf.nn.moments,
1284      or a saved moving average thereof.
1285    beta: A 1D beta Tensor with size matching the last dimension of t.
1286      An offset to be added to the normalized tensor.
1287    gamma: A 1D gamma Tensor with size matching the last dimension of t.
1288      If "scale_after_normalization" is true, this tensor will be multiplied
1289      with the normalized tensor.
1290    variance_epsilon: A small float number to avoid dividing by 0.
1291    scale_after_normalization: A bool indicating whether the resulting tensor
1292      needs to be multiplied by gamma.
1293    name: A name for this operation (optional).
1294    input: Alias for t.
1295    mean: Alias for m.
1296    variance: Alias for v.
1297
1298  Returns:
1299     A batch-normalized `t`.
1300  """
1301  t = deprecated_argument_lookup("input", input, "t", t)
1302  m = deprecated_argument_lookup("mean", mean, "m", m)
1303  v = deprecated_argument_lookup("variance", variance, "v", v)
1304  return batch_normalization(t, m, v, beta, gamma if scale_after_normalization
1305                             else None, variance_epsilon, name)
1306
1307
1308# pylint: disable=redefined-builtin,line-too-long
1309@tf_export("nn.batch_norm_with_global_normalization", v1=[])
1310def batch_norm_with_global_normalization_v2(input,
1311                                            mean,
1312                                            variance,
1313                                            beta,
1314                                            gamma,
1315                                            variance_epsilon,
1316                                            scale_after_normalization,
1317                                            name=None):
1318  """Batch normalization.
1319
1320  This op is deprecated. See `tf.nn.batch_normalization`.
1321
1322  Args:
1323    input: A 4D input Tensor.
1324    mean: A 1D mean Tensor with size matching the last dimension of input.
1325      This is the first output from tf.nn.moments,
1326      or a saved moving average thereof.
1327    variance: A 1D variance Tensor with size matching the last dimension of input.
1328      This is the second output from tf.nn.moments,
1329      or a saved moving average thereof.
1330    beta: A 1D beta Tensor with size matching the last dimension of input.
1331      An offset to be added to the normalized tensor.
1332    gamma: A 1D gamma Tensor with size matching the last dimension of input.
1333      If "scale_after_normalization" is true, this tensor will be multiplied
1334      with the normalized tensor.
1335    variance_epsilon: A small float number to avoid dividing by 0.
1336    scale_after_normalization: A bool indicating whether the resulting tensor
1337      needs to be multiplied by gamma.
1338    name: A name for this operation (optional).
1339
1340  Returns:
1341     A batch-normalized `input`.
1342  """
1343  return batch_norm_with_global_normalization(t=input,
1344                                              m=mean,
1345                                              v=variance,
1346                                              beta=beta,
1347                                              gamma=gamma,
1348                                              variance_epsilon=variance_epsilon,
1349                                              scale_after_normalization=scale_after_normalization,
1350                                              name=name)
1351
1352# pylint: enable=redefined-builtin,line-too-long
1353
1354
1355def _sum_rows(x):
1356  """Returns a vector summing up each row of the matrix x."""
1357  # _sum_rows(x) is equivalent to math_ops.reduce_sum(x, 1) when x is
1358  # a matrix.  The gradient of _sum_rows(x) is more efficient than
1359  # reduce_sum(x, 1)'s gradient in today's implementation. Therefore,
1360  # we use _sum_rows(x) in the nce_loss() computation since the loss
1361  # is mostly used for training.
1362  cols = array_ops.shape(x)[1]
1363  ones_shape = array_ops.stack([cols, 1])
1364  ones = array_ops.ones(ones_shape, x.dtype)
1365  return array_ops.reshape(math_ops.matmul(x, ones), [-1])
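

# Editor's sketch: a quick illustration of the equivalence described in the
# comment above.  For a 2-D `x`, `_sum_rows(x)` yields the same values as
# `math_ops.reduce_sum(x, 1)`; the matmul-by-ones form is used only because
# its gradient is cheaper.  (Illustrative helper; not part of this module's
# API.)
def _sum_rows_equivalence_sketch():
  x = constant_op.constant([[1., 2., 3.], [4., 5., 6.]])
  via_matmul = _sum_rows(x)               # -> [6., 15.]
  via_reduce = math_ops.reduce_sum(x, 1)  # -> [6., 15.]
  return via_matmul, via_reduce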
1366
1367
1368def _compute_sampled_logits(weights,
1369                            biases,
1370                            labels,
1371                            inputs,
1372                            num_sampled,
1373                            num_classes,
1374                            num_true=1,
1375                            sampled_values=None,
1376                            subtract_log_q=True,
1377                            remove_accidental_hits=False,
1378                            partition_strategy="mod",
1379                            name=None,
1380                            seed=None):
1381  """Helper function for nce_loss and sampled_softmax_loss functions.
1382
1383  Computes sampled output training logits and labels suitable for implementing
1384  e.g. noise-contrastive estimation (see nce_loss) or sampled softmax (see
1385  sampled_softmax_loss).
1386
1387  Note: In the case where num_true > 1, we assign to each target class
1388  the target probability 1 / num_true so that the target probabilities
1389  sum to 1 per-example.
1390
1391  Args:
1392    weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
1393        objects whose concatenation along dimension 0 has shape
1394        `[num_classes, dim]`.  The (possibly-partitioned) class embeddings.
1395    biases: A `Tensor` of shape `[num_classes]`.  The (possibly-partitioned)
1396        class biases.
1397    labels: A `Tensor` of type `int64` and shape `[batch_size,
1398        num_true]`. The target classes.  Note that this format differs from
1399        the `labels` argument of `nn.softmax_cross_entropy_with_logits_v2`.
1400    inputs: A `Tensor` of shape `[batch_size, dim]`.  The forward
1401        activations of the input network.
1402    num_sampled: An `int`.  The number of classes to randomly sample per batch.
1403    num_classes: An `int`. The number of possible classes.
1404    num_true: An `int`.  The number of target classes per training example.
1405    sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
1406        `sampled_expected_count`) returned by a `*_candidate_sampler` function.
1407        (if None, we default to `log_uniform_candidate_sampler`)
1408    subtract_log_q: A `bool`.  whether to subtract the log expected count of
1409        the labels in the sample to get the logits of the true labels.
1410        Default is True.  Turn off for Negative Sampling.
1411    remove_accidental_hits:  A `bool`.  whether to remove "accidental hits"
1412        where a sampled class equals one of the target classes.  Default is
1413        False.
1414    partition_strategy: A string specifying the partitioning strategy, relevant
1415        if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
1416        Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
1417    name: A name for the operation (optional).
1418    seed: Random seed for candidate sampling. Defaults to None, which doesn't set
1419        the op-level random seed for candidate sampling.
1420  Returns:
1421    out_logits: `Tensor` object with shape
1422        `[batch_size, num_true + num_sampled]`, for passing to either
1423        `nn.sigmoid_cross_entropy_with_logits` (NCE) or
1424        `nn.softmax_cross_entropy_with_logits_v2` (sampled softmax).
1425    out_labels: A Tensor object with the same shape as `out_logits`.
1426  """
1427
1428  if isinstance(weights, variables.PartitionedVariable):
1429    weights = list(weights)
1430  if not isinstance(weights, list):
1431    weights = [weights]
1432
1433  with ops.name_scope(name, "compute_sampled_logits",
1434                      weights + [biases, inputs, labels]):
1435    if labels.dtype != dtypes.int64:
1436      labels = math_ops.cast(labels, dtypes.int64)
1437    labels_flat = array_ops.reshape(labels, [-1])
1438
1439    # Sample the negative labels.
1440    #   sampled shape: [num_sampled] tensor
1441    #   true_expected_count shape = [batch_size, 1] tensor
1442    #   sampled_expected_count shape = [num_sampled] tensor
1443    if sampled_values is None:
1444      sampled_values = candidate_sampling_ops.log_uniform_candidate_sampler(
1445          true_classes=labels,
1446          num_true=num_true,
1447          num_sampled=num_sampled,
1448          unique=True,
1449          range_max=num_classes,
1450          seed=seed)
1451    # NOTE: pylint cannot tell that 'sampled_values' is a sequence
1452    # pylint: disable=unpacking-non-sequence
1453    sampled, true_expected_count, sampled_expected_count = (
1454        array_ops.stop_gradient(s) for s in sampled_values)
1455    # pylint: enable=unpacking-non-sequence
1456    sampled = math_ops.cast(sampled, dtypes.int64)
1457
1458    # labels_flat is a [batch_size * num_true] tensor
1459    # sampled is a [num_sampled] int tensor
1460    all_ids = array_ops.concat([labels_flat, sampled], 0)
1461
1462    # Retrieve the true weights and the logits of the sampled weights.
1463
1464    # weights shape is [num_classes, dim]
1465    all_w = embedding_ops.embedding_lookup(
1466        weights, all_ids, partition_strategy=partition_strategy)
1467    if all_w.dtype != inputs.dtype:
1468      all_w = math_ops.cast(all_w, inputs.dtype)
1469
1470    # true_w shape is [batch_size * num_true, dim]
1471    true_w = array_ops.slice(all_w, [0, 0],
1472                             array_ops.stack(
1473                                 [array_ops.shape(labels_flat)[0], -1]))
1474
1475    sampled_w = array_ops.slice(
1476        all_w, array_ops.stack([array_ops.shape(labels_flat)[0], 0]), [-1, -1])
1477    # inputs has shape [batch_size, dim]
1478    # sampled_w has shape [num_sampled, dim]
1479    # Apply X*W', which yields [batch_size, num_sampled]
1480    sampled_logits = math_ops.matmul(inputs, sampled_w, transpose_b=True)
1481
1482    # Retrieve the true and sampled biases, compute the true logits, and
1483    # add the biases to the true and sampled logits.
1484    all_b = embedding_ops.embedding_lookup(
1485        biases, all_ids, partition_strategy=partition_strategy)
1486    if all_b.dtype != inputs.dtype:
1487      all_b = math_ops.cast(all_b, inputs.dtype)
1488    # true_b is a [batch_size * num_true] tensor
1489    # sampled_b is a [num_sampled] float tensor
1490    true_b = array_ops.slice(all_b, [0], array_ops.shape(labels_flat))
1491    sampled_b = array_ops.slice(all_b, array_ops.shape(labels_flat), [-1])
1492
1493    # inputs shape is [batch_size, dim]
1494    # true_w shape is [batch_size * num_true, dim]
1495    # row_wise_dots is [batch_size, num_true, dim]
1496    dim = array_ops.shape(true_w)[1:2]
1497    new_true_w_shape = array_ops.concat([[-1, num_true], dim], 0)
1498    row_wise_dots = math_ops.multiply(
1499        array_ops.expand_dims(inputs, 1),
1500        array_ops.reshape(true_w, new_true_w_shape))
1501    # We want the row-wise dot plus biases which yields a
1502    # [batch_size, num_true] tensor of true_logits.
1503    dots_as_matrix = array_ops.reshape(row_wise_dots,
1504                                       array_ops.concat([[-1], dim], 0))
1505    true_logits = array_ops.reshape(_sum_rows(dots_as_matrix), [-1, num_true])
1506    true_b = array_ops.reshape(true_b, [-1, num_true])
1507    true_logits += true_b
1508    sampled_logits += sampled_b
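    # Editor's note on shapes at this point:
    #   true_logits: [batch_size, num_true]
    #   sampled_logits: [batch_size, num_sampled]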
1509
1510    if remove_accidental_hits:
1511      acc_hits = candidate_sampling_ops.compute_accidental_hits(
1512          labels, sampled, num_true=num_true)
1513      acc_indices, acc_ids, acc_weights = acc_hits
1514
1515      # This is how SparseToDense expects the indices.
1516      acc_indices_2d = array_ops.reshape(acc_indices, [-1, 1])
1517      acc_ids_2d_int32 = array_ops.reshape(
1518          math_ops.cast(acc_ids, dtypes.int32), [-1, 1])
1519      sparse_indices = array_ops.concat([acc_indices_2d, acc_ids_2d_int32], 1,
1520                                        "sparse_indices")
1521      # Create sampled_logits_shape = [batch_size, num_sampled]
1522      sampled_logits_shape = array_ops.concat(
1523          [array_ops.shape(labels)[:1],
1524           array_ops.expand_dims(num_sampled, 0)], 0)
1525      if sampled_logits.dtype != acc_weights.dtype:
1526        acc_weights = math_ops.cast(acc_weights, sampled_logits.dtype)
1527      sampled_logits += gen_sparse_ops.sparse_to_dense(
1528          sparse_indices,
1529          sampled_logits_shape,
1530          acc_weights,
1531          default_value=0.0,
1532          validate_indices=False)
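      # Editor's note: `acc_weights` contains large negative values, so the
      # addition above pushes the logits of accidentally sampled true classes
      # toward -inf, effectively removing them from the loss.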
1533
1534    if subtract_log_q:
1535      # Subtract log of Q(l), prior probability that l appears in sampled.
1536      true_logits -= math_ops.log(true_expected_count)
1537      sampled_logits -= math_ops.log(sampled_expected_count)
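      # Editor's note: subtracting log(expected_count) corrects the logits for
      # the sampling distribution Q, so that a softmax/sigmoid over the sampled
      # classes approximates the corresponding full loss (see the candidate
      # sampling reference cited in the loss docstrings below).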
1538
1539    # Construct output logits and labels. The true labels/logits start at col 0.
1540    out_logits = array_ops.concat([true_logits, sampled_logits], 1)
1541
1542    # true_logits is a float tensor, ones_like(true_logits) is a float
1543    # tensor of ones. We then divide by num_true to ensure the per-example
1544    # labels sum to 1.0, i.e. form a proper probability distribution.
1545    out_labels = array_ops.concat([
1546        array_ops.ones_like(true_logits) / num_true,
1547        array_ops.zeros_like(sampled_logits)
1548    ], 1)
1549
1550    return out_logits, out_labels
1551
1552
1553@tf_export("nn.nce_loss", v1=[])
1554def nce_loss_v2(weights,
1555                biases,
1556                labels,
1557                inputs,
1558                num_sampled,
1559                num_classes,
1560                num_true=1,
1561                sampled_values=None,
1562                remove_accidental_hits=False,
1563                name="nce_loss"):
1564  """Computes and returns the noise-contrastive estimation training loss.
1565
1566  See [Noise-contrastive estimation: A new estimation principle for
1567  unnormalized statistical
1568  models](http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf).
1569  Also see our [Candidate Sampling Algorithms
1570  Reference](https://www.tensorflow.org/extras/candidate_sampling.pdf)
1571
1572  A common use case is to use this method for training, and calculate the full
1573  sigmoid loss for evaluation or inference as in the following example:
1574
1575  ```python
1576  if mode == "train":
1577    loss = tf.nn.nce_loss(
1578        weights=weights,
1579        biases=biases,
1580        labels=labels,
1581        inputs=inputs,
1582        ...)
1583  elif mode == "eval":
1584    logits = tf.matmul(inputs, tf.transpose(weights))
1585    logits = tf.nn.bias_add(logits, biases)
1586    labels_one_hot = tf.one_hot(labels, n_classes)
1587    loss = tf.nn.sigmoid_cross_entropy_with_logits(
1588        labels=labels_one_hot,
1589        logits=logits)
1590    loss = tf.reduce_sum(loss, axis=1)
1591  ```
1592
1593  Note: when doing embedding lookups on `weights` and `biases`, the "div"
1594  partition strategy will be used. Support for other partition strategies
1595  will be added later.
1596
1597  Note: By default this uses a log-uniform (Zipfian) distribution for sampling,
1598  so your labels must be sorted in order of decreasing frequency to achieve
1599  good results.  For more details, see
1600  `tf.nn.log_uniform_candidate_sampler`.
1601
1602  Note: In the case where `num_true` > 1, we assign to each target class
1603  the target probability 1 / `num_true` so that the target probabilities
1604  sum to 1 per-example.
1605
1606  Note: It would be useful to allow a variable number of target classes per
1607  example.  We hope to provide this functionality in a future release.
1608  For now, if you have a variable number of target classes, you can pad them
1609  out to a constant number by either repeating them or by padding
1610  with an otherwise unused class.
1611
1612  Args:
1613    weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
1614      objects whose concatenation along dimension 0 has shape [num_classes,
1615      dim].  The (possibly-partitioned) class embeddings.
1616    biases: A `Tensor` of shape `[num_classes]`.  The class biases.
1617    labels: A `Tensor` of type `int64` and shape `[batch_size, num_true]`. The
1618      target classes.
1619    inputs: A `Tensor` of shape `[batch_size, dim]`.  The forward activations of
1620      the input network.
1621    num_sampled: An `int`.  The number of negative classes to randomly sample
1622      per batch. This single sample of negative classes is evaluated for each
1623      element in the batch.
1624    num_classes: An `int`. The number of possible classes.
1625    num_true: An `int`.  The number of target classes per training example.
1626    sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
1627      `sampled_expected_count`) returned by a `*_candidate_sampler` function.
1628      (if None, we default to `log_uniform_candidate_sampler`)
1629    remove_accidental_hits:  A `bool`.  Whether to remove "accidental hits"
1630      where a sampled class equals one of the target classes.  If set to `True`,
1631      this is a "Sampled Logistic" loss instead of NCE, and we are learning to
1632      generate log-odds instead of log probabilities.  See our [Candidate
1633      Sampling Algorithms Reference]
1634      (https://www.tensorflow.org/extras/candidate_sampling.pdf). Default is
1635      False.
1636    name: A name for the operation (optional).
1637
1638  Returns:
1639    A `batch_size` 1-D tensor of per-example NCE losses.
1640  """
1641  # TODO(yuefengz): get partition_strategy from either variables or distribution
1642  # strategies.
1643  return nce_loss(
1644      weights,
1645      biases,
1646      labels,
1647      inputs,
1648      num_sampled,
1649      num_classes,
1650      num_true=num_true,
1651      sampled_values=sampled_values,
1652      remove_accidental_hits=remove_accidental_hits,
1653      partition_strategy="div",
1654      name=name)
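

# Editor's sketch: `sampled_values` can be supplied explicitly instead of
# relying on the default sampler chosen inside `_compute_sampled_logits`.  The
# helper below is illustrative only; it calls the same sampler, with the same
# arguments, that the default path uses (`labels` must be int64 of shape
# [batch_size, num_true]).
def _explicit_sampled_values_sketch(labels, num_sampled, num_classes,
                                    num_true=1):
  return candidate_sampling_ops.log_uniform_candidate_sampler(
      true_classes=labels,
      num_true=num_true,
      num_sampled=num_sampled,
      unique=True,
      range_max=num_classes)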
1655
1656
1657@tf_export(v1=["nn.nce_loss"])
1658def nce_loss(weights,
1659             biases,
1660             labels,
1661             inputs,
1662             num_sampled,
1663             num_classes,
1664             num_true=1,
1665             sampled_values=None,
1666             remove_accidental_hits=False,
1667             partition_strategy="mod",
1668             name="nce_loss"):
1669  """Computes and returns the noise-contrastive estimation training loss.
1670
1671  See [Noise-contrastive estimation: A new estimation principle for
1672  unnormalized statistical
1673  models](http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf).
1674  Also see our [Candidate Sampling Algorithms
1675  Reference](https://www.tensorflow.org/extras/candidate_sampling.pdf)
1676
1677  A common use case is to use this method for training, and calculate the full
1678  sigmoid loss for evaluation or inference. In this case, you must set
1679  `partition_strategy="div"` for the two losses to be consistent, as in the
1680  following example:
1681
1682  ```python
1683  if mode == "train":
1684    loss = tf.nn.nce_loss(
1685        weights=weights,
1686        biases=biases,
1687        labels=labels,
1688        inputs=inputs,
1689        ...,
1690        partition_strategy="div")
1691  elif mode == "eval":
1692    logits = tf.matmul(inputs, tf.transpose(weights))
1693    logits = tf.nn.bias_add(logits, biases)
1694    labels_one_hot = tf.one_hot(labels, n_classes)
1695    loss = tf.nn.sigmoid_cross_entropy_with_logits(
1696        labels=labels_one_hot,
1697        logits=logits)
1698    loss = tf.reduce_sum(loss, axis=1)
1699  ```
1700
1701  Note: By default this uses a log-uniform (Zipfian) distribution for sampling,
1702  so your labels must be sorted in order of decreasing frequency to achieve
1703  good results.  For more details, see
1704  `tf.nn.log_uniform_candidate_sampler`.
1705
1706  Note: In the case where `num_true` > 1, we assign to each target class
1707  the target probability 1 / `num_true` so that the target probabilities
1708  sum to 1 per-example.
1709
1710  Note: It would be useful to allow a variable number of target classes per
1711  example.  We hope to provide this functionality in a future release.
1712  For now, if you have a variable number of target classes, you can pad them
1713  out to a constant number by either repeating them or by padding
1714  with an otherwise unused class.
1715
1716  Args:
1717    weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
1718        objects whose concatenation along dimension 0 has shape
1719        [num_classes, dim].  The (possibly-partitioned) class embeddings.
1720    biases: A `Tensor` of shape `[num_classes]`.  The class biases.
1721    labels: A `Tensor` of type `int64` and shape `[batch_size,
1722        num_true]`. The target classes.
1723    inputs: A `Tensor` of shape `[batch_size, dim]`.  The forward
1724        activations of the input network.
1725    num_sampled: An `int`.  The number of negative classes to randomly sample
1726        per batch. This single sample of negative classes is evaluated for each
1727        element in the batch.
1728    num_classes: An `int`. The number of possible classes.
1729    num_true: An `int`.  The number of target classes per training example.
1730    sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
1731        `sampled_expected_count`) returned by a `*_candidate_sampler` function.
1732        (if None, we default to `log_uniform_candidate_sampler`)
1733    remove_accidental_hits:  A `bool`.  Whether to remove "accidental hits"
1734        where a sampled class equals one of the target classes.  If set to
1735        `True`, this is a "Sampled Logistic" loss instead of NCE, and we are
1736        learning to generate log-odds instead of log probabilities.  See
1737        our [Candidate Sampling Algorithms Reference]
1738        (https://www.tensorflow.org/extras/candidate_sampling.pdf).
1739        Default is False.
1740    partition_strategy: A string specifying the partitioning strategy, relevant
1741        if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
1742        Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
1743    name: A name for the operation (optional).
1744
1745  Returns:
1746    A `batch_size` 1-D tensor of per-example NCE losses.
1747  """
1748  logits, labels = _compute_sampled_logits(
1749      weights=weights,
1750      biases=biases,
1751      labels=labels,
1752      inputs=inputs,
1753      num_sampled=num_sampled,
1754      num_classes=num_classes,
1755      num_true=num_true,
1756      sampled_values=sampled_values,
1757      subtract_log_q=True,
1758      remove_accidental_hits=remove_accidental_hits,
1759      partition_strategy=partition_strategy,
1760      name=name)
1761  sampled_losses = sigmoid_cross_entropy_with_logits(
1762      labels=labels, logits=logits, name="sampled_losses")
1763  # sampled_losses is batch_size x {true_loss, sampled_losses...}
1764  # We sum out true and sampled losses.
1765  return _sum_rows(sampled_losses)
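

# Editor's sketch: the tensor returned above holds one NCE loss per example; a
# typical training setup reduces it to a scalar before minimizing.
# (Illustrative helper with hypothetical arguments; not part of the API.)
def _nce_training_loss_sketch(weights, biases, labels, inputs, num_sampled,
                              num_classes):
  per_example_loss = nce_loss(weights, biases, labels, inputs, num_sampled,
                              num_classes)
  return math_ops.reduce_mean(per_example_loss)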
1766
1767
1768@tf_export("nn.sampled_softmax_loss", v1=[])
1769def sampled_softmax_loss_v2(weights,
1770                            biases,
1771                            labels,
1772                            inputs,
1773                            num_sampled,
1774                            num_classes,
1775                            num_true=1,
1776                            sampled_values=None,
1777                            remove_accidental_hits=True,
1778                            seed=None,
1779                            name="sampled_softmax_loss"):
1780  """Computes and returns the sampled softmax training loss.
1781
1782  This is a faster way to train a softmax classifier over a huge number of
1783  classes.
1784
1785  This operation is for training only.  It is generally an underestimate of
1786  the full softmax loss.
1787
1788  A common use case is to use this method for training, and calculate the full
1789  softmax loss for evaluation or inference as in the following example:
1790
1791  ```python
1792  if mode == "train":
1793    loss = tf.nn.sampled_softmax_loss(
1794        weights=weights,
1795        biases=biases,
1796        labels=labels,
1797        inputs=inputs,
1798        ...)
1799  elif mode == "eval":
1800    logits = tf.matmul(inputs, tf.transpose(weights))
1801    logits = tf.nn.bias_add(logits, biases)
1802    labels_one_hot = tf.one_hot(labels, n_classes)
1803    loss = tf.nn.softmax_cross_entropy_with_logits_v2(
1804        labels=labels_one_hot,
1805        logits=logits)
1806  ```
1807
1808  See our [Candidate Sampling Algorithms Reference]
1809  (https://www.tensorflow.org/extras/candidate_sampling.pdf)
1810
1811  Also see Section 3 of [Jean et al., 2014](http://arxiv.org/abs/1412.2007)
1812  ([pdf](http://arxiv.org/pdf/1412.2007.pdf)) for the math.
1813
1814  Note: when doing embedding lookups on `weights` and `biases`, the "div"
1815  partition strategy will be used. Support for other partition strategies
1816  will be added later.
1817
1818  Args:
1819    weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
1820      objects whose concatenation along dimension 0 has shape [num_classes,
1821      dim].  The (possibly-sharded) class embeddings.
1822    biases: A `Tensor` of shape `[num_classes]`.  The class biases.
1823    labels: A `Tensor` of type `int64` and shape `[batch_size, num_true]`. The
1824      target classes.  Note that this format differs from the `labels` argument
1825      of `nn.softmax_cross_entropy_with_logits_v2`.
1826    inputs: A `Tensor` of shape `[batch_size, dim]`.  The forward activations of
1827      the input network.
1828    num_sampled: An `int`.  The number of classes to randomly sample per batch.
1829    num_classes: An `int`. The number of possible classes.
1830    num_true: An `int`.  The number of target classes per training example.
1831    sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
1832      `sampled_expected_count`) returned by a `*_candidate_sampler` function.
1833      (if None, we default to `log_uniform_candidate_sampler`)
1834    remove_accidental_hits:  A `bool`.  whether to remove "accidental hits"
1835      where a sampled class equals one of the target classes.  Default is True.
1836    seed: Random seed for candidate sampling. Defaults to None, which doesn't set
1837      the op-level random seed for candidate sampling.
1838    name: A name for the operation (optional).
1839
1840  Returns:
1841    A `batch_size` 1-D tensor of per-example sampled softmax losses.
1842
1843  """
1844  return sampled_softmax_loss(
1845      weights,
1846      biases,
1847      labels,
1848      inputs,
1849      num_sampled,
1850      num_classes,
1851      num_true=num_true,
1852      sampled_values=sampled_values,
1853      remove_accidental_hits=remove_accidental_hits,
1854      partition_strategy="div",
1855      name=name,
1856      seed=seed)
1857
1858
1859@tf_export(v1=["nn.sampled_softmax_loss"])
1860def sampled_softmax_loss(weights,
1861                         biases,
1862                         labels,
1863                         inputs,
1864                         num_sampled,
1865                         num_classes,
1866                         num_true=1,
1867                         sampled_values=None,
1868                         remove_accidental_hits=True,
1869                         partition_strategy="mod",
1870                         name="sampled_softmax_loss",
1871                         seed=None):
1872  """Computes and returns the sampled softmax training loss.
1873
1874  This is a faster way to train a softmax classifier over a huge number of
1875  classes.
1876
1877  This operation is for training only.  It is generally an underestimate of
1878  the full softmax loss.
1879
1880  A common use case is to use this method for training, and calculate the full
1881  softmax loss for evaluation or inference. In this case, you must set
1882  `partition_strategy="div"` for the two losses to be consistent, as in the
1883  following example:
1884
1885  ```python
1886  if mode == "train":
1887    loss = tf.nn.sampled_softmax_loss(
1888        weights=weights,
1889        biases=biases,
1890        labels=labels,
1891        inputs=inputs,
1892        ...,
1893        partition_strategy="div")
1894  elif mode == "eval":
1895    logits = tf.matmul(inputs, tf.transpose(weights))
1896    logits = tf.nn.bias_add(logits, biases)
1897    labels_one_hot = tf.one_hot(labels, n_classes)
1898    loss = tf.nn.softmax_cross_entropy_with_logits_v2(
1899        labels=labels_one_hot,
1900        logits=logits)
1901  ```
1902
1903  See our [Candidate Sampling Algorithms Reference]
1904  (https://www.tensorflow.org/extras/candidate_sampling.pdf)
1905
1906  Also see Section 3 of [Jean et al., 2014](http://arxiv.org/abs/1412.2007)
1907  ([pdf](http://arxiv.org/pdf/1412.2007.pdf)) for the math.
1908
1909  Args:
1910    weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
1911        objects whose concatenation along dimension 0 has shape
1912        [num_classes, dim].  The (possibly-sharded) class embeddings.
1913    biases: A `Tensor` of shape `[num_classes]`.  The class biases.
1914    labels: A `Tensor` of type `int64` and shape `[batch_size,
1915        num_true]`. The target classes.  Note that this format differs from
1916        the `labels` argument of `nn.softmax_cross_entropy_with_logits_v2`.
1917    inputs: A `Tensor` of shape `[batch_size, dim]`.  The forward
1918        activations of the input network.
1919    num_sampled: An `int`.  The number of classes to randomly sample per batch.
1920    num_classes: An `int`. The number of possible classes.
1921    num_true: An `int`.  The number of target classes per training example.
1922    sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
1923        `sampled_expected_count`) returned by a `*_candidate_sampler` function.
1924        (if None, we default to `log_uniform_candidate_sampler`)
1925    remove_accidental_hits:  A `bool`.  whether to remove "accidental hits"
1926        where a sampled class equals one of the target classes.  Default is
1927        True.
1928    partition_strategy: A string specifying the partitioning strategy, relevant
1929        if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
1930        Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
1931    name: A name for the operation (optional).
1932    seed: Random seed for candidate sampling. Defaults to None, which doesn't set
1933        the op-level random seed for candidate sampling.
1934
1935  Returns:
1936    A `batch_size` 1-D tensor of per-example sampled softmax losses.
1937
1938  """
1939  logits, labels = _compute_sampled_logits(
1940      weights=weights,
1941      biases=biases,
1942      labels=labels,
1943      inputs=inputs,
1944      num_sampled=num_sampled,
1945      num_classes=num_classes,
1946      num_true=num_true,
1947      sampled_values=sampled_values,
1948      subtract_log_q=True,
1949      remove_accidental_hits=remove_accidental_hits,
1950      partition_strategy=partition_strategy,
1951      name=name,
1952      seed=seed)
1953  labels = array_ops.stop_gradient(labels, name="labels_stop_gradient")
1954  sampled_losses = nn_ops.softmax_cross_entropy_with_logits_v2(
1955      labels=labels, logits=logits)
1956  # sampled_losses is a [batch_size] tensor.
1957  return sampled_losses
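

# Editor's sketch: the eval-time full-softmax path from the docstring above,
# written with this module's own ops for a side-by-side with the sampled
# training loss.  Illustrative only, and it assumes `weights` is a single
# [num_classes, dim] tensor (not partitioned) and num_true == 1.
def _full_softmax_eval_loss_sketch(weights, biases, labels, inputs,
                                   num_classes):
  logits = math_ops.matmul(inputs, weights, transpose_b=True)
  logits = nn_ops.bias_add(logits, biases)
  labels_one_hot = array_ops.one_hot(array_ops.reshape(labels, [-1]),
                                     num_classes)
  return nn_ops.softmax_cross_entropy_with_logits_v2(
      labels=labels_one_hot, logits=logits)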
1958