# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sampling functions.

Provides queue-based rejection sampling (`rejection_sample`) and per-class
stratified sampling (`stratified_sample`) for building mini-batches.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import input as input_ops

__all__ = [
    'rejection_sample',
    'stratified_sample',
]


def rejection_sample(tensors,
                     accept_prob_fn,
                     batch_size,
                     queue_threads=1,
                     enqueue_many=False,
                     prebatch_capacity=16,
                     prebatch_threads=1,
                     runtime_checks=False,
                     name=None):
  """Stochastically creates batches by rejection sampling.

  Each list of non-batched tensors is evaluated by `accept_prob_fn`, to produce
  a scalar tensor between 0 and 1. This tensor corresponds to the probability of
  being accepted. When `batch_size` tensor groups have been accepted, the batch
  queue will return a mini-batch.

  Args:
    tensors: List of tensors for data. All tensors are either one item or a
      batch, according to enqueue_many.
    accept_prob_fn: A python callable that is called once with the list of
      non-batched tensors (one entry per item in `tensors`) and produces a
      scalar tensor: the probability of accepting that example.
    batch_size: Size of batch to be returned.
    queue_threads: The number of threads for the queue that will hold the final
      batch.
    enqueue_many: Bool. If true, interpret input tensors as having a batch
      dimension.
    prebatch_capacity: Capacity for the large queue that is used to convert
      batched tensors to single examples.
    prebatch_threads: Number of threads for the large queue that is used to
      convert batched tensors to single examples.
    runtime_checks: Bool. If true, insert runtime checks that the output of
      `accept_prob_fn` lies in [0, 1]. Using `True` might have a performance
      impact.
    name: Optional prefix for ops created by this function.
  Raises:
    ValueError: `enqueue_many` is True and a tensor in `tensors` doesn't have a
      batch dimension, or the batch dimensions of the tensors in `tensors`
      aren't compatible.
  Returns:
    A list of tensors of the same length as `tensors`, with batch dimension
    `batch_size`.

  Example:
    # Get tensor for a single data and label example.
    data, label = data_provider.Get(['data', 'label'])

    # Get a batch by rejection sampling, accepting examples with a
    # probability that depends on the data tensor.
    accept_prob_fn = lambda x: (tf.tanh(x[0]) + 1) / 2
    data_batch = tf.contrib.training.rejection_sample(
        [data, label], accept_prob_fn, 16)

    # Run batch through network.
    ...
  """
  with variable_scope.variable_scope(name, 'rejection_sample', tensors):
    tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensors)
    # Reduce the case of a batched example to that of a batch of a single
    # example by taking a batch of size one.
    if enqueue_many:
      # Validate that batch dimension of the input is consistent.
      tensor_list = _verify_data_inputs(tensor_list)

      # Make a single queue to hold input examples. Reshape output so examples
      # don't have singleton batch dimension.
      batched = input_ops.batch(
          tensor_list,
          batch_size=1,
          num_threads=prebatch_threads,
          capacity=prebatch_capacity,
          enqueue_many=True)
      tensor_list = [array_ops.squeeze(x, [0]) for x in batched]

    # Set up a queue containing batches that have the distribution.
    cur_prob = accept_prob_fn(tensor_list)
    if runtime_checks:
      cur_prob = array_ops.identity(
          control_flow_ops.with_dependencies([
              check_ops.assert_less_equal(0.0, cur_prob),
              check_ops.assert_less_equal(cur_prob, 1.0)
          ], cur_prob),
          name='prob_with_checks')
    # Keep each example with probability `cur_prob` (a uniform draw in [0, 1)
    # is below `cur_prob` with exactly that probability).
    minibatch = input_ops.maybe_batch(
        tensor_list,
        keep_input=random_ops.random_uniform([]) < cur_prob,
        batch_size=batch_size,
        num_threads=queue_threads)

    # Queues return a single tensor if the list of enqueued tensors is one. Since
    # we want the type to always be the same, always return a list.
    if isinstance(minibatch, ops.Tensor):
      minibatch = [minibatch]

    return minibatch


