• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Common operations for RNN Estimators (deprecated).
16
17This module and all its submodules are deprecated. See
18[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
19for migration instructions.
20"""
21
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26from tensorflow.contrib import metrics
27from tensorflow.contrib import rnn as contrib_rnn
28from tensorflow.contrib.learn.python.learn.estimators import constants
29from tensorflow.contrib.learn.python.learn.estimators import prediction_key
30from tensorflow.python.framework import ops
31from tensorflow.python.ops import array_ops
32from tensorflow.python.ops import math_ops
33
34
35# NOTE(jtbates): As of February 10, 2017, some of the `RNNKeys` have been
36# removed and replaced with values from `prediction_key.PredictionKey`. The key
37# `RNNKeys.PREDICTIONS_KEY` has been replaced by
38# `prediction_key.PredictionKey.SCORES` for regression and
39# `prediction_key.PredictionKey.CLASSES` for classification. The key
40# `RNNKeys.PROBABILITIES_KEY` has been replaced by
41# `prediction_key.PredictionKey.PROBABILITIES`.
class RNNKeys(object):
  """String constants used as dictionary keys by the RNN estimator helpers.

  Every attribute is a plain `str`; code elsewhere may match on these exact
  values, so they must not be changed.
  """
  STATE_PREFIX = 'rnn_cell_state'
  FINAL_STATE_KEY = 'final_state'
  SEQUENCE_LENGTH_KEY = 'sequence_length'
  LABELS_KEY = '__labels__'
47
48
class PredictionType(object):
  """Enum-like integer tags describing what kind of prediction a model makes.

  `SINGLE_VALUE` marks one prediction per sequence; `MULTIPLE_VALUE` marks a
  prediction per time step. Within this module they are compared with `==`.
  """
  SINGLE_VALUE, MULTIPLE_VALUE = 1, 2
54
55
# Maps the string cell-type names accepted by `_get_single_cell` (and hence by
# `construct_rnn_cell`) to the `RNNCell` subclasses they construct.
_CELL_TYPES = dict(
    basic_rnn=contrib_rnn.BasicRNNCell,
    lstm=contrib_rnn.LSTMCell,
    gru=contrib_rnn.GRUCell,
)
59
60
def _get_single_cell(cell_type, num_units):
  """Constructs and returns a single `RNNCell`.

  Args:
    cell_type: Either a string naming one of the built-in cell types (a key of
      `_CELL_TYPES`) or a subclass of `RNNCell`.
    num_units: The number of units in the `RNNCell`.

  Returns:
    An initialized `RNNCell`.

  Raises:
    ValueError: `cell_type` is an unknown `RNNCell` name, is empty/`None`, or
      is a class that does not subclass `RNNCell`.
    TypeError: `cell_type` is neither a string nor a class (including
      unhashable values, which cannot be looked up by name).
  """
  # Strings resolve to their registered class; anything else passes through.
  cell_class = _CELL_TYPES.get(cell_type, cell_type)

  def _unsupported():
    return ValueError('The supported cell types are {}; got {}'.format(
        list(_CELL_TYPES.keys()), cell_class))

  # A string that survived the lookup is an unknown name. Reject it here with
  # the documented ValueError; otherwise `issubclass` below would raise an
  # unrelated TypeError ("arg 1 must be a class").
  if not cell_class or isinstance(cell_class, str):
    raise _unsupported()
  if not isinstance(cell_class, type):
    raise TypeError(
        'cell_type must be a string or a subclass of RNNCell; got {!r}'.format(
            cell_type))
  if not issubclass(cell_class, contrib_rnn.RNNCell):
    raise _unsupported()
  return cell_class(num_units=num_units)
79
80
def construct_rnn_cell(num_units, cell_type='basic_rnn',
                       dropout_keep_probabilities=None, random_seed=None):
  """Constructs cells, applies dropout and assembles a `MultiRNNCell`.

  The cell type chosen by DynamicRNNEstimator.__init__() is the same as
  returned by this function when called with the same arguments.

  Args:
    num_units: A single `int` or a non-empty list/tuple of `int`s. The size of
      the `RNNCell`s; one cell is built per entry.
    cell_type: A string identifying the `RNNCell` type or a subclass of
      `RNNCell`.
    dropout_keep_probabilities: a list of dropout keep probabilities or
      `None`. If a non-empty list is given, it must have length
      `len(num_units) + 1`: one input keep probability per cell followed by
      the output keep probability of the last cell. `None` or an empty list
      disables dropout.
    random_seed: Optional seed forwarded to `apply_dropout`. Unused when no
      dropout is applied.

  Returns:
    An initialized `RNNCell`: the cell itself when `num_units` has a single
    entry, otherwise a `MultiRNNCell` stacking all cells in order.

  Raises:
    ValueError: If `num_units` is an empty list/tuple, if the number of
      dropout probabilities does not match, or if `cell_type` is not a
      supported cell type.
  """
  if not isinstance(num_units, (list, tuple)):
    num_units = (num_units,)
  if not num_units:
    # MultiRNNCell would reject an empty stack anyway; fail early and clearly.
    raise ValueError('num_units must contain at least one entry.')

  cells = [_get_single_cell(cell_type, n) for n in num_units]
  # Truthiness check: both `None` and an empty list mean "no dropout".
  if dropout_keep_probabilities:
    cells = apply_dropout(cells, dropout_keep_probabilities,
                          random_seed=random_seed)
  if len(cells) == 1:
    return cells[0]
  return contrib_rnn.MultiRNNCell(cells)
108
109
def apply_dropout(cells, dropout_keep_probabilities, random_seed=None):
  """Wraps every cell in `cells` with a `DropoutWrapper`.

  Probability `i` is the input keep probability of `cells[i]`, and the final
  probability is the output keep probability of the last cell. All other
  output keep probabilities are 1.0. In a stack, dropout between two cells is
  therefore applied once, on the next cell's input.

  Args:
    cells: A non-empty list of `RNNCell`s.
    dropout_keep_probabilities: a list whose elements are either floats in
      `[0.0, 1.0]` or `None` (treated as 1.0, i.e. no dropout). It must have
      length one greater than `cells`.
    random_seed: Seed for random dropout, applied to every wrapper.

  Returns:
    A list of `RNNCell`s, the result of applying the supplied dropouts.

  Raises:
    ValueError: If `len(dropout_keep_probabilities) != len(cells) + 1` or if
      `cells` is empty.
  """
  if len(dropout_keep_probabilities) != len(cells) + 1:
    raise ValueError(
        'The number of dropout probabilities must be one greater than the '
        'number of cells. Got {} cells and {} dropout probabilities.'.format(
            len(cells), len(dropout_keep_probabilities)))
  if not cells:
    raise ValueError('At least one cell is required to apply dropout.')

  # `DropoutWrapper` cannot convert `None`; map it to "keep everything".
  keep_probs = [1.0 if p is None else p for p in dropout_keep_probabilities]
  input_keep_probs = keep_probs[:-1]
  output_keep_probs = [1.0] * (len(cells) - 1) + keep_probs[-1:]
  # Every wrapper, including the last one, receives the seed.
  return [
      contrib_rnn.DropoutWrapper(
          cell,
          input_keep_prob=input_prob,
          output_keep_prob=output_prob,
          seed=random_seed)
      for cell, input_prob, output_prob in zip(cells, input_keep_probs,
                                               output_keep_probs)
  ]
138
139
def get_eval_metric_ops(problem_type, prediction_type, sequence_length,
                        prediction_dict, labels):
  """Builds the eval metric ops for a problem type / prediction type pair.

  Only classification currently produces a metric: a streaming `'accuracy'`
  op. For multi value models, padded time steps are removed with
  `mask_activations_and_labels` before comparing. Regression problems and
  unrecognized prediction types yield an empty dict.

  Args:
    problem_type: `ProblemType.CLASSIFICATION` or
      `ProblemType.LINEAR_REGRESSION`.
    prediction_type: `PredictionType.SINGLE_VALUE` or
      `PredictionType.MULTIPLE_VALUE`.
    sequence_length: A `Tensor` with shape `[batch_size]` and dtype `int32`
      containing the length of each sequence in the batch. If `None`, sequences
      are assumed to be unpadded.
    prediction_dict: A dict of prediction tensors.
    labels: The label `Tensor`.

  Returns:
    A `dict` mapping metric names to the result of calling the metric_fn.
  """
  if problem_type != constants.ProblemType.CLASSIFICATION:
    # No regression metrics are defined (single or multi value).
    return {}

  classes_key = prediction_key.PredictionKey.CLASSES
  if prediction_type == PredictionType.MULTIPLE_VALUE:
    # Drop padded steps so they do not count toward accuracy.
    predicted, expected = mask_activations_and_labels(
        prediction_dict[classes_key], labels, sequence_length)
  elif prediction_type == PredictionType.SINGLE_VALUE:
    predicted, expected = prediction_dict[classes_key], labels
  else:
    return {}

  return {
      'accuracy': metrics.streaming_accuracy(
          predictions=predicted, labels=expected),
  }
180
181
def select_last_activations(activations, sequence_lengths):
  """Selects the nth set of activations for each n in `sequence_length`.

  Returns a `Tensor` of shape `[batch_size, k]`. If `sequence_length` is not
  `None`, then `output[i, :] = activations[i, sequence_length[i] - 1, :]`. If
  `sequence_length` is `None`, then `output[i, :] = activations[i, -1, :]`.

  Args:
    activations: A `Tensor` with shape `[batch_size, padded_length, k]`.
    sequence_lengths: A `Tensor` with shape `[batch_size]` or `None`. Any
      integer dtype is accepted; it is cast to the dtype of
      `tf.shape(activations)` (int32). Each length should lie in
      `[1, padded_length]`. Other values index a row belonging to a
      different example, or fall out of range.

  Returns:
    A `Tensor` of shape `[batch_size, k]`.
  """
  with ops.name_scope(
      'select_last_activations', values=[activations, sequence_lengths]):
    activations_shape = array_ops.shape(activations)
    batch_size = activations_shape[0]
    padded_length = activations_shape[1]
    num_label_columns = activations_shape[2]
    if sequence_lengths is None:
      sequence_lengths = padded_length
    else:
      # The index arithmetic below is in the dtype of `array_ops.shape`
      # (int32). Lengths of another integer dtype (e.g. int64) would fail to
      # combine with it, so cast them first.
      sequence_lengths = math_ops.cast(sequence_lengths, padded_length.dtype)
    # Flatten [batch, time, k] -> [batch * time, k] and compute the flat row
    # index of each example's last valid step.
    reshaped_activations = array_ops.reshape(activations,
                                             [-1, num_label_columns])
    indices = math_ops.range(batch_size) * padded_length + sequence_lengths - 1
    last_activations = array_ops.gather(reshaped_activations, indices)
    # Restore whatever static shape information the input carried.
    last_activations.set_shape(
        [activations.get_shape()[0], activations.get_shape()[2]])
    return last_activations
210
211
def mask_activations_and_labels(activations, labels, sequence_lengths):
  """Removes entries outside `sequence_lengths` and returns flattened results.

  Args:
    activations: Output of the RNN, shape `[batch_size, padded_length, k]`.
    labels: Label values, shape `[batch_size, padded_length]`.
    sequence_lengths: A `Tensor` of shape `[batch_size]` with the unpadded
      length of each sequence. If `None`, then each sequence is unpadded.

  Returns:
    A pair `(activations_masked, labels_masked)`:
    activations_masked: `logit` values with those beyond `sequence_lengths`
      removed for each batch. Batches are then concatenated. Shape
      `[tf.sum(sequence_lengths), k]` if `sequence_lengths` is not `None` and
      shape `[batch_size * padded_length, k]` otherwise.
    labels_masked: Label values after removing unneeded entries. Shape
      `[tf.sum(sequence_lengths)]` if `sequence_lengths` is not `None` and shape
      `[batch_size * padded_length]` otherwise.
  """
  with ops.name_scope(
      'mask_activations_and_labels',
      values=[activations, labels, sequence_lengths]):
    label_dims = array_ops.shape(labels)
    num_sequences, max_steps = label_dims[0], label_dims[1]

    if sequence_lengths is not None:
      # Keep only the time steps that fall within each sequence's length.
      keep = array_ops.sequence_mask(sequence_lengths, max_steps)
      return (array_ops.boolean_mask(activations, keep),
              array_ops.boolean_mask(labels, keep))

    # Without lengths every step is real: just merge the batch and time axes.
    total_steps = max_steps * num_sequences
    return (array_ops.reshape(activations, [total_steps, -1]),
            array_ops.reshape(labels, [total_steps]))
246
247
def multi_value_predictions(activations, target_column, problem_type,
                            predict_probabilities):
  """Maps `activations` from the RNN to predictions for multi value models.

  If `predict_probabilities` is `False`, this function returns a `dict`
  containing a single entry keyed by `prediction_key.PredictionKey.CLASSES`
  when `problem_type` is `ProblemType.CLASSIFICATION`. For any other
  `problem_type` (e.g. `ProblemType.LINEAR_REGRESSION`), the key is
  `prediction_key.PredictionKey.SCORES`. The value has shape
  `[batch_size, padded_length]`.

  If `predict_probabilities` is `True`, the dict contains a second entry
  with key `prediction_key.PredictionKey.PROBABILITIES`. Its value is a
  `Tensor` of probabilities with shape `[batch_size, padded_length,
  num_classes]`. In that case the first entry holds the `argmax` of those
  probabilities.

  Note that variable length inputs will yield some predictions that don't have
  meaning. For example, if `sequence_length = [3, 2]`, then prediction `[1, 2]`
  has no meaningful interpretation.

  Args:
    activations: Output from an RNN. Should have dtype `float32` and shape
      `[batch_size, padded_length, ?]`.
    target_column: An initialized `TargetColumn`, used to convert the
      activations (logits) into predictions.
    problem_type: Either `ProblemType.CLASSIFICATION` or
      `ProblemType.LINEAR_REGRESSION`.
    predict_probabilities: A Python boolean, indicating whether probabilities
      should be returned. Should only be set to `True` for
      classification/logistic regression problems.

  Returns:
    A `dict` mapping strings to `Tensors`.
  """
  with ops.name_scope('MultiValuePrediction', values=[activations]):
    activations_shape = array_ops.shape(activations)
    # Merge batch and time so the target column sees a 2-D logits matrix.
    flattened_activations = array_ops.reshape(activations,
                                              [-1, activations_shape[2]])
    prediction_dict = {}
    if predict_probabilities:
      flat_probabilities = target_column.logits_to_predictions(
          flattened_activations, proba=True)
      flat_predictions = math_ops.argmax(flat_probabilities, 1)
      if target_column.num_label_columns == 1:
        # A single logit column is reshaped into two probability columns
        # (presumably [P(class 0), P(class 1)] from the target column).
        probability_shape = array_ops.concat([activations_shape[:2], [2]], 0)
      else:
        probability_shape = activations_shape
      prediction_dict[prediction_key.PredictionKey.PROBABILITIES] = (
          array_ops.reshape(
              flat_probabilities,
              probability_shape,
              name=prediction_key.PredictionKey.PROBABILITIES))
    else:
      flat_predictions = target_column.logits_to_predictions(
          flattened_activations, proba=False)
    predictions_name = (prediction_key.PredictionKey.CLASSES
                        if problem_type == constants.ProblemType.CLASSIFICATION
                        else prediction_key.PredictionKey.SCORES)
    # Restore the [batch_size, padded_length] layout.
    prediction_dict[predictions_name] = array_ops.reshape(
        flat_predictions, [activations_shape[0], activations_shape[1]],
        name=predictions_name)
    return prediction_dict
309