# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements Kalman filtering for linear state space models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.timeseries.python.timeseries import math_utils

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import numerics


# TODO(allenl): support for always-factored covariance matrices
class KalmanFilter(object):
  """Inference on linear state space models.

  The model for observations in a given state is:
    observation(t) = observation_model * state(t)
        + Gaussian(0, observation_noise_covariance)

  State updates take the following form:
    state(t) = state_transition * state(t-1)
        + state_noise_transform * Gaussian(0, state_transition_noise_covariance)

  This is a real-valued analog to hidden Markov models, with linear transitions
  and a Gaussian noise model. Given initial conditions, noise, and state
  transition, Kalman filtering recursively estimates states and observations,
  along with their associated uncertainty. When fed observations, future state
  and uncertainty estimates are conditioned on those observations (in a
  Bayesian sense).

  Typically some of the "givens" mentioned above (the noise covariances) will
  be unknown, and so optimizing the Kalman filter's probabilistic predictions
  with respect to these parameters is a good approach. The state transition and
  observation models are usually known a priori as a modeling decision.
  """

  def __init__(self, dtype=dtypes.float32,
               simplified_posterior_covariance_computation=False):
    """Initialize the Kalman filter.

    Args:
      dtype: The data type to use for floating point tensors.
      simplified_posterior_covariance_computation: If True, uses an algebraic
        simplification of the Kalman filtering posterior covariance update,
        which is slightly faster at the cost of reduced numerical stability.
        The simplified update is often stable when using double precision on
        small models or with fixed transition matrices.
    """
    self._simplified_posterior_covariance_computation = (
        simplified_posterior_covariance_computation)
    self.dtype = dtype
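
  # A minimal construction sketch (kept as a comment, since graph setup is
  # application-specific). Double precision tends to be more robust for
  # filtering, and is recommended when enabling the simplified update:
  #
  #   kf = KalmanFilter(dtype=dtypes.float64)
  #   fast_kf = KalmanFilter(
  #       dtype=dtypes.float64,
  #       simplified_posterior_covariance_computation=True)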

  def do_filter(
      self, estimated_state, estimated_state_covariance,
      predicted_observation, predicted_observation_covariance,
      observation, observation_model, observation_noise):
    """Convenience function for scoring predictions.

    Scores a prediction against an observation, and computes the updated
    posterior over states.

    Shapes given below for arguments are for single-model Kalman filtering
    (e.g. KalmanFilter). For ensembles, estimated_state and
    estimated_state_covariance are same-length tuples of values corresponding
    to each model.

    Args:
      estimated_state: A prior mean over states [batch size x state dimension]
      estimated_state_covariance: Covariance of state prior [batch size x D x
          D], with D depending on the Kalman filter implementation (typically
          the state dimension).
      predicted_observation: A prediction for the observed value, such as that
          returned by observed_from_state. A [batch size x num features]
          Tensor.
      predicted_observation_covariance: A covariance matrix corresponding to
          `predicted_observation`, a [batch size x num features x num features]
          Tensor.
      observation: The observed value corresponding to the predictions
          given [batch size x observation dimension]
      observation_model: The [batch size x observation dimension x model state
          dimension] Tensor indicating how a particular state is mapped to
          (pre-noise) observations for each part of the batch.
      observation_noise: A [batch size x observation dimension x observation
          dimension] Tensor or [observation dimension x observation dimension]
          Tensor with covariance matrices to use for each part of the batch (a
          two-dimensional input will be broadcast).
    Returns:
      posterior_state, posterior_state_var: Posterior mean and covariance,
          updated versions of estimated_state and estimated_state_covariance.
      log_prediction_prob: Log probability of the observations under the
          priors, suitable for optimization (should be maximized).
    """
    symmetrized_observation_covariance = 0.5 * (
        predicted_observation_covariance + array_ops.matrix_transpose(
            predicted_observation_covariance))
    instability_message = (
        "This may occur due to numerically unstable filtering when there is "
        "a large difference in posterior variances, or when inferences are "
        "near-deterministic. Consider tuning the "
        "'filtering_maximum_posterior_variance_ratio' or "
        "'filtering_minimum_posterior_variance' parameters in your "
        "StateSpaceModelConfiguration, or tuning the transition matrix.")
    symmetrized_observation_covariance = numerics.verify_tensor_all_finite(
        symmetrized_observation_covariance,
        "Predicted observation covariance was not finite. {}".format(
            instability_message))
    diag = array_ops.matrix_diag_part(symmetrized_observation_covariance)
    min_diag = math_ops.reduce_min(diag)
    non_negative_assert = control_flow_ops.Assert(
        min_diag >= 0.,
        [("The predicted observation covariance "
          "has a negative diagonal entry. {}").format(instability_message),
         min_diag])
    with ops.control_dependencies([non_negative_assert]):
      observation_covariance_cholesky = linalg_ops.cholesky(
          symmetrized_observation_covariance)
    log_prediction_prob = math_utils.mvn_tril_log_prob(
        loc=predicted_observation,
        scale_tril=observation_covariance_cholesky,
        x=observation)
    (posterior_state,
     posterior_state_var) = self.posterior_from_prior_state(
         prior_state=estimated_state,
         prior_state_var=estimated_state_covariance,
         observation=observation,
         observation_model=observation_model,
         predicted_observations=(predicted_observation,
                                 predicted_observation_covariance),
         observation_noise=observation_noise)
    return (posterior_state, posterior_state_var, log_prediction_prob)
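
  # The log_prediction_prob returned above is typically negated and summed
  # over time and batch to form a training loss for the unknown noise
  # parameters. A sketch (as a comment; the optimizer choice and learning
  # rate are illustrative assumptions, not part of this module):
  #
  #   loss = -math_ops.reduce_sum(log_prediction_prob)
  #   train_op = gradient_descent.GradientDescentOptimizer(0.1).minimize(loss)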

  def predict_state_mean(self, prior_state, transition_matrices):
    """Compute state transitions.

    Args:
      prior_state: Current estimated state mean [batch_size x state_dimension]
      transition_matrices: A [batch size, state dimension, state dimension]
          batch of matrices (dtype matching the `dtype` argument to the
          constructor) with the transition matrix raised to the power of the
          number of steps to be taken (not element-wise; use
          math_utils.matrix_to_powers if there is no efficient special case) if
          more than one step is desired.
    Returns:
      State mean advanced based on `transition_matrices` (dimensions matching
      first argument).
    """
    advanced_state = array_ops.squeeze(
        math_ops.matmul(
            transition_matrices,
            prior_state[..., None]),
        axis=[-1])
    return advanced_state

  def predict_state_var(
      self, prior_state_var, transition_matrices, transition_noise_sums):
    r"""Compute variance for state transitions.

    Computes a noise estimate corresponding to the value returned by
    predict_state_mean.

    Args:
      prior_state_var: Covariance matrix specifying uncertainty of the current
          state estimate [batch size x state dimension x state dimension]
      transition_matrices: A [batch size, state dimension, state dimension]
          batch of matrices (dtype matching the `dtype` argument to the
          constructor) with the transition matrix raised to the power of the
          number of steps to be taken (not element-wise; use
          math_utils.matrix_to_powers if there is no efficient special case).
      transition_noise_sums: A [batch size, state dimension, state dimension]
          Tensor (dtype matching the `dtype` argument to the constructor) with:

            \sum_{i=0}^{num_steps - 1} (
                state_transition_to_powers_fn(i)
                * state_transition_noise_covariance
                * state_transition_to_powers_fn(i)^T
            )

          for the number of steps to be taken in each part of the batch (this
          should match `transition_matrices`). Use math_utils.power_sums_tensor
          with `tf.gather` if there is no efficient special case.
    Returns:
      State variance advanced based on `transition_matrices` and
      `transition_noise_sums` (dimensions matching first argument).
    """
    prior_variance_transitioned = math_ops.matmul(
        math_ops.matmul(transition_matrices, prior_state_var),
        transition_matrices,
        adjoint_b=True)
    return prior_variance_transitioned + transition_noise_sums
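
  # A sketch of preparing multi-step inputs for the two prediction methods
  # above when each batch element advances a different number of steps
  # (comment-only; this assumes math_utils.matrix_to_powers takes a single
  # transition matrix plus an integer vector of powers, and that
  # math_utils.power_sums_tensor precomputes a table of noise sums indexed by
  # step count, per the helpers named in the docstrings above):
  #
  #   # num_steps: [batch] int32 vector of steps per batch element.
  #   transition_matrices = math_utils.matrix_to_powers(
  #       state_transition, num_steps)
  #   noise_sum_table = math_utils.power_sums_tensor(
  #       max_num_steps, state_transition, state_transition_noise_covariance)
  #   transition_noise_sums = array_ops.gather(noise_sum_table, num_steps)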

  def posterior_from_prior_state(self, prior_state, prior_state_var,
                                 observation, observation_model,
                                 predicted_observations,
                                 observation_noise):
    """Compute a posterior over states given an observation.

    Args:
      prior_state: Prior state mean [batch size x state dimension]
      prior_state_var: Prior state covariance [batch size x state dimension x
          state dimension]
      observation: The observed value corresponding to the predictions given
          [batch size x observation dimension]
      observation_model: The [batch size x observation dimension x model state
          dimension] Tensor indicating how a particular state is mapped to
          (pre-noise) observations for each part of the batch.
      predicted_observations: An (observation mean, observation variance) tuple
          computed based on the current state, usually the output of
          observed_from_state.
      observation_noise: A [batch size x observation dimension x observation
          dimension] or [observation dimension x observation dimension] Tensor
          with covariance matrices to use for each part of the batch (a
          two-dimensional input will be broadcast).
    Returns:
      Posterior mean and covariance (dimensions matching the first two
      arguments).
    """
    observed_mean, observed_var = predicted_observations
    residual = observation - observed_mean
    # TODO(allenl): Can more of this be done using matrix_solve_ls?
    kalman_solve_rhs = math_ops.matmul(
        observation_model, prior_state_var, adjoint_b=True)
    # This matrix_solve adjoint doesn't make a difference symbolically (since
    # observed_var is a covariance matrix, and should be symmetric), but
    # filtering on multivariate series is unstable without it. See
    # test_multivariate_symmetric_covariance_float64 in kalman_filter_test.py
    # for an example of the instability (fails with adjoint=False).
    kalman_gain_transposed = linalg_ops.matrix_solve(
        matrix=observed_var, rhs=kalman_solve_rhs, adjoint=True)
    posterior_state = prior_state + array_ops.squeeze(
        math_ops.matmul(
            kalman_gain_transposed,
            array_ops.expand_dims(residual, -1),
            adjoint_a=True),
        axis=[-1])
    gain_obs = math_ops.matmul(
        kalman_gain_transposed, observation_model, adjoint_a=True)
    identity_extradim = linalg_ops.eye(
        array_ops.shape(gain_obs)[1], dtype=gain_obs.dtype)[None]
    identity_minus_factor = identity_extradim - gain_obs
    if self._simplified_posterior_covariance_computation:
      # posterior covariance =
      #   (I - kalman_gain * observation_model) * prior_state_var
      posterior_state_var = math_ops.matmul(identity_minus_factor,
                                            prior_state_var)
    else:
      observation_noise = ops.convert_to_tensor(observation_noise)
      # A Joseph form update, which provides better numeric stability than the
      # simplified optimal Kalman gain update, at the cost of a few extra
      # operations. Joseph form updates are valid for any gain (not just the
      # optimal Kalman gain), and so are more forgiving of numerical errors in
      # computing the optimal Kalman gain.
      #
      # posterior covariance =
      #   (I - kalman_gain * observation_model) * prior_state_var
      #   * (I - kalman_gain * observation_model)^T
      #   + kalman_gain * observation_noise * kalman_gain^T
      left_multiplied_state_var = math_ops.matmul(identity_minus_factor,
                                                  prior_state_var)
      multiplied_state_var = math_ops.matmul(
          identity_minus_factor, left_multiplied_state_var, adjoint_b=True)
      def _batch_observation_noise_update():
        return (multiplied_state_var + math_ops.matmul(
            math_ops.matmul(
                kalman_gain_transposed, observation_noise, adjoint_a=True),
            kalman_gain_transposed))
      def _matrix_observation_noise_update():
        return (multiplied_state_var + math_ops.matmul(
            math_utils.batch_times_matrix(
                kalman_gain_transposed, observation_noise, adj_x=True),
            kalman_gain_transposed))
      if observation_noise.get_shape().ndims is None:
        posterior_state_var = control_flow_ops.cond(
            math_ops.equal(array_ops.rank(observation_noise), 2),
            _matrix_observation_noise_update, _batch_observation_noise_update)
      else:
        # If static shape information exists, it gets checked in each cond()
        # branch, so we need a special case to avoid graph-build-time
        # exceptions.
        if observation_noise.get_shape().ndims == 2:
          posterior_state_var = _matrix_observation_noise_update()
        else:
          posterior_state_var = _batch_observation_noise_update()
    return posterior_state, posterior_state_var

  def observed_from_state(self, state_mean, state_var, observation_model,
                          observation_noise):
    """Compute an observation distribution given a state distribution.

    Args:
      state_mean: State mean vector [batch size x state dimension]
      state_var: State covariance [batch size x state dimension x state
          dimension]
      observation_model: The [batch size x observation dimension x model state
          dimension] Tensor indicating how a particular state is mapped to
          (pre-noise) observations for each part of the batch.
      observation_noise: A [batch size x observation dimension x observation
          dimension] Tensor with covariance matrices to use for each part of
          the batch. To remove observation noise, pass a Tensor of zeros (or
          simply 0, which will broadcast).
    Returns:
      observed_mean: Observation mean vector [batch size x observation
          dimension]
      observed_var: Observation covariance [batch size x observation dimension
          x observation dimension]
    """
    observed_mean = array_ops.squeeze(
        math_ops.matmul(
            array_ops.expand_dims(state_mean, 1),
            observation_model,
            adjoint_b=True),
        axis=[1])
    observed_var = math_ops.matmul(
        math_ops.matmul(observation_model, state_var),
        observation_model,
        adjoint_b=True)
    observed_var += observation_noise
    return observed_mean, observed_var