def stratified_sample(tensors,
                      labels,
                      target_probs,
                      batch_size,
                      init_probs=None,
                      enqueue_many=False,
                      queue_capacity=16,
                      threads_per_queue=1,
                      name=None):
  """Stochastically creates batches based on per-class probabilities.

  This method discards examples. Internally, it creates one queue to amortize
  the cost of disk reads, and one queue to hold the properly-proportioned
  batch.

  Args:
    tensors: List of tensors for data. All tensors are either one item or a
      batch, according to enqueue_many.
    labels: Tensor for label of data. Label is a single integer or a batch,
      depending on `enqueue_many`. It is not a one-hot vector.
    target_probs: Target class proportions in batch. An object whose type has a
      registered Tensor conversion function.
    batch_size: Size of batch to be returned.
    init_probs: Class proportions in the data. An object whose type has a
      registered Tensor conversion function, or `None` for estimating the
      initial distribution.
    enqueue_many: Bool. If true, interpret input tensors as having a batch
      dimension.
    queue_capacity: Capacity of the large queue that holds input examples.
    threads_per_queue: Number of threads for the large queue that holds input
      examples and for the final queue with the proper class proportions.
    name: Optional prefix for ops created by this function.
  Raises:
    TypeError: If `tensors` isn't iterable.
    ValueError: `enqueue_many` is True and labels doesn't have a batch
      dimension, or if `enqueue_many` is False and labels isn't a scalar.
    ValueError: `enqueue_many` is True, and batch dimension on data and labels
      don't match.
    ValueError: if probs don't sum to one.
    ValueError: if a zero initial probability class has a nonzero target
      probability.
    TFAssertion: if labels aren't integers in [0, num classes).
  Returns:
    (data_batch, label_batch), where data_batch is a list of tensors of the same
    length as `tensors`

  Example:
    # Get tensor for a single data and label example.
    data, label = data_provider.Get(['data', 'label'])

    # Get stratified batch according to per-class probabilities.
    target_probs = [...distribution you want...]
    [data_batch], labels = tf.contrib.training.stratified_sample(
        [data], label, target_probs, batch_size=16)

    # Run batch through network.
    ...
  """
  with ops.name_scope(name, 'stratified_sample', list(tensors) + [labels]):
    tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensors)
    labels = ops.convert_to_tensor(labels)
    target_probs = ops.convert_to_tensor(target_probs, dtype=dtypes.float32)
    # Reduce the case of a single example to that of a batch of size 1.
    if not enqueue_many:
      tensor_list = [array_ops.expand_dims(tensor, 0) for tensor in tensor_list]
      labels = array_ops.expand_dims(labels, 0)

    # If `init_probs` is `None`, set up online estimation of data distribution.
    if init_probs is None:
      # We use `target_probs` to get the number of classes, so its shape must be
      # fully defined at graph construction time.
      target_probs.get_shape().assert_is_fully_defined()
      init_probs = _estimate_data_distribution(
          labels, target_probs.get_shape().num_elements())
    else:
      init_probs = ops.convert_to_tensor(init_probs, dtype=dtypes.float32)

    # Validate that input is consistent.
    tensor_list, labels, [init_probs, target_probs] = _verify_input(
        tensor_list, labels, [init_probs, target_probs])

    # Check that all zero initial probabilities also have zero target
    # probabilities.
    assert_op = control_flow_ops.Assert(
        math_ops.reduce_all(
            math_ops.logical_or(
                math_ops.not_equal(init_probs, 0),
                math_ops.equal(target_probs, 0))),
        ['All classes with zero initial probability must also have zero target '
         'probability: ', init_probs, target_probs
        ])
    init_probs = control_flow_ops.with_dependencies([assert_op], init_probs)

    # Calculate acceptance sampling probabilities.
    accept_probs = _calculate_acceptance_probabilities(init_probs, target_probs)
    # Expected fraction of input examples that the sampler throws away; log a
    # (rate-limited) warning when it is at least one half.
    proportion_rejected = math_ops.reduce_sum((1 - accept_probs) * init_probs)
    accept_probs = control_flow_ops.cond(
        math_ops.less(proportion_rejected, .5),
        lambda: accept_probs,
        lambda: logging_ops.Print(  # pylint: disable=g-long-lambda
            accept_probs, [accept_probs],
            message='Proportion of examples rejected by sampler is high.',
            first_n=10))

    # Make a single queue to hold input examples. Reshape output so examples
    # don't have singleton batch dimension.
    batched = input_ops.batch(
        tensor_list + [labels],
        batch_size=1,
        num_threads=threads_per_queue,
        capacity=queue_capacity,
        enqueue_many=True)
    val_list = [array_ops.squeeze(x, [0]) for x in batched[:-1]]
    label = array_ops.squeeze(batched[-1], [0])

    # Set up second queue containing batches that have the desired class
    # proportions.
    cur_prob = array_ops.gather(accept_probs, label)
    batched = input_ops.maybe_batch(
        val_list + [label],
        keep_input=random_ops.random_uniform([]) < cur_prob,
        batch_size=batch_size,
        num_threads=threads_per_queue)
    return batched[:-1], batched[-1]


