# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Commonly used special feature names for time series models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.saved_model import signature_constants


class State(object):
  """Key names used when passing model state in or out."""
  # Key for the model-dependent starting state, supplied as one (possibly
  # nested) tuple.
  STATE_TUPLE = "start_tuple"
  # Prefix for keys that carry the same starting state in flattened form,
  # one key per flattened component instead of a single nested tuple. This
  # form is primarily intended for export_savedmodel.
  STATE_PREFIX = "model_state"


class Times(object):
  """Key names used when passing times in or out."""
  # Key for an increasing vector of integer times.
  TIMES = "times"


class Values(object):
  """Key names used when passing values in or out."""
  # Key for floating point data holding one or more values for each entry
  # of TIMES.
  VALUES = "values"


class TrainEvalFeatures(Times, Values):
  """Feature names consumed during training and evaluation.

  Combines TIMES (from Times) with VALUES (from Values).
  """


class PredictionFeatures(Times, State):
  """Feature names consumed during prediction.

  Combines TIMES (from Times) with the state keys from State.
  """


class FilteringFeatures(Times, Values, State):
  """Feature names consumed during filtering.

  Combines TIMES, VALUES and the state keys from State.
  """


class PredictionResults(Times):
  """A non-exhaustive set of keys present in prediction output."""


class FilteringResults(Times, State):
  """Keys present in the output of evaluation or filtering."""


class SavedModelLabels(object):
  """Signature names used by export_savedmodel."""
  # The prediction signature reuses TensorFlow's default serving key.
  PREDICT = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
  FILTER = "filter"
  COLD_START_FILTER = "cold_start_filter"