• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Sampling functions."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import tensor_shape
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import check_ops
25from tensorflow.python.ops import control_flow_ops
26from tensorflow.python.ops import logging_ops
27from tensorflow.python.ops import math_ops
28from tensorflow.python.ops import random_ops
29from tensorflow.python.ops import variable_scope
30from tensorflow.python.training import input as input_ops
31
# Public entry points exported by `from ... import *`: the two samplers.
__all__ = ['rejection_sample', 'stratified_sample']
36
37
def rejection_sample(tensors,
                     accept_prob_fn,
                     batch_size,
                     queue_threads=1,
                     enqueue_many=False,
                     prebatch_capacity=16,
                     prebatch_threads=1,
                     runtime_checks=False,
                     name=None):
  """Stochastically creates batches by rejection sampling.

  Each list of non-batched tensors is evaluated by `accept_prob_fn`, to produce
  a scalar tensor between 0 and 1. This tensor corresponds to the probability of
  being accepted. When `batch_size` tensor groups have been accepted, the batch
  queue will return a mini-batch.

  Args:
    tensors: List of tensors for data. All tensors are either one item or a
        batch, according to enqueue_many.
    accept_prob_fn: A python callable that takes the list of non-batched
        tensors (one per item in `tensors`, in the same order) and produces a
        scalar tensor with the acceptance probability.
    batch_size: Size of batch to be returned.
    queue_threads: The number of threads for the queue that will hold the final
      batch.
    enqueue_many: Bool. If true, interpret input tensors as having a batch
        dimension.
    prebatch_capacity: Capacity for the large queue that is used to convert
      batched tensors to single examples. Only used if `enqueue_many` is True.
    prebatch_threads: Number of threads for the large queue that is used to
      convert batched tensors to single examples. Only used if `enqueue_many`
      is True.
    runtime_checks: Bool. If true, insert runtime checks that the output of
        `accept_prob_fn` lies in [0, 1]. Using `True` might have a performance
        impact.
    name: Optional prefix for ops created by this function.
  Raises:
    ValueError: `enqueue_many` is True and some tensor in `tensors` doesn't
        have a batch dimension (i.e. it is a scalar).
    ValueError: `enqueue_many` is True and the statically known batch
        dimensions of the tensors in `tensors` don't match.
  Returns:
    A list of tensors of the same length as `tensors`, with batch dimension
    `batch_size`. A list is returned even when `tensors` has one element.

  Example:
    # Get tensor for a single data and label example.
    data, label = data_provider.Get(['data', 'label'])

    # Get a batch in which each example was accepted with a probability
    # computed from its data tensor.
    accept_prob_fn = lambda x: (tf.tanh(x[0]) + 1) / 2
    data_batch, label_batch = tf.contrib.training.rejection_sample(
        [data, label], accept_prob_fn, 16)

    # Run batch through network.
    ...
  """
  with variable_scope.variable_scope(name, 'rejection_sample', tensors):
    tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensors)
    # Reduce the case of a batched example to that of a batch of a single
    # example by taking a batch of size one.
    if enqueue_many:
      # Validate that batch dimension of the input is consistent.
      tensor_list = _verify_data_inputs(tensor_list)

      # Make a single queue to hold input examples. Reshape output so examples
      # don't have singleton batch dimension.
      batched = input_ops.batch(
          tensor_list,
          batch_size=1,
          num_threads=prebatch_threads,
          capacity=prebatch_capacity,
          enqueue_many=True)
      # Queues return a bare tensor (not a list) when only one tensor was
      # enqueued. Iterating that tensor below would fail, so wrap it first.
      if isinstance(batched, ops.Tensor):
        batched = [batched]
      tensor_list = [array_ops.squeeze(x, [0]) for x in batched]

    # Set up a queue containing batches that have the distribution.
    cur_prob = accept_prob_fn(tensor_list)
    if runtime_checks:
      cur_prob = array_ops.identity(
          control_flow_ops.with_dependencies([
              check_ops.assert_less_equal(0.0, cur_prob),
              check_ops.assert_less_equal(cur_prob, 1.0)
          ], cur_prob),
          name='prob_with_checks')
    minibatch = input_ops.maybe_batch(
        tensor_list,
        keep_input=random_ops.random_uniform([]) < cur_prob,
        batch_size=batch_size,
        num_threads=queue_threads)

    # Queues return a single tensor if the list of enqueued tensors is one. Since
    # we want the type to always be the same, always return a list.
    if isinstance(minibatch, ops.Tensor):
      minibatch = [minibatch]

    return minibatch
