1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Timeseries head.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import re 21 22from tensorflow.contrib.timeseries.python.timeseries import feature_keys 23from tensorflow.python.estimator import estimator_lib 24from tensorflow.python.estimator.canned import head as head_lib 25from tensorflow.python.estimator.canned import metric_keys 26from tensorflow.python.estimator.export import export_lib 27from tensorflow.python.framework import dtypes 28from tensorflow.python.framework import ops 29from tensorflow.python.framework import sparse_tensor 30from tensorflow.python.ops import array_ops 31from tensorflow.python.ops import control_flow_ops 32from tensorflow.python.ops import math_ops 33from tensorflow.python.ops import metrics_impl 34from tensorflow.python.ops import state_ops 35from tensorflow.python.ops import variable_scope 36from tensorflow.python.summary import summary 37from tensorflow.python.training import training_util 38from tensorflow.python.util import nest 39 40 41class _NoStatePredictOutput(export_lib.PredictOutput): 42 43 def as_signature_def(self, receiver_tensors): 44 no_state_receiver_tensors = { 45 key: value for key, value in receiver_tensors.items() 46 if not key.startswith(feature_keys.State.STATE_PREFIX)} 47 return super(_NoStatePredictOutput, self).as_signature_def( 48 receiver_tensors=no_state_receiver_tensors) 49 50 51class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-access 52 """Determines input and output signatures for a time series model.""" 53 54 def __init__(self, 55 model, 56 state_manager, 57 optimizer, 58 input_statistics_generator=None, 59 name=None): 60 """Creates a `_Head` for time series regression. 61 62 Args: 63 model: A model for time series regression. 64 state_manager: A state manager. 65 optimizer: An optimizer. 66 input_statistics_generator: A input statistics generator. 67 name: An optional name for the model. 68 """ 69 self.model = model 70 self.state_manager = state_manager 71 self.optimizer = optimizer 72 self.input_statistics_generator = input_statistics_generator 73 self._name = name 74 75 @property 76 def name(self): 77 return self._name 78 79 # TODO(terrytangyuan): consolidate `model_outputs` and `_Head.LossSpec` 80 # once `_Head.create_loss` becomes extendable 81 def create_loss(self, features, mode, logits=None, labels=None): 82 """See `_Head`.""" 83 model_outputs = self.state_manager.define_loss( 84 self.model, features, mode) 85 summary.scalar( 86 head_lib._summary_key(self._name, metric_keys.MetricKeys.LOSS), 87 model_outputs.loss) 88 return model_outputs 89 90 @property 91 def logits_dimension(self): 92 """See `_Head`.""" 93 return 1 94 95 def _train_ops(self, features): 96 """Add training ops to the graph.""" 97 mode = estimator_lib.ModeKeys.TRAIN 98 with variable_scope.variable_scope( 99 "model", 100 # Use ResourceVariables to avoid race conditions. 101 use_resource=True): 102 model_outputs = self.create_loss(features, mode) 103 104 train_op = self.optimizer.minimize( 105 model_outputs.loss, 106 global_step=training_util.get_global_step()) 107 return estimator_lib.EstimatorSpec( 108 loss=model_outputs.loss, 109 mode=mode, 110 train_op=train_op) 111 112 def _evaluate_ops(self, features): 113 """Add ops for evaluation (aka filtering) to the graph.""" 114 mode = estimator_lib.ModeKeys.EVAL 115 with variable_scope.variable_scope("model", use_resource=True): 116 model_outputs = self.create_loss(features, mode) 117 metrics = {} 118 # Just output in-sample predictions for the last chunk seen 119 for prediction_key, prediction_value in model_outputs.predictions.items(): 120 metrics[prediction_key] = _identity_metric_single(prediction_key, 121 prediction_value) 122 metrics[feature_keys.FilteringResults.TIMES] = _identity_metric_single( 123 feature_keys.FilteringResults.TIMES, model_outputs.prediction_times) 124 metrics[feature_keys.FilteringResults.STATE_TUPLE] = ( 125 _identity_metric_nested(feature_keys.FilteringResults.STATE_TUPLE, 126 model_outputs.end_state)) 127 metrics[metric_keys.MetricKeys.LOSS_MEAN] = metrics_impl.mean( 128 model_outputs.loss, name="average_loss") 129 return estimator_lib.EstimatorSpec( 130 loss=model_outputs.loss, 131 mode=mode, 132 eval_metric_ops=metrics, 133 # needed for custom metrics. 134 predictions=model_outputs.predictions) 135 136 def _predict_ops(self, features): 137 """Add ops for prediction to the graph.""" 138 with variable_scope.variable_scope("model", use_resource=True): 139 prediction = self.model.predict(features=features) 140 prediction[feature_keys.PredictionResults.TIMES] = features[ 141 feature_keys.PredictionFeatures.TIMES] 142 return estimator_lib.EstimatorSpec( 143 predictions=prediction, mode=estimator_lib.ModeKeys.PREDICT) 144 145 def _serving_ops(self, features): 146 """Add ops for serving to the graph.""" 147 with variable_scope.variable_scope("model", use_resource=True): 148 prediction_outputs = self.model.predict(features=features) 149 with variable_scope.variable_scope("model", reuse=True): 150 filtering_outputs = self.create_loss( 151 features, estimator_lib.ModeKeys.EVAL) 152 with variable_scope.variable_scope("model", reuse=True): 153 no_state_features = { 154 k: v for k, v in features.items() 155 if not k.startswith(feature_keys.State.STATE_PREFIX)} 156 # Ignore any state management when cold-starting. The model's default 157 # start state is replicated across the batch. 158 cold_filtering_outputs = self.model.define_loss( 159 features=no_state_features, mode=estimator_lib.ModeKeys.EVAL) 160 return estimator_lib.EstimatorSpec( 161 mode=estimator_lib.ModeKeys.PREDICT, 162 export_outputs={ 163 feature_keys.SavedModelLabels.PREDICT: 164 export_lib.PredictOutput(prediction_outputs), 165 feature_keys.SavedModelLabels.FILTER: 166 export_lib.PredictOutput( 167 state_to_dictionary(filtering_outputs.end_state)), 168 feature_keys.SavedModelLabels.COLD_START_FILTER: 169 _NoStatePredictOutput( 170 state_to_dictionary(cold_filtering_outputs.end_state)) 171 }, 172 # Likely unused, but it is necessary to return `predictions` to satisfy 173 # the Estimator's error checking. 174 predictions={}) 175 176 def _convert_feature_to_tensor(self, name, value): 177 """Casts features to the correct dtype based on their name.""" 178 if name in [ 179 feature_keys.TrainEvalFeatures.TIMES, 180 feature_keys.PredictionFeatures.TIMES 181 ]: 182 return math_ops.cast(value, dtypes.int64) 183 if name == feature_keys.TrainEvalFeatures.VALUES: 184 return math_ops.cast(value, self.model.dtype) 185 if name == feature_keys.PredictionFeatures.STATE_TUPLE: 186 return value # Correct dtypes are model-dependent 187 return sparse_tensor.convert_to_tensor_or_sparse_tensor(value) 188 189 def _gather_state(self, features): 190 """Returns `features` with state packed, indicates if packing was done.""" 191 prefixed_state_re = re.compile(r"^" + feature_keys.State.STATE_PREFIX + 192 r"_(\d+)$") 193 numbered_state = [] 194 for key, tensor in features.items(): 195 search_result = prefixed_state_re.search(key) 196 if search_result: 197 numbered_state.append((int(search_result.group(1)), key, tensor)) 198 if not numbered_state: 199 return features, False 200 features = features.copy() 201 for _, key, _ in numbered_state: 202 del features[key] 203 numbered_state.sort(key=lambda number, *_: number) 204 features[feature_keys.State.STATE_TUPLE] = nest.pack_sequence_as( 205 structure=self.model.get_start_state(), 206 flat_sequence=[tensor for _, _, tensor in numbered_state]) 207 return features, True 208 209 def _check_predict_features(self, features): 210 """Raises errors if features are not suitable for prediction.""" 211 if feature_keys.PredictionFeatures.TIMES not in features: 212 raise ValueError("Expected a '{}' feature for prediction.".format( 213 feature_keys.PredictionFeatures.TIMES)) 214 if feature_keys.PredictionFeatures.STATE_TUPLE not in features: 215 raise ValueError("Expected a '{}' feature for prediction.".format( 216 feature_keys.PredictionFeatures.STATE_TUPLE)) 217 times_feature = features[feature_keys.PredictionFeatures.TIMES] 218 if not times_feature.get_shape().is_compatible_with([None, None]): 219 raise ValueError( 220 ("Expected shape (batch dimension, window size) for feature '{}' " 221 "(got shape {})").format(feature_keys.PredictionFeatures.TIMES, 222 times_feature.get_shape())) 223 _check_feature_shapes_compatible_with( 224 features=features, 225 compatible_with_name=feature_keys.PredictionFeatures.TIMES, 226 compatible_with_value=times_feature, 227 ignore=set([ 228 # Model-dependent shapes 229 feature_keys.PredictionFeatures.STATE_TUPLE 230 ])) 231 232 def create_estimator_spec(self, features, mode, labels=None): 233 """Performs basic error checking and returns an EstimatorSpec.""" 234 with ops.name_scope(self._name, "head"): 235 if labels is not None and labels != {}: # for better error messages. 236 raise ValueError( 237 "The model received a `labels`, which is not supported. " 238 "Pass '{}' and '{}' as features.".format( 239 feature_keys.TrainEvalFeatures.TIMES, 240 feature_keys.TrainEvalFeatures.VALUES)) 241 del labels 242 features = { 243 name: self._convert_feature_to_tensor(name=name, value=value) 244 for name, value in features.items() 245 } 246 if self.input_statistics_generator is not None: 247 input_statistics = self.input_statistics_generator.initialize_graph( 248 features, update_statistics=(mode == estimator_lib.ModeKeys.TRAIN)) 249 else: 250 input_statistics = None 251 self.model.initialize_graph(input_statistics=input_statistics) 252 253 # _gather_state requires the model to have its graph initialized (so it 254 # has access to the structure of the model's state) 255 features, passed_flat_state = self._gather_state(features) 256 if (mode == estimator_lib.ModeKeys.TRAIN or 257 mode == estimator_lib.ModeKeys.EVAL): 258 _check_train_eval_features(features, self.model) 259 elif mode == estimator_lib.ModeKeys.PREDICT: 260 self._check_predict_features(features) 261 else: 262 raise ValueError("Unknown mode '{}' passed to model_fn.".format(mode)) 263 264 self.state_manager.initialize_graph( 265 model=self.model, input_statistics=input_statistics) 266 267 if mode == estimator_lib.ModeKeys.TRAIN: 268 return self._train_ops(features) 269 elif mode == estimator_lib.ModeKeys.EVAL: 270 return self._evaluate_ops(features) 271 elif mode == estimator_lib.ModeKeys.PREDICT and not passed_flat_state: 272 return self._predict_ops(features) 273 elif mode == estimator_lib.ModeKeys.PREDICT and passed_flat_state: 274 # The mode is PREDICT, but we're actually in export_savedmodel for 275 # serving. We want to return two graphs: one for filtering (state + data 276 # -> state) and one for predicting (state -> prediction). 277 return self._serving_ops(features) 278 279 280class OneShotPredictionHead(TimeSeriesRegressionHead): 281 """A time series head which exports a single stateless serving signature. 282 283 The serving default signature exported by this head expects `times`, `values`, 284 and any exogenous features, but no state. `values` has shape `[batch_size, 285 filter_length, num_features]` and `times` has shape `[batch_size, 286 total_length]`, where `total_length > filter_length`. Any exogenous features 287 must have their shapes prefixed by the shape of the `times` feature. 288 289 When serving, first performs filtering on the series up to `filter_length` 290 starting from the default start state for the model, then computes predictions 291 on the remainder of the series, returning them. 292 293 Model state is neither accepted nor returned, so filtering must be performed 294 each time predictions are requested when using this head. 295 """ 296 297 def _check_predict_features(self, features): 298 """Raises errors if features are not suitable for one-shot prediction.""" 299 if feature_keys.PredictionFeatures.TIMES not in features: 300 raise ValueError("Expected a '{}' feature for prediction.".format( 301 feature_keys.PredictionFeatures.TIMES)) 302 if feature_keys.TrainEvalFeatures.VALUES not in features: 303 raise ValueError("Expected a '{}' feature for prediction.".format( 304 feature_keys.TrainEvalFeatures.VALUES)) 305 if feature_keys.PredictionFeatures.STATE_TUPLE not in features: 306 raise ValueError("Expected a '{}' feature for prediction.".format( 307 feature_keys.PredictionFeatures.STATE_TUPLE)) 308 times_feature = features[feature_keys.PredictionFeatures.TIMES] 309 if not times_feature.get_shape().is_compatible_with([None, None]): 310 raise ValueError( 311 ("Expected shape (batch dimension, window size) for feature '{}' " 312 "(got shape {})").format(feature_keys.PredictionFeatures.TIMES, 313 times_feature.get_shape())) 314 _check_feature_shapes_compatible_with( 315 features=features, 316 compatible_with_name=feature_keys.PredictionFeatures.TIMES, 317 compatible_with_value=times_feature, 318 ignore=set([ 319 # Model-dependent shapes 320 feature_keys.PredictionFeatures.STATE_TUPLE, 321 # One shot prediction head relies on values being shorter than 322 # times. Even though we're predicting eventually, we need values for 323 # the filtering phase. 324 feature_keys.TrainEvalFeatures.VALUES, 325 ])) 326 327 def _evaluate_ops(self, features): 328 """Add ops for evaluation (aka filtering) to the graph.""" 329 spec = super(OneShotPredictionHead, self)._evaluate_ops(features) 330 # No state is fed to OneShotPredictionHead, so we don't return it; it being 331 # a tuple can cause issues for downstream infrastructure. 332 del spec.eval_metric_ops[feature_keys.State.STATE_TUPLE] 333 return spec 334 335 def _serving_ops(self, features): 336 """Add ops for serving to the graph.""" 337 with variable_scope.variable_scope("model", use_resource=True): 338 filtering_features = {} 339 prediction_features = {} 340 values_length = array_ops.shape( 341 features[feature_keys.FilteringFeatures.VALUES])[1] 342 for key, value in features.items(): 343 if key == feature_keys.State.STATE_TUPLE: 344 # Ignore state input. The model's default start state is replicated 345 # across the batch. 346 continue 347 if key == feature_keys.FilteringFeatures.VALUES: 348 filtering_features[key] = value 349 else: 350 filtering_features[key] = value[:, :values_length] 351 prediction_features[key] = value[:, values_length:] 352 cold_filtering_outputs = self.model.define_loss( 353 features=filtering_features, mode=estimator_lib.ModeKeys.EVAL) 354 prediction_features[feature_keys.State.STATE_TUPLE] = ( 355 cold_filtering_outputs.end_state) 356 with variable_scope.variable_scope("model", reuse=True): 357 prediction_outputs = self.model.predict( 358 features=prediction_features) 359 return estimator_lib.EstimatorSpec( 360 mode=estimator_lib.ModeKeys.PREDICT, 361 export_outputs={ 362 feature_keys.SavedModelLabels.PREDICT: 363 _NoStatePredictOutput(prediction_outputs), 364 }, 365 # Likely unused, but it is necessary to return `predictions` to satisfy 366 # the Estimator's error checking. 367 predictions={}) 368 369 370def _check_feature_shapes_compatible_with(features, 371 compatible_with_name, 372 compatible_with_value, 373 ignore=None): 374 """Checks all features are compatible with the given time-like feature.""" 375 if ignore is None: 376 ignore = set() 377 for name, value in features.items(): 378 if name in ignore: 379 continue 380 feature_shape = value.get_shape() 381 if feature_shape.ndims is None: 382 continue 383 if feature_shape.ndims < 2: 384 raise ValueError( 385 ("Features must have shape (batch dimension, window size, ...) " 386 "(got rank {} for feature '{}')").format(feature_shape.ndims, name)) 387 if not feature_shape[:2].is_compatible_with( 388 compatible_with_value.get_shape()): 389 raise ValueError( 390 ("Features must have shape (batch dimension, window size, ...) " 391 "where batch dimension and window size match the " 392 "'{times_feature}' feature (got shape {feature_shape} for " 393 "feature '{feature_name}' but shape {times_shape} for feature " 394 "'{times_feature}')").format( 395 times_feature=compatible_with_name, 396 feature_shape=feature_shape, 397 feature_name=name, 398 times_shape=compatible_with_value.get_shape())) 399 400 401def _check_train_eval_features(features, model): 402 """Raise errors if features are not suitable for training/evaluation.""" 403 if feature_keys.TrainEvalFeatures.TIMES not in features: 404 raise ValueError("Expected a '{}' feature for training/evaluation.".format( 405 feature_keys.TrainEvalFeatures.TIMES)) 406 if feature_keys.TrainEvalFeatures.VALUES not in features: 407 raise ValueError("Expected a '{}' feature for training/evaluation.".format( 408 feature_keys.TrainEvalFeatures.VALUES)) 409 times_feature = features[feature_keys.TrainEvalFeatures.TIMES] 410 if not times_feature.get_shape().is_compatible_with([None, None]): 411 raise ValueError( 412 ("Expected shape (batch dimension, window size) for feature '{}' " 413 "(got shape {})").format(feature_keys.TrainEvalFeatures.TIMES, 414 times_feature.get_shape())) 415 values_feature = features[feature_keys.TrainEvalFeatures.VALUES] 416 if not values_feature.get_shape().is_compatible_with( 417 [None, None, model.num_features]): 418 raise ValueError( 419 ("Expected shape (batch dimension, window size, {num_features}) " 420 "for feature '{feature_name}', since the model was configured " 421 "with num_features={num_features} (got shape {got_shape})").format( 422 num_features=model.num_features, 423 feature_name=feature_keys.TrainEvalFeatures.VALUES, 424 got_shape=times_feature.get_shape())) 425 _check_feature_shapes_compatible_with( 426 features=features, 427 compatible_with_name=feature_keys.TrainEvalFeatures.TIMES, 428 compatible_with_value=times_feature, 429 ignore=set([ 430 feature_keys.State.STATE_TUPLE # Model-dependent shapes 431 ])) 432 433 434def _identity_metric_single(name, input_tensor): 435 """A metric which takes on its last updated value. 436 437 This keeps evaluation metrics in sync with one another, since update ops are 438 run separately from their result Tensors. Simply returning (input_tensor, 439 no_op) as a metric with a value but no update means that a metric will come 440 from a different batch of data than metrics which cache values in a Variable 441 (e.g. the default loss metric). 442 443 Args: 444 name: A name for the metric. 445 input_tensor: Any Tensor. 446 Returns: 447 A tuple of (value, update_op). 448 """ 449 metric_variable = variable_scope.variable( 450 name="{}_identity_metric".format(name), 451 initial_value=array_ops.zeros([], dtype=input_tensor.dtype), 452 collections=[ops.GraphKeys.LOCAL_VARIABLES], 453 validate_shape=False) 454 update_op = state_ops.assign( 455 metric_variable, input_tensor, validate_shape=False) 456 # This shape will be correct once the first update runs (but may be 457 # incomplete, so is not helpful for initializing the variable). 458 metric_variable.set_shape(input_tensor.get_shape()) 459 return (metric_variable.value(), update_op) 460 461 462def _identity_metric_nested(name, input_tensors): 463 """Create identity metrics for a nested tuple of Tensors.""" 464 update_ops = [] 465 value_tensors = [] 466 for tensor_number, tensor in enumerate(nest.flatten(input_tensors)): 467 value_tensor, update_op = _identity_metric_single( 468 name="{}_{}".format(name, tensor_number), input_tensor=tensor) 469 update_ops.append(update_op) 470 value_tensors.append(value_tensor) 471 return (nest.pack_sequence_as(input_tensors, value_tensors), 472 control_flow_ops.group(*update_ops)) 473 474 475def state_to_dictionary(state_tuple): 476 """Flatten model state into a dictionary with string keys.""" 477 flattened = {} 478 for state_number, state_value in enumerate(nest.flatten(state_tuple)): 479 prefixed_state_name = "{}_{:02d}".format(feature_keys.State.STATE_PREFIX, 480 state_number) 481 flattened[prefixed_state_name] = state_value 482 return flattened 483