# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Implementation of Neural Net (NN) functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

from tensorflow.python.distribute import distribution_strategy_context as ds
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import candidate_sampling_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import gen_array_ops  # pylint: disable=unused-import
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gen_sparse_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import util as losses_util
from tensorflow.python.platform import device_context
from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.deprecation import deprecated_argument_lookup
from tensorflow.python.util.tf_export import tf_export


@tf_export("nn.log_poisson_loss")
def log_poisson_loss(targets, log_input, compute_full_loss=False, name=None):
  """Computes log Poisson loss given `log_input`.

  Gives the log-likelihood loss between the prediction and the target under the
  assumption that the target has a Poisson distribution.
  Caveat: By default, this is not the exact loss, but the loss minus a
  constant term [log(z!)]. That has no effect for optimization, but
  does not play well with relative loss comparisons. To compute an
  approximation of the log factorial term, specify
  compute_full_loss=True to enable Stirling's Approximation.

  For brevity, let `c = log(x) = log_input`, `z = targets`.  The log Poisson
  loss is

        -log(exp(-x) * (x^z) / z!)
      = -log(exp(-x) * (x^z)) + log(z!)
      ~ -log(exp(-x)) - log(x^z) [+ z * log(z) - z + 0.5 * log(2 * pi * z)]
          [ Note the second term is the Stirling's Approximation for log(z!).
            It is invariant to x and does not affect optimization, though
            important for correct relative loss comparisons. It is only
            computed when compute_full_loss == True. ]
      = x - z * log(x) [+ z * log(z) - z + 0.5 * log(2 * pi * z)]
      = exp(c) - z * c [+ z * log(z) - z + 0.5 * log(2 * pi * z)]

  Args:
    targets: A `Tensor` of the same type and shape as `log_input`.
    log_input: A `Tensor` of type `float32` or `float64`.
    compute_full_loss: whether to compute the full loss.
      If false, a constant term is dropped in favor of more efficient
      optimization.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of the same shape as `log_input` with the componentwise
    logistic losses.

  Raises:
    ValueError: If `log_input` and `targets` do not have the same shape.
  """
  with ops.name_scope(name, "log_poisson_loss", [log_input, targets]) as name:
    log_input = ops.convert_to_tensor(log_input, name="log_input")
    targets = ops.convert_to_tensor(targets, name="targets")
    try:
      targets.get_shape().merge_with(log_input.get_shape())
    except ValueError:
      raise ValueError(
          "log_input and targets must have the same shape (%s vs %s)" %
          (log_input.get_shape(), targets.get_shape()))

    result = math_ops.exp(log_input) - log_input * targets
    if compute_full_loss:
      # Need to create constant tensors here so that their dtypes can be
      # matched to that of the targets.
      point_five = constant_op.constant(0.5, dtype=targets.dtype)
      two_pi = constant_op.constant(2 * math.pi, dtype=targets.dtype)

      stirling_approx = (targets * math_ops.log(targets)) - targets + (
          point_five * math_ops.log(two_pi * targets))
      zeros = array_ops.zeros_like(targets, dtype=targets.dtype)
      ones = array_ops.ones_like(targets, dtype=targets.dtype)
      cond = math_ops.logical_and(targets >= zeros, targets <= ones)
      result += array_ops.where(cond, zeros, stirling_approx)
    return result


@tf_export(v1=["nn.sigmoid_cross_entropy_with_logits"])
def sigmoid_cross_entropy_with_logits(  # pylint: disable=invalid-name
    _sentinel=None,
    labels=None,
    logits=None,
    name=None):
  """Computes sigmoid cross entropy given `logits`.

  Measures the probability error in discrete classification tasks in which each
  class is independent and not mutually exclusive.  For instance, one could
  perform multilabel classification where a picture can contain both an
  elephant and a dog at the same time.

  For brevity, let `x = logits`, `z = labels`.  The logistic loss is

        z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
      = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
      = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
      = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
      = (1 - z) * x + log(1 + exp(-x))
      = x - x * z + log(1 + exp(-x))

  For x < 0, to avoid overflow in exp(-x), we reformulate the above

        x - x * z + log(1 + exp(-x))
      = log(exp(x)) - x * z + log(1 + exp(-x))
      = - x * z + log(1 + exp(x))

  Hence, to ensure stability and avoid overflow, the implementation uses this
  equivalent formulation

      max(x, 0) - x * z + log(1 + exp(-abs(x)))

  `logits` and `labels` must have the same type and shape.

  Args:
    _sentinel: Used to prevent positional parameters. Internal, do not use.
    labels: A `Tensor` of the same type and shape as `logits`.
    logits: A `Tensor` of type `float32` or `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of the same shape as `logits` with the componentwise
    logistic losses.

  Raises:
    ValueError: If `logits` and `labels` do not have the same shape.
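
  A minimal usage sketch (illustrative only; assumes `tensorflow` is imported
  as `tf`, and the numeric values below are arbitrary):

  ```python
  logits = tf.constant([[4.0, -1.0], [0.0, 2.5]])
  labels = tf.constant([[1.0, 0.0], [0.0, 1.0]])
  # One logistic loss per element, with the same shape as `logits`.
  loss = tf.compat.v1.nn.sigmoid_cross_entropy_with_logits(labels=labels,
                                                           logits=logits)
  ```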
  """
  # pylint: disable=protected-access
  nn_ops._ensure_xent_args("sigmoid_cross_entropy_with_logits", _sentinel,
                           labels, logits)
  # pylint: enable=protected-access

  with ops.name_scope(name, "logistic_loss", [logits, labels]) as name:
    logits = ops.convert_to_tensor(logits, name="logits")
    labels = ops.convert_to_tensor(labels, name="labels")
    try:
      labels.get_shape().merge_with(logits.get_shape())
    except ValueError:
      raise ValueError("logits and labels must have the same shape (%s vs %s)" %
                       (logits.get_shape(), labels.get_shape()))

    # The logistic loss formula from above is
    #   x - x * z + log(1 + exp(-x))
    # For x < 0, a more numerically stable formula is
    #   -x * z + log(1 + exp(x))
    # Note that these two expressions can be combined into the following:
    #   max(x, 0) - x * z + log(1 + exp(-abs(x)))
    # To allow computing gradients at zero, we define custom versions of max
    # and abs functions.
    zeros = array_ops.zeros_like(logits, dtype=logits.dtype)
    cond = (logits >= zeros)
    relu_logits = array_ops.where(cond, logits, zeros)
    neg_abs_logits = array_ops.where(cond, -logits, logits)
    return math_ops.add(
        relu_logits - logits * labels,
        math_ops.log1p(math_ops.exp(neg_abs_logits)),
        name=name)


# Note: intentionally calling this v2 to not allow existing code with indirect
# imports to ignore the sentinel behavior.
@tf_export("nn.sigmoid_cross_entropy_with_logits", v1=[])
def sigmoid_cross_entropy_with_logits_v2(  # pylint: disable=invalid-name
    labels=None,
    logits=None,
    name=None):
  """Computes sigmoid cross entropy given `logits`.

  Measures the probability error in discrete classification tasks in which each
  class is independent and not mutually exclusive.  For instance, one could
  perform multilabel classification where a picture can contain both an
  elephant and a dog at the same time.

  For brevity, let `x = logits`, `z = labels`.  The logistic loss is

        z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
      = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
      = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
      = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
      = (1 - z) * x + log(1 + exp(-x))
      = x - x * z + log(1 + exp(-x))

  For x < 0, to avoid overflow in exp(-x), we reformulate the above

        x - x * z + log(1 + exp(-x))
      = log(exp(x)) - x * z + log(1 + exp(-x))
      = - x * z + log(1 + exp(x))

  Hence, to ensure stability and avoid overflow, the implementation uses this
  equivalent formulation

      max(x, 0) - x * z + log(1 + exp(-abs(x)))

  `logits` and `labels` must have the same type and shape.

  Args:
    labels: A `Tensor` of the same type and shape as `logits`.
    logits: A `Tensor` of type `float32` or `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of the same shape as `logits` with the componentwise
    logistic losses.

  Raises:
    ValueError: If `logits` and `labels` do not have the same shape.
  """
  return sigmoid_cross_entropy_with_logits(
      logits=logits, labels=labels, name=name)


@tf_export("nn.weighted_cross_entropy_with_logits", v1=[])
def weighted_cross_entropy_with_logits_v2(labels, logits, pos_weight,
                                          name=None):
  """Computes a weighted cross entropy.

  This is like `sigmoid_cross_entropy_with_logits()` except that `pos_weight`
  allows one to trade off recall and precision by up- or down-weighting the
  cost of a positive error relative to a negative error.

  The usual cross-entropy cost is defined as:

      labels * -log(sigmoid(logits)) +
          (1 - labels) * -log(1 - sigmoid(logits))

  A value `pos_weight > 1` decreases the false negative count, hence increasing
  the recall.
  Conversely setting `pos_weight < 1` decreases the false positive count and
  increases the precision.
  This can be seen from the fact that `pos_weight` is introduced as a
  multiplicative coefficient for the positive labels term
  in the loss expression:

      labels * -log(sigmoid(logits)) * pos_weight +
          (1 - labels) * -log(1 - sigmoid(logits))

  For brevity, let `x = logits`, `z = labels`, `q = pos_weight`.
  The loss is:

        qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
      = qz * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
      = qz * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
      = qz * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
      = (1 - z) * x + (qz + 1 - z) * log(1 + exp(-x))
      = (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))

  Setting `l = (1 + (q - 1) * z)`, to ensure stability and avoid overflow,
  the implementation uses

      (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))

  `logits` and `labels` must have the same type and shape.

  Args:
    labels: A `Tensor` of the same type and shape as `logits`.
    logits: A `Tensor` of type `float32` or `float64`.
    pos_weight: A coefficient to use on the positive examples.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of the same shape as `logits` with the componentwise
    weighted logistic losses.

  Raises:
    ValueError: If `logits` and `labels` do not have the same shape.
  """
  with ops.name_scope(name, "logistic_loss", [logits, labels]) as name:
    logits = ops.convert_to_tensor(logits, name="logits")
    labels = ops.convert_to_tensor(labels, name="labels")
    try:
      labels.get_shape().merge_with(logits.get_shape())
    except ValueError:
      raise ValueError("logits and labels must have the same shape (%s vs %s)" %
                       (logits.get_shape(), labels.get_shape()))

    # The logistic loss formula from above is
    #   (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))
    # For x < 0, a more numerically stable formula is
    #   (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(x)) - l * x
    # To avoid branching, we use the combined version
    #   (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))
    log_weight = 1 + (pos_weight - 1) * labels
    return math_ops.add(
        (1 - labels) * logits,
        log_weight * (math_ops.log1p(math_ops.exp(-math_ops.abs(logits))) +
                      nn_ops.relu(-logits)),
        name=name)


@tf_export(v1=["nn.weighted_cross_entropy_with_logits"])
@deprecated_args(None, "targets is deprecated, use labels instead", "targets")
def weighted_cross_entropy_with_logits(labels=None,
                                       logits=None,
                                       pos_weight=None,
                                       name=None,
                                       targets=None):
  """Computes a weighted cross entropy.

  This is like `sigmoid_cross_entropy_with_logits()` except that `pos_weight`
  allows one to trade off recall and precision by up- or down-weighting the
  cost of a positive error relative to a negative error.

  The usual cross-entropy cost is defined as:

      labels * -log(sigmoid(logits)) +
          (1 - labels) * -log(1 - sigmoid(logits))

  A value `pos_weight > 1` decreases the false negative count, hence increasing
  the recall.
  Conversely setting `pos_weight < 1` decreases the false positive count and
  increases the precision.
  This can be seen from the fact that `pos_weight` is introduced as a
  multiplicative coefficient for the positive labels term
  in the loss expression:

      labels * -log(sigmoid(logits)) * pos_weight +
          (1 - labels) * -log(1 - sigmoid(logits))

  For brevity, let `x = logits`, `z = labels`, `q = pos_weight`.
  The loss is:

        qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
      = qz * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
      = qz * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
      = qz * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
      = (1 - z) * x + (qz + 1 - z) * log(1 + exp(-x))
      = (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))

  Setting `l = (1 + (q - 1) * z)`, to ensure stability and avoid overflow,
  the implementation uses

      (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))

  `logits` and `labels` must have the same type and shape.

  Args:
    labels: A `Tensor` of the same type and shape as `logits`.
    logits: A `Tensor` of type `float32` or `float64`.
    pos_weight: A coefficient to use on the positive examples.
    name: A name for the operation (optional).
    targets: Deprecated alias for labels.

  Returns:
    A `Tensor` of the same shape as `logits` with the componentwise
    weighted logistic losses.

  Raises:
    ValueError: If `logits` and `labels` do not have the same shape.
  """
  labels = deprecated_argument_lookup("labels", labels, "targets", targets)
  return weighted_cross_entropy_with_logits_v2(labels, logits, pos_weight, name)


@tf_export("nn.compute_average_loss")
def compute_average_loss(per_example_loss,
                         sample_weight=None,
                         global_batch_size=None):
  """Scales per-example losses with sample_weights and computes their average.

  Usage with distribution strategy and custom training loop:

  ```python
  with strategy.scope():
    def compute_loss(labels, predictions, sample_weight=None):

      # If you are using a `Loss` class instead, set reduction to `NONE` so that
      # we can do the reduction afterwards and divide by global batch size.
      per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, predictions)

      # Compute loss that is scaled by sample_weight and by global batch size.
      return tf.nn.compute_average_loss(
          per_example_loss,
          sample_weight=sample_weight,
          global_batch_size=GLOBAL_BATCH_SIZE)
  ```

  Args:
    per_example_loss: Per-example loss.
    sample_weight: Optional weighting for each example.
    global_batch_size: Optional global batch size value. Defaults to (size of
      first dimension of `per_example_loss`) * (number of replicas).

  Returns:
    Scalar loss value.
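
  A minimal single-replica sketch (no `tf.distribute` strategy in scope, so the
  divisor is just the per-replica batch size; the values are arbitrary):

  ```python
  per_example_loss = tf.constant([1.0, 2.0, 3.0, 4.0])
  avg = tf.nn.compute_average_loss(per_example_loss)  # reduce_sum(...) / 4
  ```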
  """  # pylint: disable=g-doc-exception
  per_example_loss = ops.convert_to_tensor(per_example_loss)
  input_dtype = per_example_loss.dtype

  with losses_util.check_per_example_loss_rank(per_example_loss):
    if sample_weight is not None:
      per_example_loss = losses_util.scale_losses_by_sample_weight(
          per_example_loss, sample_weight)
    per_example_loss = math_ops.cast(per_example_loss, input_dtype)

    if global_batch_size is None:
      if ds.has_strategy() and ds.in_cross_replica_context():
        raise RuntimeError(
            "You are calling `compute_average_loss` in cross replica context, "
            "while it was expected to be called in replica context.")

      num_replicas = ds.get_strategy().num_replicas_in_sync
      per_replica_batch_size = array_ops.shape_v2(per_example_loss)[0]
      global_batch_size = per_replica_batch_size * num_replicas
    global_batch_size = math_ops.cast(global_batch_size, input_dtype)

    return math_ops.reduce_sum(per_example_loss) / global_batch_size


@tf_export("nn.scale_regularization_loss")
def scale_regularization_loss(regularization_loss):
  """Scales the sum of the given regularization losses by number of replicas.

  Usage with distribution strategy and custom training loop:

  ```python
  with strategy.scope():
    def compute_loss(self, labels, predictions):
      per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, predictions)

      # Compute loss that is scaled by sample_weight and by global batch size.
      loss = tf.nn.compute_average_loss(
          per_example_loss,
          sample_weight=sample_weight,
          global_batch_size=GLOBAL_BATCH_SIZE)

      # Add scaled regularization losses.
      loss += tf.nn.scale_regularization_loss(tf.nn.l2_loss(weights))
      return loss
  ```

  Args:
    regularization_loss: Regularization loss.

  Returns:
    Scalar loss value.
  """  # pylint: disable=g-doc-exception
  if ds.has_strategy() and ds.in_cross_replica_context():
    raise RuntimeError(
        "You are calling `scale_regularization_loss` in cross replica context, "
        "while it was expected to be called in replica context.")

  num_replicas = ds.get_strategy().num_replicas_in_sync
  return math_ops.reduce_sum(regularization_loss) / num_replicas


@tf_export(v1=["nn.relu_layer"])
def relu_layer(x, weights, biases, name=None):
  """Computes Relu(x * weights + biases).

  Args:
    x: a 2D tensor.  Dimensions typically: batch, in_units
    weights: a 2D tensor.  Dimensions typically: in_units, out_units
    biases: a 1D tensor.  Dimensions: out_units
    name: A name for the operation (optional).  If not specified
      "nn_relu_layer" is used.

  Returns:
    A 2-D Tensor computing relu(matmul(x, weights) + biases).
    Dimensions typically: batch, out_units.
  """
  with ops.name_scope(name, "relu_layer", [x, weights, biases]) as name:
    x = ops.convert_to_tensor(x, name="x")
    weights = ops.convert_to_tensor(weights, name="weights")
    biases = ops.convert_to_tensor(biases, name="biases")
    xw_plus_b = nn_ops.bias_add(math_ops.matmul(x, weights), biases)
    return nn_ops.relu(xw_plus_b, name=name)


@tf_export("nn.swish")
@custom_gradient.custom_gradient
def swish(features):
  # pylint: disable=g-doc-args
  """Computes the Swish activation function: `x * sigmoid(x)`.

  Source: "Searching for Activation Functions" (Ramachandran et al. 2017)
  https://arxiv.org/abs/1710.05941

  Args:
    features: A `Tensor` representing preactivation values.
    name: A name for the operation (optional).

  Returns:
    The activation value.
  """
  # pylint: enable=g-doc-args
  features = ops.convert_to_tensor(features, name="features")

  def grad(dy):
    """Gradient for the Swish activation function."""
    # Naively, x * tf.nn.sigmoid(x) requires keeping both x and sigmoid(x)
    # around for backprop, effectively doubling the tensor's memory
    # consumption. We use a control dependency here so that sigmoid(features)
    # is re-computed during backprop (the control dep prevents it being
    # de-duped with the forward pass) and we can free the sigmoid(features)
    # expression immediately after use during the forward pass.
    with ops.control_dependencies([dy]):
      sigmoid_features = math_ops.sigmoid(features)
    activation_grad = (
        sigmoid_features * (1.0 + features * (1.0 - sigmoid_features)))
    return dy * activation_grad

  return features * math_ops.sigmoid(features), grad


# pylint: disable=redefined-builtin
@tf_export("linalg.normalize")
def normalize(tensor, ord="euclidean", axis=None, name=None):
  """Normalizes `tensor` along dimension `axis` using specified norm.

  This uses `tf.linalg.norm` to compute the norm along `axis`.

  This function can compute several different vector norms (the 1-norm, the
  Euclidean or 2-norm, the inf-norm, and in general the p-norm for p > 0) and
  matrix norms (Frobenius, 1-norm, 2-norm and inf-norm).

  Args:
    tensor: `Tensor` of types `float32`, `float64`, `complex64`, `complex128`.
    ord: Order of the norm. Supported values are `'fro'`, `'euclidean'`, `1`,
      `2`, `np.inf` and any positive real number yielding the corresponding
      p-norm. Default is `'euclidean'` which is equivalent to Frobenius norm
      if `tensor` is a matrix and equivalent to 2-norm for vectors.
      Some restrictions apply: a) The Frobenius norm `'fro'` is not defined
        for vectors, b) If axis is a 2-tuple (matrix norm), only `'euclidean'`,
        `'fro'`, `1`, `2`, `np.inf` are supported. See the description of
        `axis` on how to compute norms for a batch of vectors or matrices
        stored in a tensor.
    axis: If `axis` is `None` (the default), the input is considered a vector
      and a single vector norm is computed over the entire set of values in
      the tensor, i.e. `norm(tensor, ord=ord)` is equivalent to
      `norm(reshape(tensor, [-1]), ord=ord)`. If `axis` is a Python integer,
      the input is considered a batch of vectors, and `axis` determines the
      axis in `tensor` over which to compute vector norms. If `axis` is a
      2-tuple of Python integers it is considered a batch of matrices and
      `axis` determines the axes in `tensor` over which to compute a matrix
      norm.
      Negative indices are supported. Example: If you are passing a tensor
        that can be either a matrix or a batch of matrices at runtime, pass
        `axis=[-2,-1]` instead of `axis=None` to make sure that matrix norms
        are computed.
    name: The name of the op.

  Returns:
    normalized: A normalized `Tensor` with the same shape as `tensor`.
    norm: The computed norms with the same shape and dtype as `tensor` but the
      final axis is 1 instead. Same as running
      `tf.cast(tf.linalg.norm(tensor, ord, axis, keepdims=True), tensor.dtype)`.

  Raises:
    ValueError: If `ord` or `axis` is invalid.
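
  For example (an illustrative sketch; the values are arbitrary):

  ```python
  t = tf.constant([[3.0, 4.0]])
  normalized, norm = tf.linalg.normalize(t)  # norm is [[5.0]]
  # `normalized` is t / norm, i.e. [[0.6, 0.8]].
  ```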
  """
  with ops.name_scope(name, "normalize", [tensor]) as name:
    tensor = ops.convert_to_tensor(tensor)
    norm = linalg_ops.norm(tensor, ord, axis, keepdims=True)
    norm = math_ops.cast(norm, tensor.dtype)
    normalized = tensor / norm
    return normalized, norm


@tf_export(v1=["math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize"])
@deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None):
  """Normalizes along dimension `axis` using an L2 norm.

  For a 1-D tensor with `axis = 0`, computes

      output = x / sqrt(max(sum(x**2), epsilon))

  For `x` with more dimensions, independently normalizes each 1-D slice along
  dimension `axis`.

  Args:
    x: A `Tensor`.
    axis: Dimension along which to normalize.  A scalar or a vector of
      integers.
    epsilon: A lower bound value for the norm. Will use `sqrt(epsilon)` as the
      divisor if `norm < sqrt(epsilon)`.
    name: A name for this operation (optional).
    dim: Deprecated alias for axis.

  Returns:
    A `Tensor` with the same shape as `x`.
  """
  axis = deprecated_argument_lookup("axis", axis, "dim", dim)
  return l2_normalize_v2(x, axis, epsilon, name)


@tf_export("math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize", v1=[])
def l2_normalize_v2(x, axis=None, epsilon=1e-12, name=None):
  """Normalizes along dimension `axis` using an L2 norm.

  For a 1-D tensor with `axis = 0`, computes

      output = x / sqrt(max(sum(x**2), epsilon))

  For `x` with more dimensions, independently normalizes each 1-D slice along
  dimension `axis`.

  Args:
    x: A `Tensor`.
    axis: Dimension along which to normalize.  A scalar or a vector of
      integers.
    epsilon: A lower bound value for the norm. Will use `sqrt(epsilon)` as the
      divisor if `norm < sqrt(epsilon)`.
    name: A name for this operation (optional).

  Returns:
    A `Tensor` with the same shape as `x`.
  """
  with ops.name_scope(name, "l2_normalize", [x]) as name:
    x = ops.convert_to_tensor(x, name="x")
    square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keepdims=True)
    x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon))
    return math_ops.multiply(x, x_inv_norm, name=name)


def _count_nonzero(input_tensor, dtype=dtypes.int64):
  """Same as math_ops.count_nonzero.

  The reduction is done in dtype, which can be faster for 32-bit dtypes.

  Args:
    input_tensor: numeric tensor
    dtype: reduction dtype

  Returns:
    number of nonzero values with type dtype
  """
  with ops.name_scope("count_nonzero", values=[input_tensor]):
    zero = array_ops.zeros([], dtype=input_tensor.dtype)
    nonzero_count = math_ops.reduce_sum(
        math_ops.cast(
            math_ops.not_equal(input_tensor, zero),
            dtype=dtype), name="nonzero_count")
    return nonzero_count


@tf_export("math.zero_fraction", "nn.zero_fraction")
def zero_fraction(value, name=None):
  """Returns the fraction of zeros in `value`.

  If `value` is empty, the result is `nan`.

  This is useful in summaries to measure and report sparsity.  For example,

  ```python
      z = tf.nn.relu(...)
      summ = tf.compat.v1.summary.scalar('sparsity', tf.nn.zero_fraction(z))
  ```

  Args:
    value: A tensor of numeric type.
    name: A name for the operation (optional).

  Returns:
    The fraction of zeros in `value`, with type `float32`.
  """
  with ops.name_scope(name, "zero_fraction", [value]):
    value = ops.convert_to_tensor(value, name="value")
    size = array_ops.size(value, out_type=dtypes.int64)
    # If the count is small, we can save memory/CPU with an int32 reduction.
    num_nonzero = control_flow_ops.cond(
        size <= dtypes.int32.max,
        # pylint: disable=g-long-lambda
        true_fn=lambda: math_ops.cast(
            _count_nonzero(value, dtype=dtypes.int32),
            dtype=dtypes.int64),
        false_fn=lambda: _count_nonzero(value, dtype=dtypes.int64))

    with ops.name_scope("counts_to_fraction"):
      num_zero = size - num_nonzero
      num_zero_float32 = math_ops.cast(num_zero, dtype=dtypes.float32)
      size_float32 = math_ops.cast(size, dtype=dtypes.float32)
      zero_fraction_float32 = num_zero_float32 / size_float32

  return array_ops.identity(zero_fraction_float32, "fraction")


# pylint: disable=redefined-builtin
@tf_export(v1=["nn.depthwise_conv2d"])
def depthwise_conv2d(input,
                     filter,
                     strides,
                     padding,
                     rate=None,
                     name=None,
                     data_format=None,
                     dilations=None):
  """Depthwise 2-D convolution.

  Given a 4D input tensor ('NHWC' or 'NCHW' data formats)
  and a filter tensor of shape
  `[filter_height, filter_width, in_channels, channel_multiplier]`
  containing `in_channels` convolutional filters of depth 1, `depthwise_conv2d`
  applies a different filter to each input channel (expanding from 1 channel
  to `channel_multiplier` channels for each), then concatenates the results
  together.  The output has `in_channels * channel_multiplier` channels.

  In detail, with the default NHWC format,

      output[b, i, j, k * channel_multiplier + q] = sum_{di, dj}
           filter[di, dj, k, q] * input[b, strides[1] * i + rate[0] * di,
                                           strides[2] * j + rate[1] * dj, k]

  Must have `strides[0] = strides[3] = 1`.  For the most common case of the
  same horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
  If any value in `rate` is greater than 1, we perform atrous depthwise
  convolution, in which case all values in the `strides` tensor must be equal
  to 1.

  Args:
    input: 4-D with shape according to `data_format`.
    filter: 4-D with shape
      `[filter_height, filter_width, in_channels, channel_multiplier]`.
    strides: 1-D of size 4.  The stride of the sliding window for each
      dimension of `input`.
    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
      See the "returns" section of `tf.nn.convolution` for details.
    rate: 1-D of size 2. The dilation rate in which we sample input values
      across the `height` and `width` dimensions in atrous convolution. If it
      is greater than 1, then all values of strides must be 1.
    name: A name for this operation (optional).
    data_format: The data format for input. Either "NHWC" (default) or "NCHW".
    dilations: Alias of rate.

  Returns:
    A 4-D `Tensor` with shape according to `data_format`.
    E.g., for "NHWC" format, shape is
    `[batch, out_height, out_width, in_channels * channel_multiplier]`.
  """
  rate = deprecated_argument_lookup("dilations", dilations, "rate", rate)
  with ops.name_scope(name, "depthwise", [input, filter]) as name:
    input = ops.convert_to_tensor(input, name="tensor_in")
    filter = ops.convert_to_tensor(filter, name="filter_in")
    if rate is None:
      rate = [1, 1]

    # Use depthwise_conv2d_native if executing on TPU.
    if device_context.enclosing_tpu_context() is not None:
      if data_format == "NCHW":
        dilations = [1, 1, rate[0], rate[1]]
      else:
        dilations = [1, rate[0], rate[1], 1]
      return nn_ops.depthwise_conv2d_native(
          input=input,
          filter=filter,
          strides=strides,
          padding=padding,
          data_format=data_format,
          dilations=dilations,
          name=name)

    def op(input_converted, _, padding):
      return nn_ops.depthwise_conv2d_native(
          input=input_converted,
          filter=filter,
          strides=strides,
          padding=padding,
          data_format=data_format,
          name=name)

    return nn_ops.with_space_to_batch(
        input=input,
        filter_shape=array_ops.shape(filter),
        dilation_rate=rate,
        padding=padding,
        data_format=data_format,
        op=op)


@tf_export("nn.depthwise_conv2d", v1=[])
def depthwise_conv2d_v2(input,
                        filter,
                        strides,
                        padding,
                        data_format=None,
                        dilations=None,
                        name=None):
  """Depthwise 2-D convolution.

  Given a 4D input tensor ('NHWC' or 'NCHW' data formats)
  and a filter tensor of shape
  `[filter_height, filter_width, in_channels, channel_multiplier]`
  containing `in_channels` convolutional filters of depth 1, `depthwise_conv2d`
  applies a different filter to each input channel (expanding from 1 channel
  to `channel_multiplier` channels for each), then concatenates the results
  together.  The output has `in_channels * channel_multiplier` channels.

  In detail, with the default NHWC format,

      output[b, i, j, k * channel_multiplier + q] = sum_{di, dj}
           filter[di, dj, k, q] * input[b, strides[1] * i + rate[0] * di,
                                           strides[2] * j + rate[1] * dj, k]

  Must have `strides[0] = strides[3] = 1`.  For the most common case of the
  same horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
  If any value in `dilations` is greater than 1, we perform atrous depthwise
  convolution, in which case all values in the `strides` tensor must be equal
  to 1.

  Args:
    input: 4-D with shape according to `data_format`.
    filter: 4-D with shape
      `[filter_height, filter_width, in_channels, channel_multiplier]`.
    strides: 1-D of size 4.  The stride of the sliding window for each
      dimension of `input`.
    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
      See the "returns" section of `tf.nn.convolution` for details.
    data_format: The data format for input. Either "NHWC" (default) or "NCHW".
    dilations: 1-D of size 2. The dilation rate in which we sample input values
      across the `height` and `width` dimensions in atrous convolution. If it
      is greater than 1, then all values of strides must be 1.
    name: A name for this operation (optional).

  Returns:
    A 4-D `Tensor` with shape according to `data_format`. E.g., for "NHWC"
    format, shape is
    `[batch, out_height, out_width, in_channels * channel_multiplier]`.
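
  A shape-only sketch (random values; only the shapes matter here):

  ```python
  x = tf.random.normal([1, 8, 8, 3])       # NHWC input with 3 channels
  kernel = tf.random.normal([3, 3, 3, 2])  # channel_multiplier = 2
  y = tf.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1], padding="SAME")
  # y.shape == [1, 8, 8, 6], i.e. in_channels * channel_multiplier = 3 * 2.
  ```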
  """
  return depthwise_conv2d(input=input,
                          filter=filter,
                          strides=strides,
                          padding=padding,
                          rate=dilations,
                          name=name,
                          data_format=data_format)

# pylint: enable=redefined-builtin


# pylint: disable=redefined-builtin,line-too-long
@tf_export(v1=["nn.separable_conv2d"])
def separable_conv2d(input,
                     depthwise_filter,
                     pointwise_filter,
                     strides,
                     padding,
                     rate=None,
                     name=None,
                     data_format=None,
                     dilations=None):
  """2-D convolution with separable filters.

  Performs a depthwise convolution that acts separately on channels followed by
  a pointwise convolution that mixes channels.  Note that this is separability
  between dimensions `[1, 2]` and `3`, not spatial separability between
  dimensions `1` and `2`.

  In detail, with the default NHWC format,

      output[b, i, j, k] = sum_{di, dj, q, r}
          input[b, strides[1] * i + di, strides[2] * j + dj, q] *
          depthwise_filter[di, dj, q, r] *
          pointwise_filter[0, 0, q * channel_multiplier + r, k]

  `strides` controls the strides for the depthwise convolution only, since
  the pointwise convolution has implicit strides of `[1, 1, 1, 1]`.  Must have
  `strides[0] = strides[3] = 1`.  For the most common case of the same
  horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
  If any value in `rate` is greater than 1, we perform atrous depthwise
  convolution, in which case all values in the `strides` tensor must be equal
  to 1.

  Args:
    input: 4-D `Tensor` with shape according to `data_format`.
    depthwise_filter: 4-D `Tensor` with shape
      `[filter_height, filter_width, in_channels, channel_multiplier]`.
      Contains `in_channels` convolutional filters of depth 1.
    pointwise_filter: 4-D `Tensor` with shape
      `[1, 1, channel_multiplier * in_channels, out_channels]`.  Pointwise
      filter to mix channels after `depthwise_filter` has convolved spatially.
    strides: 1-D of size 4.  The strides for the depthwise convolution for
      each dimension of `input`.
    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
      See the "returns" section of `tf.nn.convolution` for details.
    rate: 1-D of size 2. The dilation rate in which we sample input values
      across the `height` and `width` dimensions in atrous convolution. If it
      is greater than 1, then all values of strides must be 1.
    name: A name for this operation (optional).
    data_format: The data format for input. Either "NHWC" (default) or "NCHW".
    dilations: Alias of rate.

  Returns:
    A 4-D `Tensor` with shape according to 'data_format'. For
      example, with data_format="NHWC", shape is [batch, out_height,
      out_width, out_channels].
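
  A shape-only sketch (random values; only the shapes matter here):

  ```python
  x = tf.random.normal([1, 8, 8, 3])
  depthwise_filter = tf.random.normal([3, 3, 3, 2])   # channel_multiplier = 2
  pointwise_filter = tf.random.normal([1, 1, 6, 16])  # 3 * 2 -> 16 channels
  y = tf.compat.v1.nn.separable_conv2d(x, depthwise_filter, pointwise_filter,
                                       strides=[1, 1, 1, 1], padding="SAME")
  # y.shape == [1, 8, 8, 16]
  ```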
  """
  rate = deprecated_argument_lookup("dilations", dilations, "rate", rate)
  with ops.name_scope(name, "separable_conv2d",
                      [input, depthwise_filter, pointwise_filter]) as name:
    input = ops.convert_to_tensor(input, name="tensor_in")
    depthwise_filter = ops.convert_to_tensor(
        depthwise_filter, name="depthwise_filter")
    pointwise_filter = ops.convert_to_tensor(
        pointwise_filter, name="pointwise_filter")

    pointwise_filter_shape = pointwise_filter.get_shape().with_rank(4)
    pointwise_filter_shape.dims[0].assert_is_compatible_with(1)
    pointwise_filter_shape.dims[1].assert_is_compatible_with(1)

    if rate is None:
      rate = [1, 1]

    # The layout of the ops in the graph is expected to be as follows:
    # depthwise_conv2d  // Conv2D op corresponding to the native depthwise conv.
    # separable_conv2d  // Conv2D op corresponding to the pointwise conv.

    def op(input_converted, _, padding):
      return nn_ops.depthwise_conv2d_native(
          input=input_converted,
          filter=depthwise_filter,
          strides=strides,
          padding=padding,
          data_format=data_format,
          name="depthwise")

    depthwise = nn_ops.with_space_to_batch(
        input=input,
        filter_shape=array_ops.shape(depthwise_filter),
        dilation_rate=rate,
        padding=padding,
        data_format=data_format,
        op=op)

    return nn_ops.conv2d(
        depthwise,
        pointwise_filter, [1, 1, 1, 1],
        padding="VALID",
        data_format=data_format,
        name=name)


@tf_export("nn.separable_conv2d", v1=[])
def separable_conv2d_v2(
    input,
    depthwise_filter,
    pointwise_filter,
    strides,
    padding,
    data_format=None,
    dilations=None,
    name=None,
):
  """2-D convolution with separable filters.

  Performs a depthwise convolution that acts separately on channels followed by
  a pointwise convolution that mixes channels.  Note that this is separability
  between dimensions `[1, 2]` and `3`, not spatial separability between
  dimensions `1` and `2`.

  In detail, with the default NHWC format,

      output[b, i, j, k] = sum_{di, dj, q, r}
          input[b, strides[1] * i + di, strides[2] * j + dj, q] *
          depthwise_filter[di, dj, q, r] *
          pointwise_filter[0, 0, q * channel_multiplier + r, k]

  `strides` controls the strides for the depthwise convolution only, since
  the pointwise convolution has implicit strides of `[1, 1, 1, 1]`.  Must have
  `strides[0] = strides[3] = 1`.  For the most common case of the same
  horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
  If any value in `dilations` is greater than 1, we perform atrous depthwise
  convolution, in which case all values in the `strides` tensor must be equal
  to 1.

  Args:
    input: 4-D `Tensor` with shape according to `data_format`.
    depthwise_filter: 4-D `Tensor` with shape `[filter_height, filter_width,
      in_channels, channel_multiplier]`. Contains `in_channels` convolutional
      filters of depth 1.
    pointwise_filter: 4-D `Tensor` with shape `[1, 1, channel_multiplier *
      in_channels, out_channels]`.  Pointwise filter to mix channels after
      `depthwise_filter` has convolved spatially.
    strides: 1-D of size 4.  The strides for the depthwise convolution for each
      dimension of `input`.
    padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
      the "returns" section of `tf.nn.convolution` for details.
    data_format: The data format for input. Either "NHWC" (default) or "NCHW".
    dilations: 1-D of size 2. The dilation rate in which we sample input values
      across the `height` and `width` dimensions in atrous convolution. If it
      is greater than 1, then all values of strides must be 1.
    name: A name for this operation (optional).

  Returns:
    A 4-D `Tensor` with shape according to 'data_format'. For
      example, with data_format="NHWC", shape is [batch, out_height,
      out_width, out_channels].
  """
  return separable_conv2d(
      input,
      depthwise_filter,
      pointwise_filter,
      strides,
      padding,
      rate=dilations,
      name=name,
      data_format=data_format)

# pylint: enable=redefined-builtin,line-too-long


@tf_export(v1=["nn.sufficient_statistics"])
def sufficient_statistics(x, axes, shift=None, keep_dims=None, name=None,
                          keepdims=None):
  """Calculate the sufficient statistics for the mean and variance of `x`.

  These sufficient statistics are computed using the one pass algorithm on
  an input that's optionally shifted. See:
  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data

  Args:
    x: A `Tensor`.
    axes: Array of ints. Axes along which to compute mean and variance.
    shift: A `Tensor` containing the value by which to shift the data for
      numerical stability, or `None` if no shift is to be performed. A shift
      close to the true mean provides the most numerically stable results.
    keep_dims: produce statistics with the same dimensionality as the input.
    name: Name used to scope the operations that compute the sufficient stats.
    keepdims: Alias for keep_dims.

  Returns:
    Four `Tensor` objects of the same type as `x`:

    * the count (number of elements to average over).
    * the (possibly shifted) sum of the elements in the array.
    * the (possibly shifted) sum of squares of the elements in the array.
    * the shift by which the mean must be corrected or None if `shift` is None.
  """
  axes = list(set(axes))
  keep_dims = deprecated_argument_lookup(
      "keepdims", keepdims, "keep_dims", keep_dims)
  if keep_dims is None:
    keep_dims = False
  with ops.name_scope(name, "sufficient_statistics", [x, shift]):
    x = ops.convert_to_tensor(x, name="x")
    x_shape = x.get_shape()
    if x_shape.rank is not None and all(
        x_shape.dims[d].value is not None for d in axes):
      counts = 1
      for d in axes:
        counts *= x_shape.dims[d].value
      counts = constant_op.constant(counts, dtype=x.dtype)
    else:  # shape needs to be inferred at runtime.
      x_dims = array_ops.gather(
          math_ops.cast(array_ops.shape(x), x.dtype), axes)
      counts = math_ops.reduce_prod(x_dims, name="count")
    if shift is not None:
      shift = ops.convert_to_tensor(shift, name="shift")
      m_ss = math_ops.subtract(x, shift)
      v_ss = math_ops.squared_difference(x, shift)
    else:  # no shift.
      m_ss = x
      v_ss = math_ops.square(x)
    m_ss = math_ops.reduce_sum(m_ss, axes, keepdims=keep_dims, name="mean_ss")
    v_ss = math_ops.reduce_sum(v_ss, axes, keepdims=keep_dims, name="var_ss")
  return counts, m_ss, v_ss, shift


@tf_export("nn.sufficient_statistics", v1=[])
def sufficient_statistics_v2(x, axes, shift=None, keepdims=False, name=None):
  """Calculate the sufficient statistics for the mean and variance of `x`.

  These sufficient statistics are computed using the one pass algorithm on
  an input that's optionally shifted. See:
  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data

  Args:
    x: A `Tensor`.
    axes: Array of ints. Axes along which to compute mean and variance.
    shift: A `Tensor` containing the value by which to shift the data for
      numerical stability, or `None` if no shift is to be performed. A shift
      close to the true mean provides the most numerically stable results.
    keepdims: produce statistics with the same dimensionality as the input.
    name: Name used to scope the operations that compute the sufficient stats.

  Returns:
    Four `Tensor` objects of the same type as `x`:

    * the count (number of elements to average over).
    * the (possibly shifted) sum of the elements in the array.
    * the (possibly shifted) sum of squares of the elements in the array.
    * the shift by which the mean must be corrected or None if `shift` is None.
  """
  return sufficient_statistics(
      x=x, axes=axes, shift=shift, keep_dims=keepdims, name=name)


@tf_export("nn.normalize_moments")
def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
  """Calculate the mean and variance based on the sufficient statistics.

  Args:
    counts: A `Tensor` containing the total count of the data (one value).
    mean_ss: A `Tensor` containing the mean sufficient statistics: the
      (possibly shifted) sum of the elements to average over.
    variance_ss: A `Tensor` containing the variance sufficient statistics: the
      (possibly shifted) squared sum of the data to compute the variance over.
    shift: A `Tensor` containing the value by which the data is shifted for
      numerical stability, or `None` if no shift was performed.
    name: Name used to scope the operations that compute the moments.

  Returns:
    Two `Tensor` objects: `mean` and `variance`.
  """
  with ops.name_scope(name, "normalize", [counts, mean_ss, variance_ss, shift]):
    divisor = math_ops.reciprocal(counts, name="divisor")
    if shift is not None:
      shifted_mean = math_ops.multiply(mean_ss, divisor, name="shifted_mean")
      mean = math_ops.add(shifted_mean, shift, name="mean")
    else:  # no shift.
      shifted_mean = math_ops.multiply(mean_ss, divisor, name="mean")
      mean = shifted_mean
    variance = math_ops.subtract(
        math_ops.multiply(variance_ss, divisor),
        math_ops.square(shifted_mean),
        name="variance")
  return (mean, variance)


@tf_export(v1=["nn.moments"])
def moments(
    x,
    axes,
    shift=None,  # pylint: disable=unused-argument
    name=None,
    keep_dims=None,
    keepdims=None):
  """Calculate the mean and variance of `x`.

  The mean and variance are calculated by aggregating the contents of `x`
  across `axes`.  If `x` is 1-D and `axes = [0]` this is just the mean
  and variance of a vector.

  Note: shift is currently not used; the true mean is computed and used.

  When using these moments for batch normalization (see
  `tf.nn.batch_normalization`):

   * for so-called "global normalization", used with convolutional filters with
     shape `[batch, height, width, depth]`, pass `axes=[0, 1, 2]`.
   * for simple batch normalization pass `axes=[0]` (batch only).

  Args:
    x: A `Tensor`.
    axes: Array of ints. Axes along which to compute mean and
      variance.
    shift: Not used in the current implementation.
    name: Name used to scope the operations that compute the moments.
    keep_dims: produce moments with the same dimensionality as the input.
    keepdims: Alias to keep_dims.

  Returns:
    Two `Tensor` objects: `mean` and `variance`.
  """
  keep_dims = deprecated_argument_lookup(
      "keepdims", keepdims, "keep_dims", keep_dims)
  if keep_dims is None:
    keep_dims = False
  with ops.name_scope(name, "moments", [x, axes]):
    # The dynamic range of fp16 is too limited to support the collection of
    # sufficient statistics. As a workaround we simply perform the operations
    # on 32-bit floats before converting the mean and variance back to fp16.
    y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
    # Compute true mean while keeping the dims for proper broadcasting.
    mean = math_ops.reduce_mean(y, axes, keepdims=True, name="mean")
    # Sample variance, not unbiased variance.
    # Note: stop_gradient does not change the gradient that gets
    # backpropagated to the mean from the variance calculation, because that
    # gradient is zero.
    variance = math_ops.reduce_mean(
        math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
        axes,
        keepdims=True,
        name="variance")
    if not keep_dims:
      mean = array_ops.squeeze(mean, axes)
      variance = array_ops.squeeze(variance, axes)
    if x.dtype == dtypes.float16:
      return (math_ops.cast(mean, dtypes.float16),
              math_ops.cast(variance, dtypes.float16))
    else:
      return (mean, variance)


@tf_export("nn.moments", v1=[])
def moments_v2(
    x,
    axes,
    shift=None,
    keepdims=False,
    name=None):
  """Calculates the mean and variance of `x`.

  The mean and variance are calculated by aggregating the contents of `x`
  across `axes`.  If `x` is 1-D and `axes = [0]` this is just the mean
  and variance of a vector.

  Note: shift is currently not used; the true mean is computed and used.

  When using these moments for batch normalization (see
  `tf.nn.batch_normalization`):

   * for so-called "global normalization", used with convolutional filters with
     shape `[batch, height, width, depth]`, pass `axes=[0, 1, 2]`.
   * for simple batch normalization pass `axes=[0]` (batch only).

  Args:
    x: A `Tensor`.
    axes: Array of ints. Axes along which to compute mean and
      variance.
    shift: Not used in the current implementation.
    keepdims: produce moments with the same dimensionality as the input.
    name: Name used to scope the operations that compute the moments.

  Returns:
    Two `Tensor` objects: `mean` and `variance`.
  """
  return moments(x=x, axes=axes, shift=shift, name=name, keep_dims=keepdims)


@tf_export(v1=["nn.weighted_moments"])
def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=None,
                     keepdims=None):
  """Returns the frequency-weighted mean and variance of `x`.

  Args:
    x: A tensor.
    axes: 1-d tensor of int32 values; these are the axes along which
      to compute mean and variance.
    frequency_weights: A tensor of positive weights which can be
      broadcast with x.
    name: Name used to scope the operation.
    keep_dims: Produce moments with the same dimensionality as the input.
    keepdims: Alias of keep_dims.

  Returns:
    Two tensors: `weighted_mean` and `weighted_variance`.
  """
  keep_dims = deprecated_argument_lookup(
      "keepdims", keepdims, "keep_dims", keep_dims)
  if keep_dims is None:
    keep_dims = False
  with ops.name_scope(name, "weighted_moments", [x, frequency_weights, axes]):
    x = ops.convert_to_tensor(x, name="x")
    frequency_weights = ops.convert_to_tensor(
        frequency_weights, name="frequency_weights")

    # Unlike moments(), this just uses a simpler two-pass method.

    # See comment in moments() WRT precision; it applies here too.
    needs_cast = x.dtype == dtypes.float16
    if needs_cast:
      x = math_ops.cast(x, dtypes.float32)

    if frequency_weights.dtype != x.dtype:
      frequency_weights = math_ops.cast(frequency_weights, x.dtype)

    # Note that we use keep_dims=True for our reductions regardless of the arg;
    # this is so that the results remain broadcast-compatible with the inputs.
    weighted_input_sum = math_ops.reduce_sum(
        frequency_weights * x, axes, name="weighted_input_sum", keepdims=True)

    # The shape of the weights isn't necessarily the same as x's
    # shape, just broadcast-compatible with it -- so this expression
    # performs broadcasting to give a per-item weight, with the same
    # shape as (frequency_weights * x). This avoids having to reason
    # through all the broadcast logic to compute a correct
    # sum_of_weights.
    broadcasted_weights = frequency_weights + array_ops.zeros_like(x)

    sum_of_weights = math_ops.reduce_sum(
        broadcasted_weights, axes, name="sum_of_weights", keepdims=True)

    divisor = math_ops.reciprocal(sum_of_weights, name="inv_weight_sum")

    weighted_mean = math_ops.multiply(weighted_input_sum, divisor)

    # Have the weighted mean; now on to variance:
    weighted_distsq = math_ops.reduce_sum(
        frequency_weights * math_ops.squared_difference(x, weighted_mean),
        axes,
        name="weighted_distsq",
        keepdims=True)

    weighted_variance = math_ops.multiply(weighted_distsq, divisor)

    if not keep_dims:
      weighted_mean = array_ops.squeeze(weighted_mean, axis=axes)
      weighted_variance = array_ops.squeeze(
          weighted_variance, axis=axes)

    if needs_cast:
      weighted_mean = math_ops.cast(weighted_mean, dtypes.float16)
      weighted_variance = math_ops.cast(weighted_variance, dtypes.float16)

    return weighted_mean, weighted_variance


@tf_export("nn.weighted_moments", v1=[])
def weighted_moments_v2(x, axes, frequency_weights, keepdims=False, name=None):
  """Returns the frequency-weighted mean and variance of `x`.

  Args:
    x: A tensor.
    axes: 1-d tensor of int32 values; these are the axes along which
      to compute mean and variance.
    frequency_weights: A tensor of positive weights which can be
      broadcast with x.
    keepdims: Produce moments with the same dimensionality as the input.
    name: Name used to scope the operation.

  Returns:
    Two tensors: `weighted_mean` and `weighted_variance`.
  """
  return weighted_moments(
      x=x,
      axes=axes,
      frequency_weights=frequency_weights,
      name=name,
      keep_dims=keepdims)


@tf_export("nn.batch_normalization")
def batch_normalization(x,
                        mean,
                        variance,
                        offset,
                        scale,
                        variance_epsilon,
                        name=None):
  r"""Batch normalization.

  Normalizes a tensor by `mean` and `variance`, and applies (optionally) a
  `scale` \\(\gamma\\) to it, as well as an `offset` \\(\beta\\):

  \\(\frac{\gamma(x-\mu)}{\sigma}+\beta\\)

  `mean`, `variance`, `offset` and `scale` are all expected to be of one of two
  shapes:

    * In all generality, they can have the same number of dimensions as the
      input `x`, with identical sizes as `x` for the dimensions that are not
      normalized over (the 'depth' dimension(s)), and dimension 1 for the
      others which are being normalized over.
      `mean` and `variance` in this case would typically be the outputs of
      `tf.nn.moments(..., keepdims=True)` during training, or running averages
      thereof during inference.
    * In the common case where the 'depth' dimension is the last dimension in
      the input tensor `x`, they may be one dimensional tensors of the same
      size as the 'depth' dimension.
      This is the case for example for the common `[batch, depth]` layout of
      fully-connected layers, and `[batch, height, width, depth]` for
      convolutions.
      `mean` and `variance` in this case would typically be the outputs of
      `tf.nn.moments(..., keepdims=False)` during training, or running averages
      thereof during inference.

  See equation 11 in Algorithm 2 of source:
  [Batch Normalization: Accelerating Deep Network Training by
  Reducing Internal Covariate Shift; S. Ioffe, C. Szegedy]
  (http://arxiv.org/abs/1502.03167).

  Args:
    x: Input `Tensor` of arbitrary dimensionality.
    mean: A mean `Tensor`.
    variance: A variance `Tensor`.
    offset: An offset `Tensor`, often denoted \\(\beta\\) in equations, or
      None. If present, will be added to the normalized tensor.
    scale: A scale `Tensor`, often denoted \\(\gamma\\) in equations, or
      `None`. If present, the scale is applied to the normalized tensor.
    variance_epsilon: A small float number to avoid dividing by 0.
    name: A name for this operation (optional).

  Returns:
    The normalized, scaled, offset tensor.

  References:
    Batch Normalization - Accelerating Deep Network Training by Reducing
    Internal Covariate Shift:
      [Ioffe et al., 2015](http://arxiv.org/abs/1502.03167)
      ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
  """
  with ops.name_scope(name, "batchnorm", [x, mean, variance, scale, offset]):
    inv = math_ops.rsqrt(variance + variance_epsilon)
    if scale is not None:
      inv *= scale
    # Note: tensorflow/contrib/quantize/python/fold_batch_norms.py depends on
    # the precise order of ops that are generated by the expression below.
    return x * math_ops.cast(inv, x.dtype) + math_ops.cast(
        offset - mean * inv if offset is not None else -mean * inv, x.dtype)


@tf_export(v1=["nn.fused_batch_norm"])
def fused_batch_norm(
    x,
    scale,
    offset,  # pylint: disable=invalid-name
    mean=None,
    variance=None,
    epsilon=0.001,
    data_format="NHWC",
    is_training=True,
    name=None):
  r"""Batch normalization.

  See Source: [Batch Normalization: Accelerating Deep Network Training by
  Reducing Internal Covariate Shift; S. Ioffe, C. Szegedy]
  (http://arxiv.org/abs/1502.03167).

  Args:
    x: Input `Tensor` of 4 dimensions.
    scale: A `Tensor` of 1 dimension for scaling.
    offset: A `Tensor` of 1 dimension for bias.
    mean: A `Tensor` of 1 dimension for population mean used for inference.
    variance: A `Tensor` of 1 dimension for population variance
              used for inference.
    epsilon: A small float number added to the variance of x.
    data_format: The data format for x. Either "NHWC" (default) or "NCHW".
    is_training: A bool value to specify if the operation is used for
                 training or inference.
    name: A name for this operation (optional).

  Returns:
    y: A 4D Tensor for the normalized, scaled, offset x.
    batch_mean: A 1D Tensor for the mean of x.
    batch_var: A 1D Tensor for the variance of x.

  Raises:
    ValueError: If mean or variance is not None when is_training is True.

  References:
    Batch Normalization - Accelerating Deep Network Training by Reducing
    Internal Covariate Shift:
      [Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html)
      ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
  """
  x = ops.convert_to_tensor(x, name="input")
  scale = ops.convert_to_tensor(scale, name="scale")
  offset = ops.convert_to_tensor(offset, name="offset")
  if is_training:
    if (mean is not None) or (variance is not None):
      raise ValueError("Both 'mean' and 'variance' must be None "
                       "if is_training is True.")
  if mean is None:
    mean = constant_op.constant([])
  if variance is None:
    variance = constant_op.constant([])
  # Set a minimum epsilon to 1.001e-5, which is a requirement by CUDNN to
  # prevent exception (see cudnn.h).
  min_epsilon = 1.001e-5
  epsilon = epsilon if epsilon > min_epsilon else min_epsilon

  y, batch_mean, batch_var, _, _, _ = gen_nn_ops.fused_batch_norm_v3(
      x,
      scale,
      offset,
      mean,
      variance,
      epsilon=epsilon,
      data_format=data_format,
      is_training=is_training,
      name=name)
  return y, batch_mean, batch_var


@tf_export(v1=["nn.batch_norm_with_global_normalization"])
def batch_norm_with_global_normalization(t=None,
                                         m=None,
                                         v=None,
                                         beta=None,
                                         gamma=None,
                                         variance_epsilon=None,
                                         scale_after_normalization=None,
                                         name=None,
                                         input=None,  # pylint: disable=redefined-builtin
                                         mean=None,
                                         variance=None):
  """Batch normalization.

  This op is deprecated. See `tf.nn.batch_normalization`.

  Args:
    t: A 4D input Tensor.
    m: A 1D mean Tensor with size matching the last dimension of t.
      This is the first output from tf.nn.moments,
      or a saved moving average thereof.
    v: A 1D variance Tensor with size matching the last dimension of t.
      This is the second output from tf.nn.moments,
      or a saved moving average thereof.
    beta: A 1D beta Tensor with size matching the last dimension of t.
      An offset to be added to the normalized tensor.
    gamma: A 1D gamma Tensor with size matching the last dimension of t.
      If "scale_after_normalization" is true, this tensor will be multiplied
      by the normalized tensor.
    variance_epsilon: A small float number to avoid dividing by 0.
    scale_after_normalization: A bool indicating whether the resulting tensor
      needs to be multiplied by gamma.
    name: A name for this operation (optional).
    input: Alias for t.
    mean: Alias for m.
    variance: Alias for v.

  Returns:
    A batch-normalized `t`.
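  Since this op is deprecated, new code can call `tf.nn.batch_normalization`
  directly; a sketch of the equivalent call (mirroring this op's own
  implementation below) is:

  ```python
  y = tf.nn.batch_normalization(
      t, m, v, beta, gamma if scale_after_normalization else None,
      variance_epsilon)
  ```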
  References:
    Batch Normalization - Accelerating Deep Network Training by Reducing
    Internal Covariate Shift:
      [Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html)
      ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
  """
  t = deprecated_argument_lookup("input", input, "t", t)
  m = deprecated_argument_lookup("mean", mean, "m", m)
  v = deprecated_argument_lookup("variance", variance, "v", v)
  return batch_normalization(t, m, v, beta, gamma if scale_after_normalization
                             else None, variance_epsilon, name)


# pylint: disable=redefined-builtin,line-too-long
@tf_export("nn.batch_norm_with_global_normalization", v1=[])
def batch_norm_with_global_normalization_v2(input,
                                            mean,
                                            variance,
                                            beta,
                                            gamma,
                                            variance_epsilon,
                                            scale_after_normalization,
                                            name=None):
  """Batch normalization.

  This op is deprecated. See `tf.nn.batch_normalization`.

  Args:
    input: A 4D input Tensor.
    mean: A 1D mean Tensor with size matching the last dimension of `input`.
      This is the first output from tf.nn.moments,
      or a saved moving average thereof.
    variance: A 1D variance Tensor with size matching the last dimension of
      `input`. This is the second output from tf.nn.moments,
      or a saved moving average thereof.
    beta: A 1D beta Tensor with size matching the last dimension of `input`.
      An offset to be added to the normalized tensor.
    gamma: A 1D gamma Tensor with size matching the last dimension of `input`.
      If "scale_after_normalization" is true, this tensor will be multiplied
      by the normalized tensor.
    variance_epsilon: A small float number to avoid dividing by 0.
    scale_after_normalization: A bool indicating whether the resulting tensor
      needs to be multiplied by gamma.
    name: A name for this operation (optional).

  Returns:
    A batch-normalized `input`.

  References:
    Batch Normalization - Accelerating Deep Network Training by Reducing Internal Covariate Shift:
      [Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html)
      ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
  """
  return batch_norm_with_global_normalization(t=input,
                                              m=mean,
                                              v=variance,
                                              beta=beta,
                                              gamma=gamma,
                                              variance_epsilon=variance_epsilon,
                                              scale_after_normalization=scale_after_normalization,
                                              name=name)

# pylint: enable=redefined-builtin,line-too-long


def _sum_rows(x):
  """Returns a vector summing up each row of the matrix x."""
  # _sum_rows(x) is equivalent to math_ops.reduce_sum(x, 1) when x is
  # a matrix. The gradient of _sum_rows(x) is more efficient than
  # reduce_sum(x, 1)'s gradient in today's implementation. Therefore,
  # we use _sum_rows(x) in the nce_loss() computation since the loss
  # is mostly used for training.
  cols = array_ops.shape(x)[1]
  ones_shape = array_ops.stack([cols, 1])
  ones = array_ops.ones(ones_shape, x.dtype)
  return array_ops.reshape(math_ops.matmul(x, ones), [-1])


def _compute_sampled_logits(weights,
                            biases,
                            labels,
                            inputs,
                            num_sampled,
                            num_classes,
                            num_true=1,
                            sampled_values=None,
                            subtract_log_q=True,
                            remove_accidental_hits=False,
                            partition_strategy="mod",
                            name=None,
                            seed=None):
  """Helper function for nce_loss and sampled_softmax_loss functions.
1630 1631 Computes sampled output training logits and labels suitable for implementing 1632 e.g. noise-contrastive estimation (see nce_loss) or sampled softmax (see 1633 sampled_softmax_loss). 1634 1635 Note: In the case where num_true > 1, we assign to each target class 1636 the target probability 1 / num_true so that the target probabilities 1637 sum to 1 per-example. 1638 1639 Args: 1640 weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor` 1641 objects whose concatenation along dimension 0 has shape 1642 `[num_classes, dim]`. The (possibly-partitioned) class embeddings. 1643 biases: A `Tensor` of shape `[num_classes]`. The (possibly-partitioned) 1644 class biases. 1645 labels: A `Tensor` of type `int64` and shape `[batch_size, 1646 num_true]`. The target classes. Note that this format differs from 1647 the `labels` argument of `nn.softmax_cross_entropy_with_logits`. 1648 inputs: A `Tensor` of shape `[batch_size, dim]`. The forward 1649 activations of the input network. 1650 num_sampled: An `int`. The number of classes to randomly sample per batch. 1651 num_classes: An `int`. The number of possible classes. 1652 num_true: An `int`. The number of target classes per training example. 1653 sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`, 1654 `sampled_expected_count`) returned by a `*_candidate_sampler` function. 1655 (if None, we default to `log_uniform_candidate_sampler`) 1656 subtract_log_q: A `bool`. whether to subtract the log expected count of 1657 the labels in the sample to get the logits of the true labels. 1658 Default is True. Turn off for Negative Sampling. 1659 remove_accidental_hits: A `bool`. whether to remove "accidental hits" 1660 where a sampled class equals one of the target classes. Default is 1661 False. 1662 partition_strategy: A string specifying the partitioning strategy, relevant 1663 if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported. 1664 Default is `"mod"`. See `tf.nn.embedding_lookup` for more details. 1665 name: A name for the operation (optional). 1666 seed: random seed for candidate sampling. Default to None, which doesn't set 1667 the op-level random seed for candidate sampling. 1668 Returns: 1669 out_logits: `Tensor` object with shape 1670 `[batch_size, num_true + num_sampled]`, for passing to either 1671 `nn.sigmoid_cross_entropy_with_logits` (NCE) or 1672 `nn.softmax_cross_entropy_with_logits` (sampled softmax). 1673 out_labels: A Tensor object with the same shape as `out_logits`. 1674 """ 1675 1676 if isinstance(weights, variables.PartitionedVariable): 1677 weights = list(weights) 1678 if not isinstance(weights, list): 1679 weights = [weights] 1680 1681 with ops.name_scope(name, "compute_sampled_logits", 1682 weights + [biases, inputs, labels]): 1683 if labels.dtype != dtypes.int64: 1684 labels = math_ops.cast(labels, dtypes.int64) 1685 labels_flat = array_ops.reshape(labels, [-1]) 1686 1687 # Sample the negative labels. 
1688 # sampled shape: [num_sampled] tensor 1689 # true_expected_count shape = [batch_size, 1] tensor 1690 # sampled_expected_count shape = [num_sampled] tensor 1691 if sampled_values is None: 1692 sampled_values = candidate_sampling_ops.log_uniform_candidate_sampler( 1693 true_classes=labels, 1694 num_true=num_true, 1695 num_sampled=num_sampled, 1696 unique=True, 1697 range_max=num_classes, 1698 seed=seed) 1699 # NOTE: pylint cannot tell that 'sampled_values' is a sequence 1700 # pylint: disable=unpacking-non-sequence 1701 sampled, true_expected_count, sampled_expected_count = ( 1702 array_ops.stop_gradient(s) for s in sampled_values) 1703 # pylint: enable=unpacking-non-sequence 1704 sampled = math_ops.cast(sampled, dtypes.int64) 1705 1706 # labels_flat is a [batch_size * num_true] tensor 1707 # sampled is a [num_sampled] int tensor 1708 all_ids = array_ops.concat([labels_flat, sampled], 0) 1709 1710 # Retrieve the true weights and the logits of the sampled weights. 1711 1712 # weights shape is [num_classes, dim] 1713 all_w = embedding_ops.embedding_lookup( 1714 weights, all_ids, partition_strategy=partition_strategy) 1715 if all_w.dtype != inputs.dtype: 1716 all_w = math_ops.cast(all_w, inputs.dtype) 1717 1718 # true_w shape is [batch_size * num_true, dim] 1719 true_w = array_ops.slice(all_w, [0, 0], 1720 array_ops.stack( 1721 [array_ops.shape(labels_flat)[0], -1])) 1722 1723 sampled_w = array_ops.slice( 1724 all_w, array_ops.stack([array_ops.shape(labels_flat)[0], 0]), [-1, -1]) 1725 # inputs has shape [batch_size, dim] 1726 # sampled_w has shape [num_sampled, dim] 1727 # Apply X*W', which yields [batch_size, num_sampled] 1728 sampled_logits = math_ops.matmul(inputs, sampled_w, transpose_b=True) 1729 1730 # Retrieve the true and sampled biases, compute the true logits, and 1731 # add the biases to the true and sampled logits. 1732 all_b = embedding_ops.embedding_lookup( 1733 biases, all_ids, partition_strategy=partition_strategy) 1734 if all_b.dtype != inputs.dtype: 1735 all_b = math_ops.cast(all_b, inputs.dtype) 1736 # true_b is a [batch_size * num_true] tensor 1737 # sampled_b is a [num_sampled] float tensor 1738 true_b = array_ops.slice(all_b, [0], array_ops.shape(labels_flat)) 1739 sampled_b = array_ops.slice(all_b, array_ops.shape(labels_flat), [-1]) 1740 1741 # inputs shape is [batch_size, dim] 1742 # true_w shape is [batch_size * num_true, dim] 1743 # row_wise_dots is [batch_size, num_true, dim] 1744 dim = array_ops.shape(true_w)[1:2] 1745 new_true_w_shape = array_ops.concat([[-1, num_true], dim], 0) 1746 row_wise_dots = math_ops.multiply( 1747 array_ops.expand_dims(inputs, 1), 1748 array_ops.reshape(true_w, new_true_w_shape)) 1749 # We want the row-wise dot plus biases which yields a 1750 # [batch_size, num_true] tensor of true_logits. 1751 dots_as_matrix = array_ops.reshape(row_wise_dots, 1752 array_ops.concat([[-1], dim], 0)) 1753 true_logits = array_ops.reshape(_sum_rows(dots_as_matrix), [-1, num_true]) 1754 true_b = array_ops.reshape(true_b, [-1, num_true]) 1755 true_logits += true_b 1756 sampled_logits += sampled_b 1757 1758 if remove_accidental_hits: 1759 acc_hits = candidate_sampling_ops.compute_accidental_hits( 1760 labels, sampled, num_true=num_true) 1761 acc_indices, acc_ids, acc_weights = acc_hits 1762 1763 # This is how SparseToDense expects the indices. 
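      # As an illustration (hypothetical values): if the class sampled at
      # position 3 of `sampled` happens to equal one of the true labels of
      # batch row 0, then acc_indices contains 0, acc_ids contains 3, and
      # acc_weights contains a large negative value, so the scatter below
      # effectively masks sampled_logits[0, 3] out of the loss.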
1764 acc_indices_2d = array_ops.reshape(acc_indices, [-1, 1]) 1765 acc_ids_2d_int32 = array_ops.reshape( 1766 math_ops.cast(acc_ids, dtypes.int32), [-1, 1]) 1767 sparse_indices = array_ops.concat([acc_indices_2d, acc_ids_2d_int32], 1, 1768 "sparse_indices") 1769 # Create sampled_logits_shape = [batch_size, num_sampled] 1770 sampled_logits_shape = array_ops.concat( 1771 [array_ops.shape(labels)[:1], 1772 array_ops.expand_dims(num_sampled, 0)], 0) 1773 if sampled_logits.dtype != acc_weights.dtype: 1774 acc_weights = math_ops.cast(acc_weights, sampled_logits.dtype) 1775 sampled_logits += gen_sparse_ops.sparse_to_dense( 1776 sparse_indices, 1777 sampled_logits_shape, 1778 acc_weights, 1779 default_value=0.0, 1780 validate_indices=False) 1781 1782 if subtract_log_q: 1783 # Subtract log of Q(l), prior probability that l appears in sampled. 1784 true_logits -= math_ops.log(true_expected_count) 1785 sampled_logits -= math_ops.log(sampled_expected_count) 1786 1787 # Construct output logits and labels. The true labels/logits start at col 0. 1788 out_logits = array_ops.concat([true_logits, sampled_logits], 1) 1789 1790 # true_logits is a float tensor, ones_like(true_logits) is a float 1791 # tensor of ones. We then divide by num_true to ensure the per-example 1792 # labels sum to 1.0, i.e. form a proper probability distribution. 1793 out_labels = array_ops.concat([ 1794 array_ops.ones_like(true_logits) / num_true, 1795 array_ops.zeros_like(sampled_logits) 1796 ], 1) 1797 1798 return out_logits, out_labels 1799 1800 1801@tf_export("nn.nce_loss", v1=[]) 1802def nce_loss_v2(weights, 1803 biases, 1804 labels, 1805 inputs, 1806 num_sampled, 1807 num_classes, 1808 num_true=1, 1809 sampled_values=None, 1810 remove_accidental_hits=False, 1811 name="nce_loss"): 1812 """Computes and returns the noise-contrastive estimation training loss. 1813 1814 See [Noise-contrastive estimation: A new estimation principle for 1815 unnormalized statistical 1816 models](http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf). 1817 Also see our [Candidate Sampling Algorithms 1818 Reference](https://www.tensorflow.org/extras/candidate_sampling.pdf) 1819 1820 A common use case is to use this method for training, and calculate the full 1821 sigmoid loss for evaluation or inference as in the following example: 1822 1823 ```python 1824 if mode == "train": 1825 loss = tf.nn.nce_loss( 1826 weights=weights, 1827 biases=biases, 1828 labels=labels, 1829 inputs=inputs, 1830 ...) 1831 elif mode == "eval": 1832 logits = tf.matmul(inputs, tf.transpose(weights)) 1833 logits = tf.nn.bias_add(logits, biases) 1834 labels_one_hot = tf.one_hot(labels, n_classes) 1835 loss = tf.nn.sigmoid_cross_entropy_with_logits( 1836 labels=labels_one_hot, 1837 logits=logits) 1838 loss = tf.reduce_sum(loss, axis=1) 1839 ``` 1840 1841 Note: when doing embedding lookup on `weights` and `bias`, "div" partition 1842 strategy will be used. Support for other partition strategy will be added 1843 later. 1844 1845 Note: By default this uses a log-uniform (Zipfian) distribution for sampling, 1846 so your labels must be sorted in order of decreasing frequency to achieve 1847 good results. For more details, see 1848 `tf.random.log_uniform_candidate_sampler`. 1849 1850 Note: In the case where `num_true` > 1, we assign to each target class 1851 the target probability 1 / `num_true` so that the target probabilities 1852 sum to 1 per-example. 1853 1854 Note: It would be useful to allow a variable number of target classes per 1855 example. 
We hope to provide this functionality in a future release. 1856 For now, if you have a variable number of target classes, you can pad them 1857 out to a constant number by either repeating them or by padding 1858 with an otherwise unused class. 1859 1860 Args: 1861 weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor` 1862 objects whose concatenation along dimension 0 has shape [num_classes, 1863 dim]. The (possibly-partitioned) class embeddings. 1864 biases: A `Tensor` of shape `[num_classes]`. The class biases. 1865 labels: A `Tensor` of type `int64` and shape `[batch_size, num_true]`. The 1866 target classes. 1867 inputs: A `Tensor` of shape `[batch_size, dim]`. The forward activations of 1868 the input network. 1869 num_sampled: An `int`. The number of negative classes to randomly sample 1870 per batch. This single sample of negative classes is evaluated for each 1871 element in the batch. 1872 num_classes: An `int`. The number of possible classes. 1873 num_true: An `int`. The number of target classes per training example. 1874 sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`, 1875 `sampled_expected_count`) returned by a `*_candidate_sampler` function. 1876 (if None, we default to `log_uniform_candidate_sampler`) 1877 remove_accidental_hits: A `bool`. Whether to remove "accidental hits" 1878 where a sampled class equals one of the target classes. If set to `True`, 1879 this is a "Sampled Logistic" loss instead of NCE, and we are learning to 1880 generate log-odds instead of log probabilities. See our [Candidate 1881 Sampling Algorithms Reference] 1882 (https://www.tensorflow.org/extras/candidate_sampling.pdf). Default is 1883 False. 1884 name: A name for the operation (optional). 1885 1886 Returns: 1887 A `batch_size` 1-D tensor of per-example NCE losses. 1888 """ 1889 # TODO(yuefengz): get partition_strategy from either variables or distribution 1890 # strategies. 1891 return nce_loss( 1892 weights, 1893 biases, 1894 labels, 1895 inputs, 1896 num_sampled, 1897 num_classes, 1898 num_true=num_true, 1899 sampled_values=sampled_values, 1900 remove_accidental_hits=remove_accidental_hits, 1901 partition_strategy="div", 1902 name=name) 1903 1904 1905@tf_export(v1=["nn.nce_loss"]) 1906def nce_loss(weights, 1907 biases, 1908 labels, 1909 inputs, 1910 num_sampled, 1911 num_classes, 1912 num_true=1, 1913 sampled_values=None, 1914 remove_accidental_hits=False, 1915 partition_strategy="mod", 1916 name="nce_loss"): 1917 """Computes and returns the noise-contrastive estimation training loss. 1918 1919 A common use case is to use this method for training, and calculate the full 1920 sigmoid loss for evaluation or inference. In this case, you must set 1921 `partition_strategy="div"` for the two losses to be consistent, as in the 1922 following example: 1923 1924 ```python 1925 if mode == "train": 1926 loss = tf.nn.nce_loss( 1927 weights=weights, 1928 biases=biases, 1929 labels=labels, 1930 inputs=inputs, 1931 ..., 1932 partition_strategy="div") 1933 elif mode == "eval": 1934 logits = tf.matmul(inputs, tf.transpose(weights)) 1935 logits = tf.nn.bias_add(logits, biases) 1936 labels_one_hot = tf.one_hot(labels, n_classes) 1937 loss = tf.nn.sigmoid_cross_entropy_with_logits( 1938 labels=labels_one_hot, 1939 logits=logits) 1940 loss = tf.reduce_sum(loss, axis=1) 1941 ``` 1942 1943 Note: By default this uses a log-uniform (Zipfian) distribution for sampling, 1944 so your labels must be sorted in order of decreasing frequency to achieve 1945 good results. 
For more details, see 1946 `tf.random.log_uniform_candidate_sampler`. 1947 1948 Note: In the case where `num_true` > 1, we assign to each target class 1949 the target probability 1 / `num_true` so that the target probabilities 1950 sum to 1 per-example. 1951 1952 Note: It would be useful to allow a variable number of target classes per 1953 example. We hope to provide this functionality in a future release. 1954 For now, if you have a variable number of target classes, you can pad them 1955 out to a constant number by either repeating them or by padding 1956 with an otherwise unused class. 1957 1958 Args: 1959 weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor` 1960 objects whose concatenation along dimension 0 has shape 1961 [num_classes, dim]. The (possibly-partitioned) class embeddings. 1962 biases: A `Tensor` of shape `[num_classes]`. The class biases. 1963 labels: A `Tensor` of type `int64` and shape `[batch_size, 1964 num_true]`. The target classes. 1965 inputs: A `Tensor` of shape `[batch_size, dim]`. The forward 1966 activations of the input network. 1967 num_sampled: An `int`. The number of negative classes to randomly sample 1968 per batch. This single sample of negative classes is evaluated for each 1969 element in the batch. 1970 num_classes: An `int`. The number of possible classes. 1971 num_true: An `int`. The number of target classes per training example. 1972 sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`, 1973 `sampled_expected_count`) returned by a `*_candidate_sampler` function. 1974 (if None, we default to `log_uniform_candidate_sampler`) 1975 remove_accidental_hits: A `bool`. Whether to remove "accidental hits" 1976 where a sampled class equals one of the target classes. If set to 1977 `True`, this is a "Sampled Logistic" loss instead of NCE, and we are 1978 learning to generate log-odds instead of log probabilities. See 1979 our Candidate Sampling Algorithms Reference 1980 ([pdf](https://www.tensorflow.org/extras/candidate_sampling.pdf)). 1981 Default is False. 1982 partition_strategy: A string specifying the partitioning strategy, relevant 1983 if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported. 1984 Default is `"mod"`. See `tf.nn.embedding_lookup` for more details. 1985 name: A name for the operation (optional). 1986 1987 Returns: 1988 A `batch_size` 1-D tensor of per-example NCE losses. 1989 1990 References: 1991 Noise-contrastive estimation - A new estimation principle for unnormalized 1992 statistical models: 1993 [Gutmann et al., 2010](http://proceedings.mlr.press/v9/gutmann10a) 1994 ([pdf](http://proceedings.mlr.press/v9/gutmann10a/gutmann10a.pdf)) 1995 """ 1996 logits, labels = _compute_sampled_logits( 1997 weights=weights, 1998 biases=biases, 1999 labels=labels, 2000 inputs=inputs, 2001 num_sampled=num_sampled, 2002 num_classes=num_classes, 2003 num_true=num_true, 2004 sampled_values=sampled_values, 2005 subtract_log_q=True, 2006 remove_accidental_hits=remove_accidental_hits, 2007 partition_strategy=partition_strategy, 2008 name=name) 2009 sampled_losses = sigmoid_cross_entropy_with_logits( 2010 labels=labels, logits=logits, name="sampled_losses") 2011 # sampled_losses is batch_size x {true_loss, sampled_losses...} 2012 # We sum out true and sampled losses. 
  return _sum_rows(sampled_losses)


@tf_export("nn.sampled_softmax_loss", v1=[])
def sampled_softmax_loss_v2(weights,
                            biases,
                            labels,
                            inputs,
                            num_sampled,
                            num_classes,
                            num_true=1,
                            sampled_values=None,
                            remove_accidental_hits=True,
                            seed=None,
                            name="sampled_softmax_loss"):
  """Computes and returns the sampled softmax training loss.

  This is a faster way to train a softmax classifier over a huge number of
  classes.

  This operation is for training only. It is generally an underestimate of
  the full softmax loss.

  A common use case is to use this method for training, and calculate the full
  softmax loss for evaluation or inference as in the following example:

  ```python
  if mode == "train":
    loss = tf.nn.sampled_softmax_loss(
        weights=weights,
        biases=biases,
        labels=labels,
        inputs=inputs,
        ...)
  elif mode == "eval":
    logits = tf.matmul(inputs, tf.transpose(weights))
    logits = tf.nn.bias_add(logits, biases)
    labels_one_hot = tf.one_hot(labels, n_classes)
    loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=labels_one_hot,
        logits=logits)
  ```

  See our [Candidate Sampling Algorithms Reference]
  (https://www.tensorflow.org/extras/candidate_sampling.pdf)

  Also see Section 3 of [Jean et al., 2014](http://arxiv.org/abs/1412.2007)
  ([pdf](http://arxiv.org/pdf/1412.2007.pdf)) for the math.

  Note: when doing embedding lookup on `weights` and `biases`, "div" partition
  strategy will be used. Support for other partition strategies will be added
  later.

  Args:
    weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
      objects whose concatenation along dimension 0 has shape [num_classes,
      dim]. The (possibly-sharded) class embeddings.
    biases: A `Tensor` of shape `[num_classes]`. The class biases.
    labels: A `Tensor` of type `int64` and shape `[batch_size, num_true]`. The
      target classes. Note that this format differs from the `labels` argument
      of `nn.softmax_cross_entropy_with_logits`.
    inputs: A `Tensor` of shape `[batch_size, dim]`. The forward activations of
      the input network.
    num_sampled: An `int`. The number of classes to randomly sample per batch.
    num_classes: An `int`. The number of possible classes.
    num_true: An `int`. The number of target classes per training example.
    sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
      `sampled_expected_count`) returned by a `*_candidate_sampler` function.
      (if None, we default to `log_uniform_candidate_sampler`)
    remove_accidental_hits: A `bool`. whether to remove "accidental hits"
      where a sampled class equals one of the target classes. Default is True.
    seed: random seed for candidate sampling. Default to None, which doesn't
      set the op-level random seed for candidate sampling.
    name: A name for the operation (optional).

  Returns:
    A `batch_size` 1-D tensor of per-example sampled softmax losses.
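  The returned per-example losses are typically reduced to a scalar training
  objective by the caller, for example (an illustrative sketch using this
  function's own arguments):

  ```python
  loss = tf.reduce_mean(
      tf.nn.sampled_softmax_loss(
          weights=weights, biases=biases, labels=labels, inputs=inputs,
          num_sampled=num_sampled, num_classes=num_classes))
  ```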
2090 2091 """ 2092 return sampled_softmax_loss( 2093 weights, 2094 biases, 2095 labels, 2096 inputs, 2097 num_sampled, 2098 num_classes, 2099 num_true=num_true, 2100 sampled_values=sampled_values, 2101 remove_accidental_hits=remove_accidental_hits, 2102 partition_strategy="div", 2103 name=name, 2104 seed=seed) 2105 2106 2107@tf_export(v1=["nn.sampled_softmax_loss"]) 2108def sampled_softmax_loss(weights, 2109 biases, 2110 labels, 2111 inputs, 2112 num_sampled, 2113 num_classes, 2114 num_true=1, 2115 sampled_values=None, 2116 remove_accidental_hits=True, 2117 partition_strategy="mod", 2118 name="sampled_softmax_loss", 2119 seed=None): 2120 """Computes and returns the sampled softmax training loss. 2121 2122 This is a faster way to train a softmax classifier over a huge number of 2123 classes. 2124 2125 This operation is for training only. It is generally an underestimate of 2126 the full softmax loss. 2127 2128 A common use case is to use this method for training, and calculate the full 2129 softmax loss for evaluation or inference. In this case, you must set 2130 `partition_strategy="div"` for the two losses to be consistent, as in the 2131 following example: 2132 2133 ```python 2134 if mode == "train": 2135 loss = tf.nn.sampled_softmax_loss( 2136 weights=weights, 2137 biases=biases, 2138 labels=labels, 2139 inputs=inputs, 2140 ..., 2141 partition_strategy="div") 2142 elif mode == "eval": 2143 logits = tf.matmul(inputs, tf.transpose(weights)) 2144 logits = tf.nn.bias_add(logits, biases) 2145 labels_one_hot = tf.one_hot(labels, n_classes) 2146 loss = tf.nn.softmax_cross_entropy_with_logits( 2147 labels=labels_one_hot, 2148 logits=logits) 2149 ``` 2150 2151 See our Candidate Sampling Algorithms Reference 2152 ([pdf](https://www.tensorflow.org/extras/candidate_sampling.pdf)). 2153 Also see Section 3 of (Jean et al., 2014) for the math. 2154 2155 Args: 2156 weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor` 2157 objects whose concatenation along dimension 0 has shape 2158 [num_classes, dim]. The (possibly-sharded) class embeddings. 2159 biases: A `Tensor` of shape `[num_classes]`. The class biases. 2160 labels: A `Tensor` of type `int64` and shape `[batch_size, 2161 num_true]`. The target classes. Note that this format differs from 2162 the `labels` argument of `nn.softmax_cross_entropy_with_logits`. 2163 inputs: A `Tensor` of shape `[batch_size, dim]`. The forward 2164 activations of the input network. 2165 num_sampled: An `int`. The number of classes to randomly sample per batch. 2166 num_classes: An `int`. The number of possible classes. 2167 num_true: An `int`. The number of target classes per training example. 2168 sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`, 2169 `sampled_expected_count`) returned by a `*_candidate_sampler` function. 2170 (if None, we default to `log_uniform_candidate_sampler`) 2171 remove_accidental_hits: A `bool`. whether to remove "accidental hits" 2172 where a sampled class equals one of the target classes. Default is 2173 True. 2174 partition_strategy: A string specifying the partitioning strategy, relevant 2175 if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported. 2176 Default is `"mod"`. See `tf.nn.embedding_lookup` for more details. 2177 name: A name for the operation (optional). 2178 seed: random seed for candidate sampling. Default to None, which doesn't set 2179 the op-level random seed for candidate sampling. 2180 2181 Returns: 2182 A `batch_size` 1-D tensor of per-example sampled softmax losses. 
2183 2184 References: 2185 On Using Very Large Target Vocabulary for Neural Machine Translation: 2186 [Jean et al., 2014] 2187 (https://aclanthology.coli.uni-saarland.de/papers/P15-1001/p15-1001) 2188 ([pdf](http://aclweb.org/anthology/P15-1001)) 2189 """ 2190 logits, labels = _compute_sampled_logits( 2191 weights=weights, 2192 biases=biases, 2193 labels=labels, 2194 inputs=inputs, 2195 num_sampled=num_sampled, 2196 num_classes=num_classes, 2197 num_true=num_true, 2198 sampled_values=sampled_values, 2199 subtract_log_q=True, 2200 remove_accidental_hits=remove_accidental_hits, 2201 partition_strategy=partition_strategy, 2202 name=name, 2203 seed=seed) 2204 labels = array_ops.stop_gradient(labels, name="labels_stop_gradient") 2205 sampled_losses = nn_ops.softmax_cross_entropy_with_logits_v2( 2206 labels=labels, logits=logits) 2207 # sampled_losses is a [batch_size] tensor. 2208 return sampled_losses 2209
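

# Note (illustrative summary; nothing below is executed): `nce_loss` and
# `sampled_softmax_loss` both delegate to `_compute_sampled_logits` and differ
# only in the loss applied to the resulting
# `[batch_size, num_true + num_sampled]` logits. Roughly:
#
#   logits, labels = _compute_sampled_logits(...)
#   # NCE / sampled logistic: independent sigmoid losses, summed per example.
#   nce = _sum_rows(
#       sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))
#   # Sampled softmax: a single softmax cross-entropy per example.
#   softmax = nn_ops.softmax_cross_entropy_with_logits_v2(
#       labels=labels, logits=logits)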