133
134
def stratified_sample(tensors,
                      labels,
                      target_probs,
                      batch_size,
                      init_probs=None,
                      enqueue_many=False,
                      queue_capacity=16,
                      threads_per_queue=1,
                      name=None):
  """Stochastically creates batches based on per-class probabilities.

  This method discards examples. Internally, it creates one queue to amortize
  the cost of disk reads, and one queue to hold the properly-proportioned
  batch.

  Args:
    tensors: List of tensors for data. All tensors are either one item or a
        batch, according to enqueue_many. May be empty, in which case only the
        labels are sampled.
    labels: Tensor for label of data. Label is a single integer or a batch,
        depending on `enqueue_many`. It is not a one-hot vector.
    target_probs: Target class proportions in batch. An object whose type has a
        registered Tensor conversion function.
    batch_size: Size of batch to be returned.
    init_probs: Class proportions in the data. An object whose type has a
        registered Tensor conversion function, or `None` for estimating the
        initial distribution online from the labels seen so far.
    enqueue_many: Bool. If true, interpret input tensors as having a batch
        dimension.
    queue_capacity: Capacity of the large queue that holds input examples.
    threads_per_queue: Number of threads for the large queue that holds input
        examples and for the final queue with the proper class proportions.
    name: Optional prefix for ops created by this function.
  Raises:
    TypeError: If `tensors` isn't iterable.
    TypeError: If `labels` isn't an integer tensor.
    ValueError: `enqueue_many` is True and labels doesn't have a batch
        dimension, or if `enqueue_many` is False and labels isn't a scalar.
    ValueError: `enqueue_many` is True, and batch dimension on data and labels
        don't match.
    ValueError: if the shape of `target_probs` (or of `init_probs`, when
        given) isn't fully defined, or the two have different lengths.
    TFAssertion: if probs are negative or don't sum to one.
    TFAssertion: if a zero initial probability class has a nonzero target
        probability.
    TFAssertion: if labels aren't in [0, num classes).
    TFAssertion errors are raised when the graph is run, not when this
    function is called.
  Returns:
    (data_batch, label_batch), where data_batch is a list of tensors of the same
        length as `tensors` and label_batch is the corresponding batch of
        labels.

  Example:
    # Get tensor for a single data and label example.
    data, label = data_provider.Get(['data', 'label'])

    # Get stratified batch according to per-class probabilities.
    target_probs = [...distribution you want...]
    [data_batch], labels = tf.contrib.training.stratified_sample(
        [data], label, target_probs, batch_size=32)

    # Run batch through network.
    ...
  """
  with ops.name_scope(name, 'stratified_sample', list(tensors) + [labels]):
    tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensors)
    labels = ops.convert_to_tensor(labels)
    target_probs = ops.convert_to_tensor(target_probs, dtype=dtypes.float32)
    # Reduce the case of a single example to that of a batch of size 1.
    if not enqueue_many:
      tensor_list = [array_ops.expand_dims(tensor, 0) for tensor in tensor_list]
      labels = array_ops.expand_dims(labels, 0)

    # If `init_probs` is `None`, set up online estimation of data distribution.
    if init_probs is None:
      # We use `target_probs` to get the number of classes, so its shape must be
      # fully defined at graph construction time.
      target_probs.get_shape().assert_is_fully_defined()
      init_probs = _estimate_data_distribution(
          labels, target_probs.get_shape().num_elements())
    else:
      init_probs = ops.convert_to_tensor(init_probs, dtype=dtypes.float32)

    # Validate that input is consistent.
    tensor_list, labels, [init_probs, target_probs] = _verify_input(
        tensor_list, labels, [init_probs, target_probs])

    # Check that all zero initial probabilities also have zero target
    # probabilities.
    assert_op = control_flow_ops.Assert(
        math_ops.reduce_all(
            math_ops.logical_or(
                math_ops.not_equal(init_probs, 0),
                math_ops.equal(target_probs, 0))),
        ['All classes with zero initial probability must also have zero target '
         'probability: ', init_probs, target_probs
        ])
    init_probs = control_flow_ops.with_dependencies([assert_op], init_probs)

    # Calculate acceptance sampling probabilities.
    accept_probs = _calculate_acceptance_probabilities(init_probs, target_probs)
    proportion_rejected = math_ops.reduce_sum((1 - accept_probs) * init_probs)
    accept_probs = control_flow_ops.cond(
        math_ops.less(proportion_rejected, .5),
        lambda: accept_probs,
        lambda: logging_ops.Print(  # pylint: disable=g-long-lambda
            accept_probs, [accept_probs],
            message='Proportion of examples rejected by sampler is high.',
            first_n=10))

    # Make a single queue to hold input examples. Reshape output so examples
    # don't have singleton batch dimension.
    batched = input_ops.batch(
        tensor_list + [labels],
        batch_size=1,
        num_threads=threads_per_queue,
        capacity=queue_capacity,
        enqueue_many=True)
    # Queues return a bare tensor instead of a list when a single tensor is
    # enqueued, which happens here when `tensors` is empty. Normalize so the
    # data/label split below always works.
    if isinstance(batched, ops.Tensor):
      batched = [batched]
    val_list = [array_ops.squeeze(x, [0]) for x in batched[:-1]]
    label = array_ops.squeeze(batched[-1], [0])

    # Set up second queue containing batches that have the desired class
    # proportions.
    cur_prob = array_ops.gather(accept_probs, label)
    batched = input_ops.maybe_batch(
        val_list + [label],
        keep_input=random_ops.random_uniform([]) < cur_prob,
        batch_size=batch_size,
        num_threads=threads_per_queue)
    # Same single-tensor normalization as above for the output queue.
    if isinstance(batched, ops.Tensor):
      batched = [batched]
    return batched[:-1], batched[-1]