def _estimate_data_distribution(labels, num_classes, smoothing_constant=10):
  """Estimate data distribution as labels are seen.

  Keeps a non-trainable int64 variable of per-class counts, adds the one-hot
  counts of `labels` to it every time the result is evaluated, and returns the
  normalized counts.

  Args:
    labels: Integer label tensor; presumably 1-D (a batch of labels). Note that
      the caller invokes this before `_verify_input` validates the labels.
    num_classes: Python int, number of classes.
    smoothing_constant: Positive integer pseudo-count every class starts with.
      Must be compatible with the int64 counter variable.

  Returns:
    A float32 tensor of shape [num_classes] with the estimated class
    proportions.

  Raises:
    ValueError: if `smoothing_constant` is not positive.
  """
  # Variable to track running count of classes. Smooth by a nonzero value to
  # avoid division-by-zero. Higher values provide more stability at the cost of
  # slower convergence.
  if smoothing_constant <= 0:
    raise ValueError('smoothing_constant must be positive.')
  num_examples_per_class_seen = variable_scope.variable(
      initial_value=[smoothing_constant] * num_classes,
      trainable=False,
      name='class_count',
      dtype=dtypes.int64)

  # Update the class-count based on what labels are seen in batch.
  num_examples_per_class_seen = num_examples_per_class_seen.assign_add(
      math_ops.reduce_sum(
          array_ops.one_hot(
              labels, num_classes, dtype=dtypes.int64), 0))

  # Normalize count into a probability.
  # NOTE: Without the `+= 0` line below, the test
  # `testMultiThreadedEstimateDataDistribution` fails. The reason is that
  # before this line, `num_examples_per_class_seen` is a Tensor that shares a
  # buffer with an underlying `ref` object. When the `ref` is changed by another
  # thread, `num_examples_per_class_seen` changes as well. Since this can happen
  # in the middle of the normalization computation, we get probabilities that
  # are very far from summing to one. Adding `+= 0` copies the contents of the
  # tensor to a new buffer, which will be consistent from the start to the end
  # of the normalization computation.
  num_examples_per_class_seen += 0
  init_prob_estimate = math_ops.truediv(
      num_examples_per_class_seen,
      math_ops.reduce_sum(num_examples_per_class_seen))

  # Must return float32 (not float64) to agree with downstream `_verify_input`
  # checks.
  return math_ops.cast(init_prob_estimate, dtypes.float32)


def _verify_data_inputs(tensor_list):
  """Verify that batched data inputs are well-formed.

  Statically checks that every tensor has rank >= 1 and that all leading
  (batch) dimensions are compatible with that of the first tensor.

  Args:
    tensor_list: List of tensors, each expected to carry a batch dimension.

  Returns:
    `tensor_list`, unchanged.

  Raises:
    ValueError: if a tensor has rank 0 or batch dimensions are incompatible.
  """
  for tensor in tensor_list:
    # Data tensor should have a batch dimension.
    shape = tensor.get_shape().with_rank_at_least(1)

    # Data batch dimensions must be compatible.
    tensor_shape.dimension_at_index(shape, 0).assert_is_compatible_with(
        tensor_list[0].get_shape()[0])

  return tensor_list


