# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for computing moving statistics."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope


__all__ = [
    "assign_moving_mean_variance",
    "assign_log_moving_mean_exp",
    "moving_mean_variance",
]


def assign_moving_mean_variance(
    mean_var, variance_var, value, decay, name=None):
  """Compute exponentially weighted moving {mean,variance} of a streaming value.

  The `value`-updated exponentially weighted moving `mean_var` and
  `variance_var` are given by the following recurrence relations:

  ```python
  variance_var = decay * (variance_var + (1-decay) * (value - mean_var)**2)
  mean_var     = decay * mean_var + (1 - decay) * value
  ```

  Note: `mean_var` is updated *after* `variance_var`, i.e., `variance_var` uses
  the lag-`1` mean.

  For derivation justification, see [Finch (2009; Eq. 143)][1].

  Args:
    mean_var: `float`-like `Variable` representing the exponentially weighted
      moving mean. Same shape as `variance_var` and `value`.
    variance_var: `float`-like `Variable` representing the
      exponentially weighted moving variance. Same shape as `mean_var` and
      `value`.
    value: `float`-like `Tensor`. Same shape as `mean_var` and `variance_var`.
    decay: A `float`-like `Tensor`. The moving mean decay. Typically close to
      `1.`, e.g., `0.999`.
    name: Optional name of the returned operation.

  Returns:
    mean_var: `Variable` representing the `value`-updated exponentially weighted
      moving mean.
    variance_var: `Variable` representing the `value`-updated
      exponentially weighted moving variance.

  Raises:
    TypeError: if `mean_var` does not have float type `dtype`.
    TypeError: if `mean_var`, `variance_var`, `value`, `decay` have different
      `base_dtype`.

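  #### Examples

  A minimal usage sketch; the initial values, `decay`, and the session/feed
  pattern below are illustrative assumptions, not requirements of this
  function.

  ```python
  import tensorflow as tf

  # Running statistics of a stream of scalars.
  mean_var = tf.Variable(0., trainable=False)
  variance_var = tf.Variable(0., trainable=False)
  value = tf.placeholder(tf.float32, shape=[])
  update_mean, update_variance = assign_moving_mean_variance(
      mean_var, variance_var, value, decay=0.99)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for x in [1., 2., 3.]:
      moving_mean, moving_variance = sess.run(
          [update_mean, update_variance], feed_dict={value: x})
  ```
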
  #### References

  [1]: Tony Finch. Incremental calculation of weighted mean and variance.
       _Technical Report_, 2009.
       http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf
  """
  with ops.name_scope(name, "assign_moving_mean_variance",
                      [variance_var, mean_var, value, decay]):
    with ops.colocate_with(variance_var):
      with ops.colocate_with(mean_var):
        base_dtype = mean_var.dtype.base_dtype
        if not base_dtype.is_floating:
          raise TypeError(
              "mean_var.base_dtype({}) does not have float type "
              "`dtype`.".format(base_dtype.name))
        if base_dtype != variance_var.dtype.base_dtype:
          raise TypeError(
              "mean_var.base_dtype({}) != variance_var.base_dtype({})".format(
                  base_dtype.name,
                  variance_var.dtype.base_dtype.name))
        value = ops.convert_to_tensor(value, dtype=base_dtype, name="value")
        decay = ops.convert_to_tensor(decay, dtype=base_dtype, name="decay")
        delta = value - mean_var
        with ops.control_dependencies([delta]):
          mean_var = state_ops.assign_add(
              mean_var,
              (1. - decay) * delta)
          # The in-place update below is algebraically identical to the
          # docstring recurrence since
          #   variance_var - (1 - decay) * (variance_var - decay * delta**2)
          #     = decay * (variance_var + (1 - decay) * delta**2).
          variance_var = state_ops.assign_sub(
              variance_var,
              (1. - decay) * (variance_var - decay * math_ops.square(delta)))
        return mean_var, variance_var


def assign_log_moving_mean_exp(
    log_mean_exp_var, log_value, decay, name=None):
  """Compute the log of the exponentially weighted moving mean of the exp.

  If `log_value` is a draw from a stationary random variable, this function
  approximates `log(E[exp(log_value)])`, i.e., a weighted log-sum-exp. More
  precisely, a `tf.Variable`, `log_mean_exp_var`, is updated by `log_value`
  using the following identity:

  ```none
  log_mean_exp_var
  = log(decay exp(log_mean_exp_var) + (1 - decay) exp(log_value))
  = log(exp(log_mean_exp_var + log(decay)) + exp(log_value + log1p(-decay)))
  = log_mean_exp_var
    + log(  exp(log_mean_exp_var   - log_mean_exp_var + log(decay))
          + exp(log_value - log_mean_exp_var + log1p(-decay)))
  = log_mean_exp_var
    + log_sum_exp([log(decay), log_value - log_mean_exp_var + log1p(-decay)]).
  ```

  In addition to numerical stability, this formulation is advantageous because
  `log_mean_exp_var` can be updated in a lock-free manner, i.e., using
  `assign_add`. (Note: the updates are not thread-safe; it's just that the
  update to the tf.Variable is presumed efficient due to being lock-free.)

  Args:
    log_mean_exp_var: `float`-like `Variable` representing the log of the
      exponentially weighted moving mean of the exp. Same shape as `log_value`.
    log_value: `float`-like `Tensor` representing a new (streaming) observation.
      Same shape as `log_mean_exp_var`.
    decay: A `float`-like `Tensor`. The moving mean decay. Typically close to
      `1.`, e.g., `0.999`.
    name: Optional name of the returned operation.

  Returns:
    log_mean_exp_var: A reference to the input `Variable` tensor with the
      `log_value`-updated log of the exponentially weighted moving mean of exp.

  Raises:
    TypeError: if `log_mean_exp_var` does not have float type `dtype`.
    TypeError: if `log_mean_exp_var`, `log_value`, `decay` have different
      `base_dtype`.
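
  #### Examples

  A minimal usage sketch; the initial value and the session/feed pattern are
  illustrative assumptions.

  ```python
  import tensorflow as tf

  # Track log(E[exp(log_value)]) of a stream of log-values, e.g.,
  # log-importance-weights.
  log_mean_exp_var = tf.Variable(0., trainable=False)
  log_value = tf.placeholder(tf.float32, shape=[])
  update_op = assign_log_moving_mean_exp(
      log_mean_exp_var, log_value, decay=0.999)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for lv in [-1., 0.5, 0.]:
      current_estimate = sess.run(update_op, feed_dict={log_value: lv})
  ```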
149  """
150  with ops.name_scope(name, "assign_log_moving_mean_exp",
151                      [log_mean_exp_var, log_value, decay]):
152    # We want to update the variable in a numerically stable and lock-free way.
153    # To do this, observe that variable `x` updated by `v` is:
154    # x = log(w exp(x) + (1-w) exp(v))
155    #   = log(exp(x + log(w)) + exp(v + log1p(-w)))
156    #   = x + log(exp(x - x + log(w)) + exp(v - x + log1p(-w)))
157    #   = x + lse([log(w), v - x + log1p(-w)])
158    with ops.colocate_with(log_mean_exp_var):
159      base_dtype = log_mean_exp_var.dtype.base_dtype
160      if not base_dtype.is_floating:
161        raise TypeError(
162            "log_mean_exp_var.base_dtype({}) does not have float type "
163            "`dtype`.".format(base_dtype.name))
164      log_value = ops.convert_to_tensor(log_value, dtype=base_dtype,
165                                        name="log_value")
166      decay = ops.convert_to_tensor(decay, dtype=base_dtype, name="decay")
167      delta = (log_value - log_mean_exp_var)[array_ops.newaxis, ...]
168      x = array_ops.concat([
169          math_ops.log(decay) * array_ops.ones_like(delta),
170          delta + math_ops.log1p(-decay)
171      ], axis=0)
172      x = math_ops.reduce_logsumexp(x, axis=0)
173      return log_mean_exp_var.assign_add(x)


def moving_mean_variance(value, decay, collections=None, name=None):
  """Compute exponentially weighted moving {mean,variance} of a streaming value.

  The exponentially weighted moving `mean_var` and `variance_var` are updated
  by `value` according to the following recurrence:

  ```python
  variance_var = decay * (variance_var + (1-decay) * (value - mean_var)**2)
  mean_var     = decay * mean_var + (1 - decay) * value
  ```

  Note: `mean_var` is updated *after* `variance_var`, i.e., `variance_var` uses
  the lag-`1` mean.

  For derivation justification, see [Finch (2009; Eq. 143)][1].

  Unlike `assign_moving_mean_variance`, this function handles
  variable creation.

  Args:
    value: `float`-like `Tensor`. Same shape as `mean_var` and `variance_var`.
    decay: A `float`-like `Tensor`. The moving mean decay. Typically close to
      `1.`, e.g., `0.999`.
    collections: Python list of graph-collections keys to which the internal
      variables `mean_var` and `variance_var` are added.
      Default value is `[GraphKeys.GLOBAL_VARIABLES]`.
    name: Optional name of the returned operation.

  Returns:
    mean_var: `Variable` representing the `value`-updated exponentially weighted
      moving mean.
    variance_var: `Variable` representing the `value`-updated
      exponentially weighted moving variance.

  Raises:
    TypeError: if `value` does not have float type `dtype`.
    TypeError: if `value`, `decay` have different `base_dtype`.

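  #### Examples

  A minimal usage sketch; the session/feed pattern is an illustrative
  assumption. The `moving_mean` and `moving_variance` variables are created
  internally.

  ```python
  import tensorflow as tf

  value = tf.placeholder(tf.float32, shape=[])
  # Returns the update ops; running them applies one `value` observation.
  mean_var, variance_var = moving_mean_variance(value, decay=0.99)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for x in [1., 2., 3.]:
      moving_mean, moving_variance = sess.run(
          [mean_var, variance_var], feed_dict={value: x})
  ```
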
  #### References

  [1]: Tony Finch. Incremental calculation of weighted mean and variance.
       _Technical Report_, 2009.
       http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf
  """
  if collections is None:
    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  with variable_scope.variable_scope(
      name, "moving_mean_variance", [value, decay]):
    value = ops.convert_to_tensor(value, name="value")
    base_dtype = value.dtype.base_dtype
    if not base_dtype.is_floating:
      raise TypeError(
          "value.base_dtype({}) does not have float type `dtype`.".format(
              base_dtype.name))
    decay = ops.convert_to_tensor(decay, dtype=base_dtype, name="decay")
    variance_var = variable_scope.get_variable(
        "moving_variance",
        shape=value.shape,
        dtype=value.dtype,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=collections)
    mean_var = variable_scope.get_variable(
        "moving_mean",
        shape=value.shape,
        dtype=value.dtype,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=collections)
    return assign_moving_mean_variance(
        mean_var, variance_var, value, decay)