259
260
def _estimate_data_distribution(labels, num_classes, smoothing_constant=10):
  """Estimate data distribution as labels are seen.

  Keeps a non-trainable int64 variable named `class_count` holding per-class
  counts, initialized to `smoothing_constant` for every class. Whenever the
  returned tensor is evaluated, the one-hot counts of `labels` are added to the
  variable and the running counts are normalized into a distribution.

  Args:
    labels: Integer tensor of class labels. It is one-hot encoded and summed
      over axis 0, so it is presumably a 1-D batch of labels (TODO confirm for
      callers other than `stratified_sample`).
    num_classes: Python int, the number of classes.
    smoothing_constant: Positive initial count for every class (stored in an
      int64 variable). Smoothing avoids division by zero; higher values give
      more stability at the cost of slower convergence.

  Returns:
    A float32 tensor with `num_classes` elements holding the estimated class
    proportions.

  Raises:
    ValueError: if `smoothing_constant` is not positive.
  """
  # Variable to track running count of classes. Smooth by a positive value to
  # avoid division-by-zero. Higher values provide more stability at the cost of
  # slower convergence. Zero and negative values are both rejected.
  if smoothing_constant <= 0:
    raise ValueError('smoothing_constant must be positive.')
  num_examples_per_class_seen = variable_scope.variable(
      initial_value=[smoothing_constant] * num_classes,
      trainable=False,
      name='class_count',
      dtype=dtypes.int64)

  # Update the class-count based on what labels are seen in batch.
  num_examples_per_class_seen = num_examples_per_class_seen.assign_add(
      math_ops.reduce_sum(
          array_ops.one_hot(
              labels, num_classes, dtype=dtypes.int64), 0))

  # Normalize count into a probability.
  # NOTE: Without the `+= 0` line below, the test
  # `testMultiThreadedEstimateDataDistribution` fails. The reason is that
  # before this line, `num_examples_per_class_seen` is a Tensor that shares a
  # buffer with an underlying `ref` object. When the `ref` is changed by another
  # thread, `num_examples_per_class_seen` changes as well. Since this can happen
  # in the middle of the normalization computation, we get probabilities that
  # are very far from summing to one. Adding `+= 0` copies the contents of the
  # tensor to a new buffer, which will be consistent from the start to the end
  # of the normalization computation.
  num_examples_per_class_seen += 0
  init_prob_estimate = math_ops.truediv(
      num_examples_per_class_seen,
      math_ops.reduce_sum(num_examples_per_class_seen))

  # Must return float32 (not float64) to agree with downstream `_verify_input`
  # checks.
  return math_ops.cast(init_prob_estimate, dtypes.float32)