def _verify_input(tensor_list, labels, probs_list):
  """Verify that batched inputs are well-formed.

  Performs static shape checks and attaches runtime assertions to the returned
  tensors.

  Args:
    tensor_list: List of batched data tensors.
    labels: Label tensor; must be rank 1 (batch dimension only).
    probs_list: Non-empty list of probability tensors; each must be 1-D with a
      fully defined shape, and all must have the same length.

  Returns:
    (tensor_list, labels, checked_probs_list): the inputs, each wrapped with
    runtime checks (nonnegative probabilities summing to one within 1e-6,
    positive and matching batch sizes, integer labels in [0, num_classes)).

  Raises:
    ValueError: on static shape mismatches.
  """
  # Tolerance for the "probabilities sum to one" runtime check.
  tol = 1e-6
  checked_probs_list = []
  for probs in probs_list:
    # Since number of classes shouldn't change at runtime, probabilities shape
    # should be fully defined.
    probs.get_shape().assert_is_fully_defined()

    # Probabilities must be 1D.
    probs.get_shape().assert_has_rank(1)

    # Probabilities must be nonnegative and sum to one.
    prob_sum = math_ops.reduce_sum(probs)
    checked_probs = control_flow_ops.with_dependencies([
        check_ops.assert_non_negative(probs),
        check_ops.assert_less(prob_sum, 1.0 + tol),
        check_ops.assert_less(1.0 - tol, prob_sum)
    ], probs)
    checked_probs_list.append(checked_probs)

  # All probabilities should be the same length.
  prob_length = checked_probs_list[0].get_shape().num_elements()
  for checked_prob in checked_probs_list:
    if checked_prob.get_shape().num_elements() != prob_length:
      raise ValueError('Probability parameters must have the same length.')

  # Labels tensor should only have batch dimension.
  labels.get_shape().assert_has_rank(1)

  for tensor in tensor_list:
    # Data tensor should have a batch dimension.
    shape = tensor.get_shape().with_rank_at_least(1)

    # Data and label batch dimensions must be compatible.
    tensor_shape.dimension_at_index(shape, 0).assert_is_compatible_with(
        labels.get_shape()[0])

  # Data and labels must have the same, strictly positive batch size. Since we
  # can't assume we know the batch size at graph creation, add runtime checks.
  labels_batch_size = array_ops.shape(labels)[0]
  lbl_assert = check_ops.assert_positive(labels_batch_size)

  # Make each tensor depend on its own checks.
  labels = control_flow_ops.with_dependencies([lbl_assert], labels)
  tensor_list = [
      control_flow_ops.with_dependencies([
          lbl_assert,
          check_ops.assert_equal(array_ops.shape(x)[0], labels_batch_size)
      ], x) for x in tensor_list
  ]

  # Label's classes must be integers 0 <= x < num_classes.
  labels = control_flow_ops.with_dependencies([
      check_ops.assert_integer(labels), check_ops.assert_non_negative(labels),
      check_ops.assert_less(labels, math_ops.cast(prob_length, labels.dtype))
  ], labels)

  return tensor_list, labels, checked_probs_list


def _calculate_acceptance_probabilities(init_probs, target_probs):
  """Calculate the per-class acceptance rates.

  Args:
    init_probs: The class probabilities of the data.
    target_probs: The desired class proportion in minibatches.
  Returns:
    A 1-D tensor of the per-class acceptance probabilities. Classes where both
    the initial and target probability are zero (0 / 0 = NaN) get an
    acceptance probability of 0.

  This method is based on solving the following analysis:

  Let F be the probability of a rejection (on any example).
  Let p_i be the proportion of examples in the data in class i (init_probs).
  Let a_i be the rate at which the rejection sampler should *accept* class i.
  Let t_i be the target proportion in the minibatches for class i
  (target_probs).

  ```
  F = sum_i(p_i * (1-a_i))
    = 1 - sum_i(p_i * a_i)     using sum_i(p_i) = 1
  ```

  An example with class `i` will be accepted if `k` rejections occur, then an
  example with class `i` is seen by the rejector, and it is accepted. This can
  be written as follows:

  ```
  t_i = sum_k=0^inf(F^k * p_i * a_i)
      = p_i * a_i / (1 - F)    using geometric series identity, since 0 <= F < 1
      = p_i * a_i / sum_j(p_j * a_j)        using F from above
  ```

  Note that the following constraints hold:
  ```
  0 <= p_i <= 1, sum_i(p_i) = 1
  0 <= a_i <= 1
  0 <= t_i <= 1, sum_i(t_i) = 1
  ```

  A solution for a_i in terms of the other variables is the following:
    ```a_i = (t_i / p_i) / max_i[t_i / p_i]```
  """
  # Make list of t_i / p_i.
  ratio_l = target_probs / init_probs

  # Replace NaNs with 0s.
  ratio_l = array_ops.where(
      math_ops.is_nan(ratio_l), array_ops.zeros_like(ratio_l), ratio_l)

  # Calculate list of acceptance probabilities.
  max_ratio = math_ops.reduce_max(ratio_l)
  return ratio_l / max_ratio