1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for state space model infrastructure.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22 23import numpy 24 25from tensorflow.contrib import layers 26 27from tensorflow.contrib.timeseries.python.timeseries import estimators 28from tensorflow.contrib.timeseries.python.timeseries import feature_keys 29from tensorflow.contrib.timeseries.python.timeseries import input_pipeline 30from tensorflow.contrib.timeseries.python.timeseries import math_utils 31from tensorflow.contrib.timeseries.python.timeseries import saved_model_utils 32from tensorflow.contrib.timeseries.python.timeseries import state_management 33from tensorflow.contrib.timeseries.python.timeseries import test_utils 34from tensorflow.contrib.timeseries.python.timeseries.state_space_models import state_space_model 35 36from tensorflow.python.estimator import estimator_lib 37from tensorflow.python.framework import constant_op 38from tensorflow.python.framework import dtypes 39from tensorflow.python.framework import ops 40from tensorflow.python.framework import random_seed 41from tensorflow.python.framework import tensor_shape 42from tensorflow.python.ops import array_ops 43from tensorflow.python.ops import linalg_ops 44from tensorflow.python.ops import math_ops 45from tensorflow.python.ops import variable_scope 46from tensorflow.python.ops import variables 47from tensorflow.python.platform import test 48from tensorflow.python.saved_model import loader 49from tensorflow.python.saved_model import tag_constants 50from tensorflow.python.training import coordinator as coordinator_lib 51from tensorflow.python.training import gradient_descent 52from tensorflow.python.training import queue_runner_impl 53 54 55class RandomStateSpaceModel(state_space_model.StateSpaceModel): 56 57 def __init__(self, 58 state_dimension, 59 state_noise_dimension, 60 configuration=state_space_model.StateSpaceModelConfiguration()): 61 self.transition = numpy.random.normal( 62 size=[state_dimension, state_dimension]).astype( 63 configuration.dtype.as_numpy_dtype) 64 self.noise_transform = numpy.random.normal( 65 size=(state_dimension, state_noise_dimension)).astype( 66 configuration.dtype.as_numpy_dtype) 67 # Test batch broadcasting 68 self.observation_model = numpy.random.normal( 69 size=(configuration.num_features, state_dimension)).astype( 70 configuration.dtype.as_numpy_dtype) 71 super(RandomStateSpaceModel, self).__init__( 72 configuration=configuration._replace( 73 covariance_prior_fn=lambda _: 0.)) 74 75 def get_state_transition(self): 76 return self.transition 77 78 def get_noise_transform(self): 79 return self.noise_transform 80 81 def get_observation_model(self, times): 82 return self.observation_model 83 84 85class ConstructionTests(test.TestCase): 86 87 def test_initialize_graph_error(self): 88 with self.assertRaisesRegexp(ValueError, "initialize_graph"): 89 model = RandomStateSpaceModel(2, 2) 90 outputs = model.define_loss( 91 features={ 92 feature_keys.TrainEvalFeatures.TIMES: 93 constant_op.constant([[1, 2]]), 94 feature_keys.TrainEvalFeatures.VALUES: 95 constant_op.constant([[[1.], [2.]]]) 96 }, 97 mode=estimator_lib.ModeKeys.TRAIN) 98 initializer = variables.global_variables_initializer() 99 with self.cached_session() as sess: 100 sess.run([initializer]) 101 outputs.loss.eval() 102 103 def test_initialize_graph_state_manager_error(self): 104 with self.assertRaisesRegexp(ValueError, "initialize_graph"): 105 model = RandomStateSpaceModel(2, 2) 106 state_manager = state_management.ChainingStateManager() 107 outputs = state_manager.define_loss( 108 model=model, 109 features={ 110 feature_keys.TrainEvalFeatures.TIMES: 111 constant_op.constant([[1, 2]]), 112 feature_keys.TrainEvalFeatures.VALUES: 113 constant_op.constant([[[1.], [2.]]]) 114 }, 115 mode=estimator_lib.ModeKeys.TRAIN) 116 initializer = variables.global_variables_initializer() 117 with self.cached_session() as sess: 118 sess.run([initializer]) 119 outputs.loss.eval() 120 121 122class GapTests(test.TestCase): 123 124 def _gap_test_template(self, times, values): 125 random_model = RandomStateSpaceModel( 126 state_dimension=1, state_noise_dimension=1, 127 configuration=state_space_model.StateSpaceModelConfiguration( 128 num_features=1)) 129 random_model.initialize_graph() 130 input_fn = input_pipeline.WholeDatasetInputFn( 131 input_pipeline.NumpyReader({ 132 feature_keys.TrainEvalFeatures.TIMES: times, 133 feature_keys.TrainEvalFeatures.VALUES: values 134 })) 135 features, _ = input_fn() 136 times = features[feature_keys.TrainEvalFeatures.TIMES] 137 values = features[feature_keys.TrainEvalFeatures.VALUES] 138 model_outputs = random_model.get_batch_loss( 139 features={ 140 feature_keys.TrainEvalFeatures.TIMES: times, 141 feature_keys.TrainEvalFeatures.VALUES: values 142 }, 143 mode=None, 144 state=math_utils.replicate_state( 145 start_state=random_model.get_start_state(), 146 batch_size=array_ops.shape(times)[0])) 147 with self.cached_session() as session: 148 variables.global_variables_initializer().run() 149 coordinator = coordinator_lib.Coordinator() 150 queue_runner_impl.start_queue_runners(session, coord=coordinator) 151 model_outputs.loss.eval() 152 coordinator.request_stop() 153 coordinator.join() 154 155 def test_start_gap(self): 156 self._gap_test_template(times=[20, 21, 22], values=numpy.arange(3)) 157 158 def test_mid_gap(self): 159 self._gap_test_template(times=[2, 60, 61], values=numpy.arange(3)) 160 161 def test_end_gap(self): 162 self._gap_test_template(times=[2, 3, 73], values=numpy.arange(3)) 163 164 def test_all_gaps(self): 165 self._gap_test_template(times=[2, 4, 8, 16, 32, 64, 128], 166 values=numpy.arange(7)) 167 168 169class StateSpaceEquivalenceTests(test.TestCase): 170 171 def test_savedmodel_state_override(self): 172 random_model = RandomStateSpaceModel( 173 state_dimension=5, 174 state_noise_dimension=4, 175 configuration=state_space_model.StateSpaceModelConfiguration( 176 exogenous_feature_columns=[layers.real_valued_column("exogenous")], 177 dtype=dtypes.float64, num_features=1)) 178 estimator = estimators.StateSpaceRegressor( 179 model=random_model, 180 optimizer=gradient_descent.GradientDescentOptimizer(0.1)) 181 combined_input_fn = input_pipeline.WholeDatasetInputFn( 182 input_pipeline.NumpyReader({ 183 feature_keys.FilteringFeatures.TIMES: [1, 2, 3, 4], 184 feature_keys.FilteringFeatures.VALUES: [1., 2., 3., 4.], 185 "exogenous": [-1., -2., -3., -4.] 186 })) 187 estimator.train(combined_input_fn, steps=1) 188 export_location = estimator.export_saved_model( 189 self.get_temp_dir(), estimator.build_raw_serving_input_receiver_fn()) 190 with ops.Graph().as_default() as graph: 191 random_model.initialize_graph() 192 with self.session(graph=graph) as session: 193 variables.global_variables_initializer().run() 194 evaled_start_state = session.run(random_model.get_start_state()) 195 evaled_start_state = [ 196 state_element[None, ...] for state_element in evaled_start_state] 197 with ops.Graph().as_default() as graph: 198 with self.session(graph=graph) as session: 199 signatures = loader.load( 200 session, [tag_constants.SERVING], export_location) 201 first_split_filtering = saved_model_utils.filter_continuation( 202 continue_from={ 203 feature_keys.FilteringResults.STATE_TUPLE: evaled_start_state}, 204 signatures=signatures, 205 session=session, 206 features={ 207 feature_keys.FilteringFeatures.TIMES: [1, 2], 208 feature_keys.FilteringFeatures.VALUES: [1., 2.], 209 "exogenous": [[-1.], [-2.]]}) 210 second_split_filtering = saved_model_utils.filter_continuation( 211 continue_from=first_split_filtering, 212 signatures=signatures, 213 session=session, 214 features={ 215 feature_keys.FilteringFeatures.TIMES: [3, 4], 216 feature_keys.FilteringFeatures.VALUES: [3., 4.], 217 "exogenous": [[-3.], [-4.]] 218 }) 219 combined_filtering = saved_model_utils.filter_continuation( 220 continue_from={ 221 feature_keys.FilteringResults.STATE_TUPLE: evaled_start_state}, 222 signatures=signatures, 223 session=session, 224 features={ 225 feature_keys.FilteringFeatures.TIMES: [1, 2, 3, 4], 226 feature_keys.FilteringFeatures.VALUES: [1., 2., 3., 4.], 227 "exogenous": [[-1.], [-2.], [-3.], [-4.]] 228 }) 229 split_predict = saved_model_utils.predict_continuation( 230 continue_from=second_split_filtering, 231 signatures=signatures, 232 session=session, 233 steps=1, 234 exogenous_features={ 235 "exogenous": [[[-5.]]]}) 236 combined_predict = saved_model_utils.predict_continuation( 237 continue_from=combined_filtering, 238 signatures=signatures, 239 session=session, 240 steps=1, 241 exogenous_features={ 242 "exogenous": [[[-5.]]]}) 243 for state_key, combined_state_value in combined_filtering.items(): 244 if state_key == feature_keys.FilteringResults.TIMES: 245 continue 246 self.assertAllClose( 247 combined_state_value, second_split_filtering[state_key]) 248 for prediction_key, combined_value in combined_predict.items(): 249 self.assertAllClose(combined_value, split_predict[prediction_key]) 250 251 def _equivalent_to_single_model_test_template(self, model_generator): 252 with self.cached_session() as session: 253 random_model = RandomStateSpaceModel( 254 state_dimension=5, 255 state_noise_dimension=4, 256 configuration=state_space_model.StateSpaceModelConfiguration( 257 dtype=dtypes.float64, num_features=1)) 258 random_model.initialize_graph() 259 series_length = 10 260 model_data = random_model.generate( 261 number_of_series=1, series_length=series_length, 262 model_parameters=random_model.random_model_parameters()) 263 input_fn = input_pipeline.WholeDatasetInputFn( 264 input_pipeline.NumpyReader(model_data)) 265 features, _ = input_fn() 266 model_outputs = random_model.get_batch_loss( 267 features=features, 268 mode=None, 269 state=math_utils.replicate_state( 270 start_state=random_model.get_start_state(), 271 batch_size=array_ops.shape( 272 features[feature_keys.TrainEvalFeatures.TIMES])[0])) 273 variables.global_variables_initializer().run() 274 compare_outputs_evaled_fn = model_generator( 275 random_model, model_data) 276 coordinator = coordinator_lib.Coordinator() 277 queue_runner_impl.start_queue_runners(session, coord=coordinator) 278 compare_outputs_evaled = compare_outputs_evaled_fn(session) 279 model_outputs_evaled = session.run( 280 (model_outputs.end_state, model_outputs.predictions)) 281 coordinator.request_stop() 282 coordinator.join() 283 model_posteriors, model_predictions = model_outputs_evaled 284 (_, compare_posteriors, 285 compare_predictions) = compare_outputs_evaled 286 (model_posterior_mean, model_posterior_var, 287 model_from_time) = model_posteriors 288 (compare_posterior_mean, compare_posterior_var, 289 compare_from_time) = compare_posteriors 290 self.assertAllClose(model_posterior_mean, compare_posterior_mean[0]) 291 self.assertAllClose(model_posterior_var, compare_posterior_var[0]) 292 self.assertAllClose(model_from_time, compare_from_time) 293 self.assertEqual(sorted(model_predictions.keys()), 294 sorted(compare_predictions.keys())) 295 for prediction_name in model_predictions: 296 if prediction_name == "loss": 297 # Chunking means that losses will be different; skip testing them. 298 continue 299 # Compare the last chunk to their corresponding un-chunked model 300 # predictions 301 last_prediction_chunk = compare_predictions[prediction_name][-1] 302 comparison_values = last_prediction_chunk.shape[0] 303 model_prediction = ( 304 model_predictions[prediction_name][0, -comparison_values:]) 305 self.assertAllClose(model_prediction, 306 last_prediction_chunk) 307 308 def _model_equivalent_to_chained_model_test_template(self, chunk_size): 309 def chained_model_outputs(original_model, data): 310 input_fn = test_utils.AllWindowInputFn( 311 input_pipeline.NumpyReader(data), window_size=chunk_size) 312 state_manager = state_management.ChainingStateManager( 313 state_saving_interval=1) 314 features, _ = input_fn() 315 state_manager.initialize_graph(original_model) 316 model_outputs = state_manager.define_loss( 317 model=original_model, 318 features=features, 319 mode=estimator_lib.ModeKeys.TRAIN) 320 def _eval_outputs(session): 321 for _ in range(50): 322 # Warm up saved state 323 model_outputs.loss.eval() 324 (posterior_mean, posterior_var, 325 priors_from_time) = model_outputs.end_state 326 posteriors = ((posterior_mean,), (posterior_var,), priors_from_time) 327 outputs = (model_outputs.loss, posteriors, 328 model_outputs.predictions) 329 chunked_outputs_evaled = session.run(outputs) 330 return chunked_outputs_evaled 331 return _eval_outputs 332 self._equivalent_to_single_model_test_template(chained_model_outputs) 333 334 def test_model_equivalent_to_chained_model_chunk_size_one(self): 335 numpy.random.seed(2) 336 random_seed.set_random_seed(3) 337 self._model_equivalent_to_chained_model_test_template(1) 338 339 def test_model_equivalent_to_chained_model_chunk_size_five(self): 340 numpy.random.seed(4) 341 random_seed.set_random_seed(5) 342 self._model_equivalent_to_chained_model_test_template(5) 343 344 345class PredictionTests(test.TestCase): 346 347 def _check_predictions( 348 self, predicted_mean, predicted_covariance, window_size): 349 self.assertAllEqual(predicted_covariance.shape, 350 [1, # batch 351 window_size, 352 1, # num features 353 1]) # num features 354 self.assertAllEqual(predicted_mean.shape, 355 [1, # batch 356 window_size, 357 1]) # num features 358 for position in range(window_size - 2): 359 self.assertGreater(predicted_covariance[0, position + 2, 0, 0], 360 predicted_covariance[0, position, 0, 0]) 361 362 def test_predictions_direct(self): 363 dtype = dtypes.float64 364 with variable_scope.variable_scope(dtype.name): 365 random_model = RandomStateSpaceModel( 366 state_dimension=5, state_noise_dimension=4, 367 configuration=state_space_model.StateSpaceModelConfiguration( 368 dtype=dtype, num_features=1)) 369 random_model.initialize_graph() 370 prediction_dict = random_model.predict(features={ 371 feature_keys.PredictionFeatures.TIMES: [[1, 3, 5, 6]], 372 feature_keys.PredictionFeatures.STATE_TUPLE: 373 math_utils.replicate_state( 374 start_state=random_model.get_start_state(), batch_size=1) 375 }) 376 with self.cached_session(): 377 variables.global_variables_initializer().run() 378 predicted_mean = prediction_dict["mean"].eval() 379 predicted_covariance = prediction_dict["covariance"].eval() 380 self._check_predictions(predicted_mean, predicted_covariance, 381 window_size=4) 382 383 def test_predictions_after_loss(self): 384 dtype = dtypes.float32 385 with variable_scope.variable_scope(dtype.name): 386 random_model = RandomStateSpaceModel( 387 state_dimension=5, state_noise_dimension=4, 388 configuration=state_space_model.StateSpaceModelConfiguration( 389 dtype=dtype, num_features=1)) 390 features = { 391 feature_keys.TrainEvalFeatures.TIMES: [[1, 2, 3, 4]], 392 feature_keys.TrainEvalFeatures.VALUES: 393 array_ops.ones([1, 4, 1], dtype=dtype) 394 } 395 passthrough = state_management.PassthroughStateManager() 396 random_model.initialize_graph() 397 passthrough.initialize_graph(random_model) 398 model_outputs = passthrough.define_loss( 399 model=random_model, 400 features=features, 401 mode=estimator_lib.ModeKeys.EVAL) 402 predictions = random_model.predict({ 403 feature_keys.PredictionFeatures.TIMES: [[5, 7, 8]], 404 feature_keys.PredictionFeatures.STATE_TUPLE: model_outputs.end_state 405 }) 406 with self.cached_session(): 407 variables.global_variables_initializer().run() 408 predicted_mean = predictions["mean"].eval() 409 predicted_covariance = predictions["covariance"].eval() 410 self._check_predictions(predicted_mean, predicted_covariance, 411 window_size=3) 412 413 414class ExogenousTests(test.TestCase): 415 416 def test_noise_increasing(self): 417 for dtype in [dtypes.float32, dtypes.float64]: 418 with variable_scope.variable_scope(dtype.name): 419 random_model = RandomStateSpaceModel( 420 state_dimension=5, state_noise_dimension=4, 421 configuration=state_space_model.StateSpaceModelConfiguration( 422 dtype=dtype, num_features=1)) 423 original_covariance = array_ops.diag(array_ops.ones(shape=[5])) 424 _, new_covariance, _ = random_model._exogenous_noise_increasing( 425 current_times=[[1]], 426 exogenous_values=[[5.]], 427 state=[ 428 array_ops.ones(shape=[1, 5]), original_covariance[None], [0] 429 ]) 430 with self.cached_session() as session: 431 variables.global_variables_initializer().run() 432 evaled_new_covariance, evaled_original_covariance = session.run( 433 [new_covariance[0], original_covariance]) 434 new_variances = numpy.diag(evaled_new_covariance) 435 original_variances = numpy.diag(evaled_original_covariance) 436 for i in range(5): 437 self.assertGreater(new_variances[i], original_variances[i]) 438 439 def test_noise_decreasing(self): 440 for dtype in [dtypes.float32, dtypes.float64]: 441 with variable_scope.variable_scope(dtype.name): 442 random_model = RandomStateSpaceModel( 443 state_dimension=5, state_noise_dimension=4, 444 configuration=state_space_model.StateSpaceModelConfiguration( 445 dtype=dtype, num_features=1)) 446 random_model.initialize_graph() 447 original_covariance = array_ops.diag( 448 array_ops.ones(shape=[5], dtype=dtype)) 449 _, new_covariance, _ = random_model._exogenous_noise_decreasing( 450 current_times=[[1]], 451 exogenous_values=constant_op.constant([[-2.]], dtype=dtype), 452 state=[ 453 -array_ops.ones(shape=[1, 5], dtype=dtype), 454 original_covariance[None], [0] 455 ]) 456 with self.cached_session() as session: 457 variables.global_variables_initializer().run() 458 evaled_new_covariance, evaled_original_covariance = session.run( 459 [new_covariance[0], original_covariance]) 460 new_variances = numpy.diag(evaled_new_covariance) 461 original_variances = numpy.diag(evaled_original_covariance) 462 for i in range(5): 463 self.assertLess(new_variances[i], original_variances[i]) 464 465 466class StubStateSpaceModel(state_space_model.StateSpaceModel): 467 468 def __init__(self, 469 transition, 470 state_noise_dimension, 471 configuration=state_space_model.StateSpaceModelConfiguration()): 472 self.transition = transition 473 self.noise_transform = numpy.random.normal( 474 size=(transition.shape[0], state_noise_dimension)).astype(numpy.float32) 475 # Test feature + batch broadcasting 476 self.observation_model = numpy.random.normal( 477 size=(transition.shape[0])).astype(numpy.float32) 478 super(StubStateSpaceModel, self).__init__( 479 configuration=configuration) 480 481 def get_state_transition(self): 482 return self.transition 483 484 def get_noise_transform(self): 485 return self.noise_transform 486 487 def get_observation_model(self, times): 488 return self.observation_model 489 490 491GeneratedModel = collections.namedtuple( 492 "GeneratedModel", ["model", "data", "true_parameters"]) 493 494 495class PosteriorTests(test.TestCase): 496 497 def _get_cycle_transition(self, period): 498 cycle_transition = numpy.zeros([period - 1, period - 1], 499 dtype=numpy.float32) 500 cycle_transition[0, :] = -1 501 cycle_transition[1:, :-1] = numpy.identity(period - 2) 502 return cycle_transition 503 504 _adder_transition = numpy.array([[1, 1], 505 [0, 1]], dtype=numpy.float32) 506 507 def _get_single_model(self): 508 numpy.random.seed(8) 509 stub_model = StubStateSpaceModel( 510 transition=self._get_cycle_transition(5), state_noise_dimension=0) 511 series_length = 1000 512 stub_model.initialize_graph() 513 true_params = stub_model.random_model_parameters() 514 data = stub_model.generate( 515 number_of_series=1, series_length=series_length, 516 model_parameters=true_params) 517 return GeneratedModel( 518 model=stub_model, data=data, true_parameters=true_params) 519 520 def test_exact_posterior_recovery_no_transition_noise(self): 521 with self.cached_session() as session: 522 stub_model, data, true_params = self._get_single_model() 523 input_fn = input_pipeline.WholeDatasetInputFn( 524 input_pipeline.NumpyReader(data)) 525 features, _ = input_fn() 526 model_outputs = stub_model.get_batch_loss( 527 features=features, 528 mode=None, 529 state=math_utils.replicate_state( 530 start_state=stub_model.get_start_state(), 531 batch_size=array_ops.shape( 532 features[feature_keys.TrainEvalFeatures.TIMES])[0])) 533 variables.global_variables_initializer().run() 534 coordinator = coordinator_lib.Coordinator() 535 queue_runner_impl.start_queue_runners(session, coord=coordinator) 536 posterior_mean, posterior_var, posterior_times = session.run( 537 # Feed the true model parameters so that this test doesn't depend on 538 # the generated parameters being close to the variable initializations 539 # (an alternative would be training steps to fit the noise values, 540 # which would be slow). 541 model_outputs.end_state, feed_dict=true_params) 542 coordinator.request_stop() 543 coordinator.join() 544 545 self.assertAllClose(numpy.zeros([1, 4, 4]), posterior_var, 546 atol=1e-2) 547 self.assertAllClose( 548 numpy.dot( 549 numpy.linalg.matrix_power( 550 stub_model.transition, 551 data[feature_keys.TrainEvalFeatures.TIMES].shape[1]), 552 true_params[stub_model.prior_state_mean]), 553 posterior_mean[0], 554 rtol=1e-1) 555 self.assertAllClose( 556 math_utils.batch_end_time( 557 features[feature_keys.TrainEvalFeatures.TIMES]).eval(), 558 posterior_times) 559 560 def test_chained_exact_posterior_recovery_no_transition_noise(self): 561 with self.cached_session() as session: 562 stub_model, data, true_params = self._get_single_model() 563 chunk_size = 10 564 input_fn = test_utils.AllWindowInputFn( 565 input_pipeline.NumpyReader(data), window_size=chunk_size) 566 features, _ = input_fn() 567 state_manager = state_management.ChainingStateManager( 568 state_saving_interval=1) 569 state_manager.initialize_graph(stub_model) 570 model_outputs = state_manager.define_loss( 571 model=stub_model, 572 features=features, 573 mode=estimator_lib.ModeKeys.TRAIN) 574 variables.global_variables_initializer().run() 575 coordinator = coordinator_lib.Coordinator() 576 queue_runner_impl.start_queue_runners(session, coord=coordinator) 577 for _ in range( 578 data[feature_keys.TrainEvalFeatures.TIMES].shape[1] // chunk_size): 579 model_outputs.loss.eval() 580 posterior_mean, posterior_var, posterior_times = session.run( 581 model_outputs.end_state, feed_dict=true_params) 582 coordinator.request_stop() 583 coordinator.join() 584 self.assertAllClose(numpy.zeros([1, 4, 4]), posterior_var, 585 atol=1e-2) 586 self.assertAllClose( 587 numpy.dot( 588 numpy.linalg.matrix_power( 589 stub_model.transition, 590 data[feature_keys.TrainEvalFeatures.TIMES].shape[1]), 591 true_params[stub_model.prior_state_mean]), 592 posterior_mean[0], 593 rtol=1e-1) 594 self.assertAllClose(data[feature_keys.TrainEvalFeatures.TIMES][:, -1], 595 posterior_times) 596 597 598class TimeDependentStateSpaceModel(state_space_model.StateSpaceModel): 599 """A mostly trivial model which predicts values = times + 1.""" 600 601 def __init__(self, static_unrolling_window_size_threshold=None): 602 super(TimeDependentStateSpaceModel, self).__init__( 603 configuration=state_space_model.StateSpaceModelConfiguration( 604 use_observation_noise=False, 605 transition_covariance_initial_log_scale_bias=5., 606 static_unrolling_window_size_threshold= 607 static_unrolling_window_size_threshold)) 608 609 def get_state_transition(self): 610 return array_ops.ones(shape=[1, 1]) 611 612 def get_noise_transform(self): 613 return array_ops.ones(shape=[1, 1]) 614 615 def get_observation_model(self, times): 616 return array_ops.reshape( 617 tensor=math_ops.cast(times + 1, dtypes.float32), shape=[-1, 1, 1]) 618 619 def make_priors(self): 620 return (ops.convert_to_tensor([1.]), ops.convert_to_tensor([[0.]])) 621 622 623class UnknownShapeModel(TimeDependentStateSpaceModel): 624 625 def get_observation_model(self, times): 626 parent_model = super(UnknownShapeModel, self).get_observation_model(times) 627 return array_ops.placeholder_with_default( 628 input=parent_model, shape=tensor_shape.unknown_shape()) 629 630 631class TimeDependentTests(test.TestCase): 632 633 def _time_dependency_test_template(self, model_type): 634 """Test that a time-dependent observation model influences predictions.""" 635 model = model_type() 636 estimator = estimators.StateSpaceRegressor( 637 model=model, optimizer=gradient_descent.GradientDescentOptimizer(0.1)) 638 values = numpy.reshape([1., 2., 3., 4.], 639 newshape=[1, 4, 1]) 640 input_fn = input_pipeline.WholeDatasetInputFn( 641 input_pipeline.NumpyReader({ 642 feature_keys.TrainEvalFeatures.TIMES: [[0, 1, 2, 3]], 643 feature_keys.TrainEvalFeatures.VALUES: values 644 })) 645 estimator.train(input_fn=input_fn, max_steps=1) 646 predicted_values = estimator.evaluate(input_fn=input_fn, steps=1)["mean"] 647 # Throw out the first value so we don't test the prior 648 self.assertAllEqual(values[1:], predicted_values[1:]) 649 650 def test_undefined_shape_time_dependency(self): 651 self._time_dependency_test_template(UnknownShapeModel) 652 653 def test_loop_unrolling(self): 654 """Tests running/restoring from a checkpoint with static unrolling.""" 655 model = TimeDependentStateSpaceModel( 656 # Unroll during training, but not evaluation 657 static_unrolling_window_size_threshold=2) 658 estimator = estimators.StateSpaceRegressor(model=model) 659 times = numpy.arange(100) 660 values = numpy.arange(100) 661 dataset = { 662 feature_keys.TrainEvalFeatures.TIMES: times, 663 feature_keys.TrainEvalFeatures.VALUES: values 664 } 665 train_input_fn = input_pipeline.RandomWindowInputFn( 666 input_pipeline.NumpyReader(dataset), batch_size=16, window_size=2) 667 eval_input_fn = input_pipeline.WholeDatasetInputFn( 668 input_pipeline.NumpyReader(dataset)) 669 estimator.train(input_fn=train_input_fn, max_steps=1) 670 estimator.evaluate(input_fn=eval_input_fn, steps=1) 671 672 673class LevelOnlyModel(state_space_model.StateSpaceModel): 674 675 def get_state_transition(self): 676 return linalg_ops.eye(1, dtype=self.dtype) 677 678 def get_noise_transform(self): 679 return linalg_ops.eye(1, dtype=self.dtype) 680 681 def get_observation_model(self, times): 682 return [1] 683 684 685class MultivariateLevelModel( 686 state_space_model.StateSpaceCorrelatedFeaturesEnsemble): 687 688 def __init__(self, configuration): 689 univariate_component_configuration = configuration._replace( 690 num_features=1) 691 components = [] 692 for feature in range(configuration.num_features): 693 with variable_scope.variable_scope("feature{}".format(feature)): 694 components.append( 695 LevelOnlyModel(configuration=univariate_component_configuration)) 696 super(MultivariateLevelModel, self).__init__( 697 ensemble_members=components, configuration=configuration) 698 699 700class MultivariateTests(test.TestCase): 701 702 def test_multivariate(self): 703 dtype = dtypes.float32 704 num_features = 3 705 covariance = numpy.eye(num_features) 706 # A single off-diagonal has a non-zero value in the true transition 707 # noise covariance. 708 covariance[-1, 0] = 1. 709 covariance[0, -1] = 1. 710 dataset_size = 100 711 values = numpy.cumsum( 712 numpy.random.multivariate_normal( 713 mean=numpy.zeros(num_features), 714 cov=covariance, 715 size=dataset_size), 716 axis=0) 717 times = numpy.arange(dataset_size) 718 model = MultivariateLevelModel( 719 configuration=state_space_model.StateSpaceModelConfiguration( 720 num_features=num_features, 721 dtype=dtype, 722 use_observation_noise=False, 723 transition_covariance_initial_log_scale_bias=5.)) 724 estimator = estimators.StateSpaceRegressor( 725 model=model, optimizer=gradient_descent.GradientDescentOptimizer(0.1)) 726 data = { 727 feature_keys.TrainEvalFeatures.TIMES: times, 728 feature_keys.TrainEvalFeatures.VALUES: values 729 } 730 train_input_fn = input_pipeline.RandomWindowInputFn( 731 input_pipeline.NumpyReader(data), batch_size=16, window_size=16) 732 estimator.train(input_fn=train_input_fn, steps=1) 733 for component in model._ensemble_members: 734 # Check that input statistics propagated to component models 735 self.assertTrue(component._input_statistics) 736 737 def test_ensemble_observation_noise(self): 738 model = MultivariateLevelModel( 739 configuration=state_space_model.StateSpaceModelConfiguration()) 740 model.initialize_graph() 741 outputs = model.define_loss( 742 features={ 743 feature_keys.TrainEvalFeatures.TIMES: 744 constant_op.constant([[1, 2]]), 745 feature_keys.TrainEvalFeatures.VALUES: 746 constant_op.constant([[[1.], [2.]]]) 747 }, 748 mode=estimator_lib.ModeKeys.TRAIN) 749 initializer = variables.global_variables_initializer() 750 with self.cached_session() as sess: 751 sess.run([initializer]) 752 outputs.loss.eval() 753 754if __name__ == "__main__": 755 test.main() 756