# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops related to candidate sampling."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import nn_ops


def _rank_resample(weights, biases, inputs, sampled_values, num_resampled,
                   resampling_temperature, partition_strategy):
  """A helper function for rank_sampled_softmax_loss.

  This computes, for each i in `sampled_values`,

      log(sum_j exp((w_i * x_j + b_i) / resampling_temperature))

  where w_i, b_i are the weight and bias of the i-th class, respectively,
  and j ranges over the rows of `inputs`. For efficiency, we rearrange the
  computation to

      log(sum_j exp(w_i * (x_j / resampling_temperature))) +
          b_i / resampling_temperature.

  This translates to the following batched computation using tensorflow ops:

      reduce_logsumexp(matmul(embeddings,
                       transpose(inputs / resampling_temperature))) +
          biases / resampling_temperature

  The computation of the first term is colocated with the embeddings using
  `transform_fn` in `embedding_ops._embedding_lookup_and_transform`. The second
  term, not the bottleneck, is computed at the worker.

  Args:
    weights: From `rank_sampled_softmax_loss`.
    biases: From `rank_sampled_softmax_loss`.
    inputs: From `rank_sampled_softmax_loss`.
    sampled_values: A tuple of (`sampled_candidates`, `true_expected_count`,
        `sampled_expected_count`) returned by a `*_candidate_sampler` function.
    num_resampled: An `int`. This many values are selected from
        `sampled_values` using the adaptive resampling algorithm. The caller
        must ensure that `num_resampled` is less than the size of
        `sampled_values`.
    resampling_temperature: A scalar `Tensor` with the temperature parameter
        for the adaptive resampling algorithm.
    partition_strategy: From `rank_sampled_softmax_loss`.

  Returns:
    A tuple of (`resampled_candidates`, `true_expected_count`,
        `resampled_expected_count`), similar to `sampled_values` but sampled
        down to `num_resampled` values.
  """
  # This code supports passing a Tensor for num_resampled, but since it is only
  # called with an int, that's what we specify in the arg list. If this
  # function is ever externalized, we should change the doc to support Tensor.
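  # Shape note (per the docstring above): `inputs` is [batch_size, dim] and
  # each looked-up class embedding is a [dim] row, so the matmul inside
  # `logsumexp_logit` below yields one row per candidate and one column per
  # batch example, and the reduce_logsumexp over axis=1 collapses that to a
  # single logit per sampled candidate.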

  sampled, true_expected_count, sampled_expected_count = sampled_values

  sampled = math_ops.cast(array_ops.stop_gradient(sampled), dtypes.int64)
  true_expected_count = array_ops.stop_gradient(true_expected_count)
  sampled_expected_count = array_ops.stop_gradient(sampled_expected_count)

  reweighted_inputs = inputs / resampling_temperature

  def logsumexp_logit(embeddings):
    return math_ops.reduce_logsumexp(
        math_ops.matmul(embeddings, reweighted_inputs, transpose_b=True),
        axis=1,
        keepdims=False)

  # Calling this protected form of embedding_lookup allows co-locating
  # the logsumexp computation with the partitioned weights, which yields
  # a large speedup in practice.
  sampled_logits = embedding_ops._embedding_lookup_and_transform(  # pylint: disable=protected-access
      weights, sampled, partition_strategy, transform_fn=logsumexp_logit)
  sampled_b = array_ops.reshape(
      embedding_ops.embedding_lookup(biases, sampled, partition_strategy), [-1])
  sampled_logits += sampled_b / resampling_temperature

  _, resampled_indices = nn.top_k(sampled_logits, k=num_resampled, sorted=False)
  resampled = array_ops.gather(sampled, indices=resampled_indices)
  resampled_expected_count = array_ops.gather(
      sampled_expected_count, indices=resampled_indices)

  return resampled, true_expected_count, resampled_expected_count


def rank_sampled_softmax_loss(weights,
                              biases,
                              labels,
                              inputs,
                              num_sampled,
                              num_resampled,
                              num_classes,
                              num_true,
                              sampled_values,
                              resampling_temperature,
                              remove_accidental_hits,
                              partition_strategy,
                              name=None):
  """Computes softmax loss using rank-based adaptive resampling.

  This has been shown to improve rank loss after training compared to
  `tf.nn.sampled_softmax_loss`. For a description of the algorithm and some
  experimental results, please see: [TAPAS: Two-pass Approximate Adaptive
  Sampling for Softmax](https://arxiv.org/abs/1707.03073).

  Sampling follows two phases:
  * In the first phase, `num_sampled` classes are selected using
    `tf.nn.learned_unigram_candidate_sampler` or supplied `sampled_values`.
    The logits are calculated on those sampled classes. This phase is
    similar to `tf.nn.sampled_softmax_loss`.
  * In the second phase, the `num_resampled` classes with highest predicted
    probability are kept. Probabilities are
    `LogSumExp(logits / resampling_temperature)`, where the sum is over
    `inputs`.

  The `resampling_temperature` parameter controls the "adaptiveness" of the
  resampling. At lower temperatures, resampling is more adaptive because it
  picks more candidates close to the predicted classes. A common strategy is
  to decrease the temperature as training proceeds.

  See `tf.nn.sampled_softmax_loss` for more documentation on sampling and
  for typical default values for some of the parameters.

  This operation is for training only. It is generally an underestimate of
  the full softmax loss.

  A common use case is to use this method for training, and calculate the full
  softmax loss for evaluation or inference.
  In this case, you must set `partition_strategy="div"` for the two losses
  to be consistent, as in the following example:

  ```python
  if mode == "train":
    loss = rank_sampled_softmax_loss(
        weights=weights,
        biases=biases,
        labels=labels,
        inputs=inputs,
        ...,
        partition_strategy="div")
  elif mode == "eval":
    logits = tf.matmul(inputs, tf.transpose(weights))
    logits = tf.nn.bias_add(logits, biases)
    labels_one_hot = tf.one_hot(labels, n_classes)
    loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=labels_one_hot,
        logits=logits)
  ```

  Args:
    weights: A `Tensor` or `PartitionedVariable` of shape `[num_classes, dim]`,
        or a list of `Tensor` objects whose concatenation along dimension 0
        has shape [num_classes, dim]. The (possibly-sharded) class embeddings.
    biases: A `Tensor` or `PartitionedVariable` of shape `[num_classes]`.
        The (possibly-sharded) class biases.
    labels: A `Tensor` of type `int64` and shape `[batch_size, num_true]`. The
        target classes. Note that this format differs from the `labels`
        argument of `nn.softmax_cross_entropy_with_logits`.
    inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
        activations of the input network.
    num_sampled: An `int`. The number of classes to randomly sample per batch.
    num_resampled: An `int`. The number of classes to select from the
        `num_sampled` classes using the adaptive resampling algorithm. Must be
        less than `num_sampled`.
    num_classes: An `int`. The number of possible classes.
    num_true: An `int`. The number of target classes per training example.
    sampled_values: A tuple of (`sampled_candidates`, `true_expected_count`,
        `sampled_expected_count`) returned by a `*_candidate_sampler` function.
        If None, defaults to `nn.learned_unigram_candidate_sampler`.
    resampling_temperature: A scalar `Tensor` with the temperature parameter
        for the adaptive resampling algorithm.
    remove_accidental_hits: A `bool`. Whether to remove "accidental hits"
        where a sampled class equals one of the target classes.
    partition_strategy: A string specifying the partitioning strategy, relevant
        if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
        See `tf.nn.embedding_lookup` for more details.
    name: A name for the operation (optional).

  Returns:
    A `batch_size` 1-D tensor of per-example sampled softmax losses.

  Raises:
    ValueError: If `num_sampled > num_classes`, if `num_resampled >=
        num_sampled`, or if `partition_strategy` is not `"div"` or `"mod"`.
  """
  if num_sampled > num_classes:
    raise ValueError("num_sampled ({}) cannot be greater than num_classes ({})".
                     format(num_sampled, num_classes))
  if num_sampled <= num_resampled:
    raise ValueError("num_resampled ({}) must be less than num_sampled ({})".
                     format(num_resampled, num_sampled))
  if partition_strategy not in ("div", "mod"):
    raise ValueError(
        "unsupported partition_strategy ({})".format(partition_strategy))
  with ops.name_scope(name, "rank_sampled_softmax_loss", [
      weights, biases, labels, inputs, sampled_values, resampling_temperature
  ]) as name:
    if not sampled_values:
      sampled_values = nn.learned_unigram_candidate_sampler(
          true_classes=labels,
          num_true=num_true,
          num_sampled=num_sampled,
          unique=True,
          range_max=num_classes)
    # From sampled_values, select the top num_resampled values using the
    # adaptive rank resampling strategy.
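    # Like `sampled_values`, `resampled_values` is a (candidates,
    # true_expected_count, expected_count) tuple, but with only `num_resampled`
    # candidates, so it can be passed straight through to
    # `sampled_softmax_loss` as its `sampled_values` argument.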
    resampled_values = _rank_resample(weights, biases, inputs, sampled_values,
                                      num_resampled, resampling_temperature,
                                      partition_strategy)
    return nn.sampled_softmax_loss(
        weights=weights,
        biases=biases,
        labels=labels,
        inputs=inputs,
        num_sampled=num_resampled,
        num_classes=num_classes,
        num_true=num_true,
        sampled_values=resampled_values,
        remove_accidental_hits=remove_accidental_hits,
        partition_strategy=partition_strategy,
        name=name)


def sampled_sparse_softmax_loss(weights,
                                biases,
                                labels,
                                inputs,
                                num_sampled,
                                num_classes,
                                sampled_values=None,
                                remove_accidental_hits=True,
                                partition_strategy="mod",
                                name="sampled_sparse_softmax_loss"):
  """Computes and returns the sampled sparse softmax training loss.

  This is a faster way to train a softmax classifier over a huge number of
  classes.

  This operation is for training only. It is generally an underestimate of
  the full softmax loss.

  A common use case is to use this method for training, and calculate the full
  softmax loss for evaluation or inference. In this case, you must set
  `partition_strategy="div"` for the two losses to be consistent, as in the
  following example:

  ```python
  if mode == "train":
    loss = tf.nn.sampled_sparse_softmax_loss(
        weights=weights,
        biases=biases,
        labels=labels,
        inputs=inputs,
        ...,
        partition_strategy="div")
  elif mode == "eval":
    logits = tf.matmul(inputs, tf.transpose(weights))
    logits = tf.nn.bias_add(logits, biases)
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.squeeze(labels),
        logits=logits)
  ```

  See our [Candidate Sampling Algorithms Reference]
  (https://www.tensorflow.org/extras/candidate_sampling.pdf).

  Also see Section 3 of [Jean et al., 2014](http://arxiv.org/abs/1412.2007)
  ([pdf](http://arxiv.org/pdf/1412.2007.pdf)) for the math.

  Args:
    weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
        objects whose concatenation along dimension 0 has shape
        [num_classes, dim]. The (possibly-sharded) class embeddings.
    biases: A `Tensor` of shape `[num_classes]`. The class biases.
    labels: A `Tensor` of type `int64` and shape `[batch_size, 1]`.
        The index of the single target class for each row of logits. Note that
        this format differs from the `labels` argument of
        `nn.sparse_softmax_cross_entropy_with_logits`.
    inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
        activations of the input network.
    num_sampled: An `int`. The number of classes to randomly sample per batch.
    num_classes: An `int`. The number of possible classes.
    sampled_values: A tuple of (`sampled_candidates`, `true_expected_count`,
        `sampled_expected_count`) returned by a `*_candidate_sampler` function.
        If None, defaults to `log_uniform_candidate_sampler`.
    remove_accidental_hits: A `bool`. Whether to remove "accidental hits"
        where a sampled class equals one of the target classes. Default is
        True.
    partition_strategy: A string specifying the partitioning strategy, relevant
        if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
        Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
    name: A name for the operation (optional).

  Returns:
    A `batch_size` 1-D tensor of per-example sampled softmax losses.

  """
  logits, _ = nn_impl._compute_sampled_logits(
      weights=weights,
      biases=biases,
      labels=labels,
      inputs=inputs,
      num_sampled=num_sampled,
      num_classes=num_classes,
      num_true=1,
      sampled_values=sampled_values,
      subtract_log_q=True,
      remove_accidental_hits=remove_accidental_hits,
      partition_strategy=partition_strategy,
      name=name)

  # There is only one true label. _compute_sampled_logits puts the true logit
  # at index 0.
  labels = array_ops.zeros([array_ops.shape(logits)[0], 1], dtype=dtypes.int64)

  sampled_losses = nn_ops.sparse_softmax_cross_entropy_with_logits(
      labels=array_ops.squeeze(labels), logits=logits)
  # sampled_losses is a [batch_size] tensor.
  return sampled_losses
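

# A minimal usage sketch for `rank_sampled_softmax_loss`, which has no default
# argument values. The names `class_weights`, `class_biases`, `target_ids`, and
# `hidden`, as well as the numeric values, are illustrative placeholders, not
# part of this module:
#
#   loss = rank_sampled_softmax_loss(
#       weights=class_weights,       # [num_classes, dim] class embeddings
#       biases=class_biases,         # [num_classes] class biases
#       labels=target_ids,           # [batch_size, num_true] int64 targets
#       inputs=hidden,               # [batch_size, dim] forward activations
#       num_sampled=1000,
#       num_resampled=100,           # must be less than num_sampled
#       num_classes=num_classes,
#       num_true=1,
#       sampled_values=None,         # use the learned unigram sampler
#       resampling_temperature=1.0,
#       remove_accidental_hits=True,
#       partition_strategy="div")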