# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Estimator for State Saving RNNs (deprecated).
16
17This module and all its submodules are deprecated. See
18[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
19for migration instructions.
20"""
21
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
from tensorflow.contrib import layers
from tensorflow.contrib import rnn as rnn_cell
from tensorflow.contrib.layers.python.layers import feature_column_ops
from tensorflow.contrib.layers.python.layers import optimizers
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.contrib.learn.python.learn.estimators import rnn_common
from tensorflow.contrib.training.python.training import sequence_queueing_state_saver as sqss
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import rnn
from tensorflow.python.training import momentum as momentum_opt
from tensorflow.python.util import nest


def construct_state_saving_rnn(cell,
                               inputs,
                               num_label_columns,
                               state_saver,
                               state_name,
                               scope='rnn'):
  """Build a state saving RNN and apply a fully connected layer.

  Args:
    cell: An instance of `RNNCell`.
    inputs: A length `T` list of inputs, each a `Tensor` of shape
      `[batch_size, input_size, ...]`.
    num_label_columns: The desired output dimension.
    state_saver: A state saver object with methods `state` and `save_state`.
    state_name: Python string or tuple of strings.  The name to use with the
      state_saver. If the cell returns tuples of states (i.e.,
      `cell.state_size` is a tuple) then `state_name` should be a tuple of
      strings having the same length as `cell.state_size`.  Otherwise it should
      be a single string.
    scope: `VariableScope` for the created subgraph; defaults to "rnn".

  Returns:
    activations: The output of the RNN, projected to `num_label_columns`
      dimensions, a `Tensor` of shape `[batch_size, T, num_label_columns]`.
    final_state: The final state output by the RNN.
  """
  with ops.name_scope(scope):
    rnn_outputs, final_state = rnn.static_state_saving_rnn(
        cell=cell,
        inputs=inputs,
        state_saver=state_saver,
        state_name=state_name,
        scope=scope)
    # Convert rnn_outputs from a list of time-major order Tensors to a single
    # Tensor of batch-major order.
    rnn_outputs = array_ops.stack(rnn_outputs, axis=1)
    activations = layers.fully_connected(
        inputs=rnn_outputs,
        num_outputs=num_label_columns,
        activation_fn=None,
        trainable=True)
    # Use `identity` to rename `final_state`.
    final_state = array_ops.identity(
        final_state, name=rnn_common.RNNKeys.FINAL_STATE_KEY)
    return activations, final_state
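

# Illustrative usage sketch for `construct_state_saving_rnn` above (hedged,
# not a canonical call pattern). The state saver is typically the
# `NextQueuedSequenceBatch` returned by `_read_batch` below, and `state_name`
# the matching names from `_get_state_names`; `batch` and `unrolled_inputs`
# are hypothetical names:
#
#   cell = rnn_cell.GRUCell(num_units=32)
#   activations, final_state = construct_state_saving_rnn(
#       cell=cell,
#       inputs=unrolled_inputs,  # list of `num_unroll` [batch_size, d] Tensors
#       num_label_columns=1,
#       state_saver=batch,       # e.g. the batch returned by `_read_batch`
#       state_name=_get_state_names(cell))
#   # `activations` then has shape [batch_size, num_unroll, 1].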


def _multi_value_loss(
    activations, labels, sequence_length, target_column, features):
  """Maps `activations` from the RNN to loss for multi value models.

  Args:
    activations: Output from an RNN. Should have dtype `float32` and shape
      `[batch_size, padded_length, ?]`.
    labels: A `Tensor` of shape `[batch_size, padded_length]`.
    sequence_length: A `Tensor` with shape `[batch_size]` and dtype `int32`
      containing the length of each sequence in the batch. If `None`, sequences
      are assumed to be unpadded.
    target_column: An initialized `TargetColumn`, used to calculate predictions
      and loss.
    features: A `dict` containing the input and (optionally) sequence length
      information and initial state.

  Returns:
    A scalar `Tensor` containing the loss.
  """
  with ops.name_scope('MultiValueLoss'):
    activations_masked, labels_masked = rnn_common.mask_activations_and_labels(
        activations, labels, sequence_length)
    return target_column.loss(activations_masked, labels_masked, features)
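

# Shape sketch for `_multi_value_loss` (illustrative): with `activations` of
# shape [batch_size, padded_length, ?] and `labels` of shape
# [batch_size, padded_length], `mask_activations_and_labels` drops entries
# beyond each sequence's length before `target_column.loss` reduces the
# remainder to a scalar.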


def _get_name_or_parent_names(column):
  """Gets the name of a column or its parent columns' names.

  Args:
    column: A sequence feature column derived from `FeatureColumn`.

  Returns:
    A list of the name of `column` or the names of its parent columns,
    if any exist.
  """
  # pylint: disable=protected-access
  parent_columns = feature_column_ops._get_parent_columns(column)
  if parent_columns:
    return [x.name for x in parent_columns]
  return [column.name]
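

# Illustrative example for `_get_name_or_parent_names` (hypothetical columns;
# assumes the usual parent relationship of derived columns such as embeddings):
#
#   ids = layers.sparse_column_with_keys('token', keys=['a', 'b', 'c'])
#   emb = layers.embedding_column(ids, dimension=8)
#   _get_name_or_parent_names(emb)  # -> ['token'] (name of the parent column)
#   _get_name_or_parent_names(ids)  # -> ['token'] (no parents, its own name)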


def _prepare_features_for_sqss(features, labels, mode,
                               sequence_feature_columns,
                               context_feature_columns):
  """Prepares features for batching by the SQSS.

  In preparation for batching by the SQSS, this function:
  - Extracts the input key from the features dict.
  - Separates sequence and context features dicts from the features dict.
  - Adds the labels tensor to the sequence features dict.

  Args:
    features: A dict of Python string to an iterable of `Tensor` or
      `SparseTensor` of rank 2, the `features` argument of a TF.Learn model_fn.
    labels: An iterable of `Tensor`.
    mode: Defines whether this is training, evaluation or prediction.
      See `ModeKeys`.
    sequence_feature_columns: An iterable containing all the feature columns
      describing sequence features. All items in the set should be instances
      of classes derived from `FeatureColumn`.
    context_feature_columns: An iterable containing all the feature columns
      describing context features, i.e., features that apply across all time
      steps. All items in the set should be instances of classes derived from
      `FeatureColumn`.

  Returns:
    sequence_features: A dict mapping feature names to sequence features.
    context_features: A dict mapping feature names to context features.

  Raises:
    ValueError: If `features` does not contain a value for every key in
      `sequence_feature_columns` or `context_feature_columns`.
  """

  # Extract sequence features.
  feature_column_ops._check_supported_sequence_columns(sequence_feature_columns)  # pylint: disable=protected-access
  sequence_features = {}
  for column in sequence_feature_columns:
    for name in _get_name_or_parent_names(column):
      feature = features.get(name, None)
      if feature is None:
        raise ValueError('No key in features for sequence feature: ' + name)
      sequence_features[name] = feature

  # Extract context features.
  context_features = {}
  if context_feature_columns is not None:
    for column in context_feature_columns:
      name = column.name
      feature = features.get(name, None)
      if feature is None:
        raise ValueError('No key in features for context feature: ' + name)
      context_features[name] = feature

  # Add labels to the resulting sequence features dict.
  if mode != model_fn.ModeKeys.INFER:
    sequence_features[rnn_common.RNNKeys.LABELS_KEY] = labels

  return sequence_features, context_features
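

# Illustrative sketch for `_prepare_features_for_sqss` (hypothetical feature
# names and columns):
#
#   features = {'tokens': tokens_tensor,      # sequence data
#               'category': category_tensor}  # per-example context data
#   seq, ctx = _prepare_features_for_sqss(
#       features, labels, model_fn.ModeKeys.TRAIN,
#       sequence_feature_columns=seq_columns,  # columns resolving to 'tokens'
#       context_feature_columns=ctx_columns)   # columns resolving to 'category'
#
# `seq` is {'tokens': tokens_tensor, rnn_common.RNNKeys.LABELS_KEY: labels} and
# `ctx` is {'category': category_tensor}; in INFER mode the labels entry is
# omitted.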


def _get_state_names(cell):
  """Gets the state names for an `RNNCell`.

  Args:
    cell: A `RNNCell` to be used in the RNN.

  Returns:
    State names in the form of a string, a list of strings, or a list of
    string pairs, depending on the type of `cell.state_size`.

  Raises:
    TypeError: If `cell.state_size` is of type `TensorShape`.
  """
  state_size = cell.state_size
  if isinstance(state_size, tensor_shape.TensorShape):
    raise TypeError('cell.state_size of type TensorShape is not supported.')
  if isinstance(state_size, int):
    return '{}_{}'.format(rnn_common.RNNKeys.STATE_PREFIX, 0)
  if isinstance(state_size, rnn_cell.LSTMStateTuple):
    return [
        '{}_{}_c'.format(rnn_common.RNNKeys.STATE_PREFIX, 0),
        '{}_{}_h'.format(rnn_common.RNNKeys.STATE_PREFIX, 0),
    ]
  if isinstance(state_size[0], rnn_cell.LSTMStateTuple):
    return [[
        '{}_{}_c'.format(rnn_common.RNNKeys.STATE_PREFIX, i),
        '{}_{}_h'.format(rnn_common.RNNKeys.STATE_PREFIX, i),
    ] for i in range(len(state_size))]
  return [
      '{}_{}'.format(rnn_common.RNNKeys.STATE_PREFIX, i)
      for i in range(len(state_size))]
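

# Illustrative return values for `_get_state_names`, writing P for
# `rnn_common.RNNKeys.STATE_PREFIX` (cells taken from `rnn_cell`, i.e.
# tf.contrib.rnn, with tuple states):
#
#   GRUCell(32)                                -> 'P_0'
#   LSTMCell(32, state_is_tuple=True)          -> ['P_0_c', 'P_0_h']
#   MultiRNNCell([LSTMCell(32), LSTMCell(32)]) -> [['P_0_c', 'P_0_h'],
#                                                  ['P_1_c', 'P_1_h']]
#   MultiRNNCell([GRUCell(32), GRUCell(32)])   -> ['P_0', 'P_1']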


def _get_initial_states(cell):
  """Gets the initial state of the `RNNCell` used in the RNN.

  Args:
    cell: A `RNNCell` to be used in the RNN.

  Returns:
    A Python dict mapping state names to the `RNNCell`'s initial state for
    consumption by the SQSS.
  """
  names = nest.flatten(_get_state_names(cell))
  values = nest.flatten(cell.zero_state(1, dtype=dtypes.float32))
  return {n: array_ops.squeeze(v, axis=0) for [n, v] in zip(names, values)}
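

# Illustrative result of `_get_initial_states` for a single
# `rnn_cell.LSTMCell(32)`, with P again standing for
# `rnn_common.RNNKeys.STATE_PREFIX`:
#
#   {'P_0_c': <float32 zeros of shape [32]>,
#    'P_0_h': <float32 zeros of shape [32]>}
#
# The batch dimension of `zero_state(1, ...)` is squeezed away because the
# SQSS adds its own batching of saved states.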


def _read_batch(cell,
                features,
                labels,
                mode,
                num_unroll,
                batch_size,
                sequence_feature_columns,
                context_feature_columns=None,
                num_threads=3,
                queue_capacity=1000,
                seed=None):
  """Reads a batch from a state saving sequence queue.

  Args:
    cell: An initialized `RNNCell` to be used in the RNN.
    features: A dict of Python string to an iterable of `Tensor`, the
      `features` argument of a TF.Learn model_fn.
    labels: An iterable of `Tensor`, the `labels` argument of a
      TF.Learn model_fn.
    mode: Defines whether this is training, evaluation or prediction.
      See `ModeKeys`.
    num_unroll: Python integer, how many time steps to unroll at a time.
      The input sequences of length `k` are then split into `k / num_unroll`
      many segments.
    batch_size: Python integer, the size of the minibatch produced by the SQSS.
    sequence_feature_columns: An iterable containing all the feature columns
      describing sequence features. All items in the set should be instances
      of classes derived from `FeatureColumn`.
    context_feature_columns: An iterable containing all the feature columns
      describing context features, i.e., features that apply across all time
      steps. All items in the set should be instances of classes derived from
      `FeatureColumn`.
    num_threads: The Python integer number of threads enqueuing input examples
      into a queue. Defaults to 3.
    queue_capacity: The max capacity of the queue in number of examples.
      Needs to be at least `batch_size`. Defaults to 1000. When iterating
      over the same input examples multiple times and reusing their keys, the
      `queue_capacity` must be smaller than the number of examples.
    seed: Fixes the random seed used for generating input keys by the SQSS.

  Returns:
    batch: A `NextQueuedSequenceBatch` containing `batch_size` `SequenceExample`
      values and their saved internal states.
  """
  states = _get_initial_states(cell)

  sequences, context = _prepare_features_for_sqss(
      features, labels, mode, sequence_feature_columns,
      context_feature_columns)

  return sqss.batch_sequences_with_states(
      input_key='key',
      input_sequences=sequences,
      input_context=context,
      input_length=None,  # infer sequence lengths
      initial_states=states,
      num_unroll=num_unroll,
      batch_size=batch_size,
      pad=True,  # pad to a multiple of num_unroll
      make_keys_unique=True,
      make_keys_unique_seed=seed,
      num_threads=num_threads,
      capacity=queue_capacity)
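

# Sketch of the `NextQueuedSequenceBatch` returned by `_read_batch`, limited
# to the attributes used elsewhere in this file:
#   batch.sequences  - dict of sequence features sliced to `num_unroll` steps
#                      (plus the labels entry outside INFER mode),
#   batch.context    - dict of context features,
#   batch.length     - [batch_size] int Tensor of valid (unpadded) steps in
#                      the current segment,
#   batch.state(...)/batch.save_state(...) - the state-saver interface that
#                      `construct_state_saving_rnn` relies on.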


def _get_state_name(i):
  """Constructs the name string for state component `i`."""
  return '{}_{}'.format(rnn_common.RNNKeys.STATE_PREFIX, i)


def state_tuple_to_dict(state):
  """Returns a dict containing flattened `state`.

  Args:
    state: A `Tensor` or a nested tuple of `Tensor`s. All of the `Tensor`s must
      have the same rank and agree on all dimensions except the last.

  Returns:
    A dict containing the `Tensor`s that make up `state`. The keys of the dict
    are of the form "STATE_PREFIX_i" where `i` is the place of this `Tensor`
    in a depth-first traversal of `state`.
  """
  with ops.name_scope('state_tuple_to_dict'):
    flat_state = nest.flatten(state)
    state_dict = {}
    for i, state_component in enumerate(flat_state):
      state_name = _get_state_name(i)
      state_value = (None if state_component is None else array_ops.identity(
          state_component, name=state_name))
      state_dict[state_name] = state_value
  return state_dict
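

# Illustrative example for `state_tuple_to_dict`, with P standing for
# `rnn_common.RNNKeys.STATE_PREFIX`: flattening an LSTMStateTuple(c, h) yields
#
#   {'P_0': <c renamed to 'P_0'>, 'P_1': <h renamed to 'P_1'>}
#
# i.e. keys follow depth-first (nest.flatten) order, not the per-layer
# '_c'/'_h' scheme used by `_get_state_names`.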


def _prepare_inputs_for_rnn(sequence_features, context_features,
                            sequence_feature_columns, num_unroll):
  """Prepares features batched by the SQSS for input to a state-saving RNN.

  Args:
    sequence_features: A dict of sequence feature name to `Tensor` or
      `SparseTensor`, with `Tensor`s of shape `[batch_size, num_unroll, ...]`
      or `SparseTensor`s of dense shape `[batch_size, num_unroll, d]`.
    context_features: A dict of context feature name to `Tensor`, with
      tensors of shape `[batch_size, 1, ...]` and type float32.
    sequence_feature_columns: An iterable containing all the feature columns
      describing sequence features. All items in the set should be instances
      of classes derived from `FeatureColumn`.
    num_unroll: Python integer, how many time steps to unroll at a time.
      The input sequences of length `k` are then split into `k / num_unroll`
      many segments.

  Returns:
    features_by_time: A list of length `num_unroll` with `Tensor` entries of
      shape `[batch_size, sum(sequence_features dimensions) +
      sum(context_features dimensions)]` of type float32.
      Context features are copied into each time step.
  """

  def _tile(feature):
    return array_ops.squeeze(
        array_ops.tile(array_ops.expand_dims(feature, 1), [1, num_unroll, 1]),
        axis=2)

  for feature in sequence_features.values():
    if isinstance(feature, sparse_tensor.SparseTensor):
      # Explicitly set dense_shape's shape to 3 ([batch_size, num_unroll, d])
      # since it can't be statically inferred.
      feature.dense_shape.set_shape([3])
  sequence_features = layers.sequence_input_from_feature_columns(
      columns_to_tensors=sequence_features,
      feature_columns=sequence_feature_columns,
      weight_collections=None,
      scope=None)
  # Explicitly set shape along dimension 1 to num_unroll for the unstack op.
  sequence_features.set_shape([None, num_unroll, None])

  if not context_features:
    return array_ops.unstack(sequence_features, axis=1)
  # TODO(jtbates): Call layers.input_from_feature_columns for context features.
  context_features = [
      _tile(context_features[k]) for k in sorted(context_features)
  ]
  return array_ops.unstack(
      array_ops.concat(
          [sequence_features, array_ops.stack(context_features, 2)], axis=2),
      axis=1)
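

# Shape sketch for `_prepare_inputs_for_rnn` (hypothetical sizes): with
# batch_size=4, num_unroll=10, sequence columns embedding to 8 dimensions and
# one 3-dimensional context feature, the result is a list of 10 Tensors of
# shape [4, 8 + 3]; the context values are repeated at every time step.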


def _get_rnn_model_fn(cell_type,
                      target_column,
                      problem_type,
                      optimizer,
                      num_unroll,
                      num_units,
                      num_threads,
                      queue_capacity,
                      batch_size,
                      sequence_feature_columns,
                      context_feature_columns=None,
                      predict_probabilities=False,
                      learning_rate=None,
                      gradient_clipping_norm=None,
                      dropout_keep_probabilities=None,
                      name='StateSavingRNNModel',
                      seed=None):
  """Creates a state saving RNN model function for an `Estimator`.

  Args:
    cell_type: A subclass of `RNNCell` or one of 'basic_rnn', 'lstm' or 'gru'.
    target_column: An initialized `TargetColumn`, used to calculate prediction
      and loss.
    problem_type: `ProblemType.CLASSIFICATION` or
      `ProblemType.LINEAR_REGRESSION`.
    optimizer: A subclass of `Optimizer`, an instance of an `Optimizer` or a
      string.
    num_unroll: Python integer, how many time steps to unroll at a time.
      The input sequences of length `k` are then split into `k / num_unroll`
      many segments.
    num_units: The number of units in the `RNNCell`.
    num_threads: The Python integer number of threads enqueuing input examples
      into a queue.
    queue_capacity: The max capacity of the queue in number of examples.
      Needs to be at least `batch_size`. When iterating over the same input
      examples multiple times and reusing their keys, the `queue_capacity` must
      be smaller than the number of examples.
    batch_size: Python integer, the size of the minibatch produced by the SQSS.
    sequence_feature_columns: An iterable containing all the feature columns
      describing sequence features. All items in the set should be instances
      of classes derived from `FeatureColumn`.
    context_feature_columns: An iterable containing all the feature columns
      describing context features, i.e., features that apply across all time
      steps. All items in the set should be instances of classes derived from
      `FeatureColumn`.
    predict_probabilities: A boolean indicating whether to predict
      probabilities for all classes. Must only be used with
      `ProblemType.CLASSIFICATION`.
    learning_rate: Learning rate used for optimization. This argument has no
      effect if `optimizer` is an instance of an `Optimizer`.
    gradient_clipping_norm: A float. Gradients will be clipped to this value.
    dropout_keep_probabilities: A list of dropout keep probabilities or `None`.
      If given a list, it must have length `len(num_units) + 1`.
    name: A string that will be used to create a scope for the RNN.
    seed: Fixes the random seed used for generating input keys by the SQSS.

  Returns:
    A model function to be passed to an `Estimator`.

  Raises:
    ValueError: `problem_type` is not one of `ProblemType.LINEAR_REGRESSION`
      or `ProblemType.CLASSIFICATION`.
    ValueError: `predict_probabilities` is `True` for `problem_type` other
      than `ProblemType.CLASSIFICATION`.
    ValueError: `num_unroll` is not positive.
  """
  if problem_type not in (constants.ProblemType.CLASSIFICATION,
                          constants.ProblemType.LINEAR_REGRESSION):
    raise ValueError(
        'problem_type must be ProblemType.LINEAR_REGRESSION or '
        'ProblemType.CLASSIFICATION; got {}'.
        format(problem_type))
  if (problem_type != constants.ProblemType.CLASSIFICATION and
      predict_probabilities):
    raise ValueError(
        'predict_probabilities can only be set to True for problem_type'
        ' ProblemType.CLASSIFICATION; got {}.'.format(problem_type))
  if num_unroll <= 0:
    raise ValueError('num_unroll must be positive; got {}.'.format(num_unroll))

  def _rnn_model_fn(features, labels, mode):
    """The model to be passed to an `Estimator`."""
    with ops.name_scope(name):
      dropout = (dropout_keep_probabilities
                 if mode == model_fn.ModeKeys.TRAIN
                 else None)
      cell = rnn_common.construct_rnn_cell(num_units, cell_type, dropout)

      batch = _read_batch(
          cell=cell,
          features=features,
          labels=labels,
          mode=mode,
          num_unroll=num_unroll,
          batch_size=batch_size,
          sequence_feature_columns=sequence_feature_columns,
          context_feature_columns=context_feature_columns,
          num_threads=num_threads,
          queue_capacity=queue_capacity,
          seed=seed)
      sequence_features = batch.sequences
      context_features = batch.context
      if mode != model_fn.ModeKeys.INFER:
        labels = sequence_features.pop(rnn_common.RNNKeys.LABELS_KEY)
      inputs = _prepare_inputs_for_rnn(sequence_features, context_features,
                                       sequence_feature_columns, num_unroll)
      state_name = _get_state_names(cell)
      rnn_activations, final_state = construct_state_saving_rnn(
          cell=cell,
          inputs=inputs,
          num_label_columns=target_column.num_label_columns,
          state_saver=batch,
          state_name=state_name)

      loss = None  # Created below for modes TRAIN and EVAL.
      prediction_dict = rnn_common.multi_value_predictions(
          rnn_activations, target_column, problem_type, predict_probabilities)
      if mode != model_fn.ModeKeys.INFER:
        loss = _multi_value_loss(rnn_activations, labels, batch.length,
                                 target_column, features)

      eval_metric_ops = None
      if mode != model_fn.ModeKeys.INFER:
        eval_metric_ops = rnn_common.get_eval_metric_ops(
            problem_type, rnn_common.PredictionType.MULTIPLE_VALUE,
            batch.length, prediction_dict, labels)

      state_dict = state_tuple_to_dict(final_state)
      prediction_dict.update(state_dict)

      train_op = None
      if mode == model_fn.ModeKeys.TRAIN:
        train_op = optimizers.optimize_loss(
            loss=loss,
            global_step=None,  # Get it internally.
            learning_rate=learning_rate,
            optimizer=optimizer,
            clip_gradients=gradient_clipping_norm,
            summaries=optimizers.OPTIMIZER_SUMMARIES)

    return model_fn.ModelFnOps(mode=mode,
                               predictions=prediction_dict,
                               loss=loss,
                               train_op=train_op,
                               eval_metric_ops=eval_metric_ops)

  return _rnn_model_fn
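

# Illustrative sketch of how the returned model function can be consumed
# (hypothetical argument values; `seq_columns` is assumed to be a list of
# sequence feature columns). `StateSavingRnnEstimator` below does this wiring
# for you:
#
#   model_fn_ = _get_rnn_model_fn(
#       cell_type='lstm',
#       target_column=layers.multi_class_target(n_classes=2),
#       problem_type=constants.ProblemType.CLASSIFICATION,
#       optimizer='SGD',
#       num_unroll=10,
#       num_units=[32],
#       num_threads=3,
#       queue_capacity=1000,
#       batch_size=4,
#       sequence_feature_columns=seq_columns,
#       learning_rate=0.1)
#   est = estimator.Estimator(model_fn=model_fn_)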


class StateSavingRnnEstimator(estimator.Estimator):
  """RNN with static unrolling and state saving (deprecated).

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.
  """

  def __init__(self,
               problem_type,
               num_unroll,
               batch_size,
               sequence_feature_columns,
               context_feature_columns=None,
               num_classes=None,
               num_units=None,
               cell_type='basic_rnn',
               optimizer_type='SGD',
               learning_rate=0.1,
               predict_probabilities=False,
               momentum=None,
               gradient_clipping_norm=5.0,
               dropout_keep_probabilities=None,
               model_dir=None,
               config=None,
               feature_engineering_fn=None,
               num_threads=3,
               queue_capacity=1000,
               seed=None):
    """Initializes a StateSavingRnnEstimator.

    Args:
      problem_type: `ProblemType.CLASSIFICATION` or
        `ProblemType.LINEAR_REGRESSION`.
      num_unroll: Python integer, how many time steps to unroll at a time.
        The input sequences of length `k` are then split into `k / num_unroll`
        many segments.
      batch_size: Python integer, the size of the minibatch.
      sequence_feature_columns: An iterable containing all the feature columns
        describing sequence features. All items in the set should be instances
        of classes derived from `FeatureColumn`.
      context_feature_columns: An iterable containing all the feature columns
        describing context features, i.e., features that apply across all time
        steps. All items in the set should be instances of classes derived from
        `FeatureColumn`.
      num_classes: The number of classes for categorization. Used only and
        required if `problem_type` is `ProblemType.CLASSIFICATION`.
      num_units: A list of integers indicating the number of units in the
        `RNNCell`s in each layer. Either `num_units` is specified or `cell_type`
        is an instance of `RNNCell`.
      cell_type: A subclass of `RNNCell` or one of 'basic_rnn', 'lstm' or 'gru'.
      optimizer_type: The type of optimizer to use. Either a subclass of
        `Optimizer`, an instance of an `Optimizer` or a string. Strings must be
        one of 'Adagrad', 'Adam', 'Ftrl', 'Momentum', 'RMSProp', or 'SGD'.
      learning_rate: Learning rate. This argument has no effect if `optimizer`
        is an instance of an `Optimizer`.
      predict_probabilities: A boolean indicating whether to predict
        probabilities for all classes. Used only if `problem_type` is
        `ProblemType.CLASSIFICATION`.
      momentum: Momentum value. Only used if `optimizer_type` is 'Momentum'.
      gradient_clipping_norm: Parameter used for gradient clipping. If `None`,
        then no clipping is performed.
      dropout_keep_probabilities: A list of dropout keep probabilities or
        `None`. If given a list, it must have length `len(num_units) + 1`.
      model_dir: The directory in which to save and restore the model graph,
        parameters, etc.
      config: A `RunConfig` instance.
      feature_engineering_fn: Takes features and labels which are the output of
        `input_fn` and returns features and labels which will be fed into
        `model_fn`. Please check `model_fn` for a definition of features and
        labels.
      num_threads: The Python integer number of threads enqueuing input examples
        into a queue. Defaults to 3.
      queue_capacity: The max capacity of the queue in number of examples.
        Needs to be at least `batch_size`. Defaults to 1000. When iterating
        over the same input examples multiple times and reusing their keys, the
        `queue_capacity` must be smaller than the number of examples.
      seed: Fixes the random seed used for generating input keys by the SQSS.

    Raises:
      ValueError: Both or neither of the following are true: (a) `num_units` is
        specified and (b) `cell_type` is an instance of `RNNCell`.
      ValueError: `problem_type` is not one of
        `ProblemType.LINEAR_REGRESSION` or `ProblemType.CLASSIFICATION`.
      ValueError: `problem_type` is `ProblemType.CLASSIFICATION` but
        `num_classes` is not specified.
    """
    name = 'MultiValueStateSavingRNN'
    if problem_type == constants.ProblemType.LINEAR_REGRESSION:
      name += 'Regressor'
      target_column = layers.regression_target()
    elif problem_type == constants.ProblemType.CLASSIFICATION:
      if not num_classes:
        raise ValueError('For CLASSIFICATION problem_type, num_classes must be '
                         'specified.')
      target_column = layers.multi_class_target(n_classes=num_classes)
      name += 'Classifier'
    else:
      raise ValueError(
          'problem_type must be either ProblemType.LINEAR_REGRESSION '
          'or ProblemType.CLASSIFICATION; got {}'.format(
              problem_type))

    if optimizer_type == 'Momentum':
      optimizer_type = momentum_opt.MomentumOptimizer(learning_rate, momentum)

    rnn_model_fn = _get_rnn_model_fn(
        cell_type=cell_type,
        target_column=target_column,
        problem_type=problem_type,
        optimizer=optimizer_type,
        num_unroll=num_unroll,
        num_units=num_units,
        num_threads=num_threads,
        queue_capacity=queue_capacity,
        batch_size=batch_size,
        sequence_feature_columns=sequence_feature_columns,
        context_feature_columns=context_feature_columns,
        predict_probabilities=predict_probabilities,
        learning_rate=learning_rate,
        gradient_clipping_norm=gradient_clipping_norm,
        dropout_keep_probabilities=dropout_keep_probabilities,
        name=name,
        seed=seed)

    super(StateSavingRnnEstimator, self).__init__(
        model_fn=rnn_model_fn,
        model_dir=model_dir,
        config=config,
        feature_engineering_fn=feature_engineering_fn)
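

# Illustrative usage sketch for `StateSavingRnnEstimator` (hypothetical
# feature column and input_fn; the input pipeline requirements are those of
# `_read_batch` above):
#
#   seq_columns = [layers.real_valued_column('measurements', dimension=8)]
#   ssre = StateSavingRnnEstimator(
#       problem_type=constants.ProblemType.CLASSIFICATION,
#       num_unroll=10,
#       batch_size=4,
#       sequence_feature_columns=seq_columns,
#       num_classes=2,
#       num_units=[32],
#       cell_type='lstm',
#       learning_rate=0.1)
#   ssre.fit(input_fn=my_input_fn, steps=100)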