292      num_examples_per_class_seen,
293      math_ops.reduce_sum(num_examples_per_class_seen))
294
295  # Must return float32 (not float64) to agree with downstream `_verify_input`
296  # checks.
297  return math_ops.cast(init_prob_estimate, dtypes.float32)
298
299
def _verify_data_inputs(tensor_list):
  """Statically checks that batched data tensors agree on their batch size.

  Every tensor must have rank >= 1, and its leading (batch) dimension must be
  compatible with the leading dimension of the first tensor. Only static
  (graph-construction-time) shape information is inspected.

  Args:
    tensor_list: Sequence of tensors expected to share a leading batch
      dimension.

  Returns:
    `tensor_list` itself, unchanged.

  Raises:
    ValueError: if a tensor is a scalar, or if two statically known batch
      dimensions are incompatible.
  """
  for data_tensor in tensor_list:
    batch_dim = tensor_shape.dimension_at_index(
        data_tensor.get_shape().with_rank_at_least(1), 0)
    # By the time this is read, the first tensor has already passed the rank
    # check above (on the first iteration).
    reference_dim = tensor_list[0].get_shape()[0]
    batch_dim.assert_is_compatible_with(reference_dim)
  return tensor_list
311
312
def _verify_input(tensor_list, labels, probs_list):
  """Validates sampler inputs and attaches checks to them.

  Shape checks performed while building the graph:
    * every tensor in `probs_list` has a fully defined, rank-1 shape, and all
      of them have the same number of elements (the number of classes);
    * `labels` has rank 1;
    * every tensor in `tensor_list` has rank >= 1 and a leading dimension
      compatible with that of `labels`.

  Checks attached as control dependencies on the returned tensors:
    * each probability vector is non-negative and sums to 1 within 1e-6;
    * the labels batch size is positive and equals every data tensor's
      leading dimension;
    * labels are integers, non-negative, and less than the number of classes.

  Args:
    tensor_list: List of batched data tensors.
    labels: Batched label tensor.
    probs_list: Non-empty list of probability-vector tensors.

  Returns:
    `(tensor_list, labels, checked_probs_list)`: the inputs wrapped so that
    evaluating them also runs the corresponding checks.

  Raises:
    ValueError: if a shape check fails or the probability vectors differ in
      length.
  """
  tol = 1e-6

  def _with_prob_checks(probs):
    # The number of classes can't change at run time, so the distribution's
    # shape must be fully known while building the graph, and it must be 1-D.
    probs.get_shape().assert_is_fully_defined()
    probs.get_shape().assert_has_rank(1)
    # Entries must be non-negative and sum to one (within `tol`).
    total = math_ops.reduce_sum(probs)
    return control_flow_ops.with_dependencies([
        check_ops.assert_non_negative(probs),
        check_ops.assert_less(total, 1.0 + tol),
        check_ops.assert_less(1.0 - tol, total)
    ], probs)

  checked_probs_list = [_with_prob_checks(probs) for probs in probs_list]

  # Every distribution must describe the same number of classes.
  num_classes = checked_probs_list[0].get_shape().num_elements()
  if any(checked.get_shape().num_elements() != num_classes
         for checked in checked_probs_list):
    raise ValueError('Probability parameters must have the same length.')

  # Labels carry only the batch dimension; each data tensor's leading
  # dimension must be statically compatible with it.
  labels.get_shape().assert_has_rank(1)
  labels_batch_dim = labels.get_shape()[0]
  for data_tensor in tensor_list:
    data_shape = data_tensor.get_shape().with_rank_at_least(1)
    tensor_shape.dimension_at_index(data_shape, 0).assert_is_compatible_with(
        labels_batch_dim)

  # The batch size may be unknown until run time, so check it there: it must
  # be positive and shared by the labels and every data tensor.
  batch_size = array_ops.shape(labels)[0]
  positive_batch_check = check_ops.assert_positive(batch_size)

  labels = control_flow_ops.with_dependencies([positive_batch_check], labels)
  tensor_list = [
      control_flow_ops.with_dependencies(
          [positive_batch_check,
           check_ops.assert_equal(array_ops.shape(data_tensor)[0], batch_size)],
          data_tensor)
      for data_tensor in tensor_list
  ]

  # Labels must be integer class ids in [0, num_classes).
  label_checks = [
      check_ops.assert_integer(labels),
      check_ops.assert_non_negative(labels),
      check_ops.assert_less(labels, math_ops.cast(num_classes, labels.dtype)),
  ]
  labels = control_flow_ops.with_dependencies(label_checks, labels)

  return tensor_list, labels, checked_probs_list
