• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Estimator for Dynamic RNNs (deprecated).
16
17This module and all its submodules are deprecated. See
18[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
19for migration instructions.
20"""
21
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26from tensorflow.contrib import layers
27from tensorflow.contrib.layers.python.layers import optimizers
28from tensorflow.contrib.learn.python.learn.estimators import constants
29from tensorflow.contrib.learn.python.learn.estimators import estimator
30from tensorflow.contrib.learn.python.learn.estimators import model_fn
31from tensorflow.contrib.learn.python.learn.estimators import prediction_key
32from tensorflow.contrib.learn.python.learn.estimators import rnn_common
33from tensorflow.python.framework import dtypes
34from tensorflow.python.framework import ops
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import check_ops
37from tensorflow.python.ops import math_ops
38from tensorflow.python.ops import rnn
39from tensorflow.python.training import momentum as momentum_opt
40from tensorflow.python.util import nest
41
42
43# TODO(jtbates): Remove PredictionType when all non-experimental targets which
44# depend on it point to rnn_common.PredictionType.
class PredictionType(object):
  """Enumerates how many predictions a dynamic RNN emits per sequence.

  Kept as a local copy for callers that have not yet switched to
  `rnn_common.PredictionType`.
  """
  # One prediction per sequence, taken from the final time step.
  SINGLE_VALUE = 1
  # One prediction for every time step of the sequence.
  MULTIPLE_VALUE = 2
48
49
def _get_state_name(i):
  """Returns the dict key used for flattened state component `i`.

  Args:
    i: Position of the component within the flattened RNN state.

  Returns:
    A string made of `rnn_common.RNNKeys.STATE_PREFIX`, an underscore and `i`.
  """
  prefix = rnn_common.RNNKeys.STATE_PREFIX
  return '{}_{}'.format(prefix, i)
53
54
def state_tuple_to_dict(state):
  """Flattens `state` into a dict keyed by state component name.

  Args:
    state: A `Tensor` or a nested tuple of `Tensor`s. The components are
      expected to share the same rank and to agree on all dimensions except
      the last; this is not checked here.

  Returns:
    A dict mapping "STATE_PREFIX_i" to the `i`-th component of `state`, where
    `i` is the component's position in `nest.flatten(state)`. Components that
    are `None` are stored as `None`; every other component is wrapped in an
    identity op named after its key.
  """
  with ops.name_scope('state_tuple_to_dict'):
    named_components = [
        (_get_state_name(index), component)
        for index, component in enumerate(nest.flatten(state))]
    result = {}
    for key, component in named_components:
      if component is None:
        result[key] = None
      else:
        result[key] = array_ops.identity(component, name=key)
  return result
76
77
def dict_to_state_tuple(input_dict, cell):
  """Reconstructs nested `state` from a dict containing state `Tensor`s.

  Args:
    input_dict: A dict of `Tensor`s. State components are looked up under the
      keys 'STATE_PREFIX_i'; any other entries are ignored.
    cell: An instance of `RNNCell`.

  Returns:
    `None` if `input_dict` holds no state component at all. Otherwise, the
    state `Tensor`s packed into the same nesting structure as
    `cell.zero_state(...)`: a single `Tensor` if `cell.state_size` is an `int`
    or a nested tuple of `Tensor`s if `cell.state_size` is a nested tuple.

  Raises:
    ValueError: State is partially specified. The `input_dict` must contain
      values for all state components or none at all.
  """
  flat_state_sizes = nest.flatten(cell.state_size)
  state_tensors = []
  with ops.name_scope('dict_to_state_tuple'):
    for i, state_size in enumerate(flat_state_sizes):
      state_name = _get_state_name(i)
      state_tensor = input_dict.get(state_name)
      if state_tensor is None:
        # Missing (or explicitly `None`) components are skipped here and
        # accounted for by the completeness check after the loop.
        continue
      rank_check = check_ops.assert_rank(
          state_tensor, 2, name='check_state_{}_rank'.format(i))
      shape_check = check_ops.assert_equal(
          array_ops.shape(state_tensor)[1],
          state_size,
          name='check_state_{}_shape'.format(i))
      # Route the tensor through an identity op so that downstream consumers
      # depend on both assertions.
      with ops.control_dependencies([rank_check, shape_check]):
        state_tensor = array_ops.identity(state_tensor, name=state_name)
      state_tensors.append(state_tensor)

    if not state_tensors:
      return None
    if len(state_tensors) != len(flat_state_sizes):
      raise ValueError(
          'RNN state was partially specified. '
          'Expected zero or {} state Tensors; got {}'.
          format(len(flat_state_sizes), len(state_tensors)))
    # The zero state serves only as a structural template for packing; its
    # batch size, dtype and values are never used.
    dummy_state = cell.zero_state(batch_size=1, dtype=dtypes.bool)
    return nest.pack_sequence_as(dummy_state, state_tensors)
120
121
def _concatenate_context_input(sequence_input, context_input):
  """Appends `context_input` to every time step of `sequence_input`.

  A time axis is inserted into `context_input` at dimension 1, the result is
  tiled `padded_length` times along that axis, and the tiled tensor is
  concatenated to `sequence_input` along dimension 2.

  Args:
    sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size,
      padded_length, d0]`.
    context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`.

  Returns:
    A `Tensor` of dtype `float32` and shape `[batch_size, padded_length,
    d0 + d1]`.

  Raises:
    ValueError: If `sequence_input` does not have rank 3 or `context_input` does
      not have rank 2.
  """
  # Rank and dtype assertions, sequence input first, then context input.
  validation_ops = [
      check_ops.assert_rank(
          sequence_input,
          3,
          message='sequence_input must have rank 3',
          data=[array_ops.shape(sequence_input)]),
      check_ops.assert_type(
          sequence_input,
          dtypes.float32,
          message='sequence_input must have dtype float32; got {}.'.format(
              sequence_input.dtype)),
      check_ops.assert_rank(
          context_input,
          2,
          message='context_input must have rank 2',
          data=[array_ops.shape(context_input)]),
      check_ops.assert_type(
          context_input,
          dtypes.float32,
          message='context_input must have dtype float32; got {}.'.format(
              context_input.dtype)),
  ]
  with ops.control_dependencies(validation_ops):
    num_steps = array_ops.shape(sequence_input)[1]
    context_with_time_axis = array_ops.expand_dims(context_input, 1)
    # Repeat only along the freshly inserted time axis.
    multiples = array_ops.concat([[1], [num_steps], [1]], 0)
    replicated_context = array_ops.tile(context_with_time_axis, multiples)
  return array_ops.concat([sequence_input, replicated_context], 2)
169
170
def build_sequence_input(features,
                         sequence_feature_columns,
                         context_feature_columns,
                         weight_collections=None,
                         scope=None):
  """Combine sequence and context features into input for an RNN.

  Args:
    features: A `dict` containing the input and (optionally) sequence length
      information and initial state. The dict is copied, so the caller's
      object is not modified.
    sequence_feature_columns: An iterable containing all the feature columns
      describing sequence features. All items in the set should be instances
      of classes derived from `FeatureColumn`.
    context_feature_columns: An iterable containing all the feature columns
      describing context features i.e. features that apply across all time
      steps. All items in the set should be instances of classes derived from
      `FeatureColumn`. May be `None` or empty, in which case no context
      features are appended.
    weight_collections: List of graph collections to which weights are added.
    scope: Optional scope, passed through to parsing ops.

  Returns:
    A `Tensor` of dtype `float32` and shape `[batch_size, padded_length, ?]`.
    This will be used as input to an RNN.
  """
  features = features.copy()
  features.update(layers.transform_features(
      features,
      list(sequence_feature_columns) + list(context_feature_columns or [])))
  sequence_input = layers.sequence_input_from_feature_columns(
      columns_to_tensors=features,
      feature_columns=sequence_feature_columns,
      weight_collections=weight_collections,
      scope=scope)
  # Use the same truthiness test as the `or []` above: an empty collection of
  # context columns contributes nothing, so it is handled like `None` instead
  # of being forwarded to `input_from_feature_columns` with zero columns.
  if not context_feature_columns:
    return sequence_input
  context_input = layers.input_from_feature_columns(
      columns_to_tensors=features,
      feature_columns=context_feature_columns,
      weight_collections=weight_collections,
      scope=scope)
  return _concatenate_context_input(sequence_input, context_input)
211
212
def construct_rnn(initial_state,
                  sequence_input,
                  cell,
                  num_label_columns,
                  dtype=dtypes.float32,
                  parallel_iterations=32,
                  swap_memory=True):
  """Runs `cell` over `sequence_input` and projects the outputs linearly.

  The RNN is unrolled with `rnn.dynamic_rnn` (batch-major inputs) and its
  outputs are fed through a trainable fully connected layer without an
  activation function.

  Args:
    initial_state: The initial state to pass the RNN. If `None`, the
      default starting state for `cell` is used.
    sequence_input: A `Tensor` with shape `[batch_size, padded_length, d]`
      that will be passed as input to the RNN.
    cell: An initialized `RNNCell`.
    num_label_columns: The desired output dimension.
    dtype: dtype of `cell`.
    parallel_iterations: Number of iterations to run in parallel. Values >> 1
      use more memory but take less time, while smaller values use less memory
      but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU.  This allows training RNNs
      which would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.

  Returns:
    activations: The output of the RNN, projected to `num_label_columns`
      dimensions.
    final_state: A `Tensor` or nested tuple of `Tensor`s representing the final
      state output by the RNN.
  """
  with ops.name_scope('RNN'):
    per_step_outputs, last_state = rnn.dynamic_rnn(
        cell=cell,
        inputs=sequence_input,
        initial_state=initial_state,
        dtype=dtype,
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory,
        time_major=False)
    # Linear projection: no activation function is applied to the logits.
    projected = layers.fully_connected(
        inputs=per_step_outputs,
        num_outputs=num_label_columns,
        activation_fn=None,
        trainable=True)
  return projected, last_state
258
259
def _single_value_predictions(activations,
                              sequence_length,
                              target_column,
                              problem_type,
                              predict_probabilities):
  """Maps `activations` from the RNN to predictions for single value models.

  The returned `dict` always holds one entry for the prediction itself, keyed
  by `PredictionKey.CLASSES` for classification problems and by
  `PredictionKey.SCORES` otherwise. If `predict_probabilities` is `True`, a
  second entry keyed by `PredictionKey.PROBABILITIES` holds the probabilities
  returned by `target_column`, and the prediction entry is the argmax of those
  probabilities along dimension 1.

  Args:
    activations: Output from an RNN. Should have dtype `float32` and shape
      `[batch_size, padded_length, ?]`.
    sequence_length: A `Tensor` with shape `[batch_size]` and dtype `int32`
      containing the length of each sequence in the batch. If `None`, sequences
      are assumed to be unpadded.
    target_column: An initialized `TargetColumn`, calculate predictions.
    problem_type: Either `ProblemType.CLASSIFICATION` or
      `ProblemType.LINEAR_REGRESSION`.
    predict_probabilities: A Python boolean, indicating whether probabilities
      should be returned. Should only be set to `True` for
      classification/logistic regression problems.

  Returns:
    A `dict` mapping strings to `Tensors`.
  """
  with ops.name_scope('SingleValuePrediction'):
    final_step_activations = rnn_common.select_last_activations(
        activations, sequence_length)
    if problem_type == constants.ProblemType.CLASSIFICATION:
      output_key = prediction_key.PredictionKey.CLASSES
    else:
      output_key = prediction_key.PredictionKey.SCORES

    if not predict_probabilities:
      return {
          output_key: target_column.logits_to_predictions(
              final_step_activations, proba=False)}

    class_probabilities = target_column.logits_to_predictions(
        final_step_activations, proba=True)
    return {
        prediction_key.PredictionKey.PROBABILITIES: class_probabilities,
        output_key: math_ops.argmax(class_probabilities, 1)}
305
306
def _multi_value_loss(
    activations, labels, sequence_length, target_column, features):
  """Computes the loss over every time step for multi value models.

  Activations and labels are first passed through
  `rnn_common.mask_activations_and_labels` together with `sequence_length`;
  the masked results are then handed to `target_column.loss`.

  Args:
    activations: Output from an RNN. Should have dtype `float32` and shape
      `[batch_size, padded_length, ?]`.
    labels: A `Tensor` with shape `[batch_size, padded_length]`.
    sequence_length: A `Tensor` with shape `[batch_size]` and dtype `int32`
      containing the length of each sequence in the batch. If `None`, sequences
      are assumed to be unpadded.
    target_column: An initialized `TargetColumn`, calculate predictions.
    features: A `dict` containing the input and (optionally) sequence length
      information and initial state.

  Returns:
    A scalar `Tensor` containing the loss.
  """
  with ops.name_scope('MultiValueLoss'):
    masked = rnn_common.mask_activations_and_labels(
        activations, labels, sequence_length)
    masked_activations, masked_labels = masked
    return target_column.loss(masked_activations, masked_labels, features)
328
329
def _single_value_loss(
    activations, labels, sequence_length, target_column, features):
  """Computes the loss on the last activations for single value models.

  The activations are reduced with `rnn_common.select_last_activations`
  before being handed to `target_column.loss` together with `labels`.

  Args:
    activations: Output from an RNN. Should have dtype `float32` and shape
      `[batch_size, padded_length, ?]`.
    labels: A `Tensor` with shape `[batch_size]`.
    sequence_length: A `Tensor` with shape `[batch_size]` and dtype `int32`
      containing the length of each sequence in the batch. If `None`, sequences
      are assumed to be unpadded.
    target_column: An initialized `TargetColumn`, calculate predictions.
    features: A `dict` containing the input and (optionally) sequence length
      information and initial state.

  Returns:
    A scalar `Tensor` containing the loss.
  """
  with ops.name_scope('SingleValueLoss'):
    final_step_activations = rnn_common.select_last_activations(
        activations, sequence_length)
    single_value_loss = target_column.loss(
        final_step_activations, labels, features)
  return single_value_loss
352
353
def _get_output_alternatives(prediction_type,
                             problem_type,
                             prediction_dict):
  """Builds the `output_alternatives` argument for `ModelFnOps`.

  Multi value models get no output alternatives. Single value models get one
  alternative, named 'dynamic_rnn_output', holding every prediction whose key
  does not contain the RNN state prefix.

  Args:
    prediction_type: either `MULTIPLE_VALUE` or `SINGLE_VALUE`.
    problem_type: either `CLASSIFICATION` or `LINEAR_REGRESSION`.
    prediction_dict: a dictionary mapping strings to `Tensor`s containing
      predictions.

  Returns:
    `None` or a dictionary mapping a string to an output alternative, i.e. a
    `(problem_type, predictions)` tuple.

  Raises:
    ValueError: `prediction_type` is not one of `SINGLE_VALUE` or
    `MULTIPLE_VALUE`.
  """
  if prediction_type == rnn_common.PredictionType.MULTIPLE_VALUE:
    return None

  is_single_value = prediction_type == rnn_common.PredictionType.SINGLE_VALUE
  if not is_single_value:
    raise ValueError('Unrecognized prediction_type: {}'.format(prediction_type))

  state_prefix = rnn_common.RNNKeys.STATE_PREFIX
  non_state_predictions = {}
  for key, tensor in prediction_dict.items():
    if state_prefix in key:
      # RNN state entries are not part of the exported predictions.
      continue
    non_state_predictions[key] = tensor
  return {'dynamic_rnn_output': (problem_type, non_state_predictions)}
382
383
def _get_dynamic_rnn_model_fn(
    cell_type,
    num_units,
    target_column,
    problem_type,
    prediction_type,
    optimizer,
    sequence_feature_columns,
    context_feature_columns=None,
    predict_probabilities=False,
    learning_rate=None,
    gradient_clipping_norm=None,
    dropout_keep_probabilities=None,
    sequence_length_key=rnn_common.RNNKeys.SEQUENCE_LENGTH_KEY,
    dtype=dtypes.float32,
    parallel_iterations=None,
    swap_memory=True,
    name='DynamicRNNModel'):
  """Creates an RNN model function for an `Estimator`.

  The model function returns an instance of `ModelFnOps`. When
  `problem_type == ProblemType.CLASSIFICATION` and
  `predict_probabilities == True`, the returned `ModelFnOps` includes an output
  alternative containing the classes and their associated probabilities. When
  `predict_probabilities == False`, only the classes are included. When
  `problem_type == ProblemType.LINEAR_REGRESSION`, the output alternative
  contains only the predicted values. Output alternatives are only produced
  for `PredictionType.SINGLE_VALUE`.

  Args:
    cell_type: A string, a subclass of `RNNCell` or an instance of an `RNNCell`.
    num_units: A single `int` or a list of `int`s. The size of the `RNNCell`s.
    target_column: An initialized `TargetColumn`, used to calculate prediction
      and loss.
    problem_type: `ProblemType.CLASSIFICATION` or
      `ProblemType.LINEAR_REGRESSION`.
    prediction_type: `PredictionType.SINGLE_VALUE` or
      `PredictionType.MULTIPLE_VALUE`.
    optimizer: A subclass of `Optimizer`, an instance of an `Optimizer` or a
      string.
    sequence_feature_columns: An iterable containing all the feature columns
      describing sequence features. All items in the set should be instances
      of classes derived from `FeatureColumn`.
    context_feature_columns: An iterable containing all the feature columns
      describing context features, i.e., features that apply across all time
      steps. All items in the set should be instances of classes derived from
      `FeatureColumn`.
    predict_probabilities: A boolean indicating whether to predict probabilities
      for all classes. Must only be used with
      `ProblemType.CLASSIFICATION`.
    learning_rate: Learning rate used for optimization. This argument has no
      effect if `optimizer` is an instance of an `Optimizer`.
    gradient_clipping_norm: A float. Gradients will be clipped to this value.
    dropout_keep_probabilities: a list of dropout keep probabilities or `None`.
      If a list is given, it must have length `len(num_units) + 1`. Dropout is
      only applied in `ModeKeys.TRAIN`.
    sequence_length_key: The key that will be used to look up sequence length in
      the `features` dict.
    dtype: The dtype of the state and output of the given `cell`.
    parallel_iterations: Number of iterations to run in parallel. Values >> 1
      use more memory but take less time, while smaller values use less memory
      but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU.  This allows training RNNs
      which would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
    name: A string that will be used to create a scope for the RNN.

  Returns:
    A model function to be passed to an `Estimator`.

  Raises:
    ValueError: `problem_type` is not one of
      `ProblemType.LINEAR_REGRESSION` or `ProblemType.CLASSIFICATION`.
    ValueError: `prediction_type` is not one of `PredictionType.SINGLE_VALUE`
      or `PredictionType.MULTIPLE_VALUE`.
    ValueError: `predict_probabilities` is `True` for `problem_type` other
      than `ProblemType.CLASSIFICATION`.
    ValueError: `len(dropout_keep_probabilities)` is not `len(num_units) + 1`.
      NOTE(review): not raised in this function; presumably raised by
      `rnn_common.construct_rnn_cell` when the model function runs -- confirm.
  """
  if problem_type not in (constants.ProblemType.CLASSIFICATION,
                          constants.ProblemType.LINEAR_REGRESSION):
    raise ValueError(
        'problem_type must be ProblemType.LINEAR_REGRESSION or '
        'ProblemType.CLASSIFICATION; got {}'.
        format(problem_type))
  if prediction_type not in (rnn_common.PredictionType.SINGLE_VALUE,
                             rnn_common.PredictionType.MULTIPLE_VALUE):
    raise ValueError(
        'prediction_type must be PredictionType.MULTIPLE_VALUE or '
        'PredictionType.SINGLE_VALUE; got {}'.
        format(prediction_type))
  if (problem_type != constants.ProblemType.CLASSIFICATION
      and predict_probabilities):
    raise ValueError(
        'predict_probabilities can only be set to True for problem_type'
        ' ProblemType.CLASSIFICATION; got {}.'.format(problem_type))

  def _dynamic_rnn_model_fn(features, labels, mode):
    """The model to be passed to an `Estimator`."""
    with ops.name_scope(name):
      # `None` when the input function supplied no sequence lengths.
      sequence_length = features.get(sequence_length_key)
      sequence_input = build_sequence_input(features,
                                            sequence_feature_columns,
                                            context_feature_columns)
      # Dropout is a training-time regularizer only.
      dropout = (dropout_keep_probabilities
                 if mode == model_fn.ModeKeys.TRAIN
                 else None)
      # `DynamicRnnEstimator` documents that input functions may call
      # `rnn_common.construct_rnn_cell()` to obtain the same cell, so the cell
      # has to be built by that function.
      cell = rnn_common.construct_rnn_cell(num_units, cell_type, dropout)
      # `None` unless `features` carries a complete initial state.
      initial_state = dict_to_state_tuple(features, cell)
      rnn_activations, final_state = construct_rnn(
          initial_state,
          sequence_input,
          cell,
          target_column.num_label_columns,
          dtype=dtype,
          parallel_iterations=parallel_iterations,
          swap_memory=swap_memory)

      loss = None  # Created below for modes TRAIN and EVAL.
      # `prediction_type` was validated above, so exactly one branch runs.
      if prediction_type == rnn_common.PredictionType.MULTIPLE_VALUE:
        prediction_dict = rnn_common.multi_value_predictions(
            rnn_activations, target_column, problem_type, predict_probabilities)
        if mode != model_fn.ModeKeys.INFER:
          loss = _multi_value_loss(
              rnn_activations, labels, sequence_length, target_column, features)
      elif prediction_type == rnn_common.PredictionType.SINGLE_VALUE:
        prediction_dict = _single_value_predictions(
            rnn_activations, sequence_length, target_column,
            problem_type, predict_probabilities)
        if mode != model_fn.ModeKeys.INFER:
          loss = _single_value_loss(
              rnn_activations, labels, sequence_length, target_column, features)
      # Expose the final RNN state alongside the predictions.
      state_dict = state_tuple_to_dict(final_state)
      prediction_dict.update(state_dict)

      eval_metric_ops = None
      if mode != model_fn.ModeKeys.INFER:
        eval_metric_ops = rnn_common.get_eval_metric_ops(
            problem_type, prediction_type, sequence_length, prediction_dict,
            labels)

      train_op = None
      if mode == model_fn.ModeKeys.TRAIN:
        train_op = optimizers.optimize_loss(
            loss=loss,
            global_step=None,  # Get it internally.
            learning_rate=learning_rate,
            optimizer=optimizer,
            clip_gradients=gradient_clipping_norm,
            summaries=optimizers.OPTIMIZER_SUMMARIES)

    output_alternatives = _get_output_alternatives(prediction_type,
                                                   problem_type,
                                                   prediction_dict)

    return model_fn.ModelFnOps(mode=mode,
                               predictions=prediction_dict,
                               loss=loss,
                               train_op=train_op,
                               eval_metric_ops=eval_metric_ops,
                               output_alternatives=output_alternatives)
  return _dynamic_rnn_model_fn
545
546
class DynamicRnnEstimator(estimator.Estimator):
  """`Estimator` wrapping a dynamically unrolled RNN (deprecated).

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.
  """

  def __init__(self,
               problem_type,
               prediction_type,
               sequence_feature_columns,
               context_feature_columns=None,
               num_classes=None,
               num_units=None,
               cell_type='basic_rnn',
               optimizer='SGD',
               learning_rate=0.1,
               predict_probabilities=False,
               momentum=None,
               gradient_clipping_norm=5.0,
               dropout_keep_probabilities=None,
               model_dir=None,
               feature_engineering_fn=None,
               config=None):
    """Initializes a `DynamicRnnEstimator`.

    Sequence lengths: the input function passed to this `Estimator` may
    include the key `RNNKeys.SEQUENCE_LENGTH_KEY`. Its value must be a vector
    of size `batch_size` where entry `n` corresponds to the length of the
    `n`th sequence in the batch. The sequence length feature is required for
    batches of varying sizes. It will be used to calculate loss and evaluation
    metrics. If `RNNKeys.SEQUENCE_LENGTH_KEY` is not included, all sequences
    are assumed to have length equal to the size of dimension 1 of the input
    to the RNN.

    Initial state: in order to specify an initial state, the input function
    must include keys `STATE_PREFIX_i` for all `0 <= i < n` where `n` is the
    number of nested elements in `cell.state_size`. The input function must
    contain values for all state components or none of them. If none are
    included, then the default (zero) state is used as an initial state. See
    the documentation for `dict_to_state_tuple` and `state_tuple_to_dict` for
    further details. The input function can call
    rnn_common.construct_rnn_cell() to obtain the same cell type that this
    class will select from arguments to __init__.

    Predictions: the `predict()` method of the `Estimator` returns a
    dictionary with keys `STATE_PREFIX_i` for `0 <= i < n` where `n` is the
    number of nested elements in `cell.state_size`, along with
    `PredictionKey.CLASSES` for problem type `CLASSIFICATION` or
    `PredictionKey.SCORES` for problem type `LINEAR_REGRESSION`. The value
    keyed by `PredictionKey.CLASSES` or `PredictionKey.SCORES` has shape
    `[batch_size, padded_length]` in the multi-value case and shape
    `[batch_size]` in the single-value case. Here, `padded_length` is the
    largest value in the `RNNKeys.SEQUENCE_LENGTH` `Tensor` passed as input.
    Entry `[i, j]` is the prediction associated with sequence `i` and time
    step `j`. If the problem type is `CLASSIFICATION` and
    `predict_probabilities` is `True`, it will also include key
    `PredictionKey.PROBABILITIES`.

    Args:
      problem_type: whether the `Estimator` is intended for a regression or
        classification problem. Value must be one of
        `ProblemType.CLASSIFICATION` or `ProblemType.LINEAR_REGRESSION`.
      prediction_type: whether the `Estimator` should return a value for each
        step in the sequence, or just a single value for the final time step.
        Must be one of `PredictionType.SINGLE_VALUE` or
        `PredictionType.MULTIPLE_VALUE`.
      sequence_feature_columns: An iterable containing all the feature columns
        describing sequence features. All items in the iterable should be
        instances of classes derived from `FeatureColumn`.
      context_feature_columns: An iterable containing all the feature columns
        describing context features, i.e., features that apply across all time
        steps. All items in the set should be instances of classes derived from
        `FeatureColumn`.
      num_classes: the number of classes for a classification problem. Only
        used when `problem_type=ProblemType.CLASSIFICATION`.
      num_units: A list of integers indicating the number of units in the
        `RNNCell`s in each layer.
      cell_type: A subclass of `RNNCell` or one of 'basic_rnn,' 'lstm' or 'gru'.
      optimizer: The type of optimizer to use. Either a subclass of
        `Optimizer`, an instance of an `Optimizer`, a callback that returns an
        optimizer, or a string. Strings must be one of 'Adagrad', 'Adam',
        'Ftrl', 'Momentum', 'RMSProp' or 'SGD'. See `layers.optimize_loss` for
        more details.
      learning_rate: Learning rate. This argument has no effect if `optimizer`
        is an instance of an `Optimizer`.
      predict_probabilities: A boolean indicating whether to predict
        probabilities for all classes. Used only if `problem_type` is
        `ProblemType.CLASSIFICATION`.
      momentum: Momentum value. Only used if `optimizer` is 'Momentum'.
      gradient_clipping_norm: Parameter used for gradient clipping. If `None`,
        then no clipping is performed.
      dropout_keep_probabilities: a list of dropout keep probabilities or
        `None`. If a list is given, it must have length `len(num_units) + 1`.
        If `None`, then no dropout is applied.
      model_dir: The directory in which to save and restore the model graph,
        parameters, etc.
      feature_engineering_fn: Takes features and labels which are the output of
        `input_fn` and returns features and labels which will be fed into
        `model_fn`. Please check `model_fn` for a definition of features and
        labels.
      config: A `RunConfig` instance.

    Raises:
      ValueError: `problem_type` is not one of
        `ProblemType.LINEAR_REGRESSION` or `ProblemType.CLASSIFICATION`.
      ValueError: `problem_type` is `ProblemType.CLASSIFICATION` but
        `num_classes` is not specified.
      ValueError: `prediction_type` is not one of
        `PredictionType.MULTIPLE_VALUE` or `PredictionType.SINGLE_VALUE`.
    """
    # The scope name is assembled from a prediction-type prefix and a
    # problem-type suffix, e.g. 'SingleValueDynamicRNNClassifier'.
    is_multi_value = (
        prediction_type == rnn_common.PredictionType.MULTIPLE_VALUE)
    is_single_value = (
        prediction_type == rnn_common.PredictionType.SINGLE_VALUE)
    if not (is_multi_value or is_single_value):
      raise ValueError(
          'prediction_type must be one of PredictionType.MULTIPLE_VALUE or '
          'PredictionType.SINGLE_VALUE; got {}'.format(prediction_type))
    name_prefix = (
        'MultiValueDynamicRNN' if is_multi_value else 'SingleValueDynamicRNN')

    is_regression = problem_type == constants.ProblemType.LINEAR_REGRESSION
    is_classification = problem_type == constants.ProblemType.CLASSIFICATION
    if not (is_regression or is_classification):
      raise ValueError(
          'problem_type must be either ProblemType.LINEAR_REGRESSION '
          'or ProblemType.CLASSIFICATION; got {}'.format(
              problem_type))
    if is_regression:
      name_suffix = 'Regressor'
      target_column = layers.regression_target()
    else:
      if not num_classes:
        raise ValueError('For CLASSIFICATION problem_type, num_classes must be '
                         'specified.')
      target_column = layers.multi_class_target(n_classes=num_classes)
      name_suffix = 'Classifier'
    scope_name = name_prefix + name_suffix

    # The 'Momentum' string is resolved here because `momentum` is only
    # available at this level; other optimizer specs are forwarded untouched.
    if optimizer == 'Momentum':
      optimizer = momentum_opt.MomentumOptimizer(learning_rate, momentum)

    rnn_model_fn = _get_dynamic_rnn_model_fn(
        cell_type=cell_type,
        num_units=num_units,
        target_column=target_column,
        problem_type=problem_type,
        prediction_type=prediction_type,
        optimizer=optimizer,
        sequence_feature_columns=sequence_feature_columns,
        context_feature_columns=context_feature_columns,
        predict_probabilities=predict_probabilities,
        learning_rate=learning_rate,
        gradient_clipping_norm=gradient_clipping_norm,
        dropout_keep_probabilities=dropout_keep_probabilities,
        name=scope_name)

    super(DynamicRnnEstimator, self).__init__(
        model_fn=rnn_model_fn,
        model_dir=model_dir,
        config=config,
        feature_engineering_fn=feature_engineering_fn)
703