# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Estimator for Dynamic RNNs (deprecated).

This module and all its submodules are deprecated. See
[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
for migration instructions.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib import layers
from tensorflow.contrib.layers.python.layers import optimizers
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
from tensorflow.contrib.learn.python.learn.estimators import rnn_common
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.training import momentum as momentum_opt
from tensorflow.python.util import nest


# TODO(jtbates): Remove PredictionType when all non-experimental targets which
# depend on it point to rnn_common.PredictionType.
class PredictionType(object):
  """Legacy copy of `rnn_common.PredictionType`, kept for existing callers."""
  SINGLE_VALUE = 1
  MULTIPLE_VALUE = 2


def _get_state_name(i):
  """Constructs the name string for state component `i`."""
  return '{}_{}'.format(rnn_common.RNNKeys.STATE_PREFIX, i)


def state_tuple_to_dict(state):
  """Returns a dict containing flattened `state`.

  Args:
    state: A `Tensor` or a nested tuple of `Tensors`. All of the `Tensor`s must
      have the same rank and agree on all dimensions except the last.

  Returns:
    A dict containing the `Tensor`s that make up `state`. The keys of the dict
    are of the form "STATE_PREFIX_i" where `i` is the place of this `Tensor`
    in a depth-first traversal of `state`. `None` components are kept as
    `None` values.
  """
  with ops.name_scope('state_tuple_to_dict'):
    flat_state = nest.flatten(state)
    state_dict = {}
    for i, state_component in enumerate(flat_state):
      state_name = _get_state_name(i)
      # Wrap each component in a named identity so it can be looked up by name.
      state_value = (None if state_component is None
                     else array_ops.identity(state_component, name=state_name))
      state_dict[state_name] = state_value
  return state_dict


def dict_to_state_tuple(input_dict, cell):
  """Reconstructs nested `state` from a dict containing state `Tensor`s.

  Args:
    input_dict: A dict of `Tensor`s.
    cell: An instance of `RNNCell`.
  Returns:
    `None` if `input_dict` contains none of the keys 'STATE_PREFIX_i' for
    `0 <= i < n`, where `n` is the number of nested entries in
    `cell.state_size`. If all `n` keys are present, returns a `Tensor` if
    `cell.state_size` is an `int` or a nested tuple of `Tensor`s if
    `cell.state_size` is a nested tuple. Each returned `Tensor` carries runtime
    checks that it has rank 2 and that its dimension 1 equals the matching
    entry of `cell.state_size`.
  Raises:
    ValueError: State is partially specified. The `input_dict` must contain
      values for all state components or none at all.
  """
  flat_state_sizes = nest.flatten(cell.state_size)
  state_tensors = []
  with ops.name_scope('dict_to_state_tuple'):
    for i, state_size in enumerate(flat_state_sizes):
      state_name = _get_state_name(i)
      state_tensor = input_dict.get(state_name)
      if state_tensor is not None:
        rank_check = check_ops.assert_rank(
            state_tensor, 2, name='check_state_{}_rank'.format(i))
        shape_check = check_ops.assert_equal(
            array_ops.shape(state_tensor)[1],
            state_size,
            name='check_state_{}_shape'.format(i))
        with ops.control_dependencies([rank_check, shape_check]):
          state_tensor = array_ops.identity(state_tensor, name=state_name)
        state_tensors.append(state_tensor)
    if not state_tensors:
      return None
    elif len(state_tensors) == len(flat_state_sizes):
      # The zero state is built only to recover the nesting structure of the
      # cell's state; its values (and dtype) are irrelevant.
      dummy_state = cell.zero_state(batch_size=1, dtype=dtypes.bool)
      return nest.pack_sequence_as(dummy_state, state_tensors)
    else:
      raise ValueError(
          'RNN state was partially specified. '
          'Expected zero or {} state Tensors; got {}'.
          format(len(flat_state_sizes), len(state_tensors)))


def _concatenate_context_input(sequence_input, context_input):
  """Replicates `context_input` across all timesteps of `sequence_input`.

  Expands dimension 1 of `context_input`, then tiles it `padded_length` times
  (the size of dimension 1 of `sequence_input`). This value is appended to
  `sequence_input` on dimension 2 and the result is returned.

  Args:
    sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size,
      padded_length, d0]`.
    context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`.

  Returns:
    A `Tensor` of dtype `float32` and shape `[batch_size, padded_length,
    d0 + d1]`.

  Raises:
    ValueError: If `sequence_input` does not have rank 3 or `context_input` does
      not have rank 2.
  """
  seq_rank_check = check_ops.assert_rank(
      sequence_input,
      3,
      message='sequence_input must have rank 3',
      data=[array_ops.shape(sequence_input)])
  seq_type_check = check_ops.assert_type(
      sequence_input,
      dtypes.float32,
      message='sequence_input must have dtype float32; got {}.'.format(
          sequence_input.dtype))
  ctx_rank_check = check_ops.assert_rank(
      context_input,
      2,
      message='context_input must have rank 2',
      data=[array_ops.shape(context_input)])
  ctx_type_check = check_ops.assert_type(
      context_input,
      dtypes.float32,
      message='context_input must have dtype float32; got {}.'.format(
          context_input.dtype))
  with ops.control_dependencies(
      [seq_rank_check, seq_type_check, ctx_rank_check, ctx_type_check]):
    padded_length = array_ops.shape(sequence_input)[1]
    tiled_context_input = array_ops.tile(
        array_ops.expand_dims(context_input, 1),
        array_ops.concat([[1], [padded_length], [1]], 0))
  return array_ops.concat([sequence_input, tiled_context_input], 2)


def build_sequence_input(features,
                         sequence_feature_columns,
                         context_feature_columns,
                         weight_collections=None,
                         scope=None):
  """Combine sequence and context features into input for an RNN.

  Args:
    features: A `dict` containing the input and (optionally) sequence length
      information and initial state. It is copied, not modified in place.
    sequence_feature_columns: An iterable containing all the feature columns
      describing sequence features. All items in the set should be instances
      of classes derived from `FeatureColumn`.
    context_feature_columns: An iterable containing all the feature columns
      describing context features i.e. features that apply across all time
      steps. All items in the set should be instances of classes derived from
      `FeatureColumn`. May be `None`.
    weight_collections: List of graph collections to which weights are added.
    scope: Optional scope, passed through to parsing ops.
  Returns:
    A `Tensor` of dtype `float32` and shape `[batch_size, padded_length, ?]`.
    This will be used as input to an RNN.
  """
  features = features.copy()
  features.update(layers.transform_features(
      features,
      list(sequence_feature_columns) + list(context_feature_columns or [])))
  sequence_input = layers.sequence_input_from_feature_columns(
      columns_to_tensors=features,
      feature_columns=sequence_feature_columns,
      weight_collections=weight_collections,
      scope=scope)
  if context_feature_columns is not None:
    context_input = layers.input_from_feature_columns(
        columns_to_tensors=features,
        feature_columns=context_feature_columns,
        weight_collections=weight_collections,
        scope=scope)
    sequence_input = _concatenate_context_input(sequence_input, context_input)
  return sequence_input


def construct_rnn(initial_state,
                  sequence_input,
                  cell,
                  num_label_columns,
                  dtype=dtypes.float32,
                  parallel_iterations=32,
                  swap_memory=True):
  """Build an RNN and apply a fully connected layer to get the desired output.

  Args:
    initial_state: The initial state to pass the RNN. If `None`, the
      default starting state for `cell` (built with `dtype`) is used.
    sequence_input: A `Tensor` with shape `[batch_size, padded_length, d]`
      that will be passed as input to the RNN.
    cell: An initialized `RNNCell`.
    num_label_columns: The desired output dimension.
    dtype: dtype of `cell`.
    parallel_iterations: Number of iterations to run in parallel. Values >> 1
      use more memory but take less time, while smaller values use less memory
      but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU. This allows training RNNs
      which would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
  Returns:
    activations: The output of the RNN, projected to `num_label_columns`
      dimensions.
    final_state: A `Tensor` or nested tuple of `Tensor`s representing the final
      state output by the RNN.
  """
  with ops.name_scope('RNN'):
    rnn_outputs, final_state = rnn.dynamic_rnn(
        cell=cell,
        inputs=sequence_input,
        initial_state=initial_state,
        dtype=dtype,
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory,
        time_major=False)
    # Linear projection (no activation): the target column turns these
    # logits into predictions/loss.
    activations = layers.fully_connected(
        inputs=rnn_outputs,
        num_outputs=num_label_columns,
        activation_fn=None,
        trainable=True)
  return activations, final_state


def _single_value_predictions(activations,
                              sequence_length,
                              target_column,
                              problem_type,
                              predict_probabilities):
  """Maps `activations` from the RNN to predictions for single value models.

  If `predict_probabilities` is `False`, this function returns a `dict`
  containing a single entry with key `PredictionKey.CLASSES` (classification)
  or `PredictionKey.SCORES` (regression). If `predict_probabilities` is `True`,
  it will contain a second entry with key `PredictionKey.PROBABILITIES`. The
  value of this entry is a `Tensor` of probabilities with shape
  `[batch_size, num_classes]`.

  Args:
    activations: Output from an RNN. Should have dtype `float32` and shape
      `[batch_size, padded_length, ?]`.
    sequence_length: A `Tensor` with shape `[batch_size]` and dtype `int32`
      containing the length of each sequence in the batch. If `None`, sequences
      are assumed to be unpadded.
    target_column: An initialized `TargetColumn`, used to calculate predictions.
    problem_type: Either `ProblemType.CLASSIFICATION` or
      `ProblemType.LINEAR_REGRESSION`.
    predict_probabilities: A Python boolean, indicating whether probabilities
      should be returned. Should only be set to `True` for
      classification/logistic regression problems.
  Returns:
    A `dict` mapping strings to `Tensors`.
  """
  with ops.name_scope('SingleValuePrediction'):
    last_activations = rnn_common.select_last_activations(
        activations, sequence_length)
    predictions_name = (prediction_key.PredictionKey.CLASSES
                        if problem_type == constants.ProblemType.CLASSIFICATION
                        else prediction_key.PredictionKey.SCORES)
    if predict_probabilities:
      probabilities = target_column.logits_to_predictions(
          last_activations, proba=True)
      prediction_dict = {
          prediction_key.PredictionKey.PROBABILITIES: probabilities,
          predictions_name: math_ops.argmax(probabilities, 1)}
    else:
      predictions = target_column.logits_to_predictions(
          last_activations, proba=False)
      prediction_dict = {predictions_name: predictions}
    return prediction_dict


def _multi_value_loss(
    activations, labels, sequence_length, target_column, features):
  """Maps `activations` from the RNN to loss for multi value models.

  Args:
    activations: Output from an RNN. Should have dtype `float32` and shape
      `[batch_size, padded_length, ?]`.
    labels: A `Tensor` with shape `[batch_size, padded_length]`.
    sequence_length: A `Tensor` with shape `[batch_size]` and dtype `int32`
      containing the length of each sequence in the batch. If `None`, sequences
      are assumed to be unpadded.
    target_column: An initialized `TargetColumn`, used to calculate the loss.
    features: A `dict` containing the input and (optionally) sequence length
      information and initial state.
  Returns:
    A scalar `Tensor` containing the loss.
  """
  with ops.name_scope('MultiValueLoss'):
    activations_masked, labels_masked = rnn_common.mask_activations_and_labels(
        activations, labels, sequence_length)
    return target_column.loss(activations_masked, labels_masked, features)


def _single_value_loss(
    activations, labels, sequence_length, target_column, features):
  """Maps `activations` from the RNN to loss for single value models.

  Only the activations at the last (unpadded) time step of each sequence
  contribute to the loss.

  Args:
    activations: Output from an RNN. Should have dtype `float32` and shape
      `[batch_size, padded_length, ?]`.
    labels: A `Tensor` with shape `[batch_size]`.
    sequence_length: A `Tensor` with shape `[batch_size]` and dtype `int32`
      containing the length of each sequence in the batch. If `None`, sequences
      are assumed to be unpadded.
    target_column: An initialized `TargetColumn`, used to calculate the loss.
    features: A `dict` containing the input and (optionally) sequence length
      information and initial state.
  Returns:
    A scalar `Tensor` containing the loss.
  """

  with ops.name_scope('SingleValueLoss'):
    last_activations = rnn_common.select_last_activations(
        activations, sequence_length)
    return target_column.loss(last_activations, labels, features)


def _get_output_alternatives(prediction_type,
                             problem_type,
                             prediction_dict):
  """Constructs output alternatives dict for `ModelFnOps`.

  Args:
    prediction_type: either `MULTIPLE_VALUE` or `SINGLE_VALUE`.
    problem_type: either `CLASSIFICATION` or `LINEAR_REGRESSION`.
    prediction_dict: a dictionary mapping strings to `Tensor`s containing
      predictions.

  Returns:
    `None` for `MULTIPLE_VALUE`; otherwise a dictionary mapping
    'dynamic_rnn_output' to `(problem_type, predictions)`, where the
    predictions exclude every entry whose key contains the state prefix.

  Raises:
    ValueError: `prediction_type` is not one of `SINGLE_VALUE` or
      `MULTIPLE_VALUE`.
  """
  if prediction_type == rnn_common.PredictionType.MULTIPLE_VALUE:
    return None
  if prediction_type == rnn_common.PredictionType.SINGLE_VALUE:
    prediction_dict_no_state = {
        k: v
        for k, v in prediction_dict.items()
        if rnn_common.RNNKeys.STATE_PREFIX not in k
    }
    return {'dynamic_rnn_output': (problem_type, prediction_dict_no_state)}
  raise ValueError('Unrecognized prediction_type: {}'.format(prediction_type))


def _get_dynamic_rnn_model_fn(
    cell_type,
    num_units,
    target_column,
    problem_type,
    prediction_type,
    optimizer,
    sequence_feature_columns,
    context_feature_columns=None,
    predict_probabilities=False,
    learning_rate=None,
    gradient_clipping_norm=None,
    dropout_keep_probabilities=None,
    sequence_length_key=rnn_common.RNNKeys.SEQUENCE_LENGTH_KEY,
    dtype=dtypes.float32,
    parallel_iterations=None,
    swap_memory=True,
    name='DynamicRNNModel'):
  """Creates an RNN model function for an `Estimator`.

  The model function returns an instance of `ModelFnOps`. When
  `problem_type == ProblemType.CLASSIFICATION` and
  `predict_probabilities == True`, the returned `ModelFnOps` includes an output
  alternative containing the classes and their associated probabilities. When
  `predict_probabilities == False`, only the classes are included. When
  `problem_type == ProblemType.LINEAR_REGRESSION`, the output alternative
  contains only the predicted values.

  Args:
    cell_type: A string, a subclass of `RNNCell` or an instance of an `RNNCell`.
    num_units: A single `int` or a list of `int`s. The size of the `RNNCell`s.
    target_column: An initialized `TargetColumn`, used to calculate prediction
      and loss.
    problem_type: `ProblemType.CLASSIFICATION` or
      `ProblemType.LINEAR_REGRESSION`.
    prediction_type: `PredictionType.SINGLE_VALUE` or
      `PredictionType.MULTIPLE_VALUE`.
    optimizer: A subclass of `Optimizer`, an instance of an `Optimizer` or a
      string.
    sequence_feature_columns: An iterable containing all the feature columns
      describing sequence features. All items in the set should be instances
      of classes derived from `FeatureColumn`.
    context_feature_columns: An iterable containing all the feature columns
      describing context features, i.e., features that apply across all time
      steps. All items in the set should be instances of classes derived from
      `FeatureColumn`.
    predict_probabilities: A boolean indicating whether to predict probabilities
      for all classes. Must only be used with
      `ProblemType.CLASSIFICATION`.
    learning_rate: Learning rate used for optimization. This argument has no
      effect if `optimizer` is an instance of an `Optimizer`.
    gradient_clipping_norm: A float, passed to `optimizers.optimize_loss` as
      `clip_gradients`. `None` disables clipping.
    dropout_keep_probabilities: a list of dropout keep probabilities or `None`.
      If a list is given, it must have length `len(num_units) + 1`. Dropout is
      only applied in `ModeKeys.TRAIN`.
    sequence_length_key: The key that will be used to look up sequence length in
      the `features` dict.
    dtype: The dtype of the state and output of the given `cell`.
    parallel_iterations: Number of iterations to run in parallel. Values >> 1
      use more memory but take less time, while smaller values use less memory
      but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU. This allows training RNNs
      which would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
    name: A string that will be used to create a scope for the RNN.

  Returns:
    A model function to be passed to an `Estimator`.

  Raises:
    ValueError: `problem_type` is not one of
      `ProblemType.LINEAR_REGRESSION` or `ProblemType.CLASSIFICATION`.
    ValueError: `prediction_type` is not one of `PredictionType.SINGLE_VALUE`
      or `PredictionType.MULTIPLE_VALUE`.
    ValueError: `predict_probabilities` is `True` for `problem_type` other
      than `ProblemType.CLASSIFICATION`.
    ValueError: `len(dropout_keep_probabilities)` is not `len(num_units) + 1`.
      This is not checked here; presumably it is raised when the returned
      model function builds the cell via `rnn_common.construct_rnn_cell`.
  """
  if problem_type not in (constants.ProblemType.CLASSIFICATION,
                          constants.ProblemType.LINEAR_REGRESSION):
    raise ValueError(
        'problem_type must be ProblemType.LINEAR_REGRESSION or '
        'ProblemType.CLASSIFICATION; got {}'.
        format(problem_type))
  if prediction_type not in (rnn_common.PredictionType.SINGLE_VALUE,
                             rnn_common.PredictionType.MULTIPLE_VALUE):
    raise ValueError(
        'prediction_type must be PredictionType.MULTIPLE_VALUE or '
        'PredictionType.SINGLE_VALUE; got {}'.
        format(prediction_type))
  if (problem_type != constants.ProblemType.CLASSIFICATION
      and predict_probabilities):
    raise ValueError(
        'predict_probabilities can only be set to True for problem_type'
        ' ProblemType.CLASSIFICATION; got {}.'.format(problem_type))
  def _dynamic_rnn_model_fn(features, labels, mode):
    """The model to be passed to an `Estimator`."""
    with ops.name_scope(name):
      sequence_length = features.get(sequence_length_key)
      sequence_input = build_sequence_input(features,
                                            sequence_feature_columns,
                                            context_feature_columns)
      dropout = (dropout_keep_probabilities
                 if mode == model_fn.ModeKeys.TRAIN
                 else None)
      # `DynamicRnnEstimator` documents that input functions may call
      # `rnn_common.construct_rnn_cell()` to get the same cell, so the cell
      # must always be built through that helper.
      cell = rnn_common.construct_rnn_cell(num_units, cell_type, dropout)
      initial_state = dict_to_state_tuple(features, cell)
      rnn_activations, final_state = construct_rnn(
          initial_state,
          sequence_input,
          cell,
          target_column.num_label_columns,
          dtype=dtype,
          parallel_iterations=parallel_iterations,
          swap_memory=swap_memory)

      loss = None  # Created below for modes TRAIN and EVAL.
      if prediction_type == rnn_common.PredictionType.MULTIPLE_VALUE:
        prediction_dict = rnn_common.multi_value_predictions(
            rnn_activations, target_column, problem_type, predict_probabilities)
        if mode != model_fn.ModeKeys.INFER:
          loss = _multi_value_loss(
              rnn_activations, labels, sequence_length, target_column, features)
      elif prediction_type == rnn_common.PredictionType.SINGLE_VALUE:
        prediction_dict = _single_value_predictions(
            rnn_activations, sequence_length, target_column,
            problem_type, predict_probabilities)
        if mode != model_fn.ModeKeys.INFER:
          loss = _single_value_loss(
              rnn_activations, labels, sequence_length, target_column, features)
      # Expose the final RNN state so callers can feed it back as the initial
      # state of a later call.
      state_dict = state_tuple_to_dict(final_state)
      prediction_dict.update(state_dict)

      eval_metric_ops = None
      if mode != model_fn.ModeKeys.INFER:
        eval_metric_ops = rnn_common.get_eval_metric_ops(
            problem_type, prediction_type, sequence_length, prediction_dict,
            labels)

      train_op = None
      if mode == model_fn.ModeKeys.TRAIN:
        train_op = optimizers.optimize_loss(
            loss=loss,
            global_step=None,  # Get it internally.
            learning_rate=learning_rate,
            optimizer=optimizer,
            clip_gradients=gradient_clipping_norm,
            summaries=optimizers.OPTIMIZER_SUMMARIES)

    output_alternatives = _get_output_alternatives(prediction_type,
                                                   problem_type,
                                                   prediction_dict)

    return model_fn.ModelFnOps(mode=mode,
                               predictions=prediction_dict,
                               loss=loss,
                               train_op=train_op,
                               eval_metric_ops=eval_metric_ops,
                               output_alternatives=output_alternatives)
  return _dynamic_rnn_model_fn


class DynamicRnnEstimator(estimator.Estimator):
  """Dynamically unrolled RNN (deprecated).

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.
  """

  def __init__(self,
               problem_type,
               prediction_type,
               sequence_feature_columns,
               context_feature_columns=None,
               num_classes=None,
               num_units=None,
               cell_type='basic_rnn',
               optimizer='SGD',
               learning_rate=0.1,
               predict_probabilities=False,
               momentum=None,
               gradient_clipping_norm=5.0,
               dropout_keep_probabilities=None,
               model_dir=None,
               feature_engineering_fn=None,
               config=None):
    """Initializes a `DynamicRnnEstimator`.

    The input function passed to this `Estimator` optionally contains the key
    `RNNKeys.SEQUENCE_LENGTH_KEY`. The value corresponding to
    `RNNKeys.SEQUENCE_LENGTH_KEY` must be a vector of size `batch_size` where
    entry `n` corresponds to the length of the `n`th sequence in the batch. The
    sequence length feature is required for batches containing sequences of
    varying lengths. It will be used to calculate loss and evaluation metrics.
    If `RNNKeys.SEQUENCE_LENGTH_KEY` is not included, all sequences are assumed
    to have length equal to the size of dimension 1 of the input to the RNN.

    In order to specify an initial state, the input function must include keys
    `STATE_PREFIX_i` for all `0 <= i < n` where `n` is the number of nested
    elements in `cell.state_size`. The input function must contain values for
    all state components or none of them. If none are included, then the default
    (zero) state is used as an initial state. See the documentation for
    `dict_to_state_tuple` and `state_tuple_to_dict` for further details.
    The input function can call rnn_common.construct_rnn_cell() to obtain the
    same cell type that this class will select from arguments to __init__.

    The `predict()` method of the `Estimator` returns a dictionary with keys
    `STATE_PREFIX_i` for `0 <= i < n` where `n` is the number of nested elements
    in `cell.state_size`, along with `PredictionKey.CLASSES` for problem type
    `CLASSIFICATION` or `PredictionKey.SCORES` for problem type
    `LINEAR_REGRESSION`. The value keyed by
    `PredictionKey.CLASSES` or `PredictionKey.SCORES` has shape
    `[batch_size, padded_length]` in the multi-value case and shape
    `[batch_size]` in the single-value case. Here, `padded_length` is the size
    of dimension 1 of the input to the RNN.
    Entry `[i, j]` is the prediction associated with sequence `i` and time step
    `j`. If the problem type is `CLASSIFICATION` and `predict_probabilities` is
    `True`, it will also include key `PredictionKey.PROBABILITIES`.

    Args:
      problem_type: whether the `Estimator` is intended for a regression or
        classification problem. Value must be one of
        `ProblemType.CLASSIFICATION` or `ProblemType.LINEAR_REGRESSION`.
      prediction_type: whether the `Estimator` should return a value for each
        step in the sequence, or just a single value for the final time step.
        Must be one of `PredictionType.SINGLE_VALUE` or
        `PredictionType.MULTIPLE_VALUE`.
      sequence_feature_columns: An iterable containing all the feature columns
        describing sequence features. All items in the iterable should be
        instances of classes derived from `FeatureColumn`.
      context_feature_columns: An iterable containing all the feature columns
        describing context features, i.e., features that apply across all time
        steps. All items in the set should be instances of classes derived from
        `FeatureColumn`.
      num_classes: the number of classes for a classification problem. Only
        used when `problem_type=ProblemType.CLASSIFICATION`.
      num_units: A list of integers indicating the number of units in the
        `RNNCell`s in each layer.
      cell_type: A subclass of `RNNCell` or one of 'basic_rnn', 'lstm' or 'gru'.
      optimizer: The type of optimizer to use. Either a subclass of
        `Optimizer`, an instance of an `Optimizer`, a callback that returns an
        optimizer, or a string. Strings must be one of 'Adagrad', 'Adam',
        'Ftrl', 'Momentum', 'RMSProp' or 'SGD'. See `layers.optimize_loss` for
        more details.
      learning_rate: Learning rate. This argument has no effect if `optimizer`
        is an instance of an `Optimizer`.
      predict_probabilities: A boolean indicating whether to predict
        probabilities for all classes. Used only if `problem_type` is
        `ProblemType.CLASSIFICATION`.
      momentum: Momentum value. Only used, and then required, if `optimizer`
        is 'Momentum'.
      gradient_clipping_norm: Parameter used for gradient clipping. If `None`,
        then no clipping is performed.
      dropout_keep_probabilities: a list of dropout keep probabilities or
        `None`. If a list is given, it must have length `len(num_units) + 1`.
        If `None`, then no dropout is applied.
      model_dir: The directory in which to save and restore the model graph,
        parameters, etc.
      feature_engineering_fn: Takes features and labels which are the output of
        `input_fn` and returns features and labels which will be fed into
        `model_fn`. Please check `model_fn` for a definition of features and
        labels.
      config: A `RunConfig` instance.

    Raises:
      ValueError: `problem_type` is not one of
        `ProblemType.LINEAR_REGRESSION` or `ProblemType.CLASSIFICATION`.
      ValueError: `problem_type` is `ProblemType.CLASSIFICATION` but
        `num_classes` is not specified.
      ValueError: `prediction_type` is not one of
        `PredictionType.MULTIPLE_VALUE` or `PredictionType.SINGLE_VALUE`.
      ValueError: `predict_probabilities` is `True` for a `problem_type` other
        than `ProblemType.CLASSIFICATION`.
      ValueError: `optimizer` is 'Momentum' but `momentum` is `None`.
    """
    if prediction_type == rnn_common.PredictionType.MULTIPLE_VALUE:
      name = 'MultiValueDynamicRNN'
    elif prediction_type == rnn_common.PredictionType.SINGLE_VALUE:
      name = 'SingleValueDynamicRNN'
    else:
      raise ValueError(
          'prediction_type must be one of PredictionType.MULTIPLE_VALUE or '
          'PredictionType.SINGLE_VALUE; got {}'.format(prediction_type))

    if problem_type == constants.ProblemType.LINEAR_REGRESSION:
      name += 'Regressor'
      target_column = layers.regression_target()
    elif problem_type == constants.ProblemType.CLASSIFICATION:
      if not num_classes:
        raise ValueError('For CLASSIFICATION problem_type, num_classes must be '
                         'specified.')
      target_column = layers.multi_class_target(n_classes=num_classes)
      name += 'Classifier'
    else:
      raise ValueError(
          'problem_type must be either ProblemType.LINEAR_REGRESSION '
          'or ProblemType.CLASSIFICATION; got {}'.format(
              problem_type))

    if optimizer == 'Momentum':
      # Without this check, MomentumOptimizer(learning_rate, None) is accepted
      # here and only fails much later, when the training graph is built.
      if momentum is None:
        raise ValueError(
            "momentum must be specified when optimizer is 'Momentum'.")
      optimizer = momentum_opt.MomentumOptimizer(learning_rate, momentum)
    dynamic_rnn_model_fn = _get_dynamic_rnn_model_fn(
        cell_type=cell_type,
        num_units=num_units,
        target_column=target_column,
        problem_type=problem_type,
        prediction_type=prediction_type,
        optimizer=optimizer,
        sequence_feature_columns=sequence_feature_columns,
        context_feature_columns=context_feature_columns,
        predict_probabilities=predict_probabilities,
        learning_rate=learning_rate,
        gradient_clipping_norm=gradient_clipping_norm,
        dropout_keep_probabilities=dropout_keep_probabilities,
        name=name)

    super(DynamicRnnEstimator, self).__init__(
        model_fn=dynamic_rnn_model_fn,
        model_dir=model_dir,
        config=config,
        feature_engineering_fn=feature_engineering_fn)