# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for computing moving statistics."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope


__all__ = [
    "assign_moving_mean_variance",
    "assign_log_moving_mean_exp",
    "moving_mean_variance",
]


def assign_moving_mean_variance(
    mean_var, variance_var, value, decay, name=None):
  """Compute exponentially weighted moving {mean,variance} of a streaming value.

  The `value`-updated exponentially weighted moving `mean_var` and
  `variance_var` are given by the following recurrence relations:

  ```python
  variance_var = decay * (variance_var + (1 - decay) * (value - mean_var)**2)
  mean_var = decay * mean_var + (1 - decay) * value
  ```

  Note: `mean_var` is updated *after* `variance_var`, i.e., `variance_var` uses
  the lag-1 mean.

  For derivation justification, see [Finch (2009; Eq. 143)][1].
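
  #### Examples

  Here is an illustrative usage sketch; the placeholder-based input stream
  (and a graph-mode session in which the returned ops are evaluated) is an
  assumption of the example, not a requirement of this function:

  ```python
  # Illustrative setup: scalar statistics, seeded at zero.
  mean_var = tf.get_variable(
      "mean", shape=[], trainable=False,
      initializer=tf.zeros_initializer())
  variance_var = tf.get_variable(
      "variance", shape=[], trainable=False,
      initializer=tf.zeros_initializer())
  value = tf.placeholder(tf.float32, shape=[])
  update_mean, update_variance = assign_moving_mean_variance(
      mean_var, variance_var, value, decay=0.99)
  # Evaluating `update_mean` and `update_variance` with each new observation
  # fed to `value` folds that observation into the running statistics.
  ```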

  Args:
    mean_var: `float`-like `Variable` representing the exponentially weighted
      moving mean. Same shape as `variance_var` and `value`.
    variance_var: `float`-like `Variable` representing the exponentially
      weighted moving variance. Same shape as `mean_var` and `value`.
    value: `float`-like `Tensor`. Same shape as `mean_var` and `variance_var`.
    decay: A `float`-like `Tensor`. The moving mean decay. Typically close to
      `1.`, e.g., `0.999`.
    name: Optional name of the returned operation.

  Returns:
    mean_var: `Variable` representing the `value`-updated exponentially
      weighted moving mean.
    variance_var: `Variable` representing the `value`-updated exponentially
      weighted moving variance.

  Raises:
    TypeError: if `mean_var` does not have float type `dtype`.
    TypeError: if `mean_var`, `variance_var`, `value`, `decay` have different
      `base_dtype`.

  #### References

  [1]: Tony Finch. Incremental calculation of weighted mean and variance.
       _Technical Report_, 2009.
       http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf
  """
  with ops.name_scope(name, "assign_moving_mean_variance",
                      [variance_var, mean_var, value, decay]):
    with ops.colocate_with(variance_var):
      with ops.colocate_with(mean_var):
        base_dtype = mean_var.dtype.base_dtype
        if not base_dtype.is_floating:
          raise TypeError(
              "mean_var.base_dtype({}) does not have float type "
              "`dtype`.".format(base_dtype.name))
        if base_dtype != variance_var.dtype.base_dtype:
          raise TypeError(
              "mean_var.base_dtype({}) != variance_var.base_dtype({})".format(
                  base_dtype.name,
                  variance_var.dtype.base_dtype.name))
        value = ops.convert_to_tensor(value, dtype=base_dtype, name="value")
        decay = ops.convert_to_tensor(decay, dtype=base_dtype, name="decay")
        delta = value - mean_var
        with ops.control_dependencies([delta]):
          mean_var = state_ops.assign_add(
              mean_var,
              (1. - decay) * delta)
          variance_var = state_ops.assign_sub(
              variance_var,
              (1. - decay) * (variance_var - decay * math_ops.square(delta)))
        return mean_var, variance_var


def assign_log_moving_mean_exp(
    log_mean_exp_var, log_value, decay, name=None):
  """Compute the log of the exponentially weighted moving mean of the exp.

  If `log_value` is a draw from a stationary random variable, this function
  approximates `log(E[exp(log_value)])`, i.e., a weighted log-sum-exp. More
  precisely, a `tf.Variable`, `log_mean_exp_var`, is updated by `log_value`
  using the following identity:

  ```none
  log_mean_exp_var
  =  log(decay exp(log_mean_exp_var) + (1 - decay) exp(log_value))
  =  log(exp(log_mean_exp_var + log(decay)) + exp(log_value + log1p(-decay)))
  =  log_mean_exp_var
     + log(  exp(log_mean_exp_var - log_mean_exp_var + log(decay))
           + exp(log_value - log_mean_exp_var + log1p(-decay)))
  =  log_mean_exp_var
     + log_sum_exp([log(decay), log_value - log_mean_exp_var + log1p(-decay)]).
  ```

  In addition to numerical stability, this formulation is advantageous because
  `log_mean_exp_var` can be updated in a lock-free manner, i.e., using
  `assign_add`. (Note: the updates are not thread-safe; it's just that the
  update to the `tf.Variable` is presumed efficient due to being lock-free.)
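
  #### Examples

  An illustrative sketch; the placeholder-fed stream of `log_value` draws
  (and the session in which `update_op` is run) is assumed here:

  ```python
  log_mean_exp_var = tf.get_variable(
      "log_mean_exp", shape=[], trainable=False,
      initializer=tf.zeros_initializer())
  log_value = tf.placeholder(tf.float32, shape=[])
  update_op = assign_log_moving_mean_exp(
      log_mean_exp_var, log_value, decay=0.999)
  # Running `update_op` on a stream of `log_value` draws tracks
  # log(E[exp(log_value)]) entirely in log-space.
  ```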

  Args:
    log_mean_exp_var: `float`-like `Variable` representing the log of the
      exponentially weighted moving mean of the exp. Same shape as `log_value`.
    log_value: `float`-like `Tensor` representing a new (streaming)
      observation. Same shape as `log_mean_exp_var`.
    decay: A `float`-like `Tensor`. The moving mean decay. Typically close to
      `1.`, e.g., `0.999`.
    name: Optional name of the returned operation.

  Returns:
    log_mean_exp_var: A reference to the input `Variable` tensor with the
      `log_value`-updated log of the exponentially weighted moving mean of exp.

  Raises:
    TypeError: if `log_mean_exp_var` does not have float type `dtype`.
    TypeError: if `log_mean_exp_var`, `log_value`, `decay` have different
      `base_dtype`.
  """
  with ops.name_scope(name, "assign_log_moving_mean_exp",
                      [log_mean_exp_var, log_value, decay]):
    # We want to update the variable in a numerically stable and lock-free way.
    # To do this, observe that variable `x` updated by `v` is:
    #   x = log(w exp(x) + (1 - w) exp(v))
    #     = log(exp(x + log(w)) + exp(v + log1p(-w)))
    #     = x + log(exp(x - x + log(w)) + exp(v - x + log1p(-w)))
    #     = x + lse([log(w), v - x + log1p(-w)])
    with ops.colocate_with(log_mean_exp_var):
      base_dtype = log_mean_exp_var.dtype.base_dtype
      if not base_dtype.is_floating:
        raise TypeError(
            "log_mean_exp_var.base_dtype({}) does not have float type "
            "`dtype`.".format(base_dtype.name))
      log_value = ops.convert_to_tensor(log_value, dtype=base_dtype,
                                        name="log_value")
      decay = ops.convert_to_tensor(decay, dtype=base_dtype, name="decay")
      delta = (log_value - log_mean_exp_var)[array_ops.newaxis, ...]
      x = array_ops.concat([
          math_ops.log(decay) * array_ops.ones_like(delta),
          delta + math_ops.log1p(-decay)
      ], axis=0)
      x = math_ops.reduce_logsumexp(x, axis=0)
      return log_mean_exp_var.assign_add(x)


def moving_mean_variance(value, decay, collections=None, name=None):
  """Compute exponentially weighted moving {mean,variance} of a streaming value.

  The exponentially weighted moving `mean_var` and `variance_var` are updated
  by `value` according to the following recurrence:

  ```python
  variance_var = decay * (variance_var + (1 - decay) * (value - mean_var)**2)
  mean_var = decay * mean_var + (1 - decay) * value
  ```

  Note: `mean_var` is updated *after* `variance_var`, i.e., `variance_var` uses
  the lag-1 mean.

  For derivation justification, see [Finch (2009; Eq. 143)][1].

  Unlike `assign_moving_mean_variance`, this function handles variable
  creation.
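
  #### Examples

  An illustrative sketch; the placeholder input (and a graph-mode session in
  which the returned ops are evaluated) is assumed:

  ```python
  value = tf.placeholder(tf.float32, shape=[])
  mean_var, variance_var = moving_mean_variance(value, decay=0.99)
  # `mean_var` and `variance_var` are freshly created, non-trainable
  # variables; evaluating the returned assign-op tensors with each new
  # `value` advances the moving statistics.
  ```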

  Args:
    value: `float`-like `Tensor`. Same shape as `mean_var` and `variance_var`.
    decay: A `float`-like `Tensor`. The moving mean decay. Typically close to
      `1.`, e.g., `0.999`.
    collections: Python list of graph-collections keys to which the internal
      variables `mean_var` and `variance_var` are added. Default value is
      `[GraphKeys.GLOBAL_VARIABLES]`.
    name: Optional name of the returned operation.

  Returns:
    mean_var: `Variable` representing the `value`-updated exponentially
      weighted moving mean.
    variance_var: `Variable` representing the `value`-updated exponentially
      weighted moving variance.

  Raises:
    TypeError: if `value` does not have float type `dtype`.
    TypeError: if `value`, `decay` have different `base_dtype`.

  #### References

  [1]: Tony Finch. Incremental calculation of weighted mean and variance.
       _Technical Report_, 2009.
       http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf
  """
  if collections is None:
    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  with variable_scope.variable_scope(
      name, "moving_mean_variance", [value, decay]):
    value = ops.convert_to_tensor(value, name="value")
    base_dtype = value.dtype.base_dtype
    if not base_dtype.is_floating:
      raise TypeError(
          "value.base_dtype({}) does not have float type `dtype`.".format(
              base_dtype.name))
    decay = ops.convert_to_tensor(decay, dtype=base_dtype, name="decay")
    variance_var = variable_scope.get_variable(
        "moving_variance",
        shape=value.shape,
        dtype=value.dtype,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=collections)
    mean_var = variable_scope.get_variable(
        "moving_mean",
        shape=value.shape,
        dtype=value.dtype,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=collections)
    return assign_moving_mean_variance(
        mean_var, variance_var, value, decay)