1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Implementation of Neural Net (NN) functions."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import math
22
23from tensorflow.python.distribute import distribution_strategy_context as ds
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import ops
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import candidate_sampling_ops
29from tensorflow.python.ops import control_flow_ops
30from tensorflow.python.ops import custom_gradient
31from tensorflow.python.ops import embedding_ops
32from tensorflow.python.ops import gen_array_ops  # pylint: disable=unused-import
33from tensorflow.python.ops import gen_nn_ops
34from tensorflow.python.ops import gen_sparse_ops
35from tensorflow.python.ops import linalg_ops
36from tensorflow.python.ops import math_ops
37from tensorflow.python.ops import nn_ops
38from tensorflow.python.ops import variables
39from tensorflow.python.ops.losses import util as losses_util
40from tensorflow.python.platform import device_context
41from tensorflow.python.util.deprecation import deprecated_args
42from tensorflow.python.util.deprecation import deprecated_argument_lookup
43from tensorflow.python.util.tf_export import tf_export
44
45
46@tf_export("nn.log_poisson_loss")
47def log_poisson_loss(targets, log_input, compute_full_loss=False, name=None):
48  """Computes log Poisson loss given `log_input`.
49
50  Gives the log-likelihood loss between the prediction and the target under the
51  assumption that the target has a Poisson distribution.
52  Caveat: By default, this is not the exact loss, but the loss minus a
53    constant term [log(z!)]. That has no effect for optimization, but
54    does not play well with relative loss comparisons. To compute an
55    approximation of the log factorial term, specify
56    compute_full_loss=True to enable Stirling's Approximation.
57
58  For brevity, let `c = log(x) = log_input`, `z = targets`.  The log Poisson
59  loss is
60
61        -log(exp(-x) * (x^z) / z!)
62      = -log(exp(-x) * (x^z)) + log(z!)
63      ~ -log(exp(-x)) - log(x^z) [+ z * log(z) - z + 0.5 * log(2 * pi * z)]
64          [ Note the bracketed term is Stirling's Approximation for log(z!).
65            It is invariant to x and does not affect optimization, though
66            important for correct relative loss comparisons. It is only
67            computed when compute_full_loss == True. ]
68      = x - z * log(x) [+ z * log(z) - z + 0.5 * log(2 * pi * z)]
69      = exp(c) - z * c [+ z * log(z) - z + 0.5 * log(2 * pi * z)]
70
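  A minimal usage sketch (assumes `import tensorflow as tf` and eager
  execution; the values below are purely illustrative):

  ```python
  targets = tf.constant([1.0, 2.0, 4.0])
  log_input = tf.math.log(tf.constant([1.0, 3.0, 4.0]))
  loss = tf.nn.log_poisson_loss(targets, log_input, compute_full_loss=True)
  # `loss` has the same shape as `targets` and `log_input`.
  ```
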
71  Args:
72    targets: A `Tensor` of the same type and shape as `log_input`.
73    log_input: A `Tensor` of type `float32` or `float64`.
74    compute_full_loss: whether to compute the full loss. If false, a constant
75      term is dropped in favor of more efficient optimization.
76    name: A name for the operation (optional).
77
78  Returns:
79    A `Tensor` of the same shape as `log_input` with the componentwise
80    log Poisson losses.
81
82  Raises:
83    ValueError: If `log_input` and `targets` do not have the same shape.
84  """
85  with ops.name_scope(name, "log_poisson_loss", [log_input, targets]) as name:
86    log_input = ops.convert_to_tensor(log_input, name="log_input")
87    targets = ops.convert_to_tensor(targets, name="targets")
88    try:
89      targets.get_shape().merge_with(log_input.get_shape())
90    except ValueError:
91      raise ValueError(
92          "log_input and targets must have the same shape (%s vs %s)" %
93          (log_input.get_shape(), targets.get_shape()))
94
95    result = math_ops.exp(log_input) - log_input * targets
96    if compute_full_loss:
97      # need to create constant tensors here so that their dtypes can be matched
98      # to that of the targets.
99      point_five = constant_op.constant(0.5, dtype=targets.dtype)
100      two_pi = constant_op.constant(2 * math.pi, dtype=targets.dtype)
101
102      stirling_approx = (targets * math_ops.log(targets)) - targets + (
103          point_five * math_ops.log(two_pi * targets))
104      zeros = array_ops.zeros_like(targets, dtype=targets.dtype)
105      ones = array_ops.ones_like(targets, dtype=targets.dtype)
106      cond = math_ops.logical_and(targets >= zeros, targets <= ones)
107      result += array_ops.where(cond, zeros, stirling_approx)
108    return result
109
110
111@tf_export(v1=["nn.sigmoid_cross_entropy_with_logits"])
112def sigmoid_cross_entropy_with_logits(  # pylint: disable=invalid-name
113    _sentinel=None,
114    labels=None,
115    logits=None,
116    name=None):
117  """Computes sigmoid cross entropy given `logits`.
118
119  Measures the probability error in discrete classification tasks in which each
120  class is independent and not mutually exclusive.  For instance, one could
121  perform multilabel classification where a picture can contain both an elephant
122  and a dog at the same time.
123
124  For brevity, let `x = logits`, `z = labels`.  The logistic loss is
125
126        z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
127      = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
128      = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
129      = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
130      = (1 - z) * x + log(1 + exp(-x))
131      = x - x * z + log(1 + exp(-x))
132
133  For x < 0, to avoid overflow in exp(-x), we reformulate the above
134
135        x - x * z + log(1 + exp(-x))
136      = log(exp(x)) - x * z + log(1 + exp(-x))
137      = - x * z + log(1 + exp(x))
138
139  Hence, to ensure stability and avoid overflow, the implementation uses this
140  equivalent formulation
141
142      max(x, 0) - x * z + log(1 + exp(-abs(x)))
143
144  `logits` and `labels` must have the same type and shape.
145
146  Args:
147    _sentinel: Used to prevent positional parameters. Internal, do not use.
148    labels: A `Tensor` of the same type and shape as `logits`.
149    logits: A `Tensor` of type `float32` or `float64`.
150    name: A name for the operation (optional).
151
152  Returns:
153    A `Tensor` of the same shape as `logits` with the componentwise
154    logistic losses.
155
156  Raises:
157    ValueError: If `logits` and `labels` do not have the same shape.
158  """
159  # pylint: disable=protected-access
160  nn_ops._ensure_xent_args("sigmoid_cross_entropy_with_logits", _sentinel,
161                           labels, logits)
162  # pylint: enable=protected-access
163
164  with ops.name_scope(name, "logistic_loss", [logits, labels]) as name:
165    logits = ops.convert_to_tensor(logits, name="logits")
166    labels = ops.convert_to_tensor(labels, name="labels")
167    try:
168      labels.get_shape().merge_with(logits.get_shape())
169    except ValueError:
170      raise ValueError("logits and labels must have the same shape (%s vs %s)" %
171                       (logits.get_shape(), labels.get_shape()))
172
173    # The logistic loss formula from above is
174    #   x - x * z + log(1 + exp(-x))
175    # For x < 0, a more numerically stable formula is
176    #   -x * z + log(1 + exp(x))
177    # Note that these two expressions can be combined into the following:
178    #   max(x, 0) - x * z + log(1 + exp(-abs(x)))
179    # To allow computing gradients at zero, we define custom versions of max and
180    # abs functions.
181    zeros = array_ops.zeros_like(logits, dtype=logits.dtype)
182    cond = (logits >= zeros)
183    relu_logits = array_ops.where(cond, logits, zeros)
184    neg_abs_logits = array_ops.where(cond, -logits, logits)
185    return math_ops.add(
186        relu_logits - logits * labels,
187        math_ops.log1p(math_ops.exp(neg_abs_logits)),
188        name=name)
189
190
191# Note: intentionally calling this v2 to not allow existing code with indirect
192# imports to ignore the sentinel behavior.
193@tf_export("nn.sigmoid_cross_entropy_with_logits", v1=[])
194def sigmoid_cross_entropy_with_logits_v2(  # pylint: disable=invalid-name
195    labels=None,
196    logits=None,
197    name=None):
198  """Computes sigmoid cross entropy given `logits`.
199
200  Measures the probability error in discrete classification tasks in which each
201  class is independent and not mutually exclusive.  For instance, one could
202  perform multilabel classification where a picture can contain both an elephant
203  and a dog at the same time.
204
205  For brevity, let `x = logits`, `z = labels`.  The logistic loss is
206
207        z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
208      = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
209      = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
210      = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
211      = (1 - z) * x + log(1 + exp(-x))
212      = x - x * z + log(1 + exp(-x))
213
214  For x < 0, to avoid overflow in exp(-x), we reformulate the above
215
216        x - x * z + log(1 + exp(-x))
217      = log(exp(x)) - x * z + log(1 + exp(-x))
218      = - x * z + log(1 + exp(x))
219
220  Hence, to ensure stability and avoid overflow, the implementation uses this
221  equivalent formulation
222
223      max(x, 0) - x * z + log(1 + exp(-abs(x)))
224
225  `logits` and `labels` must have the same type and shape.
226
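  A minimal usage sketch for independent binary labels (assumes `import
  tensorflow as tf`; the values are illustrative):

  ```python
  labels = tf.constant([[1.0, 0.0], [0.0, 1.0]])
  logits = tf.constant([[2.3, -1.2], [0.4, 1.8]])
  loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
  # `loss` has shape [2, 2]: one loss value per label entry.
  ```
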
227  Args:
228    labels: A `Tensor` of the same type and shape as `logits`.
229    logits: A `Tensor` of type `float32` or `float64`.
230    name: A name for the operation (optional).
231
232  Returns:
233    A `Tensor` of the same shape as `logits` with the componentwise
234    logistic losses.
235
236  Raises:
237    ValueError: If `logits` and `labels` do not have the same shape.
238  """
239  return sigmoid_cross_entropy_with_logits(
240      logits=logits, labels=labels, name=name)
241
242
243@tf_export("nn.weighted_cross_entropy_with_logits", v1=[])
244def weighted_cross_entropy_with_logits_v2(labels, logits, pos_weight,
245                                          name=None):
246  """Computes a weighted cross entropy.
247
248  This is like `sigmoid_cross_entropy_with_logits()` except that `pos_weight`
249  allows one to trade off recall and precision by up- or down-weighting the
250  cost of a positive error relative to a negative error.
251
252  The usual cross-entropy cost is defined as:
253
254      labels * -log(sigmoid(logits)) +
255          (1 - labels) * -log(1 - sigmoid(logits))
256
257  A value `pos_weight > 1` decreases the false negative count, hence increasing
258  the recall.
259  Conversely setting `pos_weight < 1` decreases the false positive count and
260  increases the precision.
261  This can be seen from the fact that `pos_weight` is introduced as a
262  multiplicative coefficient for the positive labels term
263  in the loss expression:
264
265      labels * -log(sigmoid(logits)) * pos_weight +
266          (1 - labels) * -log(1 - sigmoid(logits))
267
268  For brevity, let `x = logits`, `z = labels`, `q = pos_weight`.
269  The loss is:
270
271        qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
272      = qz * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
273      = qz * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
274      = qz * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
275      = (1 - z) * x + (qz +  1 - z) * log(1 + exp(-x))
276      = (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))
277
278  Setting `l = (1 + (q - 1) * z)`, to ensure stability and avoid overflow,
279  the implementation uses
280
281      (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))
282
283  `logits` and `labels` must have the same type and shape.
284
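  A minimal usage sketch (assumes `import tensorflow as tf`; the weight of 2.0
  below is illustrative and simply doubles the cost of positive errors):

  ```python
  labels = tf.constant([[1.0, 0.0, 1.0]])
  logits = tf.constant([[1.5, -0.3, 0.2]])
  loss = tf.nn.weighted_cross_entropy_with_logits(
      labels=labels, logits=logits, pos_weight=2.0)
  ```
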
285  Args:
286    labels: A `Tensor` of the same type and shape as `logits`.
287    logits: A `Tensor` of type `float32` or `float64`.
288    pos_weight: A coefficient to use on the positive examples.
289    name: A name for the operation (optional).
290
291  Returns:
292    A `Tensor` of the same shape as `logits` with the componentwise
293    weighted logistic losses.
294
295  Raises:
296    ValueError: If `logits` and `labels` do not have the same shape.
297  """
298  with ops.name_scope(name, "logistic_loss", [logits, labels]) as name:
299    logits = ops.convert_to_tensor(logits, name="logits")
300    labels = ops.convert_to_tensor(labels, name="labels")
301    try:
302      labels.get_shape().merge_with(logits.get_shape())
303    except ValueError:
304      raise ValueError("logits and labels must have the same shape (%s vs %s)" %
305                       (logits.get_shape(), labels.get_shape()))
306
307    # The logistic loss formula from above is
308    #   (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))
309    # For x < 0, a more numerically stable formula is
310    #   (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(x)) - l * x
311    # To avoid branching, we use the combined version
312    #   (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))
313    log_weight = 1 + (pos_weight - 1) * labels
314    return math_ops.add(
315        (1 - labels) * logits,
316        log_weight * (math_ops.log1p(math_ops.exp(-math_ops.abs(logits))) +
317                      nn_ops.relu(-logits)),
318        name=name)
319
320
321@tf_export(v1=["nn.weighted_cross_entropy_with_logits"])
322@deprecated_args(None, "targets is deprecated, use labels instead", "targets")
323def weighted_cross_entropy_with_logits(labels=None,
324                                       logits=None,
325                                       pos_weight=None,
326                                       name=None,
327                                       targets=None):
328  """Computes a weighted cross entropy.
329
330  This is like `sigmoid_cross_entropy_with_logits()` except that `pos_weight`
331  allows one to trade off recall and precision by up- or down-weighting the
332  cost of a positive error relative to a negative error.
333
334  The usual cross-entropy cost is defined as:
335
336      labels * -log(sigmoid(logits)) +
337          (1 - labels) * -log(1 - sigmoid(logits))
338
339  A value `pos_weight > 1` decreases the false negative count, hence increasing
340  the recall.
341  Conversely setting `pos_weight < 1` decreases the false positive count and
342  increases the precision.
343  This can be seen from the fact that `pos_weight` is introduced as a
344  multiplicative coefficient for the positive labels term
345  in the loss expression:
346
347      labels * -log(sigmoid(logits)) * pos_weight +
348          (1 - labels) * -log(1 - sigmoid(logits))
349
350  For brevity, let `x = logits`, `z = labels`, `q = pos_weight`.
351  The loss is:
352
353        qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
354      = qz * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
355      = qz * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
356      = qz * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
357      = (1 - z) * x + (qz +  1 - z) * log(1 + exp(-x))
358      = (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))
359
360  Setting `l = (1 + (q - 1) * z)`, to ensure stability and avoid overflow,
361  the implementation uses
362
363      (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))
364
365  `logits` and `labels` must have the same type and shape.
366
367  Args:
368    labels: A `Tensor` of the same type and shape as `logits`.
369    logits: A `Tensor` of type `float32` or `float64`.
370    pos_weight: A coefficient to use on the positive examples.
371    name: A name for the operation (optional).
372    targets: Deprecated alias for labels.
373
374  Returns:
375    A `Tensor` of the same shape as `logits` with the componentwise
376    weighted logistic losses.
377
378  Raises:
379    ValueError: If `logits` and `labels` do not have the same shape.
380  """
381  labels = deprecated_argument_lookup("labels", labels, "targets", targets)
382  return weighted_cross_entropy_with_logits_v2(labels, logits, pos_weight, name)
383
384
385@tf_export("nn.compute_average_loss")
386def compute_average_loss(per_example_loss,
387                         sample_weight=None,
388                         global_batch_size=None):
389  """Scales per-example losses with sample_weights and computes their average.
390
391  Usage with distribution strategy and custom training loop:
392
393  ```python
394  with strategy.scope():
395    def compute_loss(labels, predictions, sample_weight=None):
396
397      # If you are using a `Loss` class instead, set reduction to `NONE` so that
398      # we can do the reduction afterwards and divide by global batch size.
399      per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
400          labels, predictions)
401
402      # Compute loss that is scaled by sample_weight and by global batch size.
403      return tf.nn.compute_average_loss(
404          per_example_loss,
405          sample_weight=sample_weight,
406          global_batch_size=GLOBAL_BATCH_SIZE)
407  ```
408
409  Args:
410    per_example_loss: Per-example loss.
411    sample_weight: Optional weighting for each example.
412    global_batch_size: Optional global batch size value. Defaults to (size of
413      first dimension of `per_example_loss`) * (number of replicas).
414
415  Returns:
416    Scalar loss value.
417  """  # pylint: disable=g-doc-exception
418  per_example_loss = ops.convert_to_tensor(per_example_loss)
419  input_dtype = per_example_loss.dtype
420
421  with losses_util.check_per_example_loss_rank(per_example_loss):
422    if sample_weight is not None:
423      per_example_loss = losses_util.scale_losses_by_sample_weight(
424          per_example_loss, sample_weight)
425    per_example_loss = math_ops.cast(per_example_loss, input_dtype)
426
427    if global_batch_size is None:
428      if ds.has_strategy() and ds.in_cross_replica_context():
429        raise RuntimeError(
430            "You are calling `compute_average_loss` in cross replica context, "
431            "while it was expected to be called in replica context.")
432
433      num_replicas = ds.get_strategy().num_replicas_in_sync
434      per_replica_batch_size = array_ops.shape_v2(per_example_loss)[0]
435      global_batch_size = per_replica_batch_size * num_replicas
436      global_batch_size = math_ops.cast(global_batch_size, input_dtype)
437
438    return math_ops.reduce_sum(per_example_loss) / global_batch_size
439
440
441@tf_export("nn.scale_regularization_loss")
442def scale_regularization_loss(regularization_loss):
443  """Scales the sum of the given regularization losses by number of replicas.
444
445  Usage with distribution strategy and custom training loop:
446
447  ```python
448  with strategy.scope():
449    def compute_loss(labels, predictions, sample_weight=None):
450      per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
451          labels, predictions)
452
453      # Compute loss that is scaled by sample_weight and by global batch size.
454      loss = tf.nn.compute_average_loss(
455          per_example_loss,
456          sample_weight=sample_weight,
457          global_batch_size=GLOBAL_BATCH_SIZE)
458
459      # Add scaled regularization losses.
460      loss += tf.nn.scale_regularization_loss(tf.nn.l2_loss(weights))
461      return loss
462  ```
463
464  Args:
465    regularization_loss: Regularization loss.
466
467  Returns:
468    Scalar loss value.
469  """  # pylint: disable=g-doc-exception
470  if ds.has_strategy() and ds.in_cross_replica_context():
471    raise RuntimeError(
472        "You are calling `scale_regularization_loss` in cross replica context, "
473        "while it was expected to be called in replica context.")
474
475  num_replicas = ds.get_strategy().num_replicas_in_sync
476  return math_ops.reduce_sum(regularization_loss) / num_replicas
477
478
479@tf_export(v1=["nn.relu_layer"])
480def relu_layer(x, weights, biases, name=None):
481  """Computes Relu(x * weight + biases).
482
483  Args:
484    x: a 2D tensor.  Dimensions typically: batch, in_units
485    weights: a 2D tensor.  Dimensions typically: in_units, out_units
486    biases: a 1D tensor.  Dimensions: out_units
487    name: A name for the operation (optional).  If not specified
488      "nn_relu_layer" is used.
489
490  Returns:
491    A 2-D Tensor computing relu(matmul(x, weights) + biases).
492    Dimensions typically: batch, out_units.
493  """
494  with ops.name_scope(name, "relu_layer", [x, weights, biases]) as name:
495    x = ops.convert_to_tensor(x, name="x")
496    weights = ops.convert_to_tensor(weights, name="weights")
497    biases = ops.convert_to_tensor(biases, name="biases")
498    xw_plus_b = nn_ops.bias_add(math_ops.matmul(x, weights), biases)
499    return nn_ops.relu(xw_plus_b, name=name)
500
501
502@tf_export("nn.swish")
503@custom_gradient.custom_gradient
504def swish(features):
505  # pylint: disable=g-doc-args
506  """Computes the Swish activation function: `x * sigmoid(x)`.
507
508  Source: "Searching for Activation Functions" (Ramachandran et al. 2017)
509  https://arxiv.org/abs/1710.05941
510
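  A minimal usage sketch (assumes `import tensorflow as tf`):

  ```python
  x = tf.constant([-1.0, 0.0, 1.0])
  y = tf.nn.swish(x)  # elementwise x * sigmoid(x)
  ```
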
511  Args:
512    features: A `Tensor` representing preactivation values.
513    name: A name for the operation (optional).
514
515  Returns:
516    The activation value.
517  """
518  # pylint: enable=g-doc-args
519  features = ops.convert_to_tensor(features, name="features")
520
521  def grad(dy):
522    """Gradient for the Swish activation function"""
523    # Naively, x * tf.nn.sigmoid(x) requires keeping both x and sigmoid(x)
524    # around for backprop, effectively doubling the tensor's memory consumption.
525    # We use a control dependency here so that sigmoid(features) is re-computed
526    # during backprop (the control dep prevents it being de-duped with the
527    # forward pass) and we can free the sigmoid(features) expression immediately
528    # after use during the forward pass.
529    with ops.control_dependencies([dy]):
530      sigmoid_features = math_ops.sigmoid(features)
531    activation_grad = (
532        sigmoid_features * (1.0 + features * (1.0 - sigmoid_features)))
533    return dy * activation_grad
534
535  return features * math_ops.sigmoid(features), grad
536
537
538# pylint: disable=redefined-builtin
539@tf_export("linalg.normalize")
540def normalize(tensor, ord="euclidean", axis=None, name=None):
541  """Normalizes `tensor` along dimension `axis` using specified norm.
542
543  This uses `tf.linalg.norm` to compute the norm along `axis`.
544
545  This function can compute several different vector norms (the 1-norm, the
546  Euclidean or 2-norm, the inf-norm, and in general the p-norm for p > 0) and
547  matrix norms (Frobenius, 1-norm, 2-norm and inf-norm).
548
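  A small illustrative example (assumes `import tensorflow as tf`):

  ```python
  t = tf.constant([[3.0, 4.0]])
  normalized, norm = tf.linalg.normalize(t, ord='euclidean', axis=1)
  # normalized == [[0.6, 0.8]], norm == [[5.0]]
  ```
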
549  Args:
550    tensor: `Tensor` of types `float32`, `float64`, `complex64`, `complex128`
551    ord: Order of the norm. Supported values are `'fro'`, `'euclidean'`, `1`,
552      `2`, `np.inf` and any positive real number yielding the corresponding
553      p-norm. Default is `'euclidean'` which is equivalent to Frobenius norm if
554      `tensor` is a matrix and equivalent to 2-norm for vectors.
555      Some restrictions apply: a) The Frobenius norm `'fro'` is not defined for
556        vectors, b) If axis is a 2-tuple (matrix norm), only `'euclidean'`,
557        `'fro'`, `1`, `2`, `np.inf` are supported. See the description of `axis`
558        on how to compute norms for a batch of vectors or matrices stored in a
559        tensor.
560    axis: If `axis` is `None` (the default), the input is considered a vector
561      and a single vector norm is computed over the entire set of values in the
562      tensor, i.e. `norm(tensor, ord=ord)` is equivalent to
563      `norm(reshape(tensor, [-1]), ord=ord)`. If `axis` is a Python integer, the
564      input is considered a batch of vectors, and `axis` determines the axis in
565      `tensor` over which to compute vector norms. If `axis` is a 2-tuple of
566      Python integers it is considered a batch of matrices and `axis` determines
567      the axes in `tensor` over which to compute a matrix norm.
568      Negative indices are supported. Example: If you are passing a tensor that
569        can be either a matrix or a batch of matrices at runtime, pass
570        `axis=[-2,-1]` instead of `axis=None` to make sure that matrix norms are
571        computed.
572    name: The name of the op.
573
574  Returns:
575    normalized: A normalized `Tensor` with the same shape as `tensor`.
576    norm: The computed norms with the same shape and dtype as `tensor` but the
577      final axis is 1 instead. Same as running
578      `tf.cast(tf.linalg.norm(tensor, ord, axis, keepdims=True), tensor.dtype)`.
579
580  Raises:
581    ValueError: If `ord` or `axis` is invalid.
582  """
583  with ops.name_scope(name, "normalize", [tensor]) as name:
584    tensor = ops.convert_to_tensor(tensor)
585    norm = linalg_ops.norm(tensor, ord, axis, keepdims=True)
586    norm = math_ops.cast(norm, tensor.dtype)
587    normalized = tensor / norm
588    return normalized, norm
589
590
591@tf_export(v1=["math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize"])
592@deprecated_args(None, "dim is deprecated, use axis instead", "dim")
593def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None):
594  """Normalizes along dimension `axis` using an L2 norm.
595
596  For a 1-D tensor with `axis = 0`, computes
597
598      output = x / sqrt(max(sum(x**2), epsilon))
599
600  For `x` with more dimensions, independently normalizes each 1-D slice along
601  dimension `axis`.
602
603  Args:
604    x: A `Tensor`.
605    axis: Dimension along which to normalize.  A scalar or a vector of
606      integers.
607    epsilon: A lower bound value for the norm. Will use `sqrt(epsilon)` as the
608      divisor if `norm < sqrt(epsilon)`.
609    name: A name for this operation (optional).
610    dim: Deprecated alias for axis.
611
612  Returns:
613    A `Tensor` with the same shape as `x`.
614  """
615  axis = deprecated_argument_lookup("axis", axis, "dim", dim)
616  return l2_normalize_v2(x, axis, epsilon, name)
617
618
619@tf_export("math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize", v1=[])
620def l2_normalize_v2(x, axis=None, epsilon=1e-12, name=None):
621  """Normalizes along dimension `axis` using an L2 norm.
622
623  For a 1-D tensor with `axis = 0`, computes
624
625      output = x / sqrt(max(sum(x**2), epsilon))
626
627  For `x` with more dimensions, independently normalizes each 1-D slice along
628  dimension `axis`.
629
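  A small illustrative example (assumes `import tensorflow as tf`):

  ```python
  x = tf.constant([3.0, 4.0])
  y = tf.math.l2_normalize(x)  # == [0.6, 0.8]
  ```
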
630  Args:
631    x: A `Tensor`.
632    axis: Dimension along which to normalize.  A scalar or a vector of
633      integers.
634    epsilon: A lower bound value for the norm. Will use `sqrt(epsilon)` as the
635      divisor if `norm < sqrt(epsilon)`.
636    name: A name for this operation (optional).
637
638  Returns:
639    A `Tensor` with the same shape as `x`.
640  """
641  with ops.name_scope(name, "l2_normalize", [x]) as name:
642    x = ops.convert_to_tensor(x, name="x")
643    square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keepdims=True)
644    x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon))
645    return math_ops.multiply(x, x_inv_norm, name=name)
646
647
648def _count_nonzero(input_tensor, dtype=dtypes.int64):
649  """Same as math_ops.count_nonzero.
650
651  The reduction is done in dtype, which can be faster for 32-bit dtypes.
652
653  Args:
654      input_tensor: numeric tensor
655      dtype: reduction dtype
656
657  Returns:
658      number of nonzero values with type dtype
659  """
660  with ops.name_scope("count_nonzero", values=[input_tensor]):
661    zero = array_ops.zeros([], dtype=input_tensor.dtype)
662    nonzero_count = math_ops.reduce_sum(
663        math_ops.cast(
664            math_ops.not_equal(input_tensor, zero),
665            dtype=dtype), name="nonzero_count")
666    return nonzero_count
667
668
669@tf_export("math.zero_fraction", "nn.zero_fraction")
670def zero_fraction(value, name=None):
671  """Returns the fraction of zeros in `value`.
672
673  If `value` is empty, the result is `nan`.
674
675  This is useful in summaries to measure and report sparsity.  For example,
676
677  ```python
678      z = tf.nn.relu(...)
679      summ = tf.compat.v1.summary.scalar('sparsity', tf.nn.zero_fraction(z))
680  ```
681
682  Args:
683    value: A tensor of numeric type.
684    name: A name for the operation (optional).
685
686  Returns:
687    The fraction of zeros in `value`, with type `float32`.
688  """
689  with ops.name_scope(name, "zero_fraction", [value]):
690    value = ops.convert_to_tensor(value, name="value")
691    size = array_ops.size(value, out_type=dtypes.int64)
692    # If the count is small, we can save memory/CPU with an int32 reduction.
693    num_nonzero = control_flow_ops.cond(
694        size <= dtypes.int32.max,
695        # pylint: disable=g-long-lambda
696        true_fn=lambda: math_ops.cast(
697            _count_nonzero(value, dtype=dtypes.int32),
698            dtype=dtypes.int64),
699        false_fn=lambda: _count_nonzero(value, dtype=dtypes.int64))
700
701    with ops.name_scope("counts_to_fraction"):
702      num_zero = size - num_nonzero
703      num_zero_float32 = math_ops.cast(num_zero, dtype=dtypes.float32)
704      size_float32 = math_ops.cast(size, dtype=dtypes.float32)
705      zero_fraction_float32 = num_zero_float32 / size_float32
706
707    return array_ops.identity(zero_fraction_float32, "fraction")
708
709
710# pylint: disable=redefined-builtin
711@tf_export(v1=["nn.depthwise_conv2d"])
712def depthwise_conv2d(input,
713                     filter,
714                     strides,
715                     padding,
716                     rate=None,
717                     name=None,
718                     data_format=None,
719                     dilations=None):
720  """Depthwise 2-D convolution.
721
722  Given a 4D input tensor ('NHWC' or 'NCHW' data formats)
723  and a filter tensor of shape
724  `[filter_height, filter_width, in_channels, channel_multiplier]`
725  containing `in_channels` convolutional filters of depth 1, `depthwise_conv2d`
726  applies a different filter to each input channel (expanding from 1 channel
727  to `channel_multiplier` channels for each), then concatenates the results
728  together.  The output has `in_channels * channel_multiplier` channels.
729
730  In detail, with the default NHWC format,
731
732      output[b, i, j, k * channel_multiplier + q] = sum_{di, dj}
733           filter[di, dj, k, q] * input[b, strides[1] * i + rate[0] * di,
734                                           strides[2] * j + rate[1] * dj, k]
735
736  Must have `strides[0] = strides[3] = 1`.  For the most common case of the
737  same horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
738  If any value in `rate` is greater than 1, we perform atrous depthwise
739  convolution, in which case all values in the `strides` tensor must be equal
740  to 1.
741
742  Args:
743    input: 4-D with shape according to `data_format`.
744    filter: 4-D with shape
745      `[filter_height, filter_width, in_channels, channel_multiplier]`.
746    strides: 1-D of size 4.  The stride of the sliding window for each
747      dimension of `input`.
748    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
749      See the "returns" section of `tf.nn.convolution` for details.
750    rate: 1-D of size 2. The dilation rate in which we sample input values
751      across the `height` and `width` dimensions in atrous convolution. If it is
752      greater than 1, then all values of strides must be 1.
753    name: A name for this operation (optional).
754    data_format: The data format for input. Either "NHWC" (default) or "NCHW".
755    dilations: Alias of rate.
756
757  Returns:
758    A 4-D `Tensor` with shape according to `data_format`.  E.g., for
759    "NHWC" format, shape is
760    `[batch, out_height, out_width, in_channels * channel_multiplier]`.
761  """
762  rate = deprecated_argument_lookup("dilations", dilations, "rate", rate)
763  with ops.name_scope(name, "depthwise", [input, filter]) as name:
764    input = ops.convert_to_tensor(input, name="tensor_in")
765    filter = ops.convert_to_tensor(filter, name="filter_in")
766    if rate is None:
767      rate = [1, 1]
768
769    # Use depthwise_conv2d_native if executing on TPU.
770    if device_context.enclosing_tpu_context() is not None:
771      if data_format == "NCHW":
772        dilations = [1, 1, rate[0], rate[1]]
773      else:
774        dilations = [1, rate[0], rate[1], 1]
775      return nn_ops.depthwise_conv2d_native(
776          input=input,
777          filter=filter,
778          strides=strides,
779          padding=padding,
780          data_format=data_format,
781          dilations=dilations,
782          name=name)
783
784    def op(input_converted, _, padding):
785      return nn_ops.depthwise_conv2d_native(
786          input=input_converted,
787          filter=filter,
788          strides=strides,
789          padding=padding,
790          data_format=data_format,
791          name=name)
792
793    return nn_ops.with_space_to_batch(
794        input=input,
795        filter_shape=array_ops.shape(filter),
796        dilation_rate=rate,
797        padding=padding,
798        data_format=data_format,
799        op=op)
800
801
802@tf_export("nn.depthwise_conv2d", v1=[])
803def depthwise_conv2d_v2(input,
804                        filter,
805                        strides,
806                        padding,
807                        data_format=None,
808                        dilations=None,
809                        name=None):
810  """Depthwise 2-D convolution.
811
812  Given a 4D input tensor ('NHWC' or 'NCHW' data formats)
813  and a filter tensor of shape
814  `[filter_height, filter_width, in_channels, channel_multiplier]`
815  containing `in_channels` convolutional filters of depth 1, `depthwise_conv2d`
816  applies a different filter to each input channel (expanding from 1 channel
817  to `channel_multiplier` channels for each), then concatenates the results
818  together.  The output has `in_channels * channel_multiplier` channels.
819
820  In detail, with the default NHWC format,
821
822      output[b, i, j, k * channel_multiplier + q] = sum_{di, dj}
823           filter[di, dj, k, q] * input[b, strides[1] * i + rate[0] * di,
824                                           strides[2] * j + rate[1] * dj, k]
825
826  Must have `strides[0] = strides[3] = 1`.  For the most common case of the
827  same horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
828  If any value in `rate` is greater than 1, we perform atrous depthwise
829  convolution, in which case all values in the `strides` tensor must be equal
830  to 1.
831
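  A minimal shape sketch (assumes `import tensorflow as tf`; the random
  tensors are illustrative):

  ```python
  x = tf.random.normal([1, 8, 8, 3])   # NHWC input with 3 channels
  w = tf.random.normal([3, 3, 3, 2])   # channel_multiplier = 2
  y = tf.nn.depthwise_conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME')
  # y.shape == [1, 8, 8, 6]  (3 input channels * channel_multiplier 2)
  ```
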
832  Args:
833    input: 4-D with shape according to `data_format`.
834    filter: 4-D with shape
835      `[filter_height, filter_width, in_channels, channel_multiplier]`.
836    strides: 1-D of size 4.  The stride of the sliding window for each
837      dimension of `input`.
838    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
839      See the "returns" section of `tf.nn.convolution` for details.
840    data_format: The data format for input. Either "NHWC" (default) or "NCHW".
841    dilations: 1-D of size 2. The dilation rate in which we sample input values
842      across the `height` and `width` dimensions in atrous convolution. If it is
843      greater than 1, then all values of strides must be 1.
844    name: A name for this operation (optional).
845
846  Returns:
847    A 4-D `Tensor` with shape according to `data_format`.  E.g., for
848    "NHWC" format, shape is
849    `[batch, out_height, out_width, in_channels * channel_multiplier]`.
850  """
851  return depthwise_conv2d(input=input,
852                          filter=filter,
853                          strides=strides,
854                          padding=padding,
855                          rate=dilations,
856                          name=name,
857                          data_format=data_format)
858
859# pylint: enable=redefined-builtin
860
861
862# pylint: disable=redefined-builtin,line-too-long
863@tf_export(v1=["nn.separable_conv2d"])
864def separable_conv2d(input,
865                     depthwise_filter,
866                     pointwise_filter,
867                     strides,
868                     padding,
869                     rate=None,
870                     name=None,
871                     data_format=None,
872                     dilations=None):
873  """2-D convolution with separable filters.
874
875  Performs a depthwise convolution that acts separately on channels followed by
876  a pointwise convolution that mixes channels.  Note that this is separability
877  between dimensions `[1, 2]` and `3`, not spatial separability between
878  dimensions `1` and `2`.
879
880  In detail, with the default NHWC format,
881
882      output[b, i, j, k] = sum_{di, dj, q, r}
883          input[b, strides[1] * i + di, strides[2] * j + dj, q] *
884          depthwise_filter[di, dj, q, r] *
885          pointwise_filter[0, 0, q * channel_multiplier + r, k]
886
887  `strides` controls the strides for the depthwise convolution only, since
888  the pointwise convolution has implicit strides of `[1, 1, 1, 1]`.  Must have
889  `strides[0] = strides[3] = 1`.  For the most common case of the same
890  horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
891  If any value in `rate` is greater than 1, we perform atrous depthwise
892  convolution, in which case all values in the `strides` tensor must be equal
893  to 1.
894
895  Args:
896    input: 4-D `Tensor` with shape according to `data_format`.
897    depthwise_filter: 4-D `Tensor` with shape
898      `[filter_height, filter_width, in_channels, channel_multiplier]`.
899      Contains `in_channels` convolutional filters of depth 1.
900    pointwise_filter: 4-D `Tensor` with shape
901      `[1, 1, channel_multiplier * in_channels, out_channels]`.  Pointwise
902      filter to mix channels after `depthwise_filter` has convolved spatially.
903    strides: 1-D of size 4.  The strides for the depthwise convolution for
904      each dimension of `input`.
905    padding: A string, either `'VALID'` or `'SAME'`.  The padding algorithm.
906      See the "returns" section of `tf.nn.convolution` for details.
907    rate: 1-D of size 2. The dilation rate in which we sample input values
908      across the `height` and `width` dimensions in atrous convolution. If it is
909      greater than 1, then all values of strides must be 1.
910    name: A name for this operation (optional).
911    data_format: The data format for input. Either "NHWC" (default) or "NCHW".
912    dilations: Alias of rate.
913
914  Returns:
915    A 4-D `Tensor` with shape according to 'data_format'. For
916      example, with data_format="NHWC", shape is [batch, out_height,
917      out_width, out_channels].
918  """
919  rate = deprecated_argument_lookup("dilations", dilations, "rate", rate)
920  with ops.name_scope(name, "separable_conv2d",
921                      [input, depthwise_filter, pointwise_filter]) as name:
922    input = ops.convert_to_tensor(input, name="tensor_in")
923    depthwise_filter = ops.convert_to_tensor(
924        depthwise_filter, name="depthwise_filter")
925    pointwise_filter = ops.convert_to_tensor(
926        pointwise_filter, name="pointwise_filter")
927
928    pointwise_filter_shape = pointwise_filter.get_shape().with_rank(4)
929    pointwise_filter_shape.dims[0].assert_is_compatible_with(1)
930    pointwise_filter_shape.dims[1].assert_is_compatible_with(1)
931
932    if rate is None:
933      rate = [1, 1]
934
935    # The layout of the ops in the graph is expected to be as follows:
936    # depthwise_conv2d  // Conv2D op corresponding to native depthwise conv.
937    # separable_conv2d  // Conv2D op corresponding to the pointwise conv.
938
939    def op(input_converted, _, padding):
940      return nn_ops.depthwise_conv2d_native(
941          input=input_converted,
942          filter=depthwise_filter,
943          strides=strides,
944          padding=padding,
945          data_format=data_format,
946          name="depthwise")
947
948    depthwise = nn_ops.with_space_to_batch(
949        input=input,
950        filter_shape=array_ops.shape(depthwise_filter),
951        dilation_rate=rate,
952        padding=padding,
953        data_format=data_format,
954        op=op)
955
956    return nn_ops.conv2d(
957        depthwise,
958        pointwise_filter, [1, 1, 1, 1],
959        padding="VALID",
960        data_format=data_format,
961        name=name)
962
963
964@tf_export("nn.separable_conv2d", v1=[])
965def separable_conv2d_v2(
966    input,
967    depthwise_filter,
968    pointwise_filter,
969    strides,
970    padding,
971    data_format=None,
972    dilations=None,
973    name=None,
974):
975  """2-D convolution with separable filters.
976
977  Performs a depthwise convolution that acts separately on channels followed by
978  a pointwise convolution that mixes channels.  Note that this is separability
979  between dimensions `[1, 2]` and `3`, not spatial separability between
980  dimensions `1` and `2`.
981
982  In detail, with the default NHWC format,
983
984      output[b, i, j, k] = sum_{di, dj, q, r}
985          input[b, strides[1] * i + di, strides[2] * j + dj, q] *
986          depthwise_filter[di, dj, q, r] *
987          pointwise_filter[0, 0, q * channel_multiplier + r, k]
988
989  `strides` controls the strides for the depthwise convolution only, since
990  the pointwise convolution has implicit strides of `[1, 1, 1, 1]`.  Must have
991  `strides[0] = strides[3] = 1`.  For the most common case of the same
992  horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
993  If any value in `rate` is greater than 1, we perform atrous depthwise
994  convolution, in which case all values in the `strides` tensor must be equal
995  to 1.
996
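  A minimal shape sketch (assumes `import tensorflow as tf`; the random
  tensors are illustrative):

  ```python
  x = tf.random.normal([1, 8, 8, 3])    # NHWC input with 3 channels
  dw = tf.random.normal([3, 3, 3, 2])   # depthwise: channel_multiplier = 2
  pw = tf.random.normal([1, 1, 6, 4])   # pointwise: 3 * 2 -> 4 out channels
  y = tf.nn.separable_conv2d(x, dw, pw, strides=[1, 1, 1, 1], padding='SAME')
  # y.shape == [1, 8, 8, 4]
  ```
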
997  Args:
998    input: 4-D `Tensor` with shape according to `data_format`.
999    depthwise_filter: 4-D `Tensor` with shape `[filter_height, filter_width,
1000      in_channels, channel_multiplier]`. Contains `in_channels` convolutional
1001      filters of depth 1.
1002    pointwise_filter: 4-D `Tensor` with shape `[1, 1, channel_multiplier *
1003      in_channels, out_channels]`.  Pointwise filter to mix channels after
1004      `depthwise_filter` has convolved spatially.
1005    strides: 1-D of size 4.  The strides for the depthwise convolution for each
1006      dimension of `input`.
1007    padding: A string, either `'VALID'` or `'SAME'`.  The padding algorithm. See
1008      the "returns" section of `tf.nn.convolution` for details.
1009    data_format: The data format for input. Either "NHWC" (default) or "NCHW".
1010    dilations: 1-D of size 2. The dilation rate in which we sample input values
1011      across the `height` and `width` dimensions in atrous convolution. If it is
1012      greater than 1, then all values of strides must be 1.
1013    name: A name for this operation (optional).
1014
1015  Returns:
1016    A 4-D `Tensor` with shape according to 'data_format'. For
1017      example, with data_format="NHWC", shape is [batch, out_height,
1018      out_width, out_channels].
1019  """
1020  return separable_conv2d(
1021      input,
1022      depthwise_filter,
1023      pointwise_filter,
1024      strides,
1025      padding,
1026      rate=dilations,
1027      name=name,
1028      data_format=data_format)
1029
1030# pylint: enable=redefined-builtin,line-too-long
1031
1032
1033@tf_export(v1=["nn.sufficient_statistics"])
1034def sufficient_statistics(x, axes, shift=None, keep_dims=None, name=None,
1035                          keepdims=None):
1036  """Calculate the sufficient statistics for the mean and variance of `x`.
1037
1038  These sufficient statistics are computed using the one pass algorithm on
1039  an input that's optionally shifted. See:
1040  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data
1041
1042  Args:
1043    x: A `Tensor`.
1044    axes: Array of ints. Axes along which to compute mean and variance.
1045    shift: A `Tensor` containing the value by which to shift the data for
1046      numerical stability, or `None` if no shift is to be performed. A shift
1047      close to the true mean provides the most numerically stable results.
1048    keep_dims: produce statistics with the same dimensionality as the input.
1049    name: Name used to scope the operations that compute the sufficient stats.
1050    keepdims: Alias for keep_dims.
1051
1052  Returns:
1053    Four `Tensor` objects of the same type as `x`:
1054
1055    * the count (number of elements to average over).
1056    * the (possibly shifted) sum of the elements in the array.
1057    * the (possibly shifted) sum of squares of the elements in the array.
1058    * the shift by which the mean must be corrected or None if `shift` is None.
1059  """
1060  axes = list(set(axes))
1061  keep_dims = deprecated_argument_lookup(
1062      "keepdims", keepdims, "keep_dims", keep_dims)
1063  if keep_dims is None:
1064    keep_dims = False
1065  with ops.name_scope(name, "sufficient_statistics", [x, shift]):
1066    x = ops.convert_to_tensor(x, name="x")
1067    x_shape = x.get_shape()
1068    if x_shape.rank is not None and all(
1069        x_shape.dims[d].value is not None for d in axes):
1070      counts = 1
1071      for d in axes:
1072        counts *= x_shape.dims[d].value
1073      counts = constant_op.constant(counts, dtype=x.dtype)
1074    else:  # shape needs to be inferred at runtime.
1075      x_dims = array_ops.gather(
1076          math_ops.cast(array_ops.shape(x), x.dtype), axes)
1077      counts = math_ops.reduce_prod(x_dims, name="count")
1078    if shift is not None:
1079      shift = ops.convert_to_tensor(shift, name="shift")
1080      m_ss = math_ops.subtract(x, shift)
1081      v_ss = math_ops.squared_difference(x, shift)
1082    else:  # no shift.
1083      m_ss = x
1084      v_ss = math_ops.square(x)
1085    m_ss = math_ops.reduce_sum(m_ss, axes, keepdims=keep_dims, name="mean_ss")
1086    v_ss = math_ops.reduce_sum(v_ss, axes, keepdims=keep_dims, name="var_ss")
1087  return counts, m_ss, v_ss, shift
1088
1089
1090@tf_export("nn.sufficient_statistics", v1=[])
1091def sufficient_statistics_v2(x, axes, shift=None, keepdims=False, name=None):
1092  """Calculate the sufficient statistics for the mean and variance of `x`.
1093
1094  These sufficient statistics are computed using the one pass algorithm on
1095  an input that's optionally shifted. See:
1096  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data
1097
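  A small illustrative example (assumes `import tensorflow as tf`):

  ```python
  x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
  count, mean_ss, var_ss, shift = tf.nn.sufficient_statistics(x, axes=[0])
  # count == 2.0, mean_ss == [4.0, 6.0], var_ss == [10.0, 20.0], shift is None
  ```
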
1098  Args:
1099    x: A `Tensor`.
1100    axes: Array of ints. Axes along which to compute mean and variance.
1101    shift: A `Tensor` containing the value by which to shift the data for
1102      numerical stability, or `None` if no shift is to be performed. A shift
1103      close to the true mean provides the most numerically stable results.
1104    keepdims: produce statistics with the same dimensionality as the input.
1105    name: Name used to scope the operations that compute the sufficient stats.
1106
1107  Returns:
1108    Four `Tensor` objects of the same type as `x`:
1109
1110    * the count (number of elements to average over).
1111    * the (possibly shifted) sum of the elements in the array.
1112    * the (possibly shifted) sum of squares of the elements in the array.
1113    * the shift by which the mean must be corrected or None if `shift` is None.
1114  """
1115  return sufficient_statistics(
1116      x=x, axes=axes, shift=shift, keep_dims=keepdims, name=name)
1117
1118
1119@tf_export("nn.normalize_moments")
1120def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
1121  """Calculate the mean and variance of based on the sufficient statistics.
1122
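  A small illustrative example, feeding in the outputs of
  `tf.nn.sufficient_statistics` (assumes `import tensorflow as tf`):

  ```python
  x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
  counts, mean_ss, var_ss, shift = tf.nn.sufficient_statistics(x, axes=[0])
  mean, variance = tf.nn.normalize_moments(counts, mean_ss, var_ss, shift)
  # mean == [2.0, 3.0], variance == [1.0, 1.0]
  ```
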
1123  Args:
1124    counts: A `Tensor` containing the total count of the data (one value).
1125    mean_ss: A `Tensor` containing the mean sufficient statistics: the (possibly
1126      shifted) sum of the elements to average over.
1127    variance_ss: A `Tensor` containing the variance sufficient statistics: the
1128      (possibly shifted) squared sum of the data to compute the variance over.
1129    shift: A `Tensor` containing the value by which the data is shifted for
1130      numerical stability, or `None` if no shift was performed.
1131    name: Name used to scope the operations that compute the moments.
1132
1133  Returns:
1134    Two `Tensor` objects: `mean` and `variance`.
1135  """
1136  with ops.name_scope(name, "normalize", [counts, mean_ss, variance_ss, shift]):
1137    divisor = math_ops.reciprocal(counts, name="divisor")
1138    if shift is not None:
1139      shifted_mean = math_ops.multiply(mean_ss, divisor, name="shifted_mean")
1140      mean = math_ops.add(shifted_mean, shift, name="mean")
1141    else:  # no shift.
1142      shifted_mean = math_ops.multiply(mean_ss, divisor, name="mean")
1143      mean = shifted_mean
1144    variance = math_ops.subtract(
1145        math_ops.multiply(variance_ss, divisor),
1146        math_ops.square(shifted_mean),
1147        name="variance")
1148  return (mean, variance)
1149
1150
1151@tf_export(v1=["nn.moments"])
1152def moments(
1153    x,
1154    axes,
1155    shift=None,  # pylint: disable=unused-argument
1156    name=None,
1157    keep_dims=None,
1158    keepdims=None):
1159  """Calculate the mean and variance of `x`.
1160
1161  The mean and variance are calculated by aggregating the contents of `x`
1162  across `axes`.  If `x` is 1-D and `axes = [0]` this is just the mean
1163  and variance of a vector.
1164
1165  Note: shift is currently not used; the true mean is computed and used.
1166
1167  When using these moments for batch normalization (see
1168  `tf.nn.batch_normalization`):
1169
1170   * for so-called "global normalization", used with convolutional filters with
1171     shape `[batch, height, width, depth]`, pass `axes=[0, 1, 2]`.
1172   * for simple batch normalization pass `axes=[0]` (batch only).
1173
1174  Args:
1175    x: A `Tensor`.
1176    axes: Array of ints.  Axes along which to compute mean and
1177      variance.
1178    shift: Not used in the current implementation.
1179    name: Name used to scope the operations that compute the moments.
1180    keep_dims: produce moments with the same dimensionality as the input.
1181    keepdims: Alias to keep_dims.
1182
1183  Returns:
1184    Two `Tensor` objects: `mean` and `variance`.
1185  """
1186  keep_dims = deprecated_argument_lookup(
1187      "keepdims", keepdims, "keep_dims", keep_dims)
1188  if keep_dims is None:
1189    keep_dims = False
1190  with ops.name_scope(name, "moments", [x, axes]):
1191    # The dynamic range of fp16 is too limited to support the collection of
1192    # sufficient statistics. As a workaround we simply perform the operations
1193    # on 32-bit floats before converting the mean and variance back to fp16
1194    y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
1195    # Compute true mean while keeping the dims for proper broadcasting.
1196    mean = math_ops.reduce_mean(y, axes, keepdims=True, name="mean")
1197    # sample variance, not unbiased variance
1198    # Note: stop_gradient does not change the gradient that gets
1199    #       backpropagated to the mean from the variance calculation,
1200    #       because that gradient is zero
1201    variance = math_ops.reduce_mean(
1202        math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
1203        axes,
1204        keepdims=True,
1205        name="variance")
1206    if not keep_dims:
1207      mean = array_ops.squeeze(mean, axes)
1208      variance = array_ops.squeeze(variance, axes)
1209    if x.dtype == dtypes.float16:
1210      return (math_ops.cast(mean, dtypes.float16),
1211              math_ops.cast(variance, dtypes.float16))
1212    else:
1213      return (mean, variance)
1214
1215
1216@tf_export("nn.moments", v1=[])
1217def moments_v2(
1218    x,
1219    axes,
1220    shift=None,
1221    keepdims=False,
1222    name=None):
1223  """Calculates the mean and variance of `x`.
1224
1225  The mean and variance are calculated by aggregating the contents of `x`
1226  across `axes`.  If `x` is 1-D and `axes = [0]` this is just the mean
1227  and variance of a vector.
1228
1229  Note: shift is currently not used; the true mean is computed and used.
1230
1231  When using these moments for batch normalization (see
1232  `tf.nn.batch_normalization`):
1233
1234   * for so-called "global normalization", used with convolutional filters with
1235     shape `[batch, height, width, depth]`, pass `axes=[0, 1, 2]`.
1236   * for simple batch normalization pass `axes=[0]` (batch only).
1237
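  A small illustrative example (assumes `import tensorflow as tf`):

  ```python
  x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
  mean, variance = tf.nn.moments(x, axes=[0])
  # mean == [2.0, 3.0], variance == [1.0, 1.0]
  ```
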
1238  Args:
1239    x: A `Tensor`.
1240    axes: Array of ints.  Axes along which to compute mean and
1241      variance.
1242    shift: Not used in the current implementation.
1243    keepdims: produce moments with the same dimensionality as the input.
1244    name: Name used to scope the operations that compute the moments.
1245
1246  Returns:
1247    Two `Tensor` objects: `mean` and `variance`.
1248  """
1249  return moments(x=x, axes=axes, shift=shift, name=name, keep_dims=keepdims)
1250
1251
1252@tf_export(v1=["nn.weighted_moments"])
1253def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=None,
1254                     keepdims=None):
1255  """Returns the frequency-weighted mean and variance of `x`.
1256
1257  Args:
1258    x: A tensor.
1259    axes: 1-d tensor of int32 values; these are the axes along which
1260      to compute mean and variance.
1261    frequency_weights: A tensor of positive weights which can be
1262      broadcast with x.
1263    name: Name used to scope the operation.
1264    keep_dims: Produce moments with the same dimensionality as the input.
1265    keepdims: Alias of keep_dims.
1266
1267  Returns:
1268    Two tensors: `weighted_mean` and `weighted_variance`.
1269  """
1270  keep_dims = deprecated_argument_lookup(
1271      "keepdims", keepdims, "keep_dims", keep_dims)
1272  if keep_dims is None:
1273    keep_dims = False
1274  with ops.name_scope(name, "weighted_moments", [x, frequency_weights, axes]):
1275    x = ops.convert_to_tensor(x, name="x")
1276    frequency_weights = ops.convert_to_tensor(
1277        frequency_weights, name="frequency_weights")
1278
1279    # Unlike moments(), this just uses a simpler two-pass method.
1280
1281    # See comment in moments() WRT precision; it applies here too.
1282    needs_cast = x.dtype == dtypes.float16
1283    if needs_cast:
1284      x = math_ops.cast(x, dtypes.float32)
1285
1286    if frequency_weights.dtype != x.dtype:
1287      frequency_weights = math_ops.cast(frequency_weights, x.dtype)
1288
1289    # Note that we use keep_dims=True for our reductions regardless of the arg;
1290    # this is so that the results remain broadcast-compatible with the inputs.
1291    weighted_input_sum = math_ops.reduce_sum(
1292        frequency_weights * x, axes, name="weighted_input_sum", keepdims=True)
1293
1294    # The shape of the weights isn't necessarily the same as x's
1295    # shape, just broadcast-compatible with it -- so this expression
1296    # performs broadcasting to give a per-item weight, with the same
1297    # shape as (frequency_weights * x). This avoids having to reason
1298    # through all the broadcast logic to compute a correct
1299    # sum_of_weights.
1300    broadcasted_weights = frequency_weights + array_ops.zeros_like(x)
1301
1302    sum_of_weights = math_ops.reduce_sum(
1303        broadcasted_weights, axes, name="sum_of_weights", keepdims=True)
1304
1305    divisor = math_ops.reciprocal(sum_of_weights, name="inv_weight_sum")
1306
1307    weighted_mean = math_ops.multiply(weighted_input_sum, divisor)
1308
1309    # Have the weighted mean; now on to variance:
1310    weighted_distsq = math_ops.reduce_sum(
1311        frequency_weights * math_ops.squared_difference(x, weighted_mean),
1312        axes,
1313        name="weighted_distsq",
1314        keepdims=True)
1315
1316    weighted_variance = math_ops.multiply(weighted_distsq, divisor)
1317
1318    if not keep_dims:
1319      weighted_mean = array_ops.squeeze(weighted_mean, axis=axes)
1320      weighted_variance = array_ops.squeeze(
1321          weighted_variance, axis=axes)
1322
1323    if needs_cast:
1324      weighted_mean = math_ops.cast(weighted_mean, dtypes.float16)
1325      weighted_variance = math_ops.cast(weighted_variance, dtypes.float16)
1326
1327    return weighted_mean, weighted_variance
1328
1329
1330@tf_export("nn.weighted_moments", v1=[])
1331def weighted_moments_v2(x, axes, frequency_weights, keepdims=False, name=None):
1332  """Returns the frequency-weighted mean and variance of `x`.
1333
1334  Args:
1335    x: A tensor.
1336    axes: 1-d tensor of int32 values; these are the axes along which
1337      to compute mean and variance.
1338    frequency_weights: A tensor of positive weights which can be
1339      broadcast with x.
1340    keepdims: Produce moments with the same dimensionality as the input.
1341    name: Name used to scope the operation.
1342
1343  Returns:
1344    Two tensors: `weighted_mean` and `weighted_variance`.
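
  A minimal usage sketch (the values below are illustrative placeholders):

  ```python
  import tensorflow as tf

  x = tf.constant([[1.0, 2.0, 3.0],
                   [4.0, 5.0, 6.0]])
  # Weight the second row three times as heavily as the first.
  freq = tf.constant([[1.0], [3.0]])

  mean, variance = tf.nn.weighted_moments(
      x, axes=[0], frequency_weights=freq, keepdims=False)
  # mean == (1 * [1, 2, 3] + 3 * [4, 5, 6]) / 4 == [3.25, 4.25, 5.25]
  ```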
1345  """
1346  return weighted_moments(
1347      x=x,
1348      axes=axes,
1349      frequency_weights=frequency_weights,
1350      name=name,
1351      keep_dims=keepdims)
1352
1353
1354@tf_export("nn.batch_normalization")
1355def batch_normalization(x,
1356                        mean,
1357                        variance,
1358                        offset,
1359                        scale,
1360                        variance_epsilon,
1361                        name=None):
1362  r"""Batch normalization.
1363
1364  Normalizes a tensor by `mean` and `variance`, and applies (optionally) a
1365  `scale` \\(\gamma\\) to it, as well as an `offset` \\(\beta\\):
1366
1367  \\(\frac{\gamma(x-\mu)}{\sigma}+\beta\\)
1368
1369  `mean`, `variance`, `offset` and `scale` are all expected to be of one of two
1370  shapes:
1371
1372    * In all generality, they can have the same number of dimensions as the
1373      input `x`, with identical sizes as `x` for the dimensions that are not
1374      normalized over (the 'depth' dimension(s)), and dimension 1 for the
1375      others which are being normalized over.
1376      `mean` and `variance` in this case would typically be the outputs of
1377      `tf.nn.moments(..., keepdims=True)` during training, or running averages
1378      thereof during inference.
1379    * In the common case where the 'depth' dimension is the last dimension in
1380      the input tensor `x`, they may be one dimensional tensors of the same
1381      size as the 'depth' dimension.
1382      This is the case for example for the common `[batch, depth]` layout of
1383      fully-connected layers, and `[batch, height, width, depth]` for
1384      convolutions.
1385      `mean` and `variance` in this case would typically be the outputs of
1386      `tf.nn.moments(..., keepdims=False)` during training, or running averages
1387      thereof during inference.
1388
  See equation 11 in Algorithm 2 of (Ioffe et al., 2015), listed under
  References below.
1393
1394  Args:
1395    x: Input `Tensor` of arbitrary dimensionality.
1396    mean: A mean `Tensor`.
1397    variance: A variance `Tensor`.
1398    offset: An offset `Tensor`, often denoted \\(\beta\\) in equations, or
1399      None. If present, will be added to the normalized tensor.
1400    scale: A scale `Tensor`, often denoted \\(\gamma\\) in equations, or
1401      `None`. If present, the scale is applied to the normalized tensor.
1402    variance_epsilon: A small float number to avoid dividing by 0.
1403    name: A name for this operation (optional).
1404
1405  Returns:
1406    the normalized, scaled, offset tensor.
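
  A minimal sketch of typical usage with per-channel statistics (the shapes
  and values below are illustrative placeholders):

  ```python
  import tensorflow as tf

  x = tf.random.normal([8, 32, 32, 16])  # NHWC activations
  mean, variance = tf.nn.moments(x, axes=[0, 1, 2])
  offset = tf.zeros([16])  # beta
  scale = tf.ones([16])    # gamma

  y = tf.nn.batch_normalization(
      x, mean, variance, offset, scale, variance_epsilon=1e-3)
  ```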
1407
1408  References:
1409    Batch Normalization - Accelerating Deep Network Training by Reducing
1410    Internal Covariate Shift:
1411      [Ioffe et al., 2015](http://arxiv.org/abs/1502.03167)
1412      ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
1413  """
1414  with ops.name_scope(name, "batchnorm", [x, mean, variance, scale, offset]):
1415    inv = math_ops.rsqrt(variance + variance_epsilon)
1416    if scale is not None:
1417      inv *= scale
1418    # Note: tensorflow/contrib/quantize/python/fold_batch_norms.py depends on
1419    # the precise order of ops that are generated by the expression below.
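    # The expression below evaluates to (x - mean) * inv (+ offset, if given),
    # written as a single multiply-add.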
1420    return x * math_ops.cast(inv, x.dtype) + math_ops.cast(
1421        offset - mean * inv if offset is not None else -mean * inv, x.dtype)
1422
1423
1424@tf_export(v1=["nn.fused_batch_norm"])
1425def fused_batch_norm(
1426    x,
1427    scale,
1428    offset,  # pylint: disable=invalid-name
1429    mean=None,
1430    variance=None,
1431    epsilon=0.001,
1432    data_format="NHWC",
1433    is_training=True,
1434    name=None):
1435  r"""Batch normalization.

  See (Ioffe et al., 2015) in the References below.
1441
1442  Args:
1443    x: Input `Tensor` of 4 dimensions.
1444    scale: A `Tensor` of 1 dimension for scaling.
1445    offset: A `Tensor` of 1 dimension for bias.
1446    mean: A `Tensor` of 1 dimension for population mean used for inference.
1447    variance: A `Tensor` of 1 dimension for population variance
1448              used for inference.
1449    epsilon: A small float number added to the variance of x.
1450    data_format: The data format for x. Either "NHWC" (default) or "NCHW".
1451    is_training: A bool value to specify if the operation is used for
1452                 training or inference.
1453    name: A name for this operation (optional).
1454
1455  Returns:
    y: A 4D Tensor for the normalized, scaled, offset x.
1457    batch_mean: A 1D Tensor for the mean of x.
1458    batch_var: A 1D Tensor for the variance of x.
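
  A minimal training-mode sketch (shapes and values are illustrative
  placeholders), using the `tf.compat.v1` endpoint of this function:

  ```python
  import tensorflow as tf

  x = tf.random.normal([8, 32, 32, 16])  # NHWC input
  scale = tf.ones([16])    # gamma
  offset = tf.zeros([16])  # beta

  y, batch_mean, batch_var = tf.compat.v1.nn.fused_batch_norm(
      x, scale, offset, epsilon=1e-3, is_training=True)
  ```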
1459
1460  Raises:
1461    ValueError: If mean or variance is not None when is_training is True.
1462
1463  References:
1464    Batch Normalization - Accelerating Deep Network Training by Reducing
1465    Internal Covariate Shift:
1466      [Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html)
1467      ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
1468  """
1469  x = ops.convert_to_tensor(x, name="input")
1470  scale = ops.convert_to_tensor(scale, name="scale")
1471  offset = ops.convert_to_tensor(offset, name="offset")
1472  if is_training:
1473    if (mean is not None) or (variance is not None):
1474      raise ValueError("Both 'mean' and 'variance' must be None "
1475                       "if is_training is True.")
1476  if mean is None:
1477    mean = constant_op.constant([])
1478  if variance is None:
1479    variance = constant_op.constant([])
  # Set a minimum epsilon of 1.001e-5, which is required by cuDNN to
  # prevent an exception (see cudnn.h).
1482  min_epsilon = 1.001e-5
1483  epsilon = epsilon if epsilon > min_epsilon else min_epsilon
1484
1485  y, batch_mean, batch_var, _, _, _ = gen_nn_ops.fused_batch_norm_v3(
1486      x,
1487      scale,
1488      offset,
1489      mean,
1490      variance,
1491      epsilon=epsilon,
1492      data_format=data_format,
1493      is_training=is_training,
1494      name=name)
1495  return y, batch_mean, batch_var
1496
1497
1498@tf_export(v1=["nn.batch_norm_with_global_normalization"])
1499def batch_norm_with_global_normalization(t=None,
1500                                         m=None,
1501                                         v=None,
1502                                         beta=None,
1503                                         gamma=None,
1504                                         variance_epsilon=None,
1505                                         scale_after_normalization=None,
1506                                         name=None,
1507                                         input=None,  # pylint: disable=redefined-builtin
1508                                         mean=None,
1509                                         variance=None):
1510  """Batch normalization.
1511
1512  This op is deprecated. See `tf.nn.batch_normalization`.
1513
1514  Args:
1515    t: A 4D input Tensor.
1516    m: A 1D mean Tensor with size matching the last dimension of t.
1517      This is the first output from tf.nn.moments,
1518      or a saved moving average thereof.
1519    v: A 1D variance Tensor with size matching the last dimension of t.
1520      This is the second output from tf.nn.moments,
1521      or a saved moving average thereof.
1522    beta: A 1D beta Tensor with size matching the last dimension of t.
1523      An offset to be added to the normalized tensor.
1524    gamma: A 1D gamma Tensor with size matching the last dimension of t.
1525      If "scale_after_normalization" is true, this tensor will be multiplied
1526      with the normalized tensor.
1527    variance_epsilon: A small float number to avoid dividing by 0.
    scale_after_normalization: A bool indicating whether the resulting tensor
      needs to be multiplied by gamma.
1530    name: A name for this operation (optional).
1531    input: Alias for t.
1532    mean: Alias for m.
1533    variance: Alias for v.
1534
1535  Returns:
1536     A batch-normalized `t`.
1537
1538  References:
1539    Batch Normalization - Accelerating Deep Network Training by Reducing
1540    Internal Covariate Shift:
1541      [Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html)
1542      ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
1543  """
1544  t = deprecated_argument_lookup("input", input, "t", t)
1545  m = deprecated_argument_lookup("mean", mean, "m", m)
1546  v = deprecated_argument_lookup("variance", variance, "v", v)
1547  return batch_normalization(t, m, v, beta, gamma if scale_after_normalization
1548                             else None, variance_epsilon, name)
1549
1550
1551# pylint: disable=redefined-builtin,line-too-long
1552@tf_export("nn.batch_norm_with_global_normalization", v1=[])
1553def batch_norm_with_global_normalization_v2(input,
1554                                            mean,
1555                                            variance,
1556                                            beta,
1557                                            gamma,
1558                                            variance_epsilon,
1559                                            scale_after_normalization,
1560                                            name=None):
1561  """Batch normalization.
1562
1563  This op is deprecated. See `tf.nn.batch_normalization`.
1564
1565  Args:
1566    input: A 4D input Tensor.
    mean: A 1D mean Tensor with size matching the last dimension of `input`.
      This is the first output from tf.nn.moments,
      or a saved moving average thereof.
    variance: A 1D variance Tensor with size matching the last dimension of
      `input`. This is the second output from tf.nn.moments,
      or a saved moving average thereof.
    beta: A 1D beta Tensor with size matching the last dimension of `input`.
      An offset to be added to the normalized tensor.
    gamma: A 1D gamma Tensor with size matching the last dimension of `input`.
      If "scale_after_normalization" is true, this tensor will be multiplied
      with the normalized tensor.
    variance_epsilon: A small float number to avoid dividing by 0.
    scale_after_normalization: A bool indicating whether the resulting tensor
      needs to be multiplied by gamma.
1581    name: A name for this operation (optional).
1582
1583  Returns:
     A batch-normalized `input`.
1585
1586  References:
1587    Batch Normalization - Accelerating Deep Network Training by Reducing Internal Covariate Shift:
1588      [Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html)
1589      ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
1590  """
1591  return batch_norm_with_global_normalization(t=input,
1592                                              m=mean,
1593                                              v=variance,
1594                                              beta=beta,
1595                                              gamma=gamma,
1596                                              variance_epsilon=variance_epsilon,
1597                                              scale_after_normalization=scale_after_normalization,
1598                                              name=name)
1599
1600# pylint: enable=redefined-builtin,line-too-long
1601
1602
1603def _sum_rows(x):
1604  """Returns a vector summing up each row of the matrix x."""
1605  # _sum_rows(x) is equivalent to math_ops.reduce_sum(x, 1) when x is
1606  # a matrix.  The gradient of _sum_rows(x) is more efficient than
1607  # reduce_sum(x, 1)'s gradient in today's implementation. Therefore,
1608  # we use _sum_rows(x) in the nce_loss() computation since the loss
1609  # is mostly used for training.
1610  cols = array_ops.shape(x)[1]
1611  ones_shape = array_ops.stack([cols, 1])
1612  ones = array_ops.ones(ones_shape, x.dtype)
1613  return array_ops.reshape(math_ops.matmul(x, ones), [-1])
1614
1615
1616def _compute_sampled_logits(weights,
1617                            biases,
1618                            labels,
1619                            inputs,
1620                            num_sampled,
1621                            num_classes,
1622                            num_true=1,
1623                            sampled_values=None,
1624                            subtract_log_q=True,
1625                            remove_accidental_hits=False,
1626                            partition_strategy="mod",
1627                            name=None,
1628                            seed=None):
1629  """Helper function for nce_loss and sampled_softmax_loss functions.
1630
1631  Computes sampled output training logits and labels suitable for implementing
1632  e.g. noise-contrastive estimation (see nce_loss) or sampled softmax (see
1633  sampled_softmax_loss).
1634
1635  Note: In the case where num_true > 1, we assign to each target class
1636  the target probability 1 / num_true so that the target probabilities
1637  sum to 1 per-example.
1638
1639  Args:
1640    weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
1641        objects whose concatenation along dimension 0 has shape
1642        `[num_classes, dim]`.  The (possibly-partitioned) class embeddings.
1643    biases: A `Tensor` of shape `[num_classes]`.  The (possibly-partitioned)
1644        class biases.
1645    labels: A `Tensor` of type `int64` and shape `[batch_size,
1646        num_true]`. The target classes.  Note that this format differs from
1647        the `labels` argument of `nn.softmax_cross_entropy_with_logits`.
1648    inputs: A `Tensor` of shape `[batch_size, dim]`.  The forward
1649        activations of the input network.
1650    num_sampled: An `int`.  The number of classes to randomly sample per batch.
1651    num_classes: An `int`. The number of possible classes.
1652    num_true: An `int`.  The number of target classes per training example.
1653    sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
1654        `sampled_expected_count`) returned by a `*_candidate_sampler` function.
1655        (if None, we default to `log_uniform_candidate_sampler`)
    subtract_log_q: A `bool`.  Whether to subtract the log expected count of
        the labels in the sample to get the logits of the true labels.
        Default is True.  Turn off for Negative Sampling.
    remove_accidental_hits: A `bool`.  Whether to remove "accidental hits"
        where a sampled class equals one of the target classes.  Default is
        False.
1662    partition_strategy: A string specifying the partitioning strategy, relevant
1663        if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
1664        Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
1665    name: A name for the operation (optional).
    seed: Random seed for candidate sampling. Defaults to None, which doesn't
        set the op-level random seed for candidate sampling.
1668  Returns:
1669    out_logits: `Tensor` object with shape
1670        `[batch_size, num_true + num_sampled]`, for passing to either
1671        `nn.sigmoid_cross_entropy_with_logits` (NCE) or
1672        `nn.softmax_cross_entropy_with_logits` (sampled softmax).
1673    out_labels: A Tensor object with the same shape as `out_logits`.
1674  """
1675
1676  if isinstance(weights, variables.PartitionedVariable):
1677    weights = list(weights)
1678  if not isinstance(weights, list):
1679    weights = [weights]
1680
1681  with ops.name_scope(name, "compute_sampled_logits",
1682                      weights + [biases, inputs, labels]):
1683    if labels.dtype != dtypes.int64:
1684      labels = math_ops.cast(labels, dtypes.int64)
1685    labels_flat = array_ops.reshape(labels, [-1])
1686
1687    # Sample the negative labels.
1688    #   sampled shape: [num_sampled] tensor
1689    #   true_expected_count shape = [batch_size, 1] tensor
1690    #   sampled_expected_count shape = [num_sampled] tensor
1691    if sampled_values is None:
1692      sampled_values = candidate_sampling_ops.log_uniform_candidate_sampler(
1693          true_classes=labels,
1694          num_true=num_true,
1695          num_sampled=num_sampled,
1696          unique=True,
1697          range_max=num_classes,
1698          seed=seed)
1699    # NOTE: pylint cannot tell that 'sampled_values' is a sequence
1700    # pylint: disable=unpacking-non-sequence
1701    sampled, true_expected_count, sampled_expected_count = (
1702        array_ops.stop_gradient(s) for s in sampled_values)
1703    # pylint: enable=unpacking-non-sequence
1704    sampled = math_ops.cast(sampled, dtypes.int64)
1705
1706    # labels_flat is a [batch_size * num_true] tensor
1707    # sampled is a [num_sampled] int tensor
1708    all_ids = array_ops.concat([labels_flat, sampled], 0)
1709
1710    # Retrieve the true weights and the logits of the sampled weights.
1711
1712    # weights shape is [num_classes, dim]
1713    all_w = embedding_ops.embedding_lookup(
1714        weights, all_ids, partition_strategy=partition_strategy)
1715    if all_w.dtype != inputs.dtype:
1716      all_w = math_ops.cast(all_w, inputs.dtype)
1717
1718    # true_w shape is [batch_size * num_true, dim]
1719    true_w = array_ops.slice(all_w, [0, 0],
1720                             array_ops.stack(
1721                                 [array_ops.shape(labels_flat)[0], -1]))
1722
1723    sampled_w = array_ops.slice(
1724        all_w, array_ops.stack([array_ops.shape(labels_flat)[0], 0]), [-1, -1])
1725    # inputs has shape [batch_size, dim]
1726    # sampled_w has shape [num_sampled, dim]
1727    # Apply X*W', which yields [batch_size, num_sampled]
1728    sampled_logits = math_ops.matmul(inputs, sampled_w, transpose_b=True)
1729
1730    # Retrieve the true and sampled biases, compute the true logits, and
1731    # add the biases to the true and sampled logits.
1732    all_b = embedding_ops.embedding_lookup(
1733        biases, all_ids, partition_strategy=partition_strategy)
1734    if all_b.dtype != inputs.dtype:
1735      all_b = math_ops.cast(all_b, inputs.dtype)
1736    # true_b is a [batch_size * num_true] tensor
1737    # sampled_b is a [num_sampled] float tensor
1738    true_b = array_ops.slice(all_b, [0], array_ops.shape(labels_flat))
1739    sampled_b = array_ops.slice(all_b, array_ops.shape(labels_flat), [-1])
1740
1741    # inputs shape is [batch_size, dim]
1742    # true_w shape is [batch_size * num_true, dim]
1743    # row_wise_dots is [batch_size, num_true, dim]
1744    dim = array_ops.shape(true_w)[1:2]
1745    new_true_w_shape = array_ops.concat([[-1, num_true], dim], 0)
1746    row_wise_dots = math_ops.multiply(
1747        array_ops.expand_dims(inputs, 1),
1748        array_ops.reshape(true_w, new_true_w_shape))
1749    # We want the row-wise dot plus biases which yields a
1750    # [batch_size, num_true] tensor of true_logits.
1751    dots_as_matrix = array_ops.reshape(row_wise_dots,
1752                                       array_ops.concat([[-1], dim], 0))
1753    true_logits = array_ops.reshape(_sum_rows(dots_as_matrix), [-1, num_true])
1754    true_b = array_ops.reshape(true_b, [-1, num_true])
1755    true_logits += true_b
1756    sampled_logits += sampled_b
1757
1758    if remove_accidental_hits:
1759      acc_hits = candidate_sampling_ops.compute_accidental_hits(
1760          labels, sampled, num_true=num_true)
1761      acc_indices, acc_ids, acc_weights = acc_hits
1762
1763      # This is how SparseToDense expects the indices.
1764      acc_indices_2d = array_ops.reshape(acc_indices, [-1, 1])
1765      acc_ids_2d_int32 = array_ops.reshape(
1766          math_ops.cast(acc_ids, dtypes.int32), [-1, 1])
1767      sparse_indices = array_ops.concat([acc_indices_2d, acc_ids_2d_int32], 1,
1768                                        "sparse_indices")
1769      # Create sampled_logits_shape = [batch_size, num_sampled]
1770      sampled_logits_shape = array_ops.concat(
1771          [array_ops.shape(labels)[:1],
1772           array_ops.expand_dims(num_sampled, 0)], 0)
1773      if sampled_logits.dtype != acc_weights.dtype:
1774        acc_weights = math_ops.cast(acc_weights, sampled_logits.dtype)
1775      sampled_logits += gen_sparse_ops.sparse_to_dense(
1776          sparse_indices,
1777          sampled_logits_shape,
1778          acc_weights,
1779          default_value=0.0,
1780          validate_indices=False)
1781
1782    if subtract_log_q:
1783      # Subtract log of Q(l), prior probability that l appears in sampled.
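      # This correction is what makes the sampled logits comparable to logits
      # over the full class set; see Jean et al. (2014), Section 3, for the
      # sampled softmax derivation.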
1784      true_logits -= math_ops.log(true_expected_count)
1785      sampled_logits -= math_ops.log(sampled_expected_count)
1786
1787    # Construct output logits and labels. The true labels/logits start at col 0.
1788    out_logits = array_ops.concat([true_logits, sampled_logits], 1)
1789
1790    # true_logits is a float tensor, ones_like(true_logits) is a float
1791    # tensor of ones. We then divide by num_true to ensure the per-example
1792    # labels sum to 1.0, i.e. form a proper probability distribution.
1793    out_labels = array_ops.concat([
1794        array_ops.ones_like(true_logits) / num_true,
1795        array_ops.zeros_like(sampled_logits)
1796    ], 1)
1797
1798    return out_logits, out_labels
1799
1800
1801@tf_export("nn.nce_loss", v1=[])
1802def nce_loss_v2(weights,
1803                biases,
1804                labels,
1805                inputs,
1806                num_sampled,
1807                num_classes,
1808                num_true=1,
1809                sampled_values=None,
1810                remove_accidental_hits=False,
1811                name="nce_loss"):
1812  """Computes and returns the noise-contrastive estimation training loss.
1813
1814  See [Noise-contrastive estimation: A new estimation principle for
1815  unnormalized statistical
1816  models](http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf).
1817  Also see our [Candidate Sampling Algorithms
1818  Reference](https://www.tensorflow.org/extras/candidate_sampling.pdf)
1819
1820  A common use case is to use this method for training, and calculate the full
1821  sigmoid loss for evaluation or inference as in the following example:
1822
1823  ```python
1824  if mode == "train":
1825    loss = tf.nn.nce_loss(
1826        weights=weights,
1827        biases=biases,
1828        labels=labels,
1829        inputs=inputs,
1830        ...)
1831  elif mode == "eval":
1832    logits = tf.matmul(inputs, tf.transpose(weights))
1833    logits = tf.nn.bias_add(logits, biases)
1834    labels_one_hot = tf.one_hot(labels, n_classes)
1835    loss = tf.nn.sigmoid_cross_entropy_with_logits(
1836        labels=labels_one_hot,
1837        logits=logits)
1838    loss = tf.reduce_sum(loss, axis=1)
1839  ```
1840
  Note: when doing embedding lookup on `weights` and `biases`, the "div"
  partition strategy will be used. Support for other partition strategies will
  be added later.
1844
1845  Note: By default this uses a log-uniform (Zipfian) distribution for sampling,
1846  so your labels must be sorted in order of decreasing frequency to achieve
1847  good results.  For more details, see
1848  `tf.random.log_uniform_candidate_sampler`.
1849
1850  Note: In the case where `num_true` > 1, we assign to each target class
1851  the target probability 1 / `num_true` so that the target probabilities
1852  sum to 1 per-example.
1853
1854  Note: It would be useful to allow a variable number of target classes per
1855  example.  We hope to provide this functionality in a future release.
1856  For now, if you have a variable number of target classes, you can pad them
1857  out to a constant number by either repeating them or by padding
1858  with an otherwise unused class.
1859
1860  Args:
1861    weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
1862      objects whose concatenation along dimension 0 has shape [num_classes,
1863      dim].  The (possibly-partitioned) class embeddings.
1864    biases: A `Tensor` of shape `[num_classes]`.  The class biases.
1865    labels: A `Tensor` of type `int64` and shape `[batch_size, num_true]`. The
1866      target classes.
1867    inputs: A `Tensor` of shape `[batch_size, dim]`.  The forward activations of
1868      the input network.
1869    num_sampled: An `int`.  The number of negative classes to randomly sample
1870      per batch. This single sample of negative classes is evaluated for each
1871      element in the batch.
1872    num_classes: An `int`. The number of possible classes.
1873    num_true: An `int`.  The number of target classes per training example.
1874    sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
1875      `sampled_expected_count`) returned by a `*_candidate_sampler` function.
1876      (if None, we default to `log_uniform_candidate_sampler`)
    remove_accidental_hits:  A `bool`.  Whether to remove "accidental hits"
      where a sampled class equals one of the target classes.  If set to `True`,
      this is a "Sampled Logistic" loss instead of NCE, and we are learning to
      generate log-odds instead of log probabilities.  See our Candidate
      Sampling Algorithms Reference
      ([pdf](https://www.tensorflow.org/extras/candidate_sampling.pdf)).
      Default is False.
1884    name: A name for the operation (optional).
1885
1886  Returns:
1887    A `batch_size` 1-D tensor of per-example NCE losses.
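
  A minimal end-to-end sketch (all sizes and values below are illustrative
  placeholders):

  ```python
  import tensorflow as tf

  num_classes, dim, batch_size, num_sampled = 10000, 64, 32, 57

  weights = tf.Variable(tf.random.normal([num_classes, dim]))
  biases = tf.Variable(tf.zeros([num_classes]))
  inputs = tf.random.normal([batch_size, dim])
  labels = tf.random.uniform(
      [batch_size, 1], maxval=num_classes, dtype=tf.int64)

  per_example_loss = tf.nn.nce_loss(
      weights=weights,
      biases=biases,
      labels=labels,
      inputs=inputs,
      num_sampled=num_sampled,
      num_classes=num_classes)
  loss = tf.reduce_mean(per_example_loss)  # scalar training loss
  ```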
1888  """
1889  # TODO(yuefengz): get partition_strategy from either variables or distribution
1890  # strategies.
1891  return nce_loss(
1892      weights,
1893      biases,
1894      labels,
1895      inputs,
1896      num_sampled,
1897      num_classes,
1898      num_true=num_true,
1899      sampled_values=sampled_values,
1900      remove_accidental_hits=remove_accidental_hits,
1901      partition_strategy="div",
1902      name=name)
1903
1904
1905@tf_export(v1=["nn.nce_loss"])
1906def nce_loss(weights,
1907             biases,
1908             labels,
1909             inputs,
1910             num_sampled,
1911             num_classes,
1912             num_true=1,
1913             sampled_values=None,
1914             remove_accidental_hits=False,
1915             partition_strategy="mod",
1916             name="nce_loss"):
1917  """Computes and returns the noise-contrastive estimation training loss.
1918
1919  A common use case is to use this method for training, and calculate the full
1920  sigmoid loss for evaluation or inference. In this case, you must set
1921  `partition_strategy="div"` for the two losses to be consistent, as in the
1922  following example:
1923
1924  ```python
1925  if mode == "train":
1926    loss = tf.nn.nce_loss(
1927        weights=weights,
1928        biases=biases,
1929        labels=labels,
1930        inputs=inputs,
1931        ...,
1932        partition_strategy="div")
1933  elif mode == "eval":
1934    logits = tf.matmul(inputs, tf.transpose(weights))
1935    logits = tf.nn.bias_add(logits, biases)
1936    labels_one_hot = tf.one_hot(labels, n_classes)
1937    loss = tf.nn.sigmoid_cross_entropy_with_logits(
1938        labels=labels_one_hot,
1939        logits=logits)
1940    loss = tf.reduce_sum(loss, axis=1)
1941  ```
1942
1943  Note: By default this uses a log-uniform (Zipfian) distribution for sampling,
1944  so your labels must be sorted in order of decreasing frequency to achieve
1945  good results.  For more details, see
1946  `tf.random.log_uniform_candidate_sampler`.
1947
1948  Note: In the case where `num_true` > 1, we assign to each target class
1949  the target probability 1 / `num_true` so that the target probabilities
1950  sum to 1 per-example.
1951
1952  Note: It would be useful to allow a variable number of target classes per
1953  example.  We hope to provide this functionality in a future release.
1954  For now, if you have a variable number of target classes, you can pad them
1955  out to a constant number by either repeating them or by padding
1956  with an otherwise unused class.
1957
1958  Args:
1959    weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
1960        objects whose concatenation along dimension 0 has shape
1961        [num_classes, dim].  The (possibly-partitioned) class embeddings.
1962    biases: A `Tensor` of shape `[num_classes]`.  The class biases.
1963    labels: A `Tensor` of type `int64` and shape `[batch_size,
1964        num_true]`. The target classes.
1965    inputs: A `Tensor` of shape `[batch_size, dim]`.  The forward
1966        activations of the input network.
1967    num_sampled: An `int`.  The number of negative classes to randomly sample
1968        per batch. This single sample of negative classes is evaluated for each
1969        element in the batch.
1970    num_classes: An `int`. The number of possible classes.
1971    num_true: An `int`.  The number of target classes per training example.
1972    sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
1973        `sampled_expected_count`) returned by a `*_candidate_sampler` function.
1974        (if None, we default to `log_uniform_candidate_sampler`)
1975    remove_accidental_hits:  A `bool`.  Whether to remove "accidental hits"
1976        where a sampled class equals one of the target classes.  If set to
1977        `True`, this is a "Sampled Logistic" loss instead of NCE, and we are
1978        learning to generate log-odds instead of log probabilities. See
1979        our Candidate Sampling Algorithms Reference
1980        ([pdf](https://www.tensorflow.org/extras/candidate_sampling.pdf)).
1981        Default is False.
1982    partition_strategy: A string specifying the partitioning strategy, relevant
1983        if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
1984        Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
1985    name: A name for the operation (optional).
1986
1987  Returns:
1988    A `batch_size` 1-D tensor of per-example NCE losses.
1989
1990  References:
1991    Noise-contrastive estimation - A new estimation principle for unnormalized
1992    statistical models:
1993      [Gutmann et al., 2010](http://proceedings.mlr.press/v9/gutmann10a)
1994      ([pdf](http://proceedings.mlr.press/v9/gutmann10a/gutmann10a.pdf))
1995  """
1996  logits, labels = _compute_sampled_logits(
1997      weights=weights,
1998      biases=biases,
1999      labels=labels,
2000      inputs=inputs,
2001      num_sampled=num_sampled,
2002      num_classes=num_classes,
2003      num_true=num_true,
2004      sampled_values=sampled_values,
2005      subtract_log_q=True,
2006      remove_accidental_hits=remove_accidental_hits,
2007      partition_strategy=partition_strategy,
2008      name=name)
2009  sampled_losses = sigmoid_cross_entropy_with_logits(
2010      labels=labels, logits=logits, name="sampled_losses")
2011  # sampled_losses is batch_size x {true_loss, sampled_losses...}
2012  # We sum out true and sampled losses.
2013  return _sum_rows(sampled_losses)
2014
2015
2016@tf_export("nn.sampled_softmax_loss", v1=[])
2017def sampled_softmax_loss_v2(weights,
2018                            biases,
2019                            labels,
2020                            inputs,
2021                            num_sampled,
2022                            num_classes,
2023                            num_true=1,
2024                            sampled_values=None,
2025                            remove_accidental_hits=True,
2026                            seed=None,
2027                            name="sampled_softmax_loss"):
2028  """Computes and returns the sampled softmax training loss.
2029
2030  This is a faster way to train a softmax classifier over a huge number of
2031  classes.
2032
2033  This operation is for training only.  It is generally an underestimate of
2034  the full softmax loss.
2035
2036  A common use case is to use this method for training, and calculate the full
2037  sigmoid loss for evaluation or inference as in the following example:
2038
2039  ```python
2040  if mode == "train":
2041    loss = tf.nn.sampled_softmax_loss(
2042        weights=weights,
2043        biases=biases,
2044        labels=labels,
2045        inputs=inputs,
2046        ...)
2047  elif mode == "eval":
2048    logits = tf.matmul(inputs, tf.transpose(weights))
2049    logits = tf.nn.bias_add(logits, biases)
2050    labels_one_hot = tf.one_hot(labels, n_classes)
2051    loss = tf.nn.softmax_cross_entropy_with_logits(
2052        labels=labels_one_hot,
2053        logits=logits)
2054  ```
2055
  See our Candidate Sampling Algorithms Reference
  ([pdf](https://www.tensorflow.org/extras/candidate_sampling.pdf)).
2058
2059  Also see Section 3 of [Jean et al., 2014](http://arxiv.org/abs/1412.2007)
2060  ([pdf](http://arxiv.org/pdf/1412.2007.pdf)) for the math.
2061
  Note: when doing embedding lookup on `weights` and `biases`, the "div"
  partition strategy will be used. Support for other partition strategies will
  be added later.
2065
2066  Args:
2067    weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
2068      objects whose concatenation along dimension 0 has shape [num_classes,
2069      dim].  The (possibly-sharded) class embeddings.
2070    biases: A `Tensor` of shape `[num_classes]`.  The class biases.
2071    labels: A `Tensor` of type `int64` and shape `[batch_size, num_true]`. The
2072      target classes.  Note that this format differs from the `labels` argument
2073      of `nn.softmax_cross_entropy_with_logits`.
2074    inputs: A `Tensor` of shape `[batch_size, dim]`.  The forward activations of
2075      the input network.
2076    num_sampled: An `int`.  The number of classes to randomly sample per batch.
2077    num_classes: An `int`. The number of possible classes.
2078    num_true: An `int`.  The number of target classes per training example.
2079    sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
2080      `sampled_expected_count`) returned by a `*_candidate_sampler` function.
2081      (if None, we default to `log_uniform_candidate_sampler`)
    remove_accidental_hits: A `bool`.  Whether to remove "accidental hits"
      where a sampled class equals one of the target classes.  Default is True.
    seed: Random seed for candidate sampling. Defaults to None, which doesn't
      set the op-level random seed for candidate sampling.
2086    name: A name for the operation (optional).
2087
2088  Returns:
2089    A `batch_size` 1-D tensor of per-example sampled softmax losses.
2090
2091  """
2092  return sampled_softmax_loss(
2093      weights,
2094      biases,
2095      labels,
2096      inputs,
2097      num_sampled,
2098      num_classes,
2099      num_true=num_true,
2100      sampled_values=sampled_values,
2101      remove_accidental_hits=remove_accidental_hits,
2102      partition_strategy="div",
2103      name=name,
2104      seed=seed)
2105
2106
2107@tf_export(v1=["nn.sampled_softmax_loss"])
2108def sampled_softmax_loss(weights,
2109                         biases,
2110                         labels,
2111                         inputs,
2112                         num_sampled,
2113                         num_classes,
2114                         num_true=1,
2115                         sampled_values=None,
2116                         remove_accidental_hits=True,
2117                         partition_strategy="mod",
2118                         name="sampled_softmax_loss",
2119                         seed=None):
2120  """Computes and returns the sampled softmax training loss.
2121
2122  This is a faster way to train a softmax classifier over a huge number of
2123  classes.
2124
2125  This operation is for training only.  It is generally an underestimate of
2126  the full softmax loss.
2127
2128  A common use case is to use this method for training, and calculate the full
2129  softmax loss for evaluation or inference. In this case, you must set
2130  `partition_strategy="div"` for the two losses to be consistent, as in the
2131  following example:
2132
2133  ```python
2134  if mode == "train":
2135    loss = tf.nn.sampled_softmax_loss(
2136        weights=weights,
2137        biases=biases,
2138        labels=labels,
2139        inputs=inputs,
2140        ...,
2141        partition_strategy="div")
2142  elif mode == "eval":
2143    logits = tf.matmul(inputs, tf.transpose(weights))
2144    logits = tf.nn.bias_add(logits, biases)
2145    labels_one_hot = tf.one_hot(labels, n_classes)
2146    loss = tf.nn.softmax_cross_entropy_with_logits(
2147        labels=labels_one_hot,
2148        logits=logits)
2149  ```
2150
2151  See our Candidate Sampling Algorithms Reference
2152  ([pdf](https://www.tensorflow.org/extras/candidate_sampling.pdf)).
2153  Also see Section 3 of (Jean et al., 2014) for the math.
2154
2155  Args:
2156    weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
2157        objects whose concatenation along dimension 0 has shape
2158        [num_classes, dim].  The (possibly-sharded) class embeddings.
2159    biases: A `Tensor` of shape `[num_classes]`.  The class biases.
2160    labels: A `Tensor` of type `int64` and shape `[batch_size,
2161        num_true]`. The target classes.  Note that this format differs from
2162        the `labels` argument of `nn.softmax_cross_entropy_with_logits`.
2163    inputs: A `Tensor` of shape `[batch_size, dim]`.  The forward
2164        activations of the input network.
2165    num_sampled: An `int`.  The number of classes to randomly sample per batch.
2166    num_classes: An `int`. The number of possible classes.
2167    num_true: An `int`.  The number of target classes per training example.
2168    sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
2169        `sampled_expected_count`) returned by a `*_candidate_sampler` function.
2170        (if None, we default to `log_uniform_candidate_sampler`)
    remove_accidental_hits: A `bool`.  Whether to remove "accidental hits"
        where a sampled class equals one of the target classes.  Default is
        True.
2174    partition_strategy: A string specifying the partitioning strategy, relevant
2175        if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
2176        Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
2177    name: A name for the operation (optional).
    seed: Random seed for candidate sampling. Defaults to None, which doesn't
        set the op-level random seed for candidate sampling.
2180
2181  Returns:
2182    A `batch_size` 1-D tensor of per-example sampled softmax losses.
2183
2184  References:
2185    On Using Very Large Target Vocabulary for Neural Machine Translation:
      [Jean et al., 2014](https://aclanthology.coli.uni-saarland.de/papers/P15-1001/p15-1001)
2188      ([pdf](http://aclweb.org/anthology/P15-1001))
2189  """
2190  logits, labels = _compute_sampled_logits(
2191      weights=weights,
2192      biases=biases,
2193      labels=labels,
2194      inputs=inputs,
2195      num_sampled=num_sampled,
2196      num_classes=num_classes,
2197      num_true=num_true,
2198      sampled_values=sampled_values,
2199      subtract_log_q=True,
2200      remove_accidental_hits=remove_accidental_hits,
2201      partition_strategy=partition_strategy,
2202      name=name,
2203      seed=seed)
2204  labels = array_ops.stop_gradient(labels, name="labels_stop_gradient")
2205  sampled_losses = nn_ops.softmax_cross_entropy_with_logits_v2(
2206      labels=labels, logits=logits)
2207  # sampled_losses is a [batch_size] tensor.
2208  return sampled_losses
2209