1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utilities for testing time series models.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from tensorflow.contrib.timeseries.python.timeseries import estimators 22from tensorflow.contrib.timeseries.python.timeseries import input_pipeline 23from tensorflow.contrib.timeseries.python.timeseries import state_management 24from tensorflow.contrib.timeseries.python.timeseries.feature_keys import TrainEvalFeatures 25 26from tensorflow.python.client import session 27from tensorflow.python.estimator import estimator_lib 28from tensorflow.python.framework import ops 29from tensorflow.python.framework import random_seed 30from tensorflow.python.ops import array_ops 31from tensorflow.python.ops import math_ops 32from tensorflow.python.ops import variables 33from tensorflow.python.platform import tf_logging as logging 34from tensorflow.python.training import adam 35from tensorflow.python.training import basic_session_run_hooks 36from tensorflow.python.training import coordinator as coordinator_lib 37from tensorflow.python.training import queue_runner_impl 38from tensorflow.python.util import nest 39 40 41class AllWindowInputFn(input_pipeline.TimeSeriesInputFn): 42 """Returns all contiguous windows of data from a full dataset. 43 44 In contrast to WholeDatasetInputFn, which does basic shape checking but 45 maintains the flat sequencing of data, this `TimeSeriesInputFn` creates 46 batches of windows. However, unlike `RandomWindowInputFn` these windows are 47 deterministic, starting at every possible offset (i.e. batches of size 48 series_length - window_size + 1 are produced). 49 """ 50 51 def __init__(self, time_series_reader, window_size): 52 """Initialize the input_pipeline. 53 54 Args: 55 time_series_reader: A `input_pipeline.TimeSeriesReader` object. 56 window_size: The size of contiguous windows of data to produce. 57 """ 58 self._window_size = window_size 59 self._reader = time_series_reader 60 super(AllWindowInputFn, self).__init__() 61 62 def create_batch(self): 63 features = self._reader.read_full() 64 times = features[TrainEvalFeatures.TIMES] 65 num_windows = array_ops.shape(times)[0] - self._window_size + 1 66 indices = array_ops.reshape(math_ops.range(num_windows), [num_windows, 1]) 67 # indices contains the starting point for each window. We now extend these 68 # indices to include the elements inside the windows as well by doing a 69 # broadcast addition. 70 increments = array_ops.reshape(math_ops.range(self._window_size), [1, -1]) 71 all_indices = array_ops.reshape(indices + increments, [-1]) 72 # Select the appropriate elements in the batch and reshape the output to 3D. 73 features = { 74 key: array_ops.reshape( 75 array_ops.gather(value, all_indices), 76 array_ops.concat( 77 [[num_windows, self._window_size], array_ops.shape(value)[1:]], 78 axis=0)) 79 for key, value in features.items() 80 } 81 return (features, None) 82 83 84class _SavingTensorHook(basic_session_run_hooks.LoggingTensorHook): 85 """A hook to save Tensors during training.""" 86 87 def __init__(self, tensors, every_n_iter=None, every_n_secs=None): 88 self.tensor_values = {} 89 super(_SavingTensorHook, self).__init__( 90 tensors=tensors, every_n_iter=every_n_iter, 91 every_n_secs=every_n_secs) 92 93 def after_run(self, run_context, run_values): 94 del run_context 95 if self._should_trigger: 96 for tag in self._current_tensors.keys(): 97 self.tensor_values[tag] = run_values.results[tag] 98 self._timer.update_last_triggered_step(self._iter_count) 99 self._iter_count += 1 100 101 102def _train_on_generated_data( 103 generate_fn, generative_model, train_iterations, seed, 104 learning_rate=0.1, ignore_params_fn=lambda _: (), 105 derived_param_test_fn=lambda _: (), 106 train_input_fn_type=input_pipeline.WholeDatasetInputFn, 107 train_state_manager=state_management.PassthroughStateManager()): 108 """The training portion of parameter recovery tests.""" 109 random_seed.set_random_seed(seed) 110 generate_graph = ops.Graph() 111 with generate_graph.as_default(): 112 with session.Session(graph=generate_graph): 113 generative_model.initialize_graph() 114 time_series_reader, true_parameters = generate_fn(generative_model) 115 true_parameters = { 116 tensor.name: value for tensor, value in true_parameters.items()} 117 eval_input_fn = input_pipeline.WholeDatasetInputFn(time_series_reader) 118 eval_state_manager = state_management.PassthroughStateManager() 119 true_parameter_eval_graph = ops.Graph() 120 with true_parameter_eval_graph.as_default(): 121 generative_model.initialize_graph() 122 ignore_params = ignore_params_fn(generative_model) 123 feature_dict, _ = eval_input_fn() 124 eval_state_manager.initialize_graph(generative_model) 125 feature_dict[TrainEvalFeatures.VALUES] = math_ops.cast( 126 feature_dict[TrainEvalFeatures.VALUES], generative_model.dtype) 127 model_outputs = eval_state_manager.define_loss( 128 model=generative_model, 129 features=feature_dict, 130 mode=estimator_lib.ModeKeys.EVAL) 131 with session.Session(graph=true_parameter_eval_graph) as sess: 132 variables.global_variables_initializer().run() 133 coordinator = coordinator_lib.Coordinator() 134 queue_runner_impl.start_queue_runners(sess, coord=coordinator) 135 true_param_loss = model_outputs.loss.eval(feed_dict=true_parameters) 136 true_transformed_params = { 137 param: param.eval(feed_dict=true_parameters) 138 for param in derived_param_test_fn(generative_model)} 139 coordinator.request_stop() 140 coordinator.join() 141 142 saving_hook = _SavingTensorHook( 143 tensors=true_parameters.keys(), 144 every_n_iter=train_iterations - 1) 145 146 class _RunConfig(estimator_lib.RunConfig): 147 148 @property 149 def tf_random_seed(self): 150 return seed 151 152 estimator = estimators.TimeSeriesRegressor( 153 model=generative_model, 154 config=_RunConfig(), 155 state_manager=train_state_manager, 156 optimizer=adam.AdamOptimizer(learning_rate)) 157 train_input_fn = train_input_fn_type(time_series_reader=time_series_reader) 158 trained_loss = (estimator.train( 159 input_fn=train_input_fn, 160 max_steps=train_iterations, 161 hooks=[saving_hook]).evaluate( 162 input_fn=eval_input_fn, steps=1))["loss"] 163 logging.info("Final trained loss: %f", trained_loss) 164 logging.info("True parameter loss: %f", true_param_loss) 165 return (ignore_params, true_parameters, true_transformed_params, 166 trained_loss, true_param_loss, saving_hook, 167 true_parameter_eval_graph) 168 169 170def test_parameter_recovery( 171 generate_fn, generative_model, train_iterations, test_case, seed, 172 learning_rate=0.1, rtol=0.2, atol=0.1, train_loss_tolerance_coeff=0.99, 173 ignore_params_fn=lambda _: (), 174 derived_param_test_fn=lambda _: (), 175 train_input_fn_type=input_pipeline.WholeDatasetInputFn, 176 train_state_manager=state_management.PassthroughStateManager()): 177 """Test that a generative model fits generated data. 178 179 Args: 180 generate_fn: A function taking a model and returning a `TimeSeriesReader` 181 object and dictionary mapping parameters to their 182 values. model.initialize_graph() will have been called on the model 183 before it is passed to this function. 184 generative_model: A timeseries.model.TimeSeriesModel instance to test. 185 train_iterations: Number of training steps. 186 test_case: A tf.test.TestCase to run assertions on. 187 seed: Same as for TimeSeriesModel.unconditional_generate(). 188 learning_rate: Step size for optimization. 189 rtol: Relative tolerance for tests. 190 atol: Absolute tolerance for tests. 191 train_loss_tolerance_coeff: Trained loss times this value must be less 192 than the loss evaluated using the generated parameters. 193 ignore_params_fn: Function mapping from a Model to a list of parameters 194 which are not tested for accurate recovery. 195 derived_param_test_fn: Function returning a list of derived parameters 196 (Tensors) which are checked for accurate recovery (comparing the value 197 evaluated with trained parameters to the value under the true 198 parameters). 199 200 As an example, for VARMA, in addition to checking AR and MA parameters, 201 this function can be used to also check lagged covariance. See 202 varma_ssm.py for details. 203 train_input_fn_type: The `TimeSeriesInputFn` type to use when training 204 (likely `WholeDatasetInputFn` or `RandomWindowInputFn`). If None, use 205 `WholeDatasetInputFn`. 206 train_state_manager: The state manager to use when training (likely 207 `PassthroughStateManager` or `ChainingStateManager`). If None, use 208 `PassthroughStateManager`. 209 """ 210 (ignore_params, true_parameters, true_transformed_params, 211 trained_loss, true_param_loss, saving_hook, true_parameter_eval_graph 212 ) = _train_on_generated_data( 213 generate_fn=generate_fn, generative_model=generative_model, 214 train_iterations=train_iterations, seed=seed, learning_rate=learning_rate, 215 ignore_params_fn=ignore_params_fn, 216 derived_param_test_fn=derived_param_test_fn, 217 train_input_fn_type=train_input_fn_type, 218 train_state_manager=train_state_manager) 219 trained_parameter_substitutions = {} 220 for param in true_parameters.keys(): 221 evaled_value = saving_hook.tensor_values[param] 222 trained_parameter_substitutions[param] = evaled_value 223 true_value = true_parameters[param] 224 logging.info("True %s: %s, learned: %s", 225 param, true_value, evaled_value) 226 with session.Session(graph=true_parameter_eval_graph): 227 for transformed_param, true_value in true_transformed_params.items(): 228 trained_value = transformed_param.eval( 229 feed_dict=trained_parameter_substitutions) 230 logging.info("True %s [transformed parameter]: %s, learned: %s", 231 transformed_param, true_value, trained_value) 232 test_case.assertAllClose(true_value, trained_value, 233 rtol=rtol, atol=atol) 234 235 if ignore_params is None: 236 ignore_params = [] 237 else: 238 ignore_params = nest.flatten(ignore_params) 239 ignore_params = [tensor.name for tensor in ignore_params] 240 if trained_loss > 0: 241 test_case.assertLess(trained_loss * train_loss_tolerance_coeff, 242 true_param_loss) 243 else: 244 test_case.assertLess(trained_loss / train_loss_tolerance_coeff, 245 true_param_loss) 246 for param in true_parameters.keys(): 247 if param in ignore_params: 248 continue 249 evaled_value = saving_hook.tensor_values[param] 250 true_value = true_parameters[param] 251 test_case.assertAllClose(true_value, evaled_value, 252 rtol=rtol, atol=atol) 253 254 255def parameter_recovery_dry_run( 256 generate_fn, generative_model, seed, 257 learning_rate=0.1, 258 train_input_fn_type=input_pipeline.WholeDatasetInputFn, 259 train_state_manager=state_management.PassthroughStateManager()): 260 """Test that a generative model can train on generated data. 261 262 Args: 263 generate_fn: A function taking a model and returning a 264 `input_pipeline.TimeSeriesReader` object and a dictionary mapping 265 parameters to their values. model.initialize_graph() will have been 266 called on the model before it is passed to this function. 267 generative_model: A timeseries.model.TimeSeriesModel instance to test. 268 seed: Same as for TimeSeriesModel.unconditional_generate(). 269 learning_rate: Step size for optimization. 270 train_input_fn_type: The type of `TimeSeriesInputFn` to use when training 271 (likely `WholeDatasetInputFn` or `RandomWindowInputFn`). If None, use 272 `WholeDatasetInputFn`. 273 train_state_manager: The state manager to use when training (likely 274 `PassthroughStateManager` or `ChainingStateManager`). If None, use 275 `PassthroughStateManager`. 276 """ 277 _train_on_generated_data( 278 generate_fn=generate_fn, generative_model=generative_model, 279 seed=seed, learning_rate=learning_rate, 280 train_input_fn_type=train_input_fn_type, 281 train_state_manager=train_state_manager, 282 train_iterations=2) 283