1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Implementation of Neural Net (NN) functions."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import math
22
23from tensorflow.python.distribute import distribution_strategy_context as ds
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import ops
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import candidate_sampling_ops
29from tensorflow.python.ops import control_flow_ops
30from tensorflow.python.ops import custom_gradient
31from tensorflow.python.ops import embedding_ops
32from tensorflow.python.ops import gen_array_ops  # pylint: disable=unused-import
33from tensorflow.python.ops import gen_nn_ops
34from tensorflow.python.ops import gen_sparse_ops
35from tensorflow.python.ops import linalg_ops
36from tensorflow.python.ops import math_ops
37from tensorflow.python.ops import nn_ops
38from tensorflow.python.ops import variables
39from tensorflow.python.ops.losses import util as losses_util
40from tensorflow.python.platform import device_context
41from tensorflow.python.util import dispatch
42from tensorflow.python.util.deprecation import deprecated_args
43from tensorflow.python.util.deprecation import deprecated_argument_lookup
44from tensorflow.python.util.tf_export import tf_export
45
46
47@tf_export("nn.log_poisson_loss")
48@dispatch.add_dispatch_support
49def log_poisson_loss(targets, log_input, compute_full_loss=False, name=None):
50  """Computes log Poisson loss given `log_input`.
51
52  Gives the log-likelihood loss between the prediction and the target under the
53  assumption that the target has a Poisson distribution.
54  Caveat: By default, this is not the exact loss, but the loss minus a
55    constant term [log(z!)]. That has no effect for optimization, but
56    does not play well with relative loss comparisons. To compute an
57    approximation of the log factorial term, specify
58    compute_full_loss=True to enable Stirling's Approximation.
59
60  For brevity, let `c = log(x) = log_input`, `z = targets`.  The log Poisson
61  loss is
62
63        -log(exp(-x) * (x^z) / z!)
64      = -log(exp(-x) * (x^z)) + log(z!)
65      ~ -log(exp(-x)) - log(x^z) [+ z * log(z) - z + 0.5 * log(2 * pi * z)]
66          [ Note the second term is Stirling's Approximation for log(z!).
67            It is invariant to x and does not affect optimization, though
68            important for correct relative loss comparisons. It is only
69            computed when compute_full_loss == True. ]
70      = x - z * log(x) [+ z * log(z) - z + 0.5 * log(2 * pi * z)]
71      = exp(c) - z * c [+ z * log(z) - z + 0.5 * log(2 * pi * z)]
72
73  Args:
74    targets: A `Tensor` of the same type and shape as `log_input`.
75    log_input: A `Tensor` of type `float32` or `float64`.
76    compute_full_loss: whether to compute the full loss. If false, a constant
77      term is dropped in favor of more efficient optimization.
78    name: A name for the operation (optional).
79
80  Returns:
81    A `Tensor` of the same shape as `log_input` with the componentwise
82    log Poisson losses.
83
84  Raises:
85    ValueError: If `log_input` and `targets` do not have the same shape.
86  """
87  with ops.name_scope(name, "log_poisson_loss", [log_input, targets]) as name:
88    log_input = ops.convert_to_tensor(log_input, name="log_input")
89    targets = ops.convert_to_tensor(targets, name="targets")
90    try:
91      targets.get_shape().assert_is_compatible_with(log_input.get_shape())
92    except ValueError:
93      raise ValueError(
94          "log_input and targets must have the same shape (%s vs %s)" %
95          (log_input.get_shape(), targets.get_shape()))
96
97    result = math_ops.exp(log_input) - log_input * targets
98    if compute_full_loss:
99      # need to create constant tensors here so that their dtypes can be matched
100      # to that of the targets.
101      point_five = constant_op.constant(0.5, dtype=targets.dtype)
102      two_pi = constant_op.constant(2 * math.pi, dtype=targets.dtype)
103
104      stirling_approx = (targets * math_ops.log(targets)) - targets + (
105          point_five * math_ops.log(two_pi * targets))
106      zeros = array_ops.zeros_like(targets, dtype=targets.dtype)
107      ones = array_ops.ones_like(targets, dtype=targets.dtype)
108      cond = math_ops.logical_and(targets >= zeros, targets <= ones)
109      result += array_ops.where(cond, zeros, stirling_approx)
110    return result
111
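# Illustrative sketch (not part of the original file): the docstring above works
# in terms of the log rate `c = log_input`. A hypothetical helper that starts
# from the rate `x = exp(c)` makes the identity `x - z * log(x) = exp(c) - z * c`
# explicit. The name below is an assumption, not TensorFlow API.
def _log_poisson_loss_from_rate_example(targets, rate):
  """Partial log Poisson loss computed from the rate `x` directly.

  Equivalent to `log_poisson_loss(targets, math_ops.log(rate))` with
  `compute_full_loss=False`.
  """
  rate = ops.convert_to_tensor(rate, name="rate")
  targets = ops.convert_to_tensor(targets, name="targets")
  return rate - targets * math_ops.log(rate)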
112
113@tf_export(v1=["nn.sigmoid_cross_entropy_with_logits"])
114@dispatch.add_dispatch_support
115def sigmoid_cross_entropy_with_logits(  # pylint: disable=invalid-name
116    _sentinel=None,
117    labels=None,
118    logits=None,
119    name=None):
120  """See sigmoid_cross_entropy_with_logits_v2."""
121  # pylint: disable=protected-access
122  nn_ops._ensure_xent_args("sigmoid_cross_entropy_with_logits", _sentinel,
123                           labels, logits)
124  # pylint: enable=protected-access
125
126  with ops.name_scope(name, "logistic_loss", [logits, labels]) as name:
127    logits = ops.convert_to_tensor(logits, name="logits")
128    labels = ops.convert_to_tensor(labels, name="labels")
129    try:
130      labels.get_shape().assert_is_compatible_with(logits.get_shape())
131    except ValueError:
132      raise ValueError("logits and labels must have the same shape (%s vs %s)" %
133                       (logits.get_shape(), labels.get_shape()))
134
135    # The logistic loss formula from above is
136    #   x - x * z + log(1 + exp(-x))
137    # For x < 0, a more numerically stable formula is
138    #   -x * z + log(1 + exp(x))
139    # Note that these two expressions can be combined into the following:
140    #   max(x, 0) - x * z + log(1 + exp(-abs(x)))
141    # To allow computing gradients at zero, we define custom versions of max and
142    # abs functions.
143    zeros = array_ops.zeros_like(logits, dtype=logits.dtype)
144    cond = (logits >= zeros)
145    relu_logits = array_ops.where(cond, logits, zeros)
146    neg_abs_logits = array_ops.where(cond, -logits, logits)  # pylint: disable=invalid-unary-operand-type
147    return math_ops.add(
148        relu_logits - logits * labels,
149        math_ops.log1p(math_ops.exp(neg_abs_logits)),
150        name=name)
151
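# Minimal sketch (not in the original source) of the naive logistic-loss formula
# that the max/abs rewrite above protects against: for large negative logits,
# exp(-x) overflows in float32, while `max(x, 0) - x * z + log1p(exp(-|x|))`
# stays finite. The helper name is hypothetical.
def _naive_logistic_loss_example(labels, logits):
  """Direct transcription of x - x * z + log(1 + exp(-x)); overflow-prone."""
  logits = ops.convert_to_tensor(logits, name="logits")
  labels = ops.convert_to_tensor(labels, name="labels")
  return logits - logits * labels + math_ops.log1p(math_ops.exp(-logits))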
152
153# Note: intentionally calling this v2 to not allow existing code with indirect
154# imports to ignore the sentinel behavior.
155@tf_export("nn.sigmoid_cross_entropy_with_logits", v1=[])
156@dispatch.add_dispatch_support
157def sigmoid_cross_entropy_with_logits_v2(  # pylint: disable=invalid-name
158    labels=None,
159    logits=None,
160    name=None):
161  r"""Computes sigmoid cross entropy given `logits`.
162
163  Measures the probability error in tasks with two outcomes in which each
164  outcome is independent and need not have a fully certain label. For instance,
165  one could perform a regression where the probability of an event happening is
166  known and used as a label. This loss may also be used for binary
167  classification, where labels are either zero or one.
168
169  For brevity, let `x = logits`, `z = labels`.  The logistic loss is
170
171        z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
172      = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
173      = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
174      = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
175      = (1 - z) * x + log(1 + exp(-x))
176      = x - x * z + log(1 + exp(-x))
177
178  For x < 0, to avoid overflow in exp(-x), we reformulate the above
179
180        x - x * z + log(1 + exp(-x))
181      = log(exp(x)) - x * z + log(1 + exp(-x))
182      = - x * z + log(1 + exp(x))
183
184  Hence, to ensure stability and avoid overflow, the implementation uses this
185  equivalent formulation
186
187      max(x, 0) - x * z + log(1 + exp(-abs(x)))
188
189  `logits` and `labels` must have the same type and shape.
190
191  >>> logits = tf.constant([1., -1., 0., 1., -1., 0., 0.])
192  >>> labels = tf.constant([0., 0., 0., 1., 1., 1., 0.5])
193  >>> tf.nn.sigmoid_cross_entropy_with_logits(
194  ...     labels=labels, logits=logits).numpy()
195  array([1.3132617, 0.3132617, 0.6931472, 0.3132617, 1.3132617, 0.6931472,
196         0.6931472], dtype=float32)
197
198  Compared to the losses which handle multiple outcomes,
199  `tf.nn.softmax_cross_entropy_with_logits` for general multi-class
200  classification and `tf.nn.sparse_softmax_cross_entropy_with_logits` for more
201  efficient multi-class classification with hard labels,
202  `sigmoid_cross_entropy_with_logits` is a slight simplification for binary
203  classification:
204
205        sigmoid(x) = softmax([x, 0])[0]
206
207  $$\frac{1}{1 + e^{-x}} = \frac{e^x}{e^x + e^0}$$
208
209  While `sigmoid_cross_entropy_with_logits` works for soft binary labels
210  (probabilities between 0 and 1), it can also be used for binary classification
211  where the labels are hard. There is an equivalence between all three losses
212  in this case, with a probability 0 indicating the second class or 1 indicating
213  the first class:
214
215  >>> sigmoid_logits = tf.constant([1., -1., 0.])
216  >>> softmax_logits = tf.stack([sigmoid_logits, tf.zeros_like(sigmoid_logits)],
217  ...                           axis=-1)
218  >>> soft_binary_labels = tf.constant([1., 1., 0.])
219  >>> soft_multiclass_labels = tf.stack(
220  ...     [soft_binary_labels, 1. - soft_binary_labels], axis=-1)
221  >>> hard_labels = tf.constant([0, 0, 1])
222  >>> tf.nn.sparse_softmax_cross_entropy_with_logits(
223  ...     labels=hard_labels, logits=softmax_logits).numpy()
224  array([0.31326166, 1.3132616 , 0.6931472 ], dtype=float32)
225  >>> tf.nn.softmax_cross_entropy_with_logits(
226  ...     labels=soft_multiclass_labels, logits=softmax_logits).numpy()
227  array([0.31326166, 1.3132616, 0.6931472], dtype=float32)
228  >>> tf.nn.sigmoid_cross_entropy_with_logits(
229  ...     labels=soft_binary_labels, logits=sigmoid_logits).numpy()
230  array([0.31326166, 1.3132616, 0.6931472], dtype=float32)
231
232  Args:
233    labels: A `Tensor` of the same type and shape as `logits`. Between 0 and 1,
234      inclusive.
235    logits: A `Tensor` of type `float32` or `float64`. Any real number.
236    name: A name for the operation (optional).
237
238  Returns:
239    A `Tensor` of the same shape as `logits` with the componentwise
240    logistic losses.
241
242  Raises:
243    ValueError: If `logits` and `labels` do not have the same shape.
244  """
245  return sigmoid_cross_entropy_with_logits(
246      logits=logits, labels=labels, name=name)
247
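# Hypothetical sketch (not part of the original file) of the
# `sigmoid(x) = softmax([x, 0])[0]` relationship described in the docstring:
# stacking each logit with a zero and each label with its complement lets the
# two-class softmax cross entropy reproduce the sigmoid cross entropy.
def _sigmoid_xent_via_softmax_example(labels, logits):
  """Binary cross entropy computed through the two-class softmax form."""
  logits = ops.convert_to_tensor(logits, name="logits")
  labels = ops.convert_to_tensor(labels, name="labels")
  two_class_logits = array_ops.stack(
      [logits, array_ops.zeros_like(logits)], axis=-1)
  two_class_labels = array_ops.stack([labels, 1.0 - labels], axis=-1)
  return nn_ops.softmax_cross_entropy_with_logits_v2(
      labels=two_class_labels, logits=two_class_logits)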
248
249sigmoid_cross_entropy_with_logits.__doc__ = (
250    sigmoid_cross_entropy_with_logits_v2.__doc__)
251
252
253@tf_export("nn.weighted_cross_entropy_with_logits", v1=[])
254@dispatch.add_dispatch_support
255def weighted_cross_entropy_with_logits_v2(labels, logits, pos_weight,
256                                          name=None):
257  """Computes a weighted cross entropy.
258
259  This is like `sigmoid_cross_entropy_with_logits()` except that `pos_weight`
260  allows one to trade off recall and precision by up- or down-weighting the
261  cost of a positive error relative to a negative error.
262
263  The usual cross-entropy cost is defined as:
264
265      labels * -log(sigmoid(logits)) +
266          (1 - labels) * -log(1 - sigmoid(logits))
267
268  A value `pos_weight > 1` decreases the false negative count, hence increasing
269  the recall.
270  Conversely setting `pos_weight < 1` decreases the false positive count and
271  increases the precision.
272  This can be seen from the fact that `pos_weight` is introduced as a
273  multiplicative coefficient for the positive labels term
274  in the loss expression:
275
276      labels * -log(sigmoid(logits)) * pos_weight +
277          (1 - labels) * -log(1 - sigmoid(logits))
278
279  For brevity, let `x = logits`, `z = labels`, `q = pos_weight`.
280  The loss is:
281
282        qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
283      = qz * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
284      = qz * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
285      = qz * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
286      = (1 - z) * x + (qz + 1 - z) * log(1 + exp(-x))
287      = (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))
288
289  Setting `l = (1 + (q - 1) * z)`, to ensure stability and avoid overflow,
290  the implementation uses
291
292      (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))
293
294  `logits` and `labels` must have the same type and shape.
295
296  >>> labels = tf.constant([1., 0.5, 0.])
297  >>> logits = tf.constant([1.5, -0.1, -10.])
298  >>> tf.nn.weighted_cross_entropy_with_logits(
299  ...     labels=labels, logits=logits, pos_weight=tf.constant(1.5)).numpy()
300  array([3.0211994e-01, 8.8049585e-01, 4.5776367e-05], dtype=float32)
301  >>> tf.nn.weighted_cross_entropy_with_logits(
302  ...     labels=labels, logits=logits, pos_weight=tf.constant(0.5)).numpy()
303  array([1.00706644e-01, 5.08297503e-01, 4.57763672e-05], dtype=float32)
304
305  Args:
306    labels: A `Tensor` of the same type and shape as `logits`, with values
307      between 0 and 1 inclusive.
308    logits: A `Tensor` of type `float32` or `float64`, any real numbers.
309    pos_weight: A coefficient to use on the positive examples, typically a
310      scalar but otherwise broadcastable to the shape of `logits`. Its value
311      should be non-negative.
312    name: A name for the operation (optional).
313
314  Returns:
315    A `Tensor` of the same shape as `logits` with the componentwise
316    weighted logistic losses.
317
318  Raises:
319    ValueError: If `logits` and `labels` do not have the same shape.
320  """
321  with ops.name_scope(name, "logistic_loss", [logits, labels]) as name:
322    logits = ops.convert_to_tensor(logits, name="logits")
323    labels = ops.convert_to_tensor(labels, name="labels")
324    try:
325      labels.get_shape().assert_is_compatible_with(logits.get_shape())
326    except ValueError:
327      raise ValueError("logits and labels must have the same shape (%s vs %s)" %
328                       (logits.get_shape(), labels.get_shape()))
329
330    # The logistic loss formula from above is
331    #   (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))
332    # For x < 0, a more numerically stable formula is
333    #   (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(x)) - l * x
334    # To avoid branching, we use the combined version
335    #   (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))
336    log_weight = 1 + (pos_weight - 1) * labels
337    return math_ops.add(
338        (1 - labels) * logits,
339        log_weight * (math_ops.log1p(math_ops.exp(-math_ops.abs(logits))) +
340                      nn_ops.relu(-logits)),  # pylint: disable=invalid-unary-operand-type
341        name=name)
342
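# Hypothetical sketch (not in the original source): the weighted loss above can
# also be written as the ordinary sigmoid loss plus a
# `(pos_weight - 1) * labels * softplus(-logits)` term, i.e. only the positive
# (labels == 1) part of the loss is re-weighted. The helper name is assumed.
def _weighted_xent_decomposition_example(labels, logits, pos_weight):
  """Weighted loss as unweighted sigmoid loss plus a positive-term correction."""
  logits = ops.convert_to_tensor(logits, name="logits")
  labels = ops.convert_to_tensor(labels, name="labels")
  unweighted = sigmoid_cross_entropy_with_logits_v2(
      labels=labels, logits=logits)
  return unweighted + (pos_weight - 1.0) * labels * nn_ops.softplus(-logits)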
343
344@tf_export(v1=["nn.weighted_cross_entropy_with_logits"])
345@dispatch.add_dispatch_support
346@deprecated_args(None, "targets is deprecated, use labels instead", "targets")
347def weighted_cross_entropy_with_logits(labels=None,
348                                       logits=None,
349                                       pos_weight=None,
350                                       name=None,
351                                       targets=None):
352  """Computes a weighted cross entropy.
353
354  This is like `sigmoid_cross_entropy_with_logits()` except that `pos_weight`
355  allows one to trade off recall and precision by up- or down-weighting the
356  cost of a positive error relative to a negative error.
357
358  The usual cross-entropy cost is defined as:
359
360      labels * -log(sigmoid(logits)) +
361          (1 - labels) * -log(1 - sigmoid(logits))
362
363  A value `pos_weight > 1` decreases the false negative count, hence increasing
364  the recall.
365  Conversely setting `pos_weight < 1` decreases the false positive count and
366  increases the precision.
367  This can be seen from the fact that `pos_weight` is introduced as a
368  multiplicative coefficient for the positive labels term
369  in the loss expression:
370
371      labels * -log(sigmoid(logits)) * pos_weight +
372          (1 - labels) * -log(1 - sigmoid(logits))
373
374  For brevity, let `x = logits`, `z = labels`, `q = pos_weight`.
375  The loss is:
376
377        qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
378      = qz * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
379      = qz * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
380      = qz * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
381      = (1 - z) * x + (qz + 1 - z) * log(1 + exp(-x))
382      = (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))
383
384  Setting `l = (1 + (q - 1) * z)`, to ensure stability and avoid overflow,
385  the implementation uses
386
387      (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))
388
389  `logits` and `labels` must have the same type and shape.
390
391  Args:
392    labels: A `Tensor` of the same type and shape as `logits`.
393    logits: A `Tensor` of type `float32` or `float64`.
394    pos_weight: A coefficient to use on the positive examples.
395    name: A name for the operation (optional).
396    targets: Deprecated alias for labels.
397
398  Returns:
399    A `Tensor` of the same shape as `logits` with the componentwise
400    weighted logistic losses.
401
402  Raises:
403    ValueError: If `logits` and `labels` do not have the same shape.
404  """
405  labels = deprecated_argument_lookup("labels", labels, "targets", targets)
406  return weighted_cross_entropy_with_logits_v2(labels, logits, pos_weight, name)
407
408
409@tf_export("nn.compute_average_loss")
410@dispatch.add_dispatch_support
411def compute_average_loss(per_example_loss,
412                         sample_weight=None,
413                         global_batch_size=None):
414  """Scales per-example losses with sample_weights and computes their average.
415
416  Usage with distribution strategy and custom training loop:
417
418  ```python
419  with strategy.scope():
420    def compute_loss(labels, predictions, sample_weight=None):
421
422      # If you are using a `Loss` class instead, set reduction to `NONE` so that
423      # we can do the reduction afterwards and divide by global batch size.
424      per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
425          labels, predictions)
426
427      # Compute loss that is scaled by sample_weight and by global batch size.
428      return tf.nn.compute_average_loss(
429          per_example_loss,
430          sample_weight=sample_weight,
431          global_batch_size=GLOBAL_BATCH_SIZE)
432  ```
433
434  Args:
435    per_example_loss: Per-example loss.
436    sample_weight: Optional weighting for each example.
437    global_batch_size: Optional global batch size value. Defaults to (size of
438      first dimension of `per_example_loss`) * (number of replicas).
439
440  Returns:
441    Scalar loss value.
442  """  # pylint: disable=g-doc-exception
443  per_example_loss = ops.convert_to_tensor(per_example_loss)
444  input_dtype = per_example_loss.dtype
445
446  with losses_util.check_per_example_loss_rank(per_example_loss):
447    if sample_weight is not None:
448      sample_weight = ops.convert_to_tensor(sample_weight)
449      per_example_loss = losses_util.scale_losses_by_sample_weight(
450          per_example_loss, sample_weight)
451    per_example_loss = math_ops.cast(per_example_loss, input_dtype)
452
453    if global_batch_size is None:
454      if ds.has_strategy() and ds.in_cross_replica_context():
455        raise RuntimeError(
456            "You are calling `compute_average_loss` in cross replica context, "
457            "while it was expected to be called in replica context.")
458
459      num_replicas = ds.get_strategy().num_replicas_in_sync
460      per_replica_batch_size = array_ops.shape_v2(per_example_loss)[0]
461      global_batch_size = per_replica_batch_size * num_replicas
462
463    global_batch_size = math_ops.cast(global_batch_size, input_dtype)
464    return math_ops.reduce_sum(per_example_loss) / global_batch_size
465
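# Hypothetical single-replica sketch (not part of the original file): with one
# replica and no sample weights, `compute_average_loss` reduces to an ordinary
# mean, i.e. sum(per_example_loss) / batch_size, since the local batch is the
# global batch.
def _compute_average_loss_single_replica_example(per_example_loss):
  """Plain mean over the (local == global) batch, for comparison."""
  per_example_loss = ops.convert_to_tensor(per_example_loss)
  batch_size = math_ops.cast(
      array_ops.shape_v2(per_example_loss)[0], per_example_loss.dtype)
  return math_ops.reduce_sum(per_example_loss) / batch_size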
466
467@tf_export("nn.scale_regularization_loss")
468@dispatch.add_dispatch_support
469def scale_regularization_loss(regularization_loss):
470  """Scales the sum of the given regularization losses by number of replicas.
471
472  Usage with distribution strategy and custom training loop:
473
474  ```python
475  with strategy.scope():
476    def compute_loss(self, labels, predictions, sample_weight=None):
477      per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
478          labels, predictions)
479
480      # Compute loss that is scaled by sample_weight and by global batch size.
481      loss = tf.nn.compute_average_loss(
482          per_example_loss,
483          sample_weight=sample_weight,
484          global_batch_size=GLOBAL_BATCH_SIZE)
485
486      # Add scaled regularization losses.
487      loss += tf.nn.scale_regularization_loss(tf.nn.l2_loss(weights))
488      return loss
489  ```
490
491  Args:
492    regularization_loss: Regularization loss.
493
494  Returns:
495    Scalar loss value.
496  """  # pylint: disable=g-doc-exception
497  if ds.has_strategy() and ds.in_cross_replica_context():
498    raise RuntimeError(
499        "You are calling `scale_regularization_loss` in cross replica context, "
500        "while it was expected to be called in replica context.")
501
502  num_replicas = ds.get_strategy().num_replicas_in_sync
503  return math_ops.reduce_sum(regularization_loss) / num_replicas
504
505
506@tf_export(v1=["nn.relu_layer"])
507@dispatch.add_dispatch_support
508def relu_layer(x, weights, biases, name=None):
509  """Computes Relu(x * weight + biases).
510
511  Args:
512    x: a 2D tensor.  Dimensions typically: batch, in_units
513    weights: a 2D tensor.  Dimensions typically: in_units, out_units
514    biases: a 1D tensor.  Dimensions: out_units
515    name: A name for the operation (optional).  If not specified
516      "nn_relu_layer" is used.
517
518  Returns:
519    A 2-D Tensor computing relu(matmul(x, weights) + biases).
520    Dimensions typically: batch, out_units.
521  """
522  with ops.name_scope(name, "relu_layer", [x, weights, biases]) as name:
523    x = ops.convert_to_tensor(x, name="x")
524    weights = ops.convert_to_tensor(weights, name="weights")
525    biases = ops.convert_to_tensor(biases, name="biases")
526    xw_plus_b = nn_ops.bias_add(math_ops.matmul(x, weights), biases)
527    return nn_ops.relu(xw_plus_b, name=name)
528
529
530@tf_export("nn.silu", "nn.swish")
531@dispatch.add_dispatch_support
532@custom_gradient.custom_gradient
533def swish(features):
534  # pylint: disable=g-doc-args
535  """Computes the SiLU or Swish activation function: `x * sigmoid(x)`.
536
537  The SiLU activation function was introduced in "Gaussian Error Linear Units
538  (GELUs)" [Hendrycks et al. 2016](https://arxiv.org/abs/1606.08415) and
539  "Sigmoid-Weighted Linear Units for Neural Network Function Approximation in
540  Reinforcement Learning"
541  [Elfwing et al. 2017](https://arxiv.org/abs/1702.03118) and was independently
542  discovered (and called swish) in "Searching for Activation Functions"
543  [Ramachandran et al. 2017](https://arxiv.org/abs/1710.05941)
544
545  Args:
546    features: A `Tensor` representing preactivation values.
547
548  Returns:
549    The activation value.
550  """
551  # pylint: enable=g-doc-args
552  features = ops.convert_to_tensor(features, name="features")
553
554  def grad(dy):
555    """Gradient for the Swish activation function"""
556    # Naively, x * tf.nn.sigmoid(x) requires keeping both x and sigmoid(x)
557    # around for backprop, effectively doubling the tensor's memory consumption.
558    # We use a control dependency here so that sigmoid(features) is re-computed
559    # during backprop (the control dep prevents it being de-duped with the
560    # forward pass) and we can free the sigmoid(features) expression immediately
561    # after use during the forward pass.
562    with ops.control_dependencies([dy]):
563      sigmoid_features = math_ops.sigmoid(features)
564    activation_grad = (
565        sigmoid_features * (1.0 + features * (1.0 - sigmoid_features)))
566    return dy * activation_grad
567
568  return features * math_ops.sigmoid(features), grad
569
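# Hypothetical sketch (not in the original source) of the closed-form derivative
# used by the custom gradient above:
#   d/dx [x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
#                         = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
def _swish_grad_reference_example(features):
  """Analytic derivative of x * sigmoid(x), recomputed from scratch."""
  features = ops.convert_to_tensor(features, name="features")
  sigmoid_features = math_ops.sigmoid(features)
  return sigmoid_features * (1.0 + features * (1.0 - sigmoid_features))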
570
571# pylint: disable=redefined-builtin
572@tf_export("linalg.normalize")
573@dispatch.add_dispatch_support
574def normalize(tensor, ord="euclidean", axis=None, name=None):
575  """Normalizes `tensor` along dimension `axis` using specified norm.
576
577  This uses `tf.linalg.norm` to compute the norm along `axis`.
578
579  This function can compute several different vector norms (the 1-norm, the
580  Euclidean or 2-norm, the inf-norm, and in general the p-norm for p > 0) and
581  matrix norms (Frobenius, 1-norm, 2-norm and inf-norm).
582
583  Args:
584    tensor: `Tensor` of types `float32`, `float64`, `complex64`, `complex128`
585    ord: Order of the norm. Supported values are `'fro'`, `'euclidean'`, `1`,
586      `2`, `np.inf` and any positive real number yielding the corresponding
587      p-norm. Default is `'euclidean'` which is equivalent to Frobenius norm if
588      `tensor` is a matrix and equivalent to 2-norm for vectors.
589      Some restrictions apply: a) The Frobenius norm `'fro'` is not defined for
590        vectors, b) If axis is a 2-tuple (matrix norm), only `'euclidean'`,
591        `'fro'`, `1`, `2`, `np.inf` are supported. See the description of `axis`
592        on how to compute norms for a batch of vectors or matrices stored in a
593        tensor.
594    axis: If `axis` is `None` (the default), the input is considered a vector
595      and a single vector norm is computed over the entire set of values in the
596      tensor, i.e. `norm(tensor, ord=ord)` is equivalent to
597      `norm(reshape(tensor, [-1]), ord=ord)`. If `axis` is a Python integer, the
598      input is considered a batch of vectors, and `axis` determines the axis in
599      `tensor` over which to compute vector norms. If `axis` is a 2-tuple of
600      Python integers it is considered a batch of matrices and `axis` determines
601      the axes in `tensor` over which to compute a matrix norm.
602      Negative indices are supported. Example: If you are passing a tensor that
603        can be either a matrix or a batch of matrices at runtime, pass
604        `axis=[-2,-1]` instead of `axis=None` to make sure that matrix norms are
605        computed.
606    name: The name of the op.
607
608  Returns:
609    normalized: A normalized `Tensor` with the same shape as `tensor`.
610    norm: The computed norms with the same shape and dtype as `tensor` but the
611      final axis is 1 instead. Same as running
612      `tf.cast(tf.linalg.norm(tensor, ord, axis, keepdims=True), tensor.dtype)`.
613
614  Raises:
615    ValueError: If `ord` or `axis` is invalid.
616  """
617  with ops.name_scope(name, "normalize", [tensor]) as name:
618    tensor = ops.convert_to_tensor(tensor)
619    norm = linalg_ops.norm(tensor, ord, axis, keepdims=True)
620    norm = math_ops.cast(norm, tensor.dtype)
621    normalized = tensor / norm
622    return normalized, norm
623
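# Hypothetical usage sketch (not part of the original file): re-computing the
# same norm on the normalized output returns (approximately) ones, which is a
# cheap sanity check on the `ord`/`axis` combination being used.
def _normalize_unit_norm_check_example(tensor, norm_ord="euclidean", axis=None):
  """Returns the normalized tensor and the norm of the normalized tensor."""
  normalized, _ = normalize(tensor, ord=norm_ord, axis=axis)
  renorm = linalg_ops.norm(normalized, norm_ord, axis, keepdims=True)
  return normalized, renorm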
624
625@tf_export("math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize",
626           v1=["math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize"])
627@dispatch.add_dispatch_support
628@deprecated_args(None, "dim is deprecated, use axis instead", "dim")
629def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None):
630  """Normalizes along dimension `axis` using an L2 norm.
631
632  For a 1-D tensor with `axis = 0`, computes
633
634      output = x / sqrt(max(sum(x**2), epsilon))
635
636  For `x` with more dimensions, independently normalizes each 1-D slice along
637  dimension `axis`.
638
639  1-D tensor example:
640  >>> x = tf.constant([3.0, 4.0])
641  >>> tf.math.l2_normalize(x).numpy()
642  array([0.6, 0.8], dtype=float32)
643
644  2-D tensor example:
645  >>> x = tf.constant([[3.0], [4.0]])
646  >>> tf.math.l2_normalize(x, 0).numpy()
647  array([[0.6],
648       [0.8]], dtype=float32)
649
650  >>> x = tf.constant([[3.0], [4.0]])
651  >>> tf.math.l2_normalize(x, 1).numpy()
652  array([[1.],
653       [1.]], dtype=float32)
654
655  Args:
656    x: A `Tensor`.
657    axis: Dimension along which to normalize.  A scalar or a vector of
658      integers.
659    epsilon: A lower bound value for the norm. Will use `sqrt(epsilon)` as the
660      divisor if `norm < sqrt(epsilon)`.
661    name: A name for this operation (optional).
662    dim: Deprecated, do not use.
663
664  Returns:
665    A `Tensor` with the same shape as `x`.
666  """
667  axis = deprecated_argument_lookup("axis", axis, "dim", dim)
668  with ops.name_scope(name, "l2_normalize", [x]) as name:
669    x = ops.convert_to_tensor(x, name="x")
670    if x.dtype.is_complex:
671      square_real = math_ops.square(math_ops.real(x))
672      square_imag = math_ops.square(math_ops.imag(x))
673      square_sum = math_ops.real(
674          math_ops.reduce_sum(square_real + square_imag, axis, keepdims=True))
675      x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon))
676      norm_real = math_ops.multiply(math_ops.real(x), x_inv_norm)
677      norm_imag = math_ops.multiply(math_ops.imag(x), x_inv_norm)
678      return math_ops.complex(norm_real, norm_imag, name=name)
679    square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keepdims=True)
680    x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon))
681    return math_ops.multiply(x, x_inv_norm, name=name)
682
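# Hypothetical sketch (not in the original source): a direct transcription of
# the documented formula `output = x / sqrt(max(sum(x**2), epsilon))` for real
# inputs, matching the rsqrt-based implementation above.
def _l2_normalize_reference_example(x, axis=None, epsilon=1e-12):
  """Reference L2 normalization via an explicit sqrt and division."""
  x = ops.convert_to_tensor(x, name="x")
  square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keepdims=True)
  return x / math_ops.sqrt(math_ops.maximum(square_sum, epsilon))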
683
684def _count_nonzero(input_tensor, dtype=dtypes.int64):
685  """Same as math_ops.count_nonzero.
686
687  The reduction is done in dtype, which can be faster for 32-bit dtypes.
688
689  Args:
690      input_tensor: numeric tensor
691      dtype: reduction dtype
692
693  Returns:
694      number of nonzero values with type dtype
695  """
696  with ops.name_scope("count_nonzero", values=[input_tensor]):
697    zero = array_ops.zeros([], dtype=input_tensor.dtype)
698    nonzero_count = math_ops.reduce_sum(
699        math_ops.cast(
700            math_ops.not_equal(input_tensor, zero),
701            dtype=dtype), name="nonzero_count")
702    return nonzero_count
703
704
705@tf_export("math.zero_fraction", "nn.zero_fraction")
706@dispatch.add_dispatch_support
707def zero_fraction(value, name=None):
708  """Returns the fraction of zeros in `value`.
709
710  If `value` is empty, the result is `nan`.
711
712  This is useful in summaries to measure and report sparsity.  For example,
713
714  ```python
715      z = tf.nn.relu(...)
716      summ = tf.compat.v1.summary.scalar('sparsity', tf.nn.zero_fraction(z))
717  ```
718
719  Args:
720    value: A tensor of numeric type.
721    name: A name for the operation (optional).
722
723  Returns:
724    The fraction of zeros in `value`, with type `float32`.
725  """
726  with ops.name_scope(name, "zero_fraction", [value]):
727    value = ops.convert_to_tensor(value, name="value")
728    size = array_ops.size(value, out_type=dtypes.int64)
729    # If the count is small, we can save memory/CPU with an int32 reduction.
730    num_nonzero = control_flow_ops.cond(
731        size <= dtypes.int32.max,
732        # pylint: disable=g-long-lambda
733        true_fn=lambda: math_ops.cast(
734            _count_nonzero(value, dtype=dtypes.int32),
735            dtype=dtypes.int64),
736        false_fn=lambda: _count_nonzero(value, dtype=dtypes.int64))
737
738    with ops.name_scope("counts_to_fraction"):
739      num_zero = size - num_nonzero
740      num_zero_float32 = math_ops.cast(num_zero, dtype=dtypes.float32)
741      size_float32 = math_ops.cast(size, dtype=dtypes.float32)
742      zero_fraction_float32 = num_zero_float32 / size_float32
743
744    return array_ops.identity(zero_fraction_float32, "fraction")
745
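# Hypothetical sketch (not part of the original file): the fraction of zeros is
# simply the mean of an elementwise `== 0` indicator; the implementation above
# additionally chooses an int32 or int64 count to save memory on small inputs.
def _zero_fraction_reference_example(value):
  """Naive zero fraction without the int32/int64 size dispatch."""
  value = ops.convert_to_tensor(value, name="value")
  zero = array_ops.zeros([], dtype=value.dtype)
  return math_ops.reduce_mean(
      math_ops.cast(math_ops.equal(value, zero), dtypes.float32))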
746
747# pylint: disable=redefined-builtin
748@tf_export(v1=["nn.depthwise_conv2d"])
749@dispatch.add_dispatch_support
750def depthwise_conv2d(input,
751                     filter,
752                     strides,
753                     padding,
754                     rate=None,
755                     name=None,
756                     data_format=None,
757                     dilations=None):
758  """Depthwise 2-D convolution.
759
760  Given a 4D input tensor ('NHWC' or 'NCHW' data formats)
761  and a filter tensor of shape
762  `[filter_height, filter_width, in_channels, channel_multiplier]`
763  containing `in_channels` convolutional filters of depth 1, `depthwise_conv2d`
764  applies a different filter to each input channel (expanding from 1 channel
765  to `channel_multiplier` channels for each), then concatenates the results
766  together.  The output has `in_channels * channel_multiplier` channels.
767
768  In detail, with the default NHWC format,
769
770      output[b, i, j, k * channel_multiplier + q] = sum_{di, dj}
771           filter[di, dj, k, q] * input[b, strides[1] * i + rate[0] * di,
772                                           strides[2] * j + rate[1] * dj, k]
773
774  Must have `strides[0] = strides[3] = 1`.  For the most common case of the
775  same horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
776  If any value in `rate` is greater than 1, we perform atrous depthwise
777  convolution, in which case all values in the `strides` tensor must be equal
778  to 1.
779
780  Usage Example:
781
782  >>> x = np.array([
783  ...     [1., 2.],
784  ...     [3., 4.],
785  ...     [5., 6.]
786  ... ], dtype=np.float32).reshape((1, 3, 2, 1))
787  >>> kernel = np.array([
788  ...     [1., 2.],
789  ...     [3., 4]
790  ... ], dtype=np.float32).reshape((2, 1, 1, 2))
791  >>> tf.compat.v1.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1],
792  ...                                  padding='VALID').numpy()
793    array([[[[10., 14.],
794             [14., 20.]],
795            [[18., 26.],
796             [22., 32.]]]], dtype=float32)
797
798  >>> tf.compat.v1.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1],
799  ...                                  padding=[[0, 0], [1, 0], [1, 0], [0, 0]]
800  ...                                 ).numpy()
801    array([[[[ 0.,  0.],
802             [ 3.,  4.],
803             [ 6.,  8.]],
804            [[ 0.,  0.],
805             [10., 14.],
806             [14., 20.]],
807            [[ 0.,  0.],
808             [18., 26.],
809             [22., 32.]]]], dtype=float32)
810
811  Args:
812    input: 4-D with shape according to `data_format`.
813    filter: 4-D with shape
814      `[filter_height, filter_width, in_channels, channel_multiplier]`.
815    strides: 1-D of size 4.  The stride of the sliding window for each
816      dimension of `input`.
817    padding: Controls how to pad the image before applying the convolution. Can
818      be the string `"SAME"` or `"VALID"` indicating the type of padding
819      algorithm to use, or a list indicating the explicit paddings at the start
820      and end of each dimension. When explicit padding is used and data_format
821      is `"NHWC"`, this should be in the form `[[0, 0], [pad_top, pad_bottom],
822      [pad_left, pad_right], [0, 0]]`. When explicit padding used and
823      data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
824      [pad_top, pad_bottom], [pad_left, pad_right]]`.
825    rate: 1-D of size 2. The dilation rate in which we sample input values
826      across the `height` and `width` dimensions in atrous convolution. If it is
827      greater than 1, then all values of strides must be 1.
828    name: A name for this operation (optional).
829    data_format: The data format for input. Either "NHWC" (default) or "NCHW".
830    dilations: Alias of rate.
831
832  Returns:
833    A 4-D `Tensor` with shape according to `data_format`.  E.g., for
834    "NHWC" format, shape is
835    `[batch, out_height, out_width, in_channels * channel_multiplier].`
836  """
837  rate = deprecated_argument_lookup("dilations", dilations, "rate", rate)
838  with ops.name_scope(name, "depthwise", [input, filter]) as name:
839    input = ops.convert_to_tensor(input, name="tensor_in")
840    filter = ops.convert_to_tensor(filter, name="filter_in")
841    if rate is None:
842      rate = [1, 1]
843
844    # Use depthwise_conv2d_native if executing on TPU.
845    if device_context.enclosing_tpu_context() is not None:
846      if data_format == "NCHW":
847        dilations = [1, 1, rate[0], rate[1]]
848      else:
849        dilations = [1, rate[0], rate[1], 1]
850      return nn_ops.depthwise_conv2d_native(
851          input=input,
852          filter=filter,
853          strides=strides,
854          padding=padding,
855          data_format=data_format,
856          dilations=dilations,
857          name=name)
858
859    def op(input_converted, _, padding):
860      return nn_ops.depthwise_conv2d_native(
861          input=input_converted,
862          filter=filter,
863          strides=strides,
864          padding=padding,
865          data_format=data_format,
866          name=name)
867
868    return nn_ops.with_space_to_batch(
869        input=input,
870        filter_shape=array_ops.shape(filter),
871        dilation_rate=rate,
872        padding=padding,
873        data_format=data_format,
874        op=op)
875
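# Hypothetical sketch (not in the original source): the docstring example above,
# rebuilt with module-internal ops. With NHWC input of shape [1, 3, 2, 1] and a
# filter of shape [2, 1, 1, 2], the output has
# in_channels * channel_multiplier = 1 * 2 = 2 channels.
def _depthwise_conv2d_docstring_example():
  """Reproduces the 'VALID' padding example from the docstring above."""
  x = array_ops.reshape(
      constant_op.constant([1., 2., 3., 4., 5., 6.]), [1, 3, 2, 1])
  kernel = array_ops.reshape(
      constant_op.constant([1., 2., 3., 4.]), [2, 1, 1, 2])
  return nn_ops.depthwise_conv2d_native(
      input=x, filter=kernel, strides=[1, 1, 1, 1], padding="VALID")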
876
877@tf_export("nn.depthwise_conv2d", v1=[])
878@dispatch.add_dispatch_support
879def depthwise_conv2d_v2(input,
880                        filter,
881                        strides,
882                        padding,
883                        data_format=None,
884                        dilations=None,
885                        name=None):
886  """Depthwise 2-D convolution.
887
888  Given a 4D input tensor ('NHWC' or 'NCHW' data formats)
889  and a filter tensor of shape
890  `[filter_height, filter_width, in_channels, channel_multiplier]`
891  containing `in_channels` convolutional filters of depth 1, `depthwise_conv2d`
892  applies a different filter to each input channel (expanding from 1 channel
893  to `channel_multiplier` channels for each), then concatenates the results
894  together.  The output has `in_channels * channel_multiplier` channels.
895
896  In detail, with the default NHWC format,
897
898      output[b, i, j, k * channel_multiplier + q] = sum_{di, dj}
899           filter[di, dj, k, q] * input[b, strides[1] * i + rate[0] * di,
900                                           strides[2] * j + rate[1] * dj, k]
901
902  Must have `strides[0] = strides[3] = 1`.  For the most common case of the
903  same horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
904  If any value in `rate` is greater than 1, we perform atrous depthwise
905  convolution, in which case all values in the `strides` tensor must be equal
906  to 1.
907
908  Usage Example:
909
910  >>> x = np.array([
911  ...     [1., 2.],
912  ...     [3., 4.],
913  ...     [5., 6.]
914  ... ], dtype=np.float32).reshape((1, 3, 2, 1))
915  >>> kernel = np.array([
916  ...     [1., 2.],
917  ...     [3., 4]
918  ... ], dtype=np.float32).reshape((2, 1, 1, 2))
919  >>> tf.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1],
920  ...                        padding='VALID').numpy()
921    array([[[[10., 14.],
922             [14., 20.]],
923            [[18., 26.],
924             [22., 32.]]]], dtype=float32)
925
926  >>> tf.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1],
927  ...                        padding=[[0, 0], [1, 0], [1, 0], [0, 0]]).numpy()
928    array([[[[ 0.,  0.],
929             [ 3.,  4.],
930             [ 6.,  8.]],
931            [[ 0.,  0.],
932             [10., 14.],
933             [14., 20.]],
934            [[ 0.,  0.],
935             [18., 26.],
936             [22., 32.]]]], dtype=float32)
937
938  Args:
939    input: 4-D with shape according to `data_format`.
940    filter: 4-D with shape
941      `[filter_height, filter_width, in_channels, channel_multiplier]`.
942    strides: 1-D of size 4.  The stride of the sliding window for each
943      dimension of `input`.
944    padding: Controls how to pad the image before applying the convolution. Can
945      be the string `"SAME"` or `"VALID"` indicating the type of padding
946      algorithm to use, or a list indicating the explicit paddings at the start
947      and end of each dimension. When explicit padding is used and data_format
948      is `"NHWC"`, this should be in the form `[[0, 0], [pad_top, pad_bottom],
949      [pad_left, pad_right], [0, 0]]`. When explicit padding used and
950      data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
951      [pad_top, pad_bottom], [pad_left, pad_right]]`.
952    data_format: The data format for input. Either "NHWC" (default) or "NCHW".
953    dilations: 1-D of size 2. The dilation rate in which we sample input values
954      across the `height` and `width` dimensions in atrous convolution. If it is
955      greater than 1, then all values of strides must be 1.
956    name: A name for this operation (optional).
957
958  Returns:
959    A 4-D `Tensor` with shape according to `data_format`.  E.g., for
960    "NHWC" format, shape is
961    `[batch, out_height, out_width, in_channels * channel_multiplier].`
962  """
963  return depthwise_conv2d(input=input,
964                          filter=filter,
965                          strides=strides,
966                          padding=padding,
967                          rate=dilations,
968                          name=name,
969                          data_format=data_format)
970
971# pylint: enable=redefined-builtin
972
973
974# pylint: disable=redefined-builtin,line-too-long
975@tf_export(v1=["nn.separable_conv2d"])
976@dispatch.add_dispatch_support
977def separable_conv2d(input,
978                     depthwise_filter,
979                     pointwise_filter,
980                     strides,
981                     padding,
982                     rate=None,
983                     name=None,
984                     data_format=None,
985                     dilations=None):
986  """2-D convolution with separable filters.
987
988  Performs a depthwise convolution that acts separately on channels followed by
989  a pointwise convolution that mixes channels.  Note that this is separability
990  between dimensions `[1, 2]` and `3`, not spatial separability between
991  dimensions `1` and `2`.
992
993  In detail, with the default NHWC format,
994
995      output[b, i, j, k] = sum_{di, dj, q, r}
996          input[b, strides[1] * i + di, strides[2] * j + dj, q] *
997          depthwise_filter[di, dj, q, r] *
998          pointwise_filter[0, 0, q * channel_multiplier + r, k]
999
1000  `strides` controls the strides for the depthwise convolution only, since
1001  the pointwise convolution has implicit strides of `[1, 1, 1, 1]`.  Must have
1002  `strides[0] = strides[3] = 1`.  For the most common case of the same
1003  horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
1004  If any value in `rate` is greater than 1, we perform atrous depthwise
1005  convolution, in which case all values in the `strides` tensor must be equal
1006  to 1.
1007
1008  Args:
1009    input: 4-D `Tensor` with shape according to `data_format`.
1010    depthwise_filter: 4-D `Tensor` with shape
1011      `[filter_height, filter_width, in_channels, channel_multiplier]`.
1012      Contains `in_channels` convolutional filters of depth 1.
1013    pointwise_filter: 4-D `Tensor` with shape
1014      `[1, 1, channel_multiplier * in_channels, out_channels]`.  Pointwise
1015      filter to mix channels after `depthwise_filter` has convolved spatially.
1016    strides: 1-D of size 4.  The strides for the depthwise convolution for
1017      each dimension of `input`.
1018    padding: Controls how to pad the image before applying the depthwise
1019      convolution. Can be the string `"SAME"` or `"VALID"` indicating the type
1020      of padding algorithm to use, or a Python list indicating the explicit
1021      paddings at the start and end of each dimension. When explicit padding is
1022      used and data_format is `"NHWC"`, this should be in the form `[[0, 0],
1023      [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit
1024      padding used and data_format is `"NCHW"`, this should be in the form
1025      `[[0, 0], [0, 0], [pad_top, pad_bottom], [pad_left, pad_right]]`.
1026    rate: 1-D of size 2. The dilation rate in which we sample input values
1027      across the `height` and `width` dimensions in atrous convolution. If it is
1028      greater than 1, then all values of strides must be 1.
1029    name: A name for this operation (optional).
1030    data_format: The data format for input. Either "NHWC" (default) or "NCHW".
1031    dilations: Alias of rate.
1032
1033  Returns:
1034    A 4-D `Tensor` with shape according to 'data_format'. For
1035      example, with data_format="NHWC", shape is [batch, out_height,
1036      out_width, out_channels].
1037  """
1038  rate = deprecated_argument_lookup("dilations", dilations, "rate", rate)
1039  with ops.name_scope(name, "separable_conv2d",
1040                      [input, depthwise_filter, pointwise_filter]) as name:
1041    input = ops.convert_to_tensor(input, name="tensor_in")
1042    depthwise_filter = ops.convert_to_tensor(
1043        depthwise_filter, name="depthwise_filter")
1044    pointwise_filter = ops.convert_to_tensor(
1045        pointwise_filter, name="pointwise_filter")
1046
1047    pointwise_filter_shape = pointwise_filter.get_shape().with_rank(4)
1048    pointwise_filter_shape.dims[0].assert_is_compatible_with(1)
1049    pointwise_filter_shape.dims[1].assert_is_compatible_with(1)
1050
1051    if rate is None:
1052      rate = [1, 1]
1053
1054    # The layout of the ops in the graph are expected to be as follows:
1055    # depthwise_conv2d  // Conv2D op corresponding to native depthwise conv.
1056    # separable_conv2d  // Conv2D op corresponding to the pointwise conv.
1057
1058    def op(input_converted, _, padding):
1059      return nn_ops.depthwise_conv2d_native(
1060          input=input_converted,
1061          filter=depthwise_filter,
1062          strides=strides,
1063          padding=padding,
1064          data_format=data_format,
1065          name="depthwise")
1066
1067    depthwise = nn_ops.with_space_to_batch(
1068        input=input,
1069        filter_shape=array_ops.shape(depthwise_filter),
1070        dilation_rate=rate,
1071        padding=padding,
1072        data_format=data_format,
1073        op=op)
1074
1075    return nn_ops.conv2d(
1076        depthwise,
1077        pointwise_filter, [1, 1, 1, 1],
1078        padding="VALID",
1079        data_format=data_format,
1080        name=name)
1081
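# Hypothetical sketch (not part of the original file) of the decomposition the
# docstring describes: a depthwise convolution followed by a 1x1 (pointwise)
# convolution. It omits the dilation / space-to-batch handling of the real
# implementation above and assumes string padding and NHWC layout.
def _separable_conv2d_reference_example(input, depthwise_filter,
                                        pointwise_filter, strides, padding):
  """Depthwise conv followed by a pointwise 1x1 conv."""
  depthwise = nn_ops.depthwise_conv2d_native(
      input=input, filter=depthwise_filter, strides=strides, padding=padding)
  return nn_ops.conv2d(
      depthwise, pointwise_filter, [1, 1, 1, 1], padding="VALID")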
1082
1083@tf_export("nn.separable_conv2d", v1=[])
1084@dispatch.add_dispatch_support
1085def separable_conv2d_v2(
1086    input,
1087    depthwise_filter,
1088    pointwise_filter,
1089    strides,
1090    padding,
1091    data_format=None,
1092    dilations=None,
1093    name=None,
1094):
1095  """2-D convolution with separable filters.
1096
1097  Performs a depthwise convolution that acts separately on channels followed by
1098  a pointwise convolution that mixes channels.  Note that this is separability
1099  between dimensions `[1, 2]` and `3`, not spatial separability between
1100  dimensions `1` and `2`.
1101
1102  In detail, with the default NHWC format,
1103
1104      output[b, i, j, k] = sum_{di, dj, q, r}
1105          input[b, strides[1] * i + di, strides[2] * j + dj, q] *
1106          depthwise_filter[di, dj, q, r] *
1107          pointwise_filter[0, 0, q * channel_multiplier + r, k]
1108
1109  `strides` controls the strides for the depthwise convolution only, since
1110  the pointwise convolution has implicit strides of `[1, 1, 1, 1]`.  Must have
1111  `strides[0] = strides[3] = 1`.  For the most common case of the same
1112  horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
1113  If any value in `rate` is greater than 1, we perform atrous depthwise
1114  convolution, in which case all values in the `strides` tensor must be equal
1115  to 1.
1116
1117  Args:
1118    input: 4-D `Tensor` with shape according to `data_format`.
1119    depthwise_filter: 4-D `Tensor` with shape `[filter_height, filter_width,
1120      in_channels, channel_multiplier]`. Contains `in_channels` convolutional
1121      filters of depth 1.
1122    pointwise_filter: 4-D `Tensor` with shape `[1, 1, channel_multiplier *
1123      in_channels, out_channels]`.  Pointwise filter to mix channels after
1124      `depthwise_filter` has convolved spatially.
1125    strides: 1-D of size 4.  The strides for the depthwise convolution for each
1126      dimension of `input`.
1127    padding: Controls how to pad the image before applying the depthwise
1128      convolution. Can be the string `"SAME"` or `"VALID"` indicating the type
1129      of padding algorithm to use, or a Python list indicating the explicit
1130      paddings at the start and end of each dimension. When explicit padding is
1131      used and data_format is `"NHWC"`, this should be in the form `[[0, 0],
1132      [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit
1133      padding used and data_format is `"NCHW"`, this should be in the form
1134      `[[0, 0], [0, 0], [pad_top, pad_bottom], [pad_left, pad_right]]`.
1135    data_format: The data format for input. Either "NHWC" (default) or "NCHW".
1136    dilations: 1-D of size 2. The dilation rate in which we sample input values
1137      across the `height` and `width` dimensions in atrous convolution. If it is
1138      greater than 1, then all values of strides must be 1.
1139    name: A name for this operation (optional).
1140
1141  Returns:
1142    A 4-D `Tensor` with shape according to 'data_format'. For
1143      example, with data_format="NHWC", shape is [batch, out_height,
1144      out_width, out_channels].
1145  """
1146  return separable_conv2d(
1147      input,
1148      depthwise_filter,
1149      pointwise_filter,
1150      strides,
1151      padding,
1152      rate=dilations,
1153      name=name,
1154      data_format=data_format)
1155
1156# pylint: enable=redefined-builtin,line-too-long
1157
1158
1159@tf_export(v1=["nn.sufficient_statistics"])
1160@dispatch.add_dispatch_support
1161def sufficient_statistics(x, axes, shift=None, keep_dims=None, name=None,
1162                          keepdims=None):
1163  """Calculate the sufficient statistics for the mean and variance of `x`.
1164
1165  These sufficient statistics are computed using the one pass algorithm on
1166  an input that's optionally shifted. See:
1167  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data
1168
1169  For example:
1170  >>> t = [[1, 2, 3], [4, 5, 6]]
1171  >>> sufficient_statistics(t, [1])
1172  (<tf.Tensor: shape=(), dtype=int32, numpy=3>, <tf.Tensor: shape=(2,),
1173  dtype=int32, numpy=array([ 6, 15], dtype=int32)>, <tf.Tensor: shape=(2,),
1174  dtype=int32, numpy=array([14, 77], dtype=int32)>, None)
1175  >>> sufficient_statistics(t, [-1])
1176  (<tf.Tensor: shape=(), dtype=int32, numpy=3>, <tf.Tensor: shape=(2,),
1177  dtype=int32, numpy=array([ 6, 15], dtype=int32)>, <tf.Tensor: shape=(2,),
1178  dtype=int32, numpy=array([14, 77], dtype=int32)>, None)
1179
1180  Args:
1181    x: A `Tensor`.
1182    axes: Array of ints. Axes along which to compute mean and variance. As in
1183      Python, the axes can also be negative numbers. A negative axis is
1184      interpreted as counting from the end of the rank, i.e., axis +
1185      rank(values)-th dimension.
1186    shift: A `Tensor` containing the value by which to shift the data for
1187      numerical stability, or `None` if no shift is to be performed. A shift
1188      close to the true mean provides the most numerically stable results.
1189    keep_dims: produce statistics with the same dimensionality as the input.
1190    name: Name used to scope the operations that compute the sufficient stats.
1191    keepdims: Alias for keep_dims.
1192
1193  Returns:
1194    Four `Tensor` objects of the same type as `x`:
1195
1196    * the count (number of elements to average over).
1197    * the (possibly shifted) sum of the elements in the array.
1198    * the (possibly shifted) sum of squares of the elements in the array.
1199    * the shift by which the mean must be corrected or None if `shift` is None.
1200  """
1201  axes = list(set(axes))
1202  keep_dims = deprecated_argument_lookup(
1203      "keepdims", keepdims, "keep_dims", keep_dims)
1204  if keep_dims is None:
1205    keep_dims = False
1206  with ops.name_scope(name, "sufficient_statistics", [x, shift]):
1207    x = ops.convert_to_tensor(x, name="x")
1208    x_shape = x.get_shape()
1209    if x_shape.rank is not None and all(
1210        x_shape.dims[d].value is not None for d in axes):
1211      counts = 1
1212      for d in axes:
1213        counts *= x_shape.dims[d].value
1214      counts = constant_op.constant(counts, dtype=x.dtype)
1215    else:  # shape needs to be inferred at runtime.
1216      # Normalize axes to be positive. Required for gather.
1217      rank = array_ops.rank(x)
1218      positive_axes = [axis + rank if axis < 0 else axis for axis in axes]
1219      x_dims = array_ops.gather(
1220          math_ops.cast(array_ops.shape(x), x.dtype), positive_axes)
1221      counts = math_ops.reduce_prod(x_dims, name="count")
1222    if shift is not None:
1223      shift = ops.convert_to_tensor(shift, name="shift")
1224      m_ss = math_ops.subtract(x, shift)
1225      v_ss = math_ops.squared_difference(x, shift)
1226    else:  # no shift.
1227      m_ss = x
1228      v_ss = math_ops.square(x)
1229    m_ss = math_ops.reduce_sum(m_ss, axes, keepdims=keep_dims, name="mean_ss")
1230    v_ss = math_ops.reduce_sum(v_ss, axes, keepdims=keep_dims, name="var_ss")
1231  return counts, m_ss, v_ss, shift
1232
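# Hypothetical sketch (not in the original source): how the four returned
# statistics recombine into a mean and variance; `normalize_moments` below does
# the same with reciprocal/multiply ops.
def _moments_from_sufficient_statistics_example(x, axes, shift=None):
  """Recovers (mean, variance) from `sufficient_statistics`."""
  counts, mean_ss, var_ss, shift = sufficient_statistics(x, axes, shift=shift)
  shifted_mean = mean_ss / counts
  mean = shifted_mean if shift is None else shifted_mean + shift
  variance = var_ss / counts - math_ops.square(shifted_mean)
  return mean, variance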
1233
1234@tf_export("nn.sufficient_statistics", v1=[])
1235@dispatch.add_dispatch_support
1236def sufficient_statistics_v2(x, axes, shift=None, keepdims=False, name=None):
1237  """Calculate the sufficient statistics for the mean and variance of `x`.
1238
1239  These sufficient statistics are computed using the one pass algorithm on
1240  an input that's optionally shifted. See:
1241  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data
1242
1243  Args:
1244    x: A `Tensor`.
1245    axes: Array of ints. Axes along which to compute mean and variance.
1246    shift: A `Tensor` containing the value by which to shift the data for
1247      numerical stability, or `None` if no shift is to be performed. A shift
1248      close to the true mean provides the most numerically stable results.
1249    keepdims: produce statistics with the same dimensionality as the input.
1250    name: Name used to scope the operations that compute the sufficient stats.
1251
1252  Returns:
1253    Four `Tensor` objects of the same type as `x`:
1254
1255    * the count (number of elements to average over).
1256    * the (possibly shifted) sum of the elements in the array.
1257    * the (possibly shifted) sum of squares of the elements in the array.
1258    * the shift by which the mean must be corrected or None if `shift` is None.
1259  """
1260  return sufficient_statistics(
1261      x=x, axes=axes, shift=shift, keep_dims=keepdims, name=name)
1262
1263
1264@tf_export("nn.normalize_moments")
1265@dispatch.add_dispatch_support
1266def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
1267  """Calculate the mean and variance of based on the sufficient statistics.
1268
1269  Args:
1270    counts: A `Tensor` containing the total count of the data (one value).
1271    mean_ss: A `Tensor` containing the mean sufficient statistics: the (possibly
1272      shifted) sum of the elements to average over.
1273    variance_ss: A `Tensor` containing the variance sufficient statistics: the
1274      (possibly shifted) sum of the squared data to compute the variance over.
1275    shift: A `Tensor` containing the value by which the data is shifted for
1276      numerical stability, or `None` if no shift was performed.
1277    name: Name used to scope the operations that compute the moments.
1278
1279  Returns:
1280    Two `Tensor` objects: `mean` and `variance`.
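
  For example, a minimal sketch chaining this with
  `tf.nn.sufficient_statistics` (the values follow from the definitions above):

  ```python
  x = tf.constant([[1., 2.], [3., 4.]])
  counts, mean_ss, var_ss, shift = tf.nn.sufficient_statistics(x, axes=[0])
  # Recovers the ordinary moments over axis 0: mean == [2., 3.] and
  # variance == [1., 1.], matching tf.nn.moments(x, axes=[0]).
  mean, variance = tf.nn.normalize_moments(counts, mean_ss, var_ss, shift)
  ```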
1281  """
1282  with ops.name_scope(name, "normalize", [counts, mean_ss, variance_ss, shift]):
1283    divisor = math_ops.reciprocal(counts, name="divisor")
1284    if shift is not None:
1285      shifted_mean = math_ops.multiply(mean_ss, divisor, name="shifted_mean")
1286      mean = math_ops.add(shifted_mean, shift, name="mean")
1287    else:  # no shift.
1288      shifted_mean = math_ops.multiply(mean_ss, divisor, name="mean")
1289      mean = shifted_mean
1290    variance = math_ops.subtract(
1291        math_ops.multiply(variance_ss, divisor),
1292        math_ops.square(shifted_mean),
1293        name="variance")
1294  return (mean, variance)
1295
1296
1297@tf_export(v1=["nn.moments"])
1298@dispatch.add_dispatch_support
1299def moments(
1300    x,
1301    axes,
1302    shift=None,  # pylint: disable=unused-argument
1303    name=None,
1304    keep_dims=None,
1305    keepdims=None):
1306  """Calculate the mean and variance of `x`.
1307
1308  The mean and variance are calculated by aggregating the contents of `x`
1309  across `axes`.  If `x` is 1-D and `axes = [0]` this is just the mean
1310  and variance of a vector.
1311
1312  Note: shift is currently not used; the true mean is computed and used.
1313
1314  When using these moments for batch normalization (see
1315  `tf.nn.batch_normalization`):
1316
1317   * for so-called "global normalization", used with convolutional filters with
1318     shape `[batch, height, width, depth]`, pass `axes=[0, 1, 2]`.
1319   * for simple batch normalization pass `axes=[0]` (batch only).
1320
1321  Args:
1322    x: A `Tensor`.
1323    axes: Array of ints.  Axes along which to compute mean and
1324      variance.
1325    shift: Not used in the current implementation.
1326    name: Name used to scope the operations that compute the moments.
1327    keep_dims: Produce moments with the same dimensionality as the input.
1328    keepdims: Alias of keep_dims.
1329
1330  Returns:
1331    Two `Tensor` objects: `mean` and `variance`.
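
  For example, a minimal sketch (in TF2 this v1 symbol is exposed as
  `tf.compat.v1.nn.moments`):

  ```python
  x = tf.constant([[1., 2.], [3., 4.]])
  # Moments over the batch dimension; with keep_dims=True both outputs have
  # shape [1, 2], so they broadcast against x.
  mean, variance = tf.compat.v1.nn.moments(x, axes=[0], keep_dims=True)
  ```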
1332  """
1333  keep_dims = deprecated_argument_lookup(
1334      "keepdims", keepdims, "keep_dims", keep_dims)
1335  if keep_dims is None:
1336    keep_dims = False
1337  with ops.name_scope(name, "moments", [x, axes]):
1338    # The dynamic range of fp16 is too limited to support the collection of
1339    # sufficient statistics. As a workaround we simply perform the operations
1340    # on 32-bit floats before converting the mean and variance back to fp16
1341    y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
1342    # Compute true mean while keeping the dims for proper broadcasting.
1343    mean = math_ops.reduce_mean(y, axes, keepdims=True, name="mean")
1344    # sample variance, not unbiased variance
1345    # Note: stop_gradient does not change the gradient that gets
1346    #       backpropagated to the mean from the variance calculation,
1347    #       because that gradient is zero
1348    variance = math_ops.reduce_mean(
1349        math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
1350        axes,
1351        keepdims=True,
1352        name="variance")
1353    if not keep_dims:
1354      mean = array_ops.squeeze(mean, axes)
1355      variance = array_ops.squeeze(variance, axes)
1356    if x.dtype == dtypes.float16:
1357      return (math_ops.cast(mean, dtypes.float16),
1358              math_ops.cast(variance, dtypes.float16))
1359    else:
1360      return (mean, variance)
1361
1362
1363@tf_export("nn.moments", v1=[])
1364@dispatch.add_dispatch_support
1365def moments_v2(
1366    x,
1367    axes,
1368    shift=None,
1369    keepdims=False,
1370    name=None):
1371  """Calculates the mean and variance of `x`.
1372
1373  The mean and variance are calculated by aggregating the contents of `x`
1374  across `axes`.  If `x` is 1-D and `axes = [0]` this is just the mean
1375  and variance of a vector.
1376
1377  Note: shift is currently not used; the true mean is computed and used.
1378
1379  When using these moments for batch normalization (see
1380  `tf.nn.batch_normalization`):
1381
1382   * for so-called "global normalization", used with convolutional filters with
1383     shape `[batch, height, width, depth]`, pass `axes=[0, 1, 2]`.
1384   * for simple batch normalization pass `axes=[0]` (batch only).
1385
1386  Args:
1387    x: A `Tensor`.
1388    axes: Array of ints.  Axes along which to compute mean and
1389      variance.
1390    shift: Not used in the current implementation.
1391    keepdims: Produce moments with the same dimensionality as the input.
1392    name: Name used to scope the operations that compute the moments.
1393
1394  Returns:
1395    Two `Tensor` objects: `mean` and `variance`.
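
  For example, a minimal sketch of "simple batch normalization" statistics
  computed over the batch dimension:

  ```python
  x = tf.constant([[1., 2.], [3., 4.]])
  # mean == [2., 3.] and variance == [1., 1.]; both have shape [2] because
  # keepdims defaults to False.
  mean, variance = tf.nn.moments(x, axes=[0])
  ```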
1396  """
1397  return moments(x=x, axes=axes, shift=shift, name=name, keep_dims=keepdims)
1398
1399
1400@tf_export(v1=["nn.weighted_moments"])
1401@dispatch.add_dispatch_support
1402def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=None,
1403                     keepdims=None):
1404  """Returns the frequency-weighted mean and variance of `x`.
1405
1406  Args:
1407    x: A tensor.
1408    axes: 1-d tensor of int32 values; these are the axes along which
1409      to compute mean and variance.
1410    frequency_weights: A tensor of positive weights which can be
1411      broadcast with x.
1412    name: Name used to scope the operation.
1413    keep_dims: Produce moments with the same dimensionality as the input.
1414    keepdims: Alias of keep_dims.
1415
1416  Returns:
1417    Two tensors: `weighted_mean` and `weighted_variance`.
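
  For example, a minimal sketch (in TF2 this v1 symbol is exposed as
  `tf.compat.v1.nn.weighted_moments`):

  ```python
  x = tf.constant([1., 2., 3.])
  frequency_weights = tf.constant([1., 2., 1.])
  # Each element of x is counted frequency_weights times, so
  # weighted_mean == 2.0 and weighted_variance == 0.5.
  weighted_mean, weighted_variance = tf.compat.v1.nn.weighted_moments(
      x, axes=[0], frequency_weights=frequency_weights)
  ```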
1418  """
1419  keep_dims = deprecated_argument_lookup(
1420      "keepdims", keepdims, "keep_dims", keep_dims)
1421  if keep_dims is None:
1422    keep_dims = False
1423  with ops.name_scope(name, "weighted_moments", [x, frequency_weights, axes]):
1424    x = ops.convert_to_tensor(x, name="x")
1425    frequency_weights = ops.convert_to_tensor(
1426        frequency_weights, name="frequency_weights")
1427
1428    # Unlike moments(), this just uses a simpler two-pass method.
1429
1430    # See comment in moments() WRT precision; it applies here too.
1431    needs_cast = x.dtype == dtypes.float16
1432    if needs_cast:
1433      x = math_ops.cast(x, dtypes.float32)
1434
1435    if frequency_weights.dtype != x.dtype:
1436      frequency_weights = math_ops.cast(frequency_weights, x.dtype)
1437
1438    # Note that we use keep_dims=True for our reductions regardless of the arg;
1439    # this is so that the results remain broadcast-compatible with the inputs.
1440    weighted_input_sum = math_ops.reduce_sum(
1441        frequency_weights * x, axes, name="weighted_input_sum", keepdims=True)
1442
1443    # The shape of the weights isn't necessarily the same as x's
1444    # shape, just broadcast-compatible with it -- so this expression
1445    # performs broadcasting to give a per-item weight, with the same
1446    # shape as (frequency_weights * x). This avoids having to reason
1447    # through all the broadcast logic to compute a correct
1448    # sum_of_weights.
1449    broadcasted_weights = frequency_weights + array_ops.zeros_like(x)
1450
1451    sum_of_weights = math_ops.reduce_sum(
1452        broadcasted_weights, axes, name="sum_of_weights", keepdims=True)
1453
1454    divisor = math_ops.reciprocal(sum_of_weights, name="inv_weight_sum")
1455
1456    weighted_mean = math_ops.multiply(weighted_input_sum, divisor)
1457
1458    # Have the weighted mean; now on to variance:
1459    weighted_distsq = math_ops.reduce_sum(
1460        frequency_weights * math_ops.squared_difference(x, weighted_mean),
1461        axes,
1462        name="weighted_distsq",
1463        keepdims=True)
1464
1465    weighted_variance = math_ops.multiply(weighted_distsq, divisor)
1466
1467    if not keep_dims:
1468      weighted_mean = array_ops.squeeze(weighted_mean, axis=axes)
1469      weighted_variance = array_ops.squeeze(
1470          weighted_variance, axis=axes)
1471
1472    if needs_cast:
1473      weighted_mean = math_ops.cast(weighted_mean, dtypes.float16)
1474      weighted_variance = math_ops.cast(weighted_variance, dtypes.float16)
1475
1476    return weighted_mean, weighted_variance
1477
1478
1479@tf_export("nn.weighted_moments", v1=[])
1480@dispatch.add_dispatch_support
1481def weighted_moments_v2(x, axes, frequency_weights, keepdims=False, name=None):
1482  """Returns the frequency-weighted mean and variance of `x`.
1483
1484  Args:
1485    x: A tensor.
1486    axes: 1-d tensor of int32 values; these are the axes along which
1487      to compute mean and variance.
1488    frequency_weights: A tensor of positive weights which can be
1489      broadcast with x.
1490    keepdims: Produce moments with the same dimensionality as the input.
1491    name: Name used to scope the operation.
1492
1493  Returns:
1494    Two tensors: `weighted_mean` and `weighted_variance`.
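
  For example, a minimal sketch:

  ```python
  x = tf.constant([1., 2., 3.])
  frequency_weights = tf.constant([1., 2., 1.])
  # weighted_mean == [2.0] and weighted_variance == [0.5]; the reduced axis
  # is retained because keepdims=True.
  weighted_mean, weighted_variance = tf.nn.weighted_moments(
      x, axes=[0], frequency_weights=frequency_weights, keepdims=True)
  ```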
1495  """
1496  return weighted_moments(
1497      x=x,
1498      axes=axes,
1499      frequency_weights=frequency_weights,
1500      name=name,
1501      keep_dims=keepdims)
1502
1503
1504@tf_export("nn.batch_normalization")
1505@dispatch.add_dispatch_support
1506def batch_normalization(x,
1507                        mean,
1508                        variance,
1509                        offset,
1510                        scale,
1511                        variance_epsilon,
1512                        name=None):
1513  r"""Batch normalization.
1514
1515  Normalizes a tensor by `mean` and `variance`, and applies (optionally) a
1516  `scale` \\(\gamma\\) to it, as well as an `offset` \\(\beta\\):
1517
1518  \\(\frac{\gamma(x-\mu)}{\sigma}+\beta\\)
1519
1520  `mean`, `variance`, `offset` and `scale` are all expected to be of one of two
1521  shapes:
1522
1523    * In all generality, they can have the same number of dimensions as the
1524      input `x`, with identical sizes as `x` for the dimensions that are not
1525      normalized over (the 'depth' dimension(s)), and dimension 1 for the
1526      others which are being normalized over.
1527      `mean` and `variance` in this case would typically be the outputs of
1528      `tf.nn.moments(..., keepdims=True)` during training, or running averages
1529      thereof during inference.
1530    * In the common case where the 'depth' dimension is the last dimension in
1531      the input tensor `x`, they may be one dimensional tensors of the same
1532      size as the 'depth' dimension.
1533      This is the case for example for the common `[batch, depth]` layout of
1534      fully-connected layers, and `[batch, height, width, depth]` for
1535      convolutions.
1536      `mean` and `variance` in this case would typically be the outputs of
1537      `tf.nn.moments(..., keepdims=False)` during training, or running averages
1538      thereof during inference.
1539
1540  See equation 11 in Algorithm 2 of source:
1541  [Batch Normalization: Accelerating Deep Network Training by
1542  Reducing Internal Covariate Shift; S. Ioffe, C. Szegedy]
1543  (http://arxiv.org/abs/1502.03167).
1544
1545  Args:
1546    x: Input `Tensor` of arbitrary dimensionality.
1547    mean: A mean `Tensor`.
1548    variance: A variance `Tensor`.
1549    offset: An offset `Tensor`, often denoted \\(\beta\\) in equations, or
1550      None. If present, will be added to the normalized tensor.
1551    scale: A scale `Tensor`, often denoted \\(\gamma\\) in equations, or
1552      `None`. If present, the scale is applied to the normalized tensor.
1553    variance_epsilon: A small float number to avoid dividing by 0.
1554    name: A name for this operation (optional).
1555
1556  Returns:
1557    The normalized, scaled, offset tensor.
1558
1559  References:
1560    Batch Normalization - Accelerating Deep Network Training by Reducing
1561    Internal Covariate Shift:
1562      [Ioffe et al., 2015](http://arxiv.org/abs/1502.03167)
1563      ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
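
  For example, a minimal sketch of the common depth-last case described above
  (shapes are illustrative):

  ```python
  x = tf.random.normal([8, 4])                 # [batch, depth]
  mean, variance = tf.nn.moments(x, axes=[0])  # one value per depth channel
  offset, scale = tf.zeros([4]), tf.ones([4])
  y = tf.nn.batch_normalization(x, mean, variance, offset, scale,
                                variance_epsilon=1e-3)
  ```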
1564  """
1565  with ops.name_scope(name, "batchnorm", [x, mean, variance, scale, offset]):
1566    inv = math_ops.rsqrt(variance + variance_epsilon)
1567    if scale is not None:
1568      inv *= scale
1569    # Note: tensorflow/contrib/quantize/python/fold_batch_norms.py depends on
1570    # the precise order of ops that are generated by the expression below.
1571    return x * math_ops.cast(inv, x.dtype) + math_ops.cast(
1572        offset - mean * inv if offset is not None else -mean * inv, x.dtype)
1573
1574
1575@tf_export(v1=["nn.fused_batch_norm"])
1576@dispatch.add_dispatch_support
1577def fused_batch_norm(
1578    x,
1579    scale,
1580    offset,  # pylint: disable=invalid-name
1581    mean=None,
1582    variance=None,
1583    epsilon=0.001,
1584    data_format="NHWC",
1585    is_training=True,
1586    name=None,
1587    exponential_avg_factor=1.0):
1588  r"""Batch normalization.
1589
1590
1591  See Source: [Batch Normalization: Accelerating Deep Network Training by
1592  Reducing Internal Covariate Shift; S. Ioffe, C. Szegedy]
1593  (http://arxiv.org/abs/1502.03167).
1594
1595  Args:
1596    x: Input `Tensor` of 4 or 5 dimensions.
1597    scale: A `Tensor` of 1 dimension for scaling.
1598    offset: A `Tensor` of 1 dimension for bias.
1599    mean: A `Tensor` of 1 dimension for population mean. The shape and meaning
1600          of this argument depends on the value of is_training and
1601          exponential_avg_factor as follows:
1602          is_training==False (inference):
1603            Mean must be a `Tensor` of the same shape as scale containing the
1604            estimated population mean computed during training.
1605          is_training==True and exponential_avg_factor == 1.0:
1606            Mean must be None.
1607          is_training==True and exponential_avg_factor != 1.0:
1608            Mean must be a `Tensor` of the same shape as scale containing the
1609            exponential running mean.
1610    variance: A `Tensor` of 1 dimension for population variance. The shape and
1611          meaning of this argument depends on the value of is_training and
1612          exponential_avg_factor as follows:
1613          is_training==False (inference):
1614            Variance must be a `Tensor` of the same shape as scale containing
1615            the estimated population variance computed during training.
1616          is_training==True and exponential_avg_factor == 1.0:
1617            Variance must be None.
1618          is_training==True and exponential_avg_factor != 1.0:
1619            Variance must be a `Tensor` of the same shape as scale containing
1620            the exponential running variance.
1621    epsilon: A small float number added to the variance of x.
1622    data_format: The data format for x. Supports "NHWC" (default) or "NCHW"
1623                 for 4D tensors and "NDHWC" or "NCDHW" for 5D tensors.
1624    is_training: A bool value to specify if the operation is used for
1625                 training or inference.
1626    name: A name for this operation (optional).
1627    exponential_avg_factor: A float number (usually between 0 and 1) used
1628                            for controlling the decay of the running
1629                            population average of mean and variance.
1630                            If set to 1.0, the current batch average is
1631                            returned.
1632
1633  Returns:
1634    y: A 4D or 5D Tensor for the normalized, scaled, offset x.
1635    running_mean: A 1D Tensor for the exponential running mean of x.
1636                  The output value is (1 - exponential_avg_factor) * mean +
1637                  exponential_avg_factor * batch_mean, where batch_mean
1638                  is the mean of the current batch in x.
1639    running_var: A 1D Tensor for the exponential running variance.
1640                 The output value is (1 - exponential_avg_factor) * variance +
1641                 exponential_avg_factor * batch_variance, where batch_variance
1642                 is the variance of the current batch in x.
1643
1644  References:
1645    Batch Normalization - Accelerating Deep Network Training by Reducing
1646    Internal Covariate Shift:
1647      [Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html)
1648      ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
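
  For example, a minimal sketch for training mode (in TF2 this v1 symbol is
  exposed as `tf.compat.v1.nn.fused_batch_norm`; shapes are illustrative):

  ```python
  x = tf.random.normal([2, 4, 4, 3])   # NHWC input
  scale, offset = tf.ones([3]), tf.zeros([3])
  # With is_training=True and exponential_avg_factor=1.0, mean and variance
  # may be None; the current batch statistics are returned.
  y, running_mean, running_var = tf.compat.v1.nn.fused_batch_norm(
      x, scale, offset, is_training=True)
  ```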
1649  """
1650  if (not is_training or exponential_avg_factor != 1.0) and (
1651      (mean is None) or (variance is None)):
1652    raise ValueError("Both 'mean' and 'variance' must be 1D tensors when "
1653                     "is_training is False or "
1654                     "exponential_avg_factor != 1.0.")
1655  x = ops.convert_to_tensor(x, name="input")
1656  scale = ops.convert_to_tensor(scale, name="scale")
1657  offset = ops.convert_to_tensor(offset, name="offset")
1658  if mean is None:
1659    mean = constant_op.constant([])
1660  if variance is None:
1661    variance = constant_op.constant([])
1662
1663  # Set a minimum epsilon to 1.001e-5, which is required by cuDNN to prevent
1664  # an exception (see cudnn.h).
1665  min_epsilon = 1.001e-5
1666  epsilon = epsilon if epsilon > min_epsilon else min_epsilon
1667
1668  y, running_mean, running_var, _, _, _ = gen_nn_ops.fused_batch_norm_v3(
1669      x,
1670      scale,
1671      offset,
1672      mean,
1673      variance,
1674      epsilon=epsilon,
1675      exponential_avg_factor=exponential_avg_factor,
1676      data_format=data_format,
1677      is_training=is_training,
1678      name=name)
1679  return y, running_mean, running_var
1680
1681
1682@tf_export(v1=["nn.batch_norm_with_global_normalization"])
1683@dispatch.add_dispatch_support
1684def batch_norm_with_global_normalization(t=None,
1685                                         m=None,
1686                                         v=None,
1687                                         beta=None,
1688                                         gamma=None,
1689                                         variance_epsilon=None,
1690                                         scale_after_normalization=None,
1691                                         name=None,
1692                                         input=None,  # pylint: disable=redefined-builtin
1693                                         mean=None,
1694                                         variance=None):
1695  """Batch normalization.
1696
1697  This op is deprecated. See `tf.nn.batch_normalization`.
1698
1699  Args:
1700    t: A 4D input Tensor.
1701    m: A 1D mean Tensor with size matching the last dimension of t.
1702      This is the first output from tf.nn.moments,
1703      or a saved moving average thereof.
1704    v: A 1D variance Tensor with size matching the last dimension of t.
1705      This is the second output from tf.nn.moments,
1706      or a saved moving average thereof.
1707    beta: A 1D beta Tensor with size matching the last dimension of t.
1708      An offset to be added to the normalized tensor.
1709    gamma: A 1D gamma Tensor with size matching the last dimension of t.
1710      If "scale_after_normalization" is true, this tensor will be multiplied
1711      with the normalized tensor.
1712    variance_epsilon: A small float number to avoid dividing by 0.
1713    scale_after_normalization: A bool indicating whether the resulting tensor
1714      needs to be multiplied by gamma.
1715    name: A name for this operation (optional).
1716    input: Alias for t.
1717    mean: Alias for m.
1718    variance: Alias for v.
1719
1720  Returns:
1721     A batch-normalized `t`.
1722
1723  References:
1724    Batch Normalization - Accelerating Deep Network Training by Reducing
1725    Internal Covariate Shift:
1726      [Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html)
1727      ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
1728  """
1729  t = deprecated_argument_lookup("input", input, "t", t)
1730  m = deprecated_argument_lookup("mean", mean, "m", m)
1731  v = deprecated_argument_lookup("variance", variance, "v", v)
1732  return batch_normalization(t, m, v, beta, gamma if scale_after_normalization
1733                             else None, variance_epsilon, name)
1734
1735
1736# pylint: disable=redefined-builtin,line-too-long
1737@tf_export("nn.batch_norm_with_global_normalization", v1=[])
1738@dispatch.add_dispatch_support
1739def batch_norm_with_global_normalization_v2(input,
1740                                            mean,
1741                                            variance,
1742                                            beta,
1743                                            gamma,
1744                                            variance_epsilon,
1745                                            scale_after_normalization,
1746                                            name=None):
1747  """Batch normalization.
1748
1749  This op is deprecated. See `tf.nn.batch_normalization`.
1750
1751  Args:
1752    input: A 4D input Tensor.
1753    mean: A 1D mean Tensor with size matching the last dimension of input.
1754      This is the first output from tf.nn.moments,
1755      or a saved moving average thereof.
1756    variance: A 1D variance Tensor with size matching the last dimension of input.
1757      This is the second output from tf.nn.moments,
1758      or a saved moving average thereof.
1759    beta: A 1D beta Tensor with size matching the last dimension of input.
1760      An offset to be added to the normalized tensor.
1761    gamma: A 1D gamma Tensor with size matching the last dimension of input.
1762      If "scale_after_normalization" is true, this tensor will be multiplied
1763      with the normalized tensor.
1764    variance_epsilon: A small float number to avoid dividing by 0.
1765    scale_after_normalization: A bool indicating whether the resulting tensor
1766      needs to be multiplied by gamma.
1767    name: A name for this operation (optional).
1768
1769  Returns:
1770     A batch-normalized `input`.
1771
1772  References:
1773    Batch Normalization - Accelerating Deep Network Training by Reducing Internal Covariate Shift:
1774      [Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html)
1775      ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
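
  For example, a minimal sketch of the equivalent call through the
  non-deprecated API (shapes are illustrative):

  ```python
  x = tf.random.normal([2, 4, 4, 3])
  mean, variance = tf.nn.moments(x, axes=[0, 1, 2])
  beta, gamma = tf.zeros([3]), tf.ones([3])
  # With scale_after_normalization=True, this op is equivalent to:
  y = tf.nn.batch_normalization(x, mean, variance, beta, gamma, 1e-3)
  ```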
1776  """
1777  return batch_norm_with_global_normalization(t=input,
1778                                              m=mean,
1779                                              v=variance,
1780                                              beta=beta,
1781                                              gamma=gamma,
1782                                              variance_epsilon=variance_epsilon,
1783                                              scale_after_normalization=scale_after_normalization,
1784                                              name=name)
1785
1786# pylint: enable=redefined-builtin,line-too-long
1787
1788
1789def _sum_rows(x):
1790  """Returns a vector summing up each row of the matrix x."""
1791  # _sum_rows(x) is equivalent to math_ops.reduce_sum(x, 1) when x is
1792  # a matrix.  The gradient of _sum_rows(x) is more efficient than
1793  # reduce_sum(x, 1)'s gradient in today's implementation. Therefore,
1794  # we use _sum_rows(x) in the nce_loss() computation since the loss
1795  # is mostly used for training.
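  # For example, for x == [[1., 2.], [3., 4.]], _sum_rows(x) == [3., 7.].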
1796  cols = array_ops.shape(x)[1]
1797  ones_shape = array_ops.stack([cols, 1])
1798  ones = array_ops.ones(ones_shape, x.dtype)
1799  return array_ops.reshape(math_ops.matmul(x, ones), [-1])
1800
1801
1802def _compute_sampled_logits(weights,
1803                            biases,
1804                            labels,
1805                            inputs,
1806                            num_sampled,
1807                            num_classes,
1808                            num_true=1,
1809                            sampled_values=None,
1810                            subtract_log_q=True,
1811                            remove_accidental_hits=False,
1812                            partition_strategy="mod",
1813                            name=None,
1814                            seed=None):
1815  """Helper function for nce_loss and sampled_softmax_loss functions.
1816
1817  Computes sampled output training logits and labels suitable for implementing
1818  e.g. noise-contrastive estimation (see nce_loss) or sampled softmax (see
1819  sampled_softmax_loss).
1820
1821  Note: In the case where num_true > 1, we assign to each target class
1822  the target probability 1 / num_true so that the target probabilities
1823  sum to 1 per-example.
1824
1825  Args:
1826    weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
1827        objects whose concatenation along dimension 0 has shape
1828        `[num_classes, dim]`.  The (possibly-partitioned) class embeddings.
1829    biases: A `Tensor` of shape `[num_classes]`.  The (possibly-partitioned)
1830        class biases.
1831    labels: A `Tensor` of type `int64` and shape `[batch_size,
1832        num_true]`. The target classes.  Note that this format differs from
1833        the `labels` argument of `nn.softmax_cross_entropy_with_logits`.
1834    inputs: A `Tensor` of shape `[batch_size, dim]`.  The forward
1835        activations of the input network.
1836    num_sampled: An `int`.  The number of classes to randomly sample per batch.
1837    num_classes: An `int`. The number of possible classes.
1838    num_true: An `int`.  The number of target classes per training example.
1839    sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
1840        `sampled_expected_count`) returned by a `*_candidate_sampler` function.
1841        (if None, we default to `log_uniform_candidate_sampler`)
1842    subtract_log_q: A `bool`.  Whether to subtract the log expected count of
1843        the labels in the sample to get the logits of the true labels.
1844        Default is True.  Turn off for Negative Sampling.
1845    remove_accidental_hits:  A `bool`.  Whether to remove "accidental hits"
1846        where a sampled class equals one of the target classes.  Default is
1847        False.
1848    partition_strategy: A string specifying the partitioning strategy, relevant
1849        if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
1850        Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
1851    name: A name for the operation (optional).
1852    seed: Random seed for candidate sampling. Defaults to None, which doesn't
1853        set the op-level random seed for candidate sampling.
1854  Returns:
1855    out_logits: `Tensor` object with shape
1856        `[batch_size, num_true + num_sampled]`, for passing to either
1857        `nn.sigmoid_cross_entropy_with_logits` (NCE) or
1858        `nn.softmax_cross_entropy_with_logits` (sampled softmax).
1859    out_labels: A Tensor object with the same shape as `out_logits`.
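
  For example, a minimal sketch of how `nce_loss` below consumes these outputs
  (argument names match this function's parameters):

  ```python
  out_logits, out_labels = _compute_sampled_logits(
      weights, biases, labels, inputs, num_sampled, num_classes)
  per_example_loss = _sum_rows(
      sigmoid_cross_entropy_with_logits(labels=out_labels, logits=out_logits))
  ```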
1860  """
1861
1862  if isinstance(weights, variables.PartitionedVariable):
1863    weights = list(weights)
1864  if not isinstance(weights, list):
1865    weights = [weights]
1866
1867  with ops.name_scope(name, "compute_sampled_logits",
1868                      weights + [biases, inputs, labels]):
1869    if labels.dtype != dtypes.int64:
1870      labels = math_ops.cast(labels, dtypes.int64)
1871    labels_flat = array_ops.reshape(labels, [-1])
1872
1873    # Sample the negative labels.
1874    #   sampled shape: [num_sampled] tensor
1875    #   true_expected_count shape = [batch_size, 1] tensor
1876    #   sampled_expected_count shape = [num_sampled] tensor
1877    if sampled_values is None:
1878      sampled_values = candidate_sampling_ops.log_uniform_candidate_sampler(
1879          true_classes=labels,
1880          num_true=num_true,
1881          num_sampled=num_sampled,
1882          unique=True,
1883          range_max=num_classes,
1884          seed=seed)
1885    # NOTE: pylint cannot tell that 'sampled_values' is a sequence
1886    # pylint: disable=unpacking-non-sequence
1887    sampled, true_expected_count, sampled_expected_count = (
1888        array_ops.stop_gradient(s) for s in sampled_values)
1889    # pylint: enable=unpacking-non-sequence
1890    sampled = math_ops.cast(sampled, dtypes.int64)
1891
1892    # labels_flat is a [batch_size * num_true] tensor
1893    # sampled is a [num_sampled] int tensor
1894    all_ids = array_ops.concat([labels_flat, sampled], 0)
1895
1896    # Retrieve the true weights and the logits of the sampled weights.
1897
1898    # weights shape is [num_classes, dim]
1899    all_w = embedding_ops.embedding_lookup(
1900        weights, all_ids, partition_strategy=partition_strategy)
1901    if all_w.dtype != inputs.dtype:
1902      all_w = math_ops.cast(all_w, inputs.dtype)
1903
1904    # true_w shape is [batch_size * num_true, dim]
1905    true_w = array_ops.slice(all_w, [0, 0],
1906                             array_ops.stack(
1907                                 [array_ops.shape(labels_flat)[0], -1]))
1908
1909    sampled_w = array_ops.slice(
1910        all_w, array_ops.stack([array_ops.shape(labels_flat)[0], 0]), [-1, -1])
1911    # inputs has shape [batch_size, dim]
1912    # sampled_w has shape [num_sampled, dim]
1913    # Apply X*W', which yields [batch_size, num_sampled]
1914    sampled_logits = math_ops.matmul(inputs, sampled_w, transpose_b=True)
1915
1916    # Retrieve the true and sampled biases, compute the true logits, and
1917    # add the biases to the true and sampled logits.
1918    all_b = embedding_ops.embedding_lookup(
1919        biases, all_ids, partition_strategy=partition_strategy)
1920    if all_b.dtype != inputs.dtype:
1921      all_b = math_ops.cast(all_b, inputs.dtype)
1922    # true_b is a [batch_size * num_true] tensor
1923    # sampled_b is a [num_sampled] float tensor
1924    true_b = array_ops.slice(all_b, [0], array_ops.shape(labels_flat))
1925    sampled_b = array_ops.slice(all_b, array_ops.shape(labels_flat), [-1])
1926
1927    # inputs shape is [batch_size, dim]
1928    # true_w shape is [batch_size * num_true, dim]
1929    # row_wise_dots is [batch_size, num_true, dim]
1930    dim = array_ops.shape(true_w)[1:2]
1931    new_true_w_shape = array_ops.concat([[-1, num_true], dim], 0)
1932    row_wise_dots = math_ops.multiply(
1933        array_ops.expand_dims(inputs, 1),
1934        array_ops.reshape(true_w, new_true_w_shape))
1935    # We want the row-wise dot plus biases which yields a
1936    # [batch_size, num_true] tensor of true_logits.
1937    dots_as_matrix = array_ops.reshape(row_wise_dots,
1938                                       array_ops.concat([[-1], dim], 0))
1939    true_logits = array_ops.reshape(_sum_rows(dots_as_matrix), [-1, num_true])
1940    true_b = array_ops.reshape(true_b, [-1, num_true])
1941    true_logits += true_b
1942    sampled_logits += sampled_b
1943
1944    if remove_accidental_hits:
1945      acc_hits = candidate_sampling_ops.compute_accidental_hits(
1946          labels, sampled, num_true=num_true)
1947      acc_indices, acc_ids, acc_weights = acc_hits
1948
1949      # This is how SparseToDense expects the indices.
1950      acc_indices_2d = array_ops.reshape(acc_indices, [-1, 1])
1951      acc_ids_2d_int32 = array_ops.reshape(
1952          math_ops.cast(acc_ids, dtypes.int32), [-1, 1])
1953      sparse_indices = array_ops.concat([acc_indices_2d, acc_ids_2d_int32], 1,
1954                                        "sparse_indices")
1955      # Create sampled_logits_shape = [batch_size, num_sampled]
1956      sampled_logits_shape = array_ops.concat(
1957          [array_ops.shape(labels)[:1],
1958           array_ops.expand_dims(num_sampled, 0)], 0)
1959      if sampled_logits.dtype != acc_weights.dtype:
1960        acc_weights = math_ops.cast(acc_weights, sampled_logits.dtype)
1961      sampled_logits += gen_sparse_ops.sparse_to_dense(
1962          sparse_indices,
1963          sampled_logits_shape,
1964          acc_weights,
1965          default_value=0.0,
1966          validate_indices=False)
1967
1968    if subtract_log_q:
1969      # Subtract log of Q(l), prior probability that l appears in sampled.
1970      true_logits -= math_ops.log(true_expected_count)
1971      sampled_logits -= math_ops.log(sampled_expected_count)
1972
1973    # Construct output logits and labels. The true labels/logits start at col 0.
1974    out_logits = array_ops.concat([true_logits, sampled_logits], 1)
1975
1976    # true_logits is a float tensor, ones_like(true_logits) is a float
1977    # tensor of ones. We then divide by num_true to ensure the per-example
1978    # labels sum to 1.0, i.e. form a proper probability distribution.
1979    out_labels = array_ops.concat([
1980        array_ops.ones_like(true_logits) / num_true,
1981        array_ops.zeros_like(sampled_logits)
1982    ], 1)
1983
1984    return out_logits, out_labels
1985
1986
1987@tf_export("nn.nce_loss", v1=[])
1988@dispatch.add_dispatch_support
1989def nce_loss_v2(weights,
1990                biases,
1991                labels,
1992                inputs,
1993                num_sampled,
1994                num_classes,
1995                num_true=1,
1996                sampled_values=None,
1997                remove_accidental_hits=False,
1998                name="nce_loss"):
1999  """Computes and returns the noise-contrastive estimation training loss.
2000
2001  See [Noise-contrastive estimation: A new estimation principle for
2002  unnormalized statistical
2003  models](http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf).
2004  Also see our [Candidate Sampling Algorithms
2005  Reference](https://www.tensorflow.org/extras/candidate_sampling.pdf)
2006
2007  A common use case is to use this method for training, and calculate the full
2008  sigmoid loss for evaluation or inference as in the following example:
2009
2010  ```python
2011  if mode == "train":
2012    loss = tf.nn.nce_loss(
2013        weights=weights,
2014        biases=biases,
2015        labels=labels,
2016        inputs=inputs,
2017        ...)
2018  elif mode == "eval":
2019    logits = tf.matmul(inputs, tf.transpose(weights))
2020    logits = tf.nn.bias_add(logits, biases)
2021    labels_one_hot = tf.one_hot(labels, n_classes)
2022    loss = tf.nn.sigmoid_cross_entropy_with_logits(
2023        labels=labels_one_hot,
2024        logits=logits)
2025    loss = tf.reduce_sum(loss, axis=1)
2026  ```
2027
2028  Note: when doing embedding lookup on `weights` and `biases`, the "div"
2029  partition strategy will be used. Support for other partition strategies
2030  will be added later.
2031
2032  Note: By default this uses a log-uniform (Zipfian) distribution for sampling,
2033  so your labels must be sorted in order of decreasing frequency to achieve
2034  good results.  For more details, see
2035  `tf.random.log_uniform_candidate_sampler`.
2036
2037  Note: In the case where `num_true` > 1, we assign to each target class
2038  the target probability 1 / `num_true` so that the target probabilities
2039  sum to 1 per-example.
2040
2041  Note: It would be useful to allow a variable number of target classes per
2042  example.  We hope to provide this functionality in a future release.
2043  For now, if you have a variable number of target classes, you can pad them
2044  out to a constant number by either repeating them or by padding
2045  with an otherwise unused class.
2046
2047  Args:
2048    weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
2049      objects whose concatenation along dimension 0 has shape [num_classes,
2050      dim].  The (possibly-partitioned) class embeddings.
2051    biases: A `Tensor` of shape `[num_classes]`.  The class biases.
2052    labels: A `Tensor` of type `int64` and shape `[batch_size, num_true]`. The
2053      target classes.
2054    inputs: A `Tensor` of shape `[batch_size, dim]`.  The forward activations of
2055      the input network.
2056    num_sampled: An `int`.  The number of negative classes to randomly sample
2057      per batch. This single sample of negative classes is evaluated for each
2058      element in the batch.
2059    num_classes: An `int`. The number of possible classes.
2060    num_true: An `int`.  The number of target classes per training example.
2061    sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
2062      `sampled_expected_count`) returned by a `*_candidate_sampler` function.
2063      (if None, we default to `log_uniform_candidate_sampler`)
2064    remove_accidental_hits:  A `bool`.  Whether to remove "accidental hits"
2065      where a sampled class equals one of the target classes.  If set to `True`,
2066      this is a "Sampled Logistic" loss instead of NCE, and we are learning to
2067      generate log-odds instead of log probabilities.  See our [Candidate
2068      Sampling Algorithms Reference]
2069        (https://www.tensorflow.org/extras/candidate_sampling.pdf). Default is
2070          False.
2071    name: A name for the operation (optional).
2072
2073  Returns:
2074    A `batch_size` 1-D tensor of per-example NCE losses.
2075  """
2076  # TODO(yuefengz): get partition_strategy from either variables or distribution
2077  # strategies.
2078  return nce_loss(
2079      weights,
2080      biases,
2081      labels,
2082      inputs,
2083      num_sampled,
2084      num_classes,
2085      num_true=num_true,
2086      sampled_values=sampled_values,
2087      remove_accidental_hits=remove_accidental_hits,
2088      partition_strategy="div",
2089      name=name)
2090
2091
2092@tf_export(v1=["nn.nce_loss"])
2093@dispatch.add_dispatch_support
2094def nce_loss(weights,
2095             biases,
2096             labels,
2097             inputs,
2098             num_sampled,
2099             num_classes,
2100             num_true=1,
2101             sampled_values=None,
2102             remove_accidental_hits=False,
2103             partition_strategy="mod",
2104             name="nce_loss"):
2105  """Computes and returns the noise-contrastive estimation training loss.
2106
2107  A common use case is to use this method for training, and calculate the full
2108  sigmoid loss for evaluation or inference. In this case, you must set
2109  `partition_strategy="div"` for the two losses to be consistent, as in the
2110  following example:
2111
2112  ```python
2113  if mode == "train":
2114    loss = tf.nn.nce_loss(
2115        weights=weights,
2116        biases=biases,
2117        labels=labels,
2118        inputs=inputs,
2119        ...,
2120        partition_strategy="div")
2121  elif mode == "eval":
2122    logits = tf.matmul(inputs, tf.transpose(weights))
2123    logits = tf.nn.bias_add(logits, biases)
2124    labels_one_hot = tf.one_hot(labels, n_classes)
2125    loss = tf.nn.sigmoid_cross_entropy_with_logits(
2126        labels=labels_one_hot,
2127        logits=logits)
2128    loss = tf.reduce_sum(loss, axis=1)
2129  ```
2130
2131  Note: By default this uses a log-uniform (Zipfian) distribution for sampling,
2132  so your labels must be sorted in order of decreasing frequency to achieve
2133  good results.  For more details, see
2134  `tf.random.log_uniform_candidate_sampler`.
2135
2136  Note: In the case where `num_true` > 1, we assign to each target class
2137  the target probability 1 / `num_true` so that the target probabilities
2138  sum to 1 per-example.
2139
2140  Note: It would be useful to allow a variable number of target classes per
2141  example.  We hope to provide this functionality in a future release.
2142  For now, if you have a variable number of target classes, you can pad them
2143  out to a constant number by either repeating them or by padding
2144  with an otherwise unused class.
2145
2146  Args:
2147    weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
2148        objects whose concatenation along dimension 0 has shape
2149        [num_classes, dim].  The (possibly-partitioned) class embeddings.
2150    biases: A `Tensor` of shape `[num_classes]`.  The class biases.
2151    labels: A `Tensor` of type `int64` and shape `[batch_size,
2152        num_true]`. The target classes.
2153    inputs: A `Tensor` of shape `[batch_size, dim]`.  The forward
2154        activations of the input network.
2155    num_sampled: An `int`.  The number of negative classes to randomly sample
2156        per batch. This single sample of negative classes is evaluated for each
2157        element in the batch.
2158    num_classes: An `int`. The number of possible classes.
2159    num_true: An `int`.  The number of target classes per training example.
2160    sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
2161        `sampled_expected_count`) returned by a `*_candidate_sampler` function.
2162        (if None, we default to `log_uniform_candidate_sampler`)
2163    remove_accidental_hits:  A `bool`.  Whether to remove "accidental hits"
2164        where a sampled class equals one of the target classes.  If set to
2165        `True`, this is a "Sampled Logistic" loss instead of NCE, and we are
2166        learning to generate log-odds instead of log probabilities. See
2167        our Candidate Sampling Algorithms Reference
2168        ([pdf](https://www.tensorflow.org/extras/candidate_sampling.pdf)).
2169        Default is False.
2170    partition_strategy: A string specifying the partitioning strategy, relevant
2171        if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
2172        Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
2173    name: A name for the operation (optional).
2174
2175  Returns:
2176    A `batch_size` 1-D tensor of per-example NCE losses.
2177
2178  References:
2179    Noise-contrastive estimation - A new estimation principle for unnormalized
2180    statistical models:
2181      [Gutmann et al., 2010](http://proceedings.mlr.press/v9/gutmann10a)
2182      ([pdf](http://proceedings.mlr.press/v9/gutmann10a/gutmann10a.pdf))
2183  """
2184  logits, labels = _compute_sampled_logits(
2185      weights=weights,
2186      biases=biases,
2187      labels=labels,
2188      inputs=inputs,
2189      num_sampled=num_sampled,
2190      num_classes=num_classes,
2191      num_true=num_true,
2192      sampled_values=sampled_values,
2193      subtract_log_q=True,
2194      remove_accidental_hits=remove_accidental_hits,
2195      partition_strategy=partition_strategy,
2196      name=name)
2197  sampled_losses = sigmoid_cross_entropy_with_logits(
2198      labels=labels, logits=logits, name="sampled_losses")
2199  # sampled_losses is batch_size x {true_loss, sampled_losses...}
2200  # We sum out true and sampled losses.
2201  return _sum_rows(sampled_losses)
2202
2203
2204@tf_export("nn.sampled_softmax_loss", v1=[])
2205@dispatch.add_dispatch_support
2206def sampled_softmax_loss_v2(weights,
2207                            biases,
2208                            labels,
2209                            inputs,
2210                            num_sampled,
2211                            num_classes,
2212                            num_true=1,
2213                            sampled_values=None,
2214                            remove_accidental_hits=True,
2215                            seed=None,
2216                            name="sampled_softmax_loss"):
2217  """Computes and returns the sampled softmax training loss.
2218
2219  This is a faster way to train a softmax classifier over a huge number of
2220  classes.
2221
2222  This operation is for training only.  It is generally an underestimate of
2223  the full softmax loss.
2224
2225  A common use case is to use this method for training, and calculate the full
2226  softmax loss for evaluation or inference as in the following example:
2227
2228  ```python
2229  if mode == "train":
2230    loss = tf.nn.sampled_softmax_loss(
2231        weights=weights,
2232        biases=biases,
2233        labels=labels,
2234        inputs=inputs,
2235        ...)
2236  elif mode == "eval":
2237    logits = tf.matmul(inputs, tf.transpose(weights))
2238    logits = tf.nn.bias_add(logits, biases)
2239    labels_one_hot = tf.one_hot(labels, n_classes)
2240    loss = tf.nn.softmax_cross_entropy_with_logits(
2241        labels=labels_one_hot,
2242        logits=logits)
2243  ```
2244
2245  See our [Candidate Sampling Algorithms Reference]
2246  (https://www.tensorflow.org/extras/candidate_sampling.pdf)
2247
2248  Also see Section 3 of [Jean et al., 2014](http://arxiv.org/abs/1412.2007)
2249  ([pdf](http://arxiv.org/pdf/1412.2007.pdf)) for the math.
2250
2251  Note: when doing embedding lookup on `weights` and `biases`, the "div"
2252  partition strategy will be used. Support for other partition strategies
2253  will be added later.
2254
2255  Args:
2256    weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
2257      objects whose concatenation along dimension 0 has shape [num_classes,
2258      dim].  The (possibly-sharded) class embeddings.
2259    biases: A `Tensor` of shape `[num_classes]`.  The class biases.
2260    labels: A `Tensor` of type `int64` and shape `[batch_size, num_true]`. The
2261      target classes.  Note that this format differs from the `labels` argument
2262      of `nn.softmax_cross_entropy_with_logits`.
2263    inputs: A `Tensor` of shape `[batch_size, dim]`.  The forward activations of
2264      the input network.
2265    num_sampled: An `int`.  The number of classes to randomly sample per batch.
2266    num_classes: An `int`. The number of possible classes.
2267    num_true: An `int`.  The number of target classes per training example.
2268    sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
2269      `sampled_expected_count`) returned by a `*_candidate_sampler` function.
2270      (if None, we default to `log_uniform_candidate_sampler`)
2271    remove_accidental_hits:  A `bool`.  Whether to remove "accidental hits"
2272      where a sampled class equals one of the target classes.  Default is True.
2273    seed: Random seed for candidate sampling. Defaults to None, which doesn't
2274      set the op-level random seed for candidate sampling.
2275    name: A name for the operation (optional).
2276
2277  Returns:
2278    A `batch_size` 1-D tensor of per-example sampled softmax losses.
2279
2280  """
2281  return sampled_softmax_loss(
2282      weights,
2283      biases,
2284      labels,
2285      inputs,
2286      num_sampled,
2287      num_classes,
2288      num_true=num_true,
2289      sampled_values=sampled_values,
2290      remove_accidental_hits=remove_accidental_hits,
2291      partition_strategy="div",
2292      name=name,
2293      seed=seed)
2294
2295
2296@tf_export(v1=["nn.sampled_softmax_loss"])
2297@dispatch.add_dispatch_support
2298def sampled_softmax_loss(weights,
2299                         biases,
2300                         labels,
2301                         inputs,
2302                         num_sampled,
2303                         num_classes,
2304                         num_true=1,
2305                         sampled_values=None,
2306                         remove_accidental_hits=True,
2307                         partition_strategy="mod",
2308                         name="sampled_softmax_loss",
2309                         seed=None):
2310  """Computes and returns the sampled softmax training loss.
2311
2312  This is a faster way to train a softmax classifier over a huge number of
2313  classes.
2314
2315  This operation is for training only.  It is generally an underestimate of
2316  the full softmax loss.
2317
2318  A common use case is to use this method for training, and calculate the full
2319  softmax loss for evaluation or inference. In this case, you must set
2320  `partition_strategy="div"` for the two losses to be consistent, as in the
2321  following example:
2322
2323  ```python
2324  if mode == "train":
2325    loss = tf.nn.sampled_softmax_loss(
2326        weights=weights,
2327        biases=biases,
2328        labels=labels,
2329        inputs=inputs,
2330        ...,
2331        partition_strategy="div")
2332  elif mode == "eval":
2333    logits = tf.matmul(inputs, tf.transpose(weights))
2334    logits = tf.nn.bias_add(logits, biases)
2335    labels_one_hot = tf.one_hot(labels, n_classes)
2336    loss = tf.nn.softmax_cross_entropy_with_logits(
2337        labels=labels_one_hot,
2338        logits=logits)
2339  ```
2340
2341  See our Candidate Sampling Algorithms Reference
2342  ([pdf](https://www.tensorflow.org/extras/candidate_sampling.pdf)).
2343  Also see Section 3 of (Jean et al., 2014) for the math.
2344
2345  Args:
2346    weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
2347        objects whose concatenation along dimension 0 has shape
2348        [num_classes, dim].  The (possibly-sharded) class embeddings.
2349    biases: A `Tensor` of shape `[num_classes]`.  The class biases.
2350    labels: A `Tensor` of type `int64` and shape `[batch_size,
2351        num_true]`. The target classes.  Note that this format differs from
2352        the `labels` argument of `nn.softmax_cross_entropy_with_logits`.
2353    inputs: A `Tensor` of shape `[batch_size, dim]`.  The forward
2354        activations of the input network.
2355    num_sampled: An `int`.  The number of classes to randomly sample per batch.
2356    num_classes: An `int`. The number of possible classes.
2357    num_true: An `int`.  The number of target classes per training example.
2358    sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
2359        `sampled_expected_count`) returned by a `*_candidate_sampler` function.
2360        (if None, we default to `log_uniform_candidate_sampler`)
2361    remove_accidental_hits:  A `bool`.  Whether to remove "accidental hits"
2362        where a sampled class equals one of the target classes.  Default is
2363        True.
2364    partition_strategy: A string specifying the partitioning strategy, relevant
2365        if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
2366        Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
2367    name: A name for the operation (optional).
2368    seed: Random seed for candidate sampling. Defaults to None, which doesn't
2369        set the op-level random seed for candidate sampling.
2370
2371  Returns:
2372    A `batch_size` 1-D tensor of per-example sampled softmax losses.
2373
2374  References:
2375    On Using Very Large Target Vocabulary for Neural Machine Translation:
2376      [Jean et al., 2014]
2377      (https://aclanthology.coli.uni-saarland.de/papers/P15-1001/p15-1001)
2378      ([pdf](http://aclweb.org/anthology/P15-1001))
2379  """
2380  logits, labels = _compute_sampled_logits(
2381      weights=weights,
2382      biases=biases,
2383      labels=labels,
2384      inputs=inputs,
2385      num_sampled=num_sampled,
2386      num_classes=num_classes,
2387      num_true=num_true,
2388      sampled_values=sampled_values,
2389      subtract_log_q=True,
2390      remove_accidental_hits=remove_accidental_hits,
2391      partition_strategy=partition_strategy,
2392      name=name,
2393      seed=seed)
2394  labels = array_ops.stop_gradient(labels, name="labels_stop_gradient")
2395  sampled_losses = nn_ops.softmax_cross_entropy_with_logits_v2(
2396      labels=labels, logits=logits)
2397  # sampled_losses is a [batch_size] tensor.
2398  return sampled_losses
2399