• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Resampling dataset transformations."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21
22from tensorflow.contrib.data.python.ops import batching
23from tensorflow.contrib.data.python.ops import scan_ops
24from tensorflow.python.data.ops import dataset_ops
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import ops
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import control_flow_ops
29from tensorflow.python.ops import logging_ops
30from tensorflow.python.ops import math_ops
31from tensorflow.python.ops import random_ops
32
33
34def rejection_resample(class_func, target_dist, initial_dist=None, seed=None):
35  """A transformation that resamples a dataset to achieve a target distribution.
36
37  **NOTE** Resampling is performed via rejection sampling; some fraction
38  of the input values will be dropped.
39
40  Args:
41    class_func: A function mapping an element of the input dataset to a scalar
42      `tf.int32` tensor. Values should be in `[0, num_classes)`.
43    target_dist: A floating point type tensor, shaped `[num_classes]`.
44    initial_dist: (Optional.)  A floating point type tensor, shaped
45      `[num_classes]`.  If not provided, the true class distribution is
46      estimated live in a streaming fashion.
47    seed: (Optional.) Python integer seed for the resampler.
48
49  Returns:
50    A `Dataset` transformation function, which can be passed to
51    @{tf.data.Dataset.apply}.
52  """
53
54  def _apply_fn(dataset):
55    """Function from `Dataset` to `Dataset` that applies the transformation."""
56    dist_estimation_batch_size = 32
57    target_dist_t = ops.convert_to_tensor(target_dist, name="initial_dist")
58    class_values_ds = dataset.map(class_func)
59    if initial_dist is not None:
60      initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist")
61      acceptance_dist = _calculate_acceptance_probs(initial_dist_t,
62                                                    target_dist_t)
63      initial_dist_ds = dataset_ops.Dataset.from_tensors(
64          initial_dist_t).repeat()
65      acceptance_dist_ds = dataset_ops.Dataset.from_tensors(
66          acceptance_dist).repeat()
67    else:
68      num_classes = (target_dist_t.shape[0].value or
69                     array_ops.shape(target_dist_t)[0])
70      smoothing_constant = 10
71      initial_examples_per_class_seen = array_ops.fill(
72          [num_classes], np.int64(smoothing_constant))
73
74      def update_estimate_and_tile(num_examples_per_class_seen, c):
75        updated_examples_per_class_seen, dist = _estimate_data_distribution(
76            c, num_examples_per_class_seen)
77        tiled_dist = array_ops.tile(
78            array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1])
79        return updated_examples_per_class_seen, tiled_dist
80
81      initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size)
82                         .apply(scan_ops.scan(initial_examples_per_class_seen,
83                                              update_estimate_and_tile))
84                         .apply(batching.unbatch()))
85      acceptance_dist_ds = initial_dist_ds.map(
86          lambda initial: _calculate_acceptance_probs(initial, target_dist_t))
87
88    def maybe_warn_on_large_rejection(accept_dist, initial_dist):
89      proportion_rejected = math_ops.reduce_sum(
90          (1 - accept_dist) * initial_dist)
91      return control_flow_ops.cond(
92          math_ops.less(proportion_rejected, .5),
93          lambda: accept_dist,
94          lambda: logging_ops.Print(  # pylint: disable=g-long-lambda
95              accept_dist, [proportion_rejected, initial_dist, accept_dist],
96              message="Proportion of examples rejected by sampler is high: ",
97              summarize=100,
98              first_n=10))
99
100    acceptance_dist_ds = (dataset_ops.Dataset.zip((acceptance_dist_ds,
101                                                   initial_dist_ds))
102                          .map(maybe_warn_on_large_rejection))
103
104    current_probabilities_ds = dataset_ops.Dataset.zip(
105        (acceptance_dist_ds, class_values_ds)).map(array_ops.gather)
106    filtered_ds = (
107        dataset_ops.Dataset.zip((class_values_ds, current_probabilities_ds,
108                                 dataset))
109        .filter(lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p))
110    return filtered_ds.map(lambda class_value, _, data: (class_value, data))
111
112  return _apply_fn
113
114
115def _calculate_acceptance_probs(initial_probs, target_probs):
116  """Calculate the per-class acceptance rates.
117
118  Args:
119    initial_probs: The class probabilities of the data.
120    target_probs: The desired class proportion in minibatches.
121  Returns:
122    A list of the per-class acceptance probabilities.
123
124  This method is based on solving the following analysis:
125
126  Let F be the probability of a rejection (on any example).
127  Let p_i be the proportion of examples in the data in class i (init_probs)
128  Let a_i is the rate the rejection sampler should *accept* class i
129  Let t_i is the target proportion in the minibatches for class i (target_probs)
130
131  ```
132  F = sum_i(p_i * (1-a_i))
133    = 1 - sum_i(p_i * a_i)     using sum_i(p_i) = 1
134  ```
135
136  An example with class `i` will be accepted if `k` rejections occur, then an
137  example with class `i` is seen by the rejector, and it is accepted. This can
138  be written as follows:
139
140  ```
141  t_i = sum_k=0^inf(F^k * p_i * a_i)
142      = p_i * a_j / (1 - F)    using geometric series identity, since 0 <= F < 1
143      = p_i * a_i / sum_j(p_j * a_j)        using F from above
144  ```
145
146  Note that the following constraints hold:
147  ```
148  0 <= p_i <= 1, sum_i(p_i) = 1
149  0 <= a_i <= 1
150  0 <= t_i <= 1, sum_i(t_i) = 1
151  ```
152
153
154  A solution for a_i in terms of the other variabes is the following:
155    ```a_i = (t_i / p_i) / max_i[t_i / p_i]```
156  """
157  # Add tiny to initial_probs to avoid divide by zero.
158  denom = (initial_probs + np.finfo(initial_probs.dtype.as_numpy_dtype).tiny)
159  ratio_l = target_probs / denom
160
161  # Calculate list of acceptance probabilities.
162  max_ratio = math_ops.reduce_max(ratio_l)
163  return ratio_l / max_ratio
164
165
166def _estimate_data_distribution(c, num_examples_per_class_seen):
167  """Estimate data distribution as labels are seen.
168
169  Args:
170    c: The class labels.  Type `int32`, shape `[batch_size]`.
171    num_examples_per_class_seen: Type `int64`, shape `[num_classes]`,
172      containing counts.
173
174  Returns:
175    num_examples_per_lass_seen: Updated counts.  Type `int64`, shape
176      `[num_classes]`.
177    dist: The updated distribution.  Type `float32`, shape `[num_classes]`.
178  """
179  num_classes = num_examples_per_class_seen.get_shape()[0].value
180  # Update the class-count based on what labels are seen in batch.
181  num_examples_per_class_seen = math_ops.add(
182      num_examples_per_class_seen, math_ops.reduce_sum(
183          array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0))
184  init_prob_estimate = math_ops.truediv(
185      num_examples_per_class_seen,
186      math_ops.reduce_sum(num_examples_per_class_seen))
187  dist = math_ops.cast(init_prob_estimate, dtypes.float32)
188  return num_examples_per_class_seen, dist
189