1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for Kalman filtering.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy 22 23from tensorflow.contrib.timeseries.python.timeseries import math_utils 24from tensorflow.contrib.timeseries.python.timeseries.state_space_models import kalman_filter 25 26from tensorflow.python.framework import constant_op 27from tensorflow.python.framework import dtypes 28from tensorflow.python.ops import array_ops 29from tensorflow.python.ops import linalg_ops 30from tensorflow.python.ops import math_ops 31from tensorflow.python.platform import test 32 33 34# Two-dimensional state model with "slope" and "level" components. 35STATE_TRANSITION = [ 36 [1., 1.], # Add slope to level 37 [0., 1.] # Maintain slope 38] 39# Independent noise for each component 40STATE_TRANSITION_NOISE = [[0.1, 0.0], [0.0, 0.2]] 41OBSERVATION_MODEL = [[[0.5, 0.0], [0.0, 1.0]]] 42OBSERVATION_NOISE = [[0.0001, 0.], [0., 0.0002]] 43STATE_NOISE_TRANSFORM = [[1.0, 0.0], [0.0, 1.0]] 44 45 46def _powers_and_sums_from_transition_matrix( 47 state_transition, state_transition_noise_covariance, 48 state_noise_transform, max_gap=1): 49 def _transition_matrix_powers(powers): 50 return math_utils.matrix_to_powers(state_transition, powers) 51 def _power_sums(num_steps): 52 power_sums_tensor = math_utils.power_sums_tensor( 53 max_gap + 1, state_transition, 54 math_ops.matmul(state_noise_transform, 55 math_ops.matmul( 56 state_transition_noise_covariance, 57 state_noise_transform, 58 adjoint_b=True))) 59 return array_ops.gather(power_sums_tensor, indices=num_steps) 60 return (_transition_matrix_powers, _power_sums) 61 62 63class MultivariateTests(test.TestCase): 64 65 def _multivariate_symmetric_covariance_test_template( 66 self, dtype, simplified_posterior_variance_computation): 67 """Check that errors aren't building up asymmetries in covariances.""" 68 kf = kalman_filter.KalmanFilter(dtype=dtype) 69 observation_noise_covariance = constant_op.constant( 70 [[1., 0.5], [0.5, 1.]], dtype=dtype) 71 observation_model = constant_op.constant( 72 [[[1., 0., 0., 0.], [0., 0., 1., 0.]]], dtype=dtype) 73 state = array_ops.placeholder(shape=[1, 4], dtype=dtype) 74 state_var = array_ops.placeholder(shape=[1, 4, 4], dtype=dtype) 75 observation = array_ops.placeholder(shape=[1, 2], dtype=dtype) 76 transition_fn, power_sum_fn = _powers_and_sums_from_transition_matrix( 77 state_transition=constant_op.constant( 78 [[1., 1., 0., 0.], [0., 1., 0., 0.], [0., 0., 1., 1.], 79 [0., 0., 0., 1.]], 80 dtype=dtype), 81 state_noise_transform=linalg_ops.eye(4, dtype=dtype), 82 state_transition_noise_covariance=constant_op.constant( 83 [[1., 0., 0.5, 0.], [0., 1., 0., 0.5], [0.5, 0., 1., 0.], 84 [0., 0.5, 0., 1.]], 85 dtype=dtype)) 86 pred_state = kf.predict_state_mean( 87 prior_state=state, transition_matrices=transition_fn([1])) 88 pred_state_var = kf.predict_state_var( 89 prior_state_var=state_var, transition_matrices=transition_fn([1]), 90 transition_noise_sums=power_sum_fn([1])) 91 observed_mean, observed_var = kf.observed_from_state( 92 state_mean=pred_state, state_var=pred_state_var, 93 observation_model=observation_model, 94 observation_noise=observation_noise_covariance) 95 post_state, post_state_var = kf.posterior_from_prior_state( 96 prior_state=pred_state, prior_state_var=pred_state_var, 97 observation=observation, 98 observation_model=observation_model, 99 predicted_observations=(observed_mean, observed_var), 100 observation_noise=observation_noise_covariance) 101 with self.cached_session() as session: 102 evaled_state = numpy.array([[1., 1., 1., 1.]]) 103 evaled_state_var = numpy.eye(4)[None] 104 for i in range(500): 105 evaled_state, evaled_state_var, evaled_observed_var = session.run( 106 [post_state, post_state_var, observed_var], 107 feed_dict={state: evaled_state, 108 state_var: evaled_state_var, 109 observation: [[float(i), float(i)]]}) 110 self.assertAllClose(evaled_observed_var[0], 111 evaled_observed_var[0].T) 112 self.assertAllClose(evaled_state_var[0], 113 evaled_state_var[0].T) 114 115 def test_multivariate_symmetric_covariance_float32(self): 116 self._multivariate_symmetric_covariance_test_template( 117 dtypes.float32, simplified_posterior_variance_computation=False) 118 119 def test_multivariate_symmetric_covariance_float64(self): 120 self._multivariate_symmetric_covariance_test_template( 121 dtypes.float64, simplified_posterior_variance_computation=True) 122 123 124class KalmanFilterNonBatchTest(test.TestCase): 125 """Single-batch KalmanFilter tests.""" 126 127 def setUp(self): 128 """The basic model defined above, with unit batches.""" 129 self.kalman_filter = kalman_filter.KalmanFilter() 130 self.transition_fn, self.power_sum_fn = ( 131 _powers_and_sums_from_transition_matrix( 132 state_transition=STATE_TRANSITION, 133 state_transition_noise_covariance=STATE_TRANSITION_NOISE, 134 state_noise_transform=STATE_NOISE_TRANSFORM, 135 max_gap=5)) 136 137 def test_observed_from_state(self): 138 """Compare observation mean and noise to hand-computed values.""" 139 with self.cached_session(): 140 state = constant_op.constant([[2., 1.]]) 141 state_var = constant_op.constant([[[4., 0.], [0., 3.]]]) 142 observed_mean, observed_var = self.kalman_filter.observed_from_state( 143 state, state_var, 144 observation_model=OBSERVATION_MODEL, 145 observation_noise=OBSERVATION_NOISE) 146 observed_mean_override, observed_var_override = ( 147 self.kalman_filter.observed_from_state( 148 state, state_var, 149 observation_model=OBSERVATION_MODEL, 150 observation_noise=100 * constant_op.constant( 151 OBSERVATION_NOISE)[None])) 152 self.assertAllClose(numpy.array([[1., 1.]]), 153 observed_mean.eval()) 154 self.assertAllClose(numpy.array([[1., 1.]]), 155 observed_mean_override.eval()) 156 self.assertAllClose(numpy.array([[[1.0001, 0.], [0., 3.0002]]]), 157 observed_var.eval()) 158 self.assertAllClose(numpy.array([[[1.01, 0.], [0., 3.02]]]), 159 observed_var_override.eval()) 160 161 def _posterior_from_prior_state_test_template( 162 self, state, state_var, observation, observation_model, observation_noise, 163 expected_state, expected_state_var): 164 """Test that repeated observations converge to the expected value.""" 165 predicted_observations = self.kalman_filter.observed_from_state( 166 state, state_var, observation_model, 167 observation_noise=observation_noise) 168 state_update, state_var_update = ( 169 self.kalman_filter.posterior_from_prior_state( 170 state, state_var, observation, 171 observation_model=observation_model, 172 predicted_observations=predicted_observations, 173 observation_noise=observation_noise)) 174 with self.cached_session() as session: 175 evaled_state, evaled_state_var = session.run([state, state_var]) 176 for _ in range(300): 177 evaled_state, evaled_state_var = session.run( 178 [state_update, state_var_update], 179 feed_dict={state: evaled_state, state_var: evaled_state_var}) 180 self.assertAllClose(expected_state, 181 evaled_state, 182 atol=1e-5) 183 self.assertAllClose( 184 expected_state_var, 185 evaled_state_var, 186 atol=1e-5) 187 188 def test_posterior_from_prior_state_univariate(self): 189 self._posterior_from_prior_state_test_template( 190 state=constant_op.constant([[0.3]]), 191 state_var=constant_op.constant([[[1.]]]), 192 observation=constant_op.constant([[1.]]), 193 observation_model=[[[2.]]], 194 observation_noise=[[[0.01]]], 195 expected_state=numpy.array([[0.5]]), 196 expected_state_var=[[[0.]]]) 197 198 def test_posterior_from_prior_state_univariate_unit_noise(self): 199 self._posterior_from_prior_state_test_template( 200 state=constant_op.constant([[0.3]]), 201 state_var=constant_op.constant([[[1e10]]]), 202 observation=constant_op.constant([[1.]]), 203 observation_model=[[[2.]]], 204 observation_noise=[[[1.0]]], 205 expected_state=numpy.array([[0.5]]), 206 expected_state_var=[[[1. / (300. * 2. ** 2)]]]) 207 208 def test_posterior_from_prior_state_multivariate_2d(self): 209 self._posterior_from_prior_state_test_template( 210 state=constant_op.constant([[1.9, 1.]]), 211 state_var=constant_op.constant([[[1., 0.], [0., 2.]]]), 212 observation=constant_op.constant([[1., 1.]]), 213 observation_model=OBSERVATION_MODEL, 214 observation_noise=OBSERVATION_NOISE, 215 expected_state=numpy.array([[2., 1.]]), 216 expected_state_var=[[[0., 0.], [0., 0.]]]) 217 218 def test_posterior_from_prior_state_multivariate_3d(self): 219 self._posterior_from_prior_state_test_template( 220 state=constant_op.constant([[1.9, 1., 5.]]), 221 state_var=constant_op.constant( 222 [[[200., 0., 1.], [0., 2000., 0.], [1., 0., 40000.]]]), 223 observation=constant_op.constant([[1., 1., 3.]]), 224 observation_model=constant_op.constant( 225 [[[0.5, 0., 0.], 226 [0., 10., 0.], 227 [0., 0., 100.]]]), 228 observation_noise=linalg_ops.eye(3) / 10000., 229 expected_state=numpy.array([[2., .1, .03]]), 230 expected_state_var=numpy.zeros([1, 3, 3])) 231 232 def test_predict_state_mean(self): 233 """Compare state mean transitions with simple hand-computed values.""" 234 with self.cached_session(): 235 state = constant_op.constant([[4., 2.]]) 236 state = self.kalman_filter.predict_state_mean( 237 state, self.transition_fn([1])) 238 for _ in range(2): 239 state = self.kalman_filter.predict_state_mean( 240 state, self.transition_fn([1])) 241 self.assertAllClose( 242 numpy.array([[2. * 3. + 4., # Slope * time + base 243 2.]]), 244 state.eval()) 245 246 def test_predict_state_var(self): 247 """Compare a variance transition with simple hand-computed values.""" 248 with self.cached_session(): 249 state_var = constant_op.constant([[[1., 0.], [0., 2.]]]) 250 state_var = self.kalman_filter.predict_state_var( 251 state_var, self.transition_fn([1]), self.power_sum_fn([1])) 252 self.assertAllClose( 253 numpy.array([[[3.1, 2.0], [2.0, 2.2]]]), 254 state_var.eval()) 255 256 def test_do_filter(self): 257 """Tests do_filter. 258 259 Tests that correct values have high probability and incorrect values 260 have low probability when there is low uncertainty. 261 """ 262 with self.cached_session(): 263 state = constant_op.constant([[4., 2.]]) 264 state_var = constant_op.constant([[[0.0001, 0.], [0., 0.0001]]]) 265 observation = constant_op.constant([[ 266 .5 * ( 267 4. # Base 268 + 2.), # State transition 269 2. 270 ]]) 271 estimated_state = self.kalman_filter.predict_state_mean( 272 state, self.transition_fn([1])) 273 estimated_state_covariance = self.kalman_filter.predict_state_var( 274 state_var, self.transition_fn([1]), self.power_sum_fn([1])) 275 (predicted_observation, 276 predicted_observation_covariance) = ( 277 self.kalman_filter.observed_from_state( 278 estimated_state, estimated_state_covariance, 279 observation_model=OBSERVATION_MODEL, 280 observation_noise=OBSERVATION_NOISE)) 281 (_, _, first_log_prob) = self.kalman_filter.do_filter( 282 estimated_state=estimated_state, 283 estimated_state_covariance=estimated_state_covariance, 284 predicted_observation=predicted_observation, 285 predicted_observation_covariance=predicted_observation_covariance, 286 observation=observation, 287 observation_model=OBSERVATION_MODEL, 288 observation_noise=OBSERVATION_NOISE) 289 self.assertGreater(first_log_prob.eval()[0], numpy.log(0.99)) 290 291 def test_predict_n_ahead_mean(self): 292 with self.cached_session(): 293 original_state = constant_op.constant([[4., 2.]]) 294 n = 5 295 iterative_state = original_state 296 for i in range(n): 297 self.assertAllClose( 298 iterative_state.eval(), 299 self.kalman_filter.predict_state_mean( 300 original_state, 301 self.transition_fn([i])).eval()) 302 iterative_state = self.kalman_filter.predict_state_mean( 303 iterative_state, 304 self.transition_fn([1])) 305 306 def test_predict_n_ahead_var(self): 307 with self.cached_session(): 308 original_var = constant_op.constant([[[2., 3.], [4., 5.]]]) 309 n = 5 310 iterative_var = original_var 311 for i in range(n): 312 self.assertAllClose( 313 iterative_var.eval(), 314 self.kalman_filter.predict_state_var( 315 original_var, 316 self.transition_fn([i]), 317 self.power_sum_fn([i])).eval()) 318 iterative_var = self.kalman_filter.predict_state_var( 319 iterative_var, 320 self.transition_fn([1]), 321 self.power_sum_fn([1])) 322 323 324class KalmanFilterBatchTest(test.TestCase): 325 """KalmanFilter tests with more than one element batches.""" 326 327 def test_do_filter_batch(self): 328 """Tests do_filter, in batch mode. 329 330 Tests that correct values have high probability and incorrect values 331 have low probability when there is low uncertainty. 332 """ 333 with self.cached_session(): 334 state = constant_op.constant([[4., 2.], [5., 3.], [6., 4.]]) 335 state_var = constant_op.constant(3 * [[[0.0001, 0.], [0., 0.0001]]]) 336 observation = constant_op.constant([ 337 [ 338 .5 * ( 339 4. # Base 340 + 2.), # State transition 341 2. 342 ], 343 [ 344 .5 * ( 345 5. # Base 346 + 3.), # State transition 347 3. 348 ], 349 [3.14, 2.71] 350 ]) # Low probability observation 351 kf = kalman_filter.KalmanFilter() 352 transition_fn, power_sum_fn = _powers_and_sums_from_transition_matrix( 353 state_transition=STATE_TRANSITION, 354 state_transition_noise_covariance=STATE_TRANSITION_NOISE, 355 state_noise_transform=STATE_NOISE_TRANSFORM, 356 max_gap=2) 357 estimated_state = kf.predict_state_mean(state, transition_fn(3*[1])) 358 estimated_state_covariance = kf.predict_state_var( 359 state_var, transition_fn(3*[1]), power_sum_fn(3*[1])) 360 observation_model = array_ops.tile(OBSERVATION_MODEL, [3, 1, 1]) 361 (predicted_observation, 362 predicted_observation_covariance) = ( 363 kf.observed_from_state( 364 estimated_state, estimated_state_covariance, 365 observation_model=observation_model, 366 observation_noise=OBSERVATION_NOISE)) 367 (state, state_var, log_prob) = kf.do_filter( 368 estimated_state=estimated_state, 369 estimated_state_covariance=estimated_state_covariance, 370 predicted_observation=predicted_observation, 371 predicted_observation_covariance=predicted_observation_covariance, 372 observation=observation, 373 observation_model=observation_model, 374 observation_noise=OBSERVATION_NOISE) 375 first_log_prob, second_log_prob, third_log_prob = log_prob.eval() 376 self.assertGreater(first_log_prob.sum(), numpy.log(0.99)) 377 self.assertGreater(second_log_prob.sum(), numpy.log(0.99)) 378 self.assertLess(third_log_prob.sum(), numpy.log(0.01)) 379 380 def test_predict_n_ahead_mean(self): 381 with self.cached_session(): 382 kf = kalman_filter.KalmanFilter() 383 transition_fn, _ = _powers_and_sums_from_transition_matrix( 384 state_transition=STATE_TRANSITION, 385 state_transition_noise_covariance=STATE_TRANSITION_NOISE, 386 state_noise_transform=STATE_NOISE_TRANSFORM, 387 max_gap=2) 388 original_state = constant_op.constant([[4., 2.], [3., 1.], [6., 2.]]) 389 state0 = original_state 390 state1 = kf.predict_state_mean(state0, transition_fn(3 * [1])) 391 state2 = kf.predict_state_mean(state1, transition_fn(3 * [1])) 392 batch_eval = kf.predict_state_mean( 393 original_state, transition_fn([1, 0, 2])).eval() 394 self.assertAllClose(state0.eval()[1], batch_eval[1]) 395 self.assertAllClose(state1.eval()[0], batch_eval[0]) 396 self.assertAllClose(state2.eval()[2], batch_eval[2]) 397 398 def test_predict_n_ahead_var(self): 399 with self.cached_session(): 400 kf = kalman_filter.KalmanFilter() 401 transition_fn, power_sum_fn = _powers_and_sums_from_transition_matrix( 402 state_transition=STATE_TRANSITION, 403 state_transition_noise_covariance=STATE_TRANSITION_NOISE, 404 state_noise_transform=STATE_NOISE_TRANSFORM, 405 max_gap=2) 406 base_var = 2.0 * numpy.identity(2) + numpy.ones([2, 2]) 407 original_var = constant_op.constant( 408 numpy.array( 409 [base_var, 2.0 * base_var, 3.0 * base_var], dtype=numpy.float32)) 410 var0 = original_var 411 var1 = kf.predict_state_var( 412 var0, transition_fn(3 * [1]), power_sum_fn(3 * [1])) 413 var2 = kf.predict_state_var( 414 var1, transition_fn(3 * [1]), power_sum_fn(3 * [1])) 415 batch_eval = kf.predict_state_var( 416 original_var, 417 transition_fn([1, 0, 2]), 418 power_sum_fn([1, 0, 2])).eval() 419 self.assertAllClose(var0.eval()[1], batch_eval[1]) 420 self.assertAllClose(var1.eval()[0], batch_eval[0]) 421 self.assertAllClose(var2.eval()[2], batch_eval[2]) 422 423 424if __name__ == "__main__": 425 test.main() 426