372
373
def _calculate_acceptance_probabilities(init_probs, target_probs):
  """Calculate the per-class acceptance rates.

  Args:
    init_probs: The class probabilities of the data.
    target_probs: The desired class proportion in minibatches.
  Returns:
    A tensor of per-class acceptance probabilities: the elementwise ratios
    `target_probs / init_probs` (with NaN ratios, i.e. 0 / 0, replaced by 0),
    each divided by the largest ratio.

  This method is based on solving the following analysis:

  Let F be the probability of a rejection (on any example).
  Let p_i be the proportion of examples in the data in class i (init_probs).
  Let a_i be the rate at which the rejection sampler should *accept* class i.
  Let t_i be the target proportion in the minibatches for class i
  (target_probs).

  ```
  F = sum_i(p_i * (1-a_i))
    = 1 - sum_i(p_i * a_i)     using sum_i(p_i) = 1
  ```

  An example with class `i` will be accepted if `k` rejections occur, then an
  example with class `i` is seen by the rejector, and it is accepted. This can
  be written as follows:

  ```
  t_i = sum_k=0^inf(F^k * p_i * a_i)
      = p_i * a_i / (1 - F)    using geometric series identity, since 0 <= F < 1
      = p_i * a_i / sum_j(p_j * a_j)        using F from above
  ```

  Note that the following constraints hold:
  ```
  0 <= p_i <= 1, sum_i(p_i) = 1
  0 <= a_i <= 1
  0 <= t_i <= 1, sum_i(t_i) = 1
  ```

  A solution for a_i in terms of the other variables is the following:
    ```a_i = (t_i / p_i) / max_j[t_j / p_j]```

  Classes with p_i = t_i = 0 give a NaN ratio, which is mapped to a_i = 0.
  Classes with p_i = 0 and t_i > 0 have no solution (infinite ratio), so
  callers should rule them out beforehand.
  """
  # Make list of t_i / p_i.
  ratio_l = target_probs / init_probs

  # Replace NaNs (from 0 / 0) with 0s, so such classes are never accepted.
  ratio_l = array_ops.where(
      math_ops.is_nan(ratio_l), array_ops.zeros_like(ratio_l), ratio_l)

  # Calculate list of acceptance probabilities.
  max_ratio = math_ops.reduce_max(ratio_l)
  return ratio_l / max_ratio
426