1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Common operations for RNN Estimators (deprecated). 16 17This module and all its submodules are deprecated. See 18[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) 19for migration instructions. 20""" 21 22from __future__ import absolute_import 23from __future__ import division 24from __future__ import print_function 25 26from tensorflow.contrib import metrics 27from tensorflow.contrib import rnn as contrib_rnn 28from tensorflow.contrib.learn.python.learn.estimators import constants 29from tensorflow.contrib.learn.python.learn.estimators import prediction_key 30from tensorflow.python.framework import ops 31from tensorflow.python.ops import array_ops 32from tensorflow.python.ops import math_ops 33 34 35# NOTE(jtbates): As of February 10, 2017, some of the `RNNKeys` have been 36# removed and replaced with values from `prediction_key.PredictionKey`. The key 37# `RNNKeys.PREDICTIONS_KEY` has been replaced by 38# `prediction_key.PredictionKey.SCORES` for regression and 39# `prediction_key.PredictionKey.CLASSES` for classification. The key 40# `RNNKeys.PROBABILITIES_KEY` has been replaced by 41# `prediction_key.PredictionKey.PROBABILITIES`. 
class RNNKeys(object):
  """String keys used by the RNN estimators.

  Prediction outputs are keyed by `prediction_key.PredictionKey` values
  (`SCORES`, `CLASSES`, `PROBABILITIES`) rather than by keys in this class.
  """
  FINAL_STATE_KEY = 'final_state'
  LABELS_KEY = '__labels__'
  SEQUENCE_LENGTH_KEY = 'sequence_length'
  STATE_PREFIX = 'rnn_cell_state'


class PredictionType(object):
  """Enum-like values for the type of prediction that the model makes."""
  # Presumably one prediction per sequence; see `get_eval_metric_ops`.
  SINGLE_VALUE = 1
  # One prediction per time step; `multi_value_predictions` returns
  # predictions of shape `[batch_size, padded_length]`.
  MULTIPLE_VALUE = 2


# Maps the string names accepted by `construct_rnn_cell` to cell classes.
_CELL_TYPES = {'basic_rnn': contrib_rnn.BasicRNNCell,
               'lstm': contrib_rnn.LSTMCell,
               'gru': contrib_rnn.GRUCell,}


def _get_single_cell(cell_type, num_units):
  """Constructs and returns a single `RNNCell`.

  Args:
    cell_type: Either a string identifying the `RNNCell` type (a key of
      `_CELL_TYPES`) or a subclass of `RNNCell`.
    num_units: The number of units in the `RNNCell`.
  Returns:
    An initialized `RNNCell`.
  Raises:
    ValueError: `cell_type` is an invalid `RNNCell` name, is falsy, or is a
      class that does not subclass `RNNCell`.
    TypeError: `cell_type` is neither a string nor a class.
  """
  cell_class = _CELL_TYPES.get(cell_type, cell_type)
  if isinstance(cell_class, type) and issubclass(cell_class,
                                                  contrib_rnn.RNNCell):
    return cell_class(num_units=num_units)
  # Unknown names must not reach `issubclass`. It would raise an opaque
  # TypeError for a string, so report them as a ValueError, as documented.
  if not cell_class or isinstance(cell_class, (type, bytes, type(u''))):
    raise ValueError('The supported cell types are {}; got {}'.format(
        list(_CELL_TYPES.keys()), cell_class))
  raise TypeError(
      '`cell_type` must be a string or a subclass of `RNNCell`; got {}'.format(
          cell_class))


def construct_rnn_cell(num_units, cell_type='basic_rnn',
                       dropout_keep_probabilities=None):
  """Constructs cells, applies dropout and assembles a `MultiRNNCell`.

  The cell type chosen by DynamicRNNEstimator.__init__() is the same as
  returned by this function when called with the same arguments.

  Args:
    num_units: A single `int` or a list/tuple of `int`s. The size of the
      `RNNCell`s.
    cell_type: A string identifying the `RNNCell` type or a subclass of
      `RNNCell`.
    dropout_keep_probabilities: a list of dropout keep probabilities or `None`.
      If a non-empty list is given, it must have length `len(num_units) + 1`
      (2 when `num_units` is a single `int`). See `apply_dropout`.

  Returns:
    An initialized `RNNCell`: the single cell when only one size is given,
    otherwise a `MultiRNNCell` stacking one cell per entry of `num_units`.
  """
  if not isinstance(num_units, (list, tuple)):
    num_units = (num_units,)

  cells = [_get_single_cell(cell_type, n) for n in num_units]
  # An empty list is treated the same as `None`: no dropout.
  if dropout_keep_probabilities:
    cells = apply_dropout(cells, dropout_keep_probabilities)
  if len(cells) == 1:
    return cells[0]
  return contrib_rnn.MultiRNNCell(cells)


def apply_dropout(cells, dropout_keep_probabilities, random_seed=None):
  """Applies dropout to the inputs of every cell and the outputs of the last.

  `dropout_keep_probabilities[i]` is the input keep probability of
  `cells[i]`, and `dropout_keep_probabilities[-1]` is the output keep
  probability of the last cell. All other cells keep their outputs
  unchanged (output keep probability 1.0).

  Args:
    cells: A list of `RNNCell`s.
    dropout_keep_probabilities: a list whose elements are either floats in
      `[0.0, 1.0]` or `None`. It must have length one greater than `cells`.
    random_seed: Seed for random dropout, applied to every wrapped cell.

  Returns:
    A list of `RNNCell`s, the result of applying the supplied dropouts.

  Raises:
    ValueError: If `len(dropout_keep_probabilities) != len(cells) + 1`.
  """
  if len(dropout_keep_probabilities) != len(cells) + 1:
    raise ValueError(
        'The number of dropout probabilities must be one greater than the '
        'number of cells. Got {} cells and {} dropout probabilities.'.format(
            len(cells), len(dropout_keep_probabilities)))
  # Positional args of `DropoutWrapper`: input keep prob, output keep prob.
  wrapped_cells = [
      contrib_rnn.DropoutWrapper(cell, prob, 1.0, seed=random_seed)
      for cell, prob in zip(cells[:-1], dropout_keep_probabilities[:-2])
  ]
  # The last cell also receives the output dropout. It uses the same seed as
  # the other cells so that `random_seed` governs every wrapper.
  wrapped_cells.append(
      contrib_rnn.DropoutWrapper(cells[-1], dropout_keep_probabilities[-2],
                                 dropout_keep_probabilities[-1],
                                 seed=random_seed))
  return wrapped_cells


def get_eval_metric_ops(problem_type, prediction_type, sequence_length,
                        prediction_dict, labels):
  """Returns eval metric ops for given `problem_type` and `prediction_type`.

  Only classification problems get metrics (an 'accuracy' entry).
  Regression problems get an empty dict.

  Args:
    problem_type: `ProblemType.CLASSIFICATION` or
      `ProblemType.LINEAR_REGRESSION`.
    prediction_type: `PredictionType.SINGLE_VALUE` or
      `PredictionType.MULTIPLE_VALUE`.
    sequence_length: A `Tensor` with shape `[batch_size]` and dtype `int32`
      containing the length of each sequence in the batch. If `None`, sequences
      are assumed to be unpadded.
    prediction_dict: A dict of prediction tensors.
    labels: The label `Tensor`.

  Returns:
    A `dict` mapping strings to the result of calling the metric_fn.
  """
  eval_metric_ops = {}
  if problem_type == constants.ProblemType.CLASSIFICATION:
    if prediction_type == PredictionType.MULTIPLE_VALUE:
      # Per-step predictions: drop padded steps before comparing with labels.
      mask_predictions, mask_labels = mask_activations_and_labels(
          prediction_dict[prediction_key.PredictionKey.CLASSES], labels,
          sequence_length)
      eval_metric_ops['accuracy'] = metrics.streaming_accuracy(
          predictions=mask_predictions, labels=mask_labels)
    elif prediction_type == PredictionType.SINGLE_VALUE:
      eval_metric_ops['accuracy'] = metrics.streaming_accuracy(
          predictions=prediction_dict[prediction_key.PredictionKey.CLASSES],
          labels=labels)
  # No eval metrics are defined for `ProblemType.LINEAR_REGRESSION`, either
  # single- or multi-value, so the dict stays empty in that case.
  return eval_metric_ops


def select_last_activations(activations, sequence_lengths):
  """Selects the nth set of activations for each n in `sequence_length`.

  Returns a `Tensor` of shape `[batch_size, k]`. If `sequence_length` is not
  `None`, then `output[i, :] = activations[i, sequence_length[i] - 1, :]`. If
  `sequence_length` is `None`, then `output[i, :] = activations[i, -1, :]`.

  Args:
    activations: A `Tensor` with shape `[batch_size, padded_length, k]`.
    sequence_lengths: A `Tensor` with shape `[batch_size]` of any integer dtype,
      or `None`.
  Returns:
    A `Tensor` of shape `[batch_size, k]`.
  """
  with ops.name_scope(
      'select_last_activations', values=[activations, sequence_lengths]):
    activations_shape = array_ops.shape(activations)
    batch_size = activations_shape[0]
    padded_length = activations_shape[1]
    num_label_columns = activations_shape[2]
    if sequence_lengths is None:
      sequence_lengths = padded_length
    # Match the dtype of the shape tensor so that lengths of another integer
    # dtype (e.g. int64) can be combined with the offsets computed below.
    sequence_lengths = math_ops.cast(sequence_lengths, activations_shape.dtype)
    reshaped_activations = array_ops.reshape(activations,
                                             [-1, num_label_columns])
    # Row of step `sequence_lengths[i] - 1` of sequence `i` once the batch and
    # time dimensions are flattened into one.
    indices = math_ops.range(batch_size) * padded_length + sequence_lengths - 1
    last_activations = array_ops.gather(reshaped_activations, indices)
    last_activations.set_shape(
        [activations.get_shape()[0], activations.get_shape()[2]])
    return last_activations


def mask_activations_and_labels(activations, labels, sequence_lengths):
  """Removes entries outside `sequence_lengths` and returns flattened results.

  Args:
    activations: Output of the RNN, shape `[batch_size, padded_length, k]`.
    labels: Label values, shape `[batch_size, padded_length]`.
    sequence_lengths: A `Tensor` of shape `[batch_size]` with the unpadded
      length of each sequence. If `None`, then each sequence is unpadded.

  Returns:
    activations_masked: `logit` values with those beyond `sequence_lengths`
      removed for each batch. Batches are then concatenated. Shape
      `[tf.sum(sequence_lengths), k]` if `sequence_lengths` is not `None` and
      shape `[batch_size * padded_length, k]` otherwise.
    labels_masked: Label values after removing unneeded entries. Shape
      `[tf.sum(sequence_lengths)]` if `sequence_lengths` is not `None` and shape
      `[batch_size * padded_length]` otherwise.
  """
  with ops.name_scope(
      'mask_activations_and_labels',
      values=[activations, labels, sequence_lengths]):
    labels_shape = array_ops.shape(labels)
    batch_size = labels_shape[0]
    padded_length = labels_shape[1]
    if sequence_lengths is None:
      # Nothing to mask: just merge the batch and time dimensions.
      flattened_dimension = padded_length * batch_size
      activations_masked = array_ops.reshape(activations,
                                             [flattened_dimension, -1])
      labels_masked = array_ops.reshape(labels, [flattened_dimension])
    else:
      # `mask[i, t]` is True for the valid steps `t < sequence_lengths[i]`.
      mask = array_ops.sequence_mask(sequence_lengths, padded_length)
      activations_masked = array_ops.boolean_mask(activations, mask)
      labels_masked = array_ops.boolean_mask(labels, mask)
    return activations_masked, labels_masked


def multi_value_predictions(activations, target_column, problem_type,
                            predict_probabilities):
  """Maps `activations` from the RNN to predictions for multi value models.

  If `predict_probabilities` is `False`, this function returns a `dict`
  containing a single entry. Its key is `prediction_key.PredictionKey.CLASSES`
  for `problem_type` `ProblemType.CLASSIFICATION` and
  `prediction_key.PredictionKey.SCORES` otherwise (e.g.
  `ProblemType.LINEAR_REGRESSION`). The predictions have shape
  `[batch_size, padded_length]`.

  If `predict_probabilities` is `True`, it will contain a second entry with key
  `prediction_key.PredictionKey.PROBABILITIES`. The
  value of this entry is a `Tensor` of probabilities with shape
  `[batch_size, padded_length, num_classes]` (`num_classes` is 2 when
  `target_column.num_label_columns == 1`). In that case the predictions are the
  argmax of the probabilities.

  Note that variable length inputs will yield some predictions that don't have
  meaning. For example, if `sequence_length = [3, 2]`, then prediction `[1, 2]`
  has no meaningful interpretation.

  Args:
    activations: Output from an RNN. Should have dtype `float32` and shape
      `[batch_size, padded_length, ?]`.
    target_column: An initialized `TargetColumn`, used to calculate
      predictions.
    problem_type: Either `ProblemType.CLASSIFICATION` or
      `ProblemType.LINEAR_REGRESSION`.
    predict_probabilities: A Python boolean, indicating whether probabilities
      should be returned. Should only be set to `True` for
      classification/logistic regression problems.
  Returns:
    A `dict` mapping strings to `Tensors`.
  """
  with ops.name_scope('MultiValuePrediction'):
    activations_shape = array_ops.shape(activations)
    # Merge batch and time so `target_column` sees a 2-D `[N, k]` input.
    flattened_activations = array_ops.reshape(activations,
                                              [-1, activations_shape[2]])
    prediction_dict = {}
    if predict_probabilities:
      flat_probabilities = target_column.logits_to_predictions(
          flattened_activations, proba=True)
      flat_predictions = math_ops.argmax(flat_probabilities, 1)
      if target_column.num_label_columns == 1:
        # A single label column yields two probability columns per step.
        probability_shape = array_ops.concat([activations_shape[:2], [2]], 0)
      else:
        probability_shape = activations_shape
      probabilities = array_ops.reshape(
          flat_probabilities,
          probability_shape,
          name=prediction_key.PredictionKey.PROBABILITIES)
      prediction_dict[
          prediction_key.PredictionKey.PROBABILITIES] = probabilities
    else:
      flat_predictions = target_column.logits_to_predictions(
          flattened_activations, proba=False)
    predictions_name = (prediction_key.PredictionKey.CLASSES
                        if problem_type == constants.ProblemType.CLASSIFICATION
                        else prediction_key.PredictionKey.SCORES)
    predictions = array_ops.reshape(
        flat_predictions, [activations_shape[0], activations_shape[1]],
        name=predictions_name)
    prediction_dict[predictions_name] = predictions
    return prediction_dict