# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Estimator for State Saving RNNs (deprecated).

This module and all its submodules are deprecated. See
[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
for migration instructions.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib import layers
from tensorflow.contrib import rnn as rnn_cell
from tensorflow.contrib.layers.python.layers import feature_column_ops
from tensorflow.contrib.layers.python.layers import optimizers
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.contrib.learn.python.learn.estimators import rnn_common
from tensorflow.contrib.training.python.training import sequence_queueing_state_saver as sqss
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import rnn
from tensorflow.python.training import momentum as momentum_opt
from tensorflow.python.util import nest


def construct_state_saving_rnn(cell,
                               inputs,
                               num_label_columns,
                               state_saver,
                               state_name,
                               scope='rnn'):
  """Build a state saving RNN and apply a fully connected layer.

  Args:
    cell: An instance of `RNNCell`.
    inputs: A length `T` list of inputs, each a `Tensor` of shape
      `[batch_size, input_size, ...]`.
    num_label_columns: The desired output dimension.
    state_saver: A state saver object with methods `state` and `save_state`.
    state_name: Python string or tuple of strings. The name to use with the
      state_saver. If the cell returns tuples of states (i.e.,
      `cell.state_size` is a tuple) then `state_name` should be a tuple of
      strings having the same length as `cell.state_size`. Otherwise it should
      be a single string.
    scope: `VariableScope` for the created subgraph; defaults to "rnn".

  Returns:
    activations: The output of the RNN, projected to `num_label_columns`
      dimensions, a `Tensor` of shape `[batch_size, T, num_label_columns]`.
    final_state: The final state output by the RNN.
  """
  with ops.name_scope(scope):
    rnn_outputs, final_state = rnn.static_state_saving_rnn(
        cell=cell,
        inputs=inputs,
        state_saver=state_saver,
        state_name=state_name,
        scope=scope)
    # Convert rnn_outputs from a list of time-major order Tensors to a single
    # Tensor of batch-major order.
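    # For example, a length `T` list of `[batch_size, cell.output_size]`
    # Tensors becomes a single `[batch_size, T, cell.output_size]` Tensor.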
    rnn_outputs = array_ops.stack(rnn_outputs, axis=1)
    activations = layers.fully_connected(
        inputs=rnn_outputs,
        num_outputs=num_label_columns,
        activation_fn=None,
        trainable=True)
    # Use `identity` to rename `final_state`.
    final_state = array_ops.identity(
        final_state, name=rnn_common.RNNKeys.FINAL_STATE_KEY)
    return activations, final_state


def _multi_value_loss(
    activations, labels, sequence_length, target_column, features):
  """Maps `activations` from the RNN to loss for multi value models.

  Args:
    activations: Output from an RNN. Should have dtype `float32` and shape
      `[batch_size, padded_length, ?]`.
    labels: A `Tensor` with shape `[batch_size, padded_length]`.
    sequence_length: A `Tensor` with shape `[batch_size]` and dtype `int32`
      containing the length of each sequence in the batch. If `None`, sequences
      are assumed to be unpadded.
    target_column: An initialized `TargetColumn`, used to calculate predictions.
    features: A `dict` containing the input and (optionally) sequence length
      information and initial state.
  Returns:
    A scalar `Tensor` containing the loss.
  """
  with ops.name_scope('MultiValueLoss'):
    activations_masked, labels_masked = rnn_common.mask_activations_and_labels(
        activations, labels, sequence_length)
    return target_column.loss(activations_masked, labels_masked, features)


def _get_name_or_parent_names(column):
  """Gets the name of a column or its parent columns' names.

  Args:
    column: A sequence feature column derived from `FeatureColumn`.

  Returns:
    A list of the name of `column` or the names of its parent columns,
    if any exist.
  """
  # pylint: disable=protected-access
  parent_columns = feature_column_ops._get_parent_columns(column)
  if parent_columns:
    return [x.name for x in parent_columns]
  return [column.name]


def _prepare_features_for_sqss(features, labels, mode,
                               sequence_feature_columns,
                               context_feature_columns):
  """Prepares features for batching by the SQSS.

  In preparation for batching by the SQSS, this function:
  - Extracts the input key from the features dict.
  - Separates sequence and context features dicts from the features dict.
  - Adds the labels tensor to the sequence features dict.

  Args:
    features: A dict of Python string to an iterable of `Tensor` or
      `SparseTensor` of rank 2, the `features` argument of a TF.Learn model_fn.
    labels: An iterable of `Tensor`.
    mode: Defines whether this is training, evaluation or prediction.
      See `ModeKeys`.
    sequence_feature_columns: An iterable containing all the feature columns
      describing sequence features. All items in the set should be instances
      of classes derived from `FeatureColumn`.
    context_feature_columns: An iterable containing all the feature columns
      describing context features, i.e., features that apply across all time
      steps. All items in the set should be instances of classes derived from
      `FeatureColumn`.

  Returns:
    sequence_features: A dict mapping feature names to sequence features.
    context_features: A dict mapping feature names to context features.

  Raises:
    ValueError: If `features` does not contain a value for every key in
      `sequence_feature_columns` or `context_feature_columns`.
  """

  # Extract sequence features.
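  # Derived columns (e.g. embedding columns) are looked up in `features`
  # under their parent columns' names; see `_get_name_or_parent_names`.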
  feature_column_ops._check_supported_sequence_columns(sequence_feature_columns)  # pylint: disable=protected-access
  sequence_features = {}
  for column in sequence_feature_columns:
    for name in _get_name_or_parent_names(column):
      feature = features.get(name, None)
      if feature is None:
        raise ValueError('No key in features for sequence feature: ' + name)
      sequence_features[name] = feature

  # Extract context features.
  context_features = {}
  if context_feature_columns is not None:
    for column in context_feature_columns:
      name = column.name
      feature = features.get(name, None)
      if feature is None:
        raise ValueError('No key in features for context feature: ' + name)
      context_features[name] = feature

  # Add labels to the resulting sequence features dict.
  if mode != model_fn.ModeKeys.INFER:
    sequence_features[rnn_common.RNNKeys.LABELS_KEY] = labels

  return sequence_features, context_features


def _get_state_names(cell):
  """Gets the state names for an `RNNCell`.

  Args:
    cell: A `RNNCell` to be used in the RNN.

  Returns:
    State names in the form of a string, a list of strings, or a list of
    string pairs, depending on the type of `cell.state_size`.

  Raises:
    TypeError: If cell.state_size is of type TensorShape.
  """
  state_size = cell.state_size
  if isinstance(state_size, tensor_shape.TensorShape):
    raise TypeError('cell.state_size of type TensorShape is not supported.')
  if isinstance(state_size, int):
    return '{}_{}'.format(rnn_common.RNNKeys.STATE_PREFIX, 0)
  if isinstance(state_size, rnn_cell.LSTMStateTuple):
    return [
        '{}_{}_c'.format(rnn_common.RNNKeys.STATE_PREFIX, 0),
        '{}_{}_h'.format(rnn_common.RNNKeys.STATE_PREFIX, 0),
    ]
  if isinstance(state_size[0], rnn_cell.LSTMStateTuple):
    return [[
        '{}_{}_c'.format(rnn_common.RNNKeys.STATE_PREFIX, i),
        '{}_{}_h'.format(rnn_common.RNNKeys.STATE_PREFIX, i),
    ] for i in range(len(state_size))]
  return [
      '{}_{}'.format(rnn_common.RNNKeys.STATE_PREFIX, i)
      for i in range(len(state_size))]


def _get_initial_states(cell):
  """Gets the initial state of the `RNNCell` used in the RNN.

  Args:
    cell: A `RNNCell` to be used in the RNN.

  Returns:
    A Python dict mapping state names to the `RNNCell`'s initial state for
    consumption by the SQSS.
  """
  names = nest.flatten(_get_state_names(cell))
  values = nest.flatten(cell.zero_state(1, dtype=dtypes.float32))
  return {n: array_ops.squeeze(v, axis=0) for [n, v] in zip(names, values)}


def _read_batch(cell,
                features,
                labels,
                mode,
                num_unroll,
                batch_size,
                sequence_feature_columns,
                context_feature_columns=None,
                num_threads=3,
                queue_capacity=1000,
                seed=None):
  """Reads a batch from a state saving sequence queue.

  Args:
    cell: An initialized `RNNCell` to be used in the RNN.
    features: A dict of Python string to an iterable of `Tensor`, the
      `features` argument of a TF.Learn model_fn.
    labels: An iterable of `Tensor`, the `labels` argument of a
      TF.Learn model_fn.
    mode: Defines whether this is training, evaluation or prediction.
      See `ModeKeys`.
    num_unroll: Python integer, how many time steps to unroll at a time.
      The input sequences of length `k` are then split into `k / num_unroll`
      many segments.
    batch_size: Python integer, the size of the minibatch produced by the SQSS.
    sequence_feature_columns: An iterable containing all the feature columns
      describing sequence features. All items in the set should be instances
      of classes derived from `FeatureColumn`.
    context_feature_columns: An iterable containing all the feature columns
      describing context features, i.e., features that apply across all time
      steps. All items in the set should be instances of classes derived from
      `FeatureColumn`.
    num_threads: The Python integer number of threads enqueuing input examples
      into a queue. Defaults to 3.
    queue_capacity: The max capacity of the queue in number of examples.
      Needs to be at least `batch_size`. Defaults to 1000. When iterating
      over the same input example multiple times reusing their keys, the
      `queue_capacity` must be smaller than the number of examples.
    seed: Fixes the random seed used for generating input keys by the SQSS.

  Returns:
    batch: A `NextQueuedSequenceBatch` containing batch_size `SequenceExample`
      values and their saved internal states.
  """
  states = _get_initial_states(cell)

  sequences, context = _prepare_features_for_sqss(
      features, labels, mode, sequence_feature_columns,
      context_feature_columns)

  return sqss.batch_sequences_with_states(
      input_key='key',
      input_sequences=sequences,
      input_context=context,
      input_length=None,  # infer sequence lengths
      initial_states=states,
      num_unroll=num_unroll,
      batch_size=batch_size,
      pad=True,  # pad to a multiple of num_unroll
      make_keys_unique=True,
      make_keys_unique_seed=seed,
      num_threads=num_threads,
      capacity=queue_capacity)


def _get_state_name(i):
  """Constructs the name string for state component `i`."""
  return '{}_{}'.format(rnn_common.RNNKeys.STATE_PREFIX, i)


def state_tuple_to_dict(state):
  """Returns a dict containing flattened `state`.

  Args:
    state: A `Tensor` or a nested tuple of `Tensors`. All of the `Tensor`s must
      have the same rank and agree on all dimensions except the last.

  Returns:
    A dict containing the `Tensor`s that make up `state`. The keys of the dict
    are of the form "STATE_PREFIX_i" where `i` is the place of this `Tensor`
    in a depth-first traversal of `state`.
  """
  with ops.name_scope('state_tuple_to_dict'):
    flat_state = nest.flatten(state)
    state_dict = {}
    for i, state_component in enumerate(flat_state):
      state_name = _get_state_name(i)
      state_value = (None if state_component is None else array_ops.identity(
          state_component, name=state_name))
      state_dict[state_name] = state_value
  return state_dict


def _prepare_inputs_for_rnn(sequence_features, context_features,
                            sequence_feature_columns, num_unroll):
  """Prepares features batched by the SQSS for input to a state-saving RNN.

  Args:
    sequence_features: A dict of sequence feature name to `Tensor` or
      `SparseTensor`, with `Tensor`s of shape `[batch_size, num_unroll, ...]`
      or `SparseTensors` of dense shape `[batch_size, num_unroll, d]`.
    context_features: A dict of context feature name to `Tensor`, with
      tensors of shape `[batch_size, 1, ...]` and type float32.
    sequence_feature_columns: An iterable containing all the feature columns
      describing sequence features. All items in the set should be instances
      of classes derived from `FeatureColumn`.
    num_unroll: Python integer, how many time steps to unroll at a time.
      The input sequences of length `k` are then split into `k / num_unroll`
      many segments.

  Returns:
    features_by_time: A list of length `num_unroll` with `Tensor` entries of
      shape `[batch_size, sum(sequence_features dimensions) +
      sum(context_features dimensions)]` of type float32.
      Context features are copied into each time step.
  """

  def _tile(feature):
    return array_ops.squeeze(
        array_ops.tile(array_ops.expand_dims(feature, 1), [1, num_unroll, 1]),
        axis=2)
  for feature in sequence_features.values():
    if isinstance(feature, sparse_tensor.SparseTensor):
      # Explicitly set dense_shape's shape to 3 ([batch_size, num_unroll, d])
      # since it can't be statically inferred.
      feature.dense_shape.set_shape([3])
  sequence_features = layers.sequence_input_from_feature_columns(
      columns_to_tensors=sequence_features,
      feature_columns=sequence_feature_columns,
      weight_collections=None,
      scope=None)
  # Explicitly set shape along dimension 1 to num_unroll for the unstack op.
  sequence_features.set_shape([None, num_unroll, None])

  if not context_features:
    return array_ops.unstack(sequence_features, axis=1)
  # TODO(jtbates): Call layers.input_from_feature_columns for context features.
  context_features = [
      _tile(context_features[k]) for k in sorted(context_features)
  ]
  return array_ops.unstack(
      array_ops.concat(
          [sequence_features, array_ops.stack(context_features, 2)], axis=2),
      axis=1)


def _get_rnn_model_fn(cell_type,
                      target_column,
                      problem_type,
                      optimizer,
                      num_unroll,
                      num_units,
                      num_threads,
                      queue_capacity,
                      batch_size,
                      sequence_feature_columns,
                      context_feature_columns=None,
                      predict_probabilities=False,
                      learning_rate=None,
                      gradient_clipping_norm=None,
                      dropout_keep_probabilities=None,
                      name='StateSavingRNNModel',
                      seed=None):
  """Creates a state saving RNN model function for an `Estimator`.

  Args:
    cell_type: A subclass of `RNNCell` or one of 'basic_rnn', 'lstm' or 'gru'.
    target_column: An initialized `TargetColumn`, used to calculate prediction
      and loss.
    problem_type: `ProblemType.CLASSIFICATION` or
      `ProblemType.LINEAR_REGRESSION`.
    optimizer: A subclass of `Optimizer`, an instance of an `Optimizer` or a
      string.
    num_unroll: Python integer, how many time steps to unroll at a time.
      The input sequences of length `k` are then split into `k / num_unroll`
      many segments.
    num_units: The number of units in the `RNNCell`.
    num_threads: The Python integer number of threads enqueuing input examples
      into a queue.
    queue_capacity: The max capacity of the queue in number of examples.
      Needs to be at least `batch_size`. When iterating over the same input
      example multiple times reusing their keys, the `queue_capacity` must be
      smaller than the number of examples.
    batch_size: Python integer, the size of the minibatch produced by the SQSS.
    sequence_feature_columns: An iterable containing all the feature columns
      describing sequence features. All items in the set should be instances
      of classes derived from `FeatureColumn`.
    context_feature_columns: An iterable containing all the feature columns
      describing context features, i.e., features that apply across all time
      steps. All items in the set should be instances of classes derived from
      `FeatureColumn`.
    predict_probabilities: A boolean indicating whether to predict probabilities
      for all classes.
      Must only be used with `ProblemType.CLASSIFICATION`.
    learning_rate: Learning rate used for optimization. This argument has no
      effect if `optimizer` is an instance of an `Optimizer`.
    gradient_clipping_norm: A float. Gradients will be clipped to this value.
    dropout_keep_probabilities: A list of dropout keep probabilities or `None`.
      If given a list, it must have length `len(num_units) + 1`.
    name: A string that will be used to create a scope for the RNN.
    seed: Fixes the random seed used for generating input keys by the SQSS.

  Returns:
    A model function to be passed to an `Estimator`.

  Raises:
    ValueError: `problem_type` is not one of
      `ProblemType.LINEAR_REGRESSION`
      or `ProblemType.CLASSIFICATION`.
    ValueError: `predict_probabilities` is `True` for `problem_type` other
      than `ProblemType.CLASSIFICATION`.
    ValueError: `num_unroll` is not positive.
  """
  if problem_type not in (constants.ProblemType.CLASSIFICATION,
                          constants.ProblemType.LINEAR_REGRESSION):
    raise ValueError(
        'problem_type must be ProblemType.LINEAR_REGRESSION or '
        'ProblemType.CLASSIFICATION; got {}'.
        format(problem_type))
  if (problem_type != constants.ProblemType.CLASSIFICATION and
      predict_probabilities):
    raise ValueError(
        'predict_probabilities can only be set to True for problem_type'
        ' ProblemType.CLASSIFICATION; got {}.'.format(problem_type))
  if num_unroll <= 0:
    raise ValueError('num_unroll must be positive; got {}.'.format(num_unroll))

  def _rnn_model_fn(features, labels, mode):
    """The model to be passed to an `Estimator`."""
    with ops.name_scope(name):
      dropout = (dropout_keep_probabilities
                 if mode == model_fn.ModeKeys.TRAIN
                 else None)
      cell = rnn_common.construct_rnn_cell(num_units, cell_type, dropout)

      batch = _read_batch(
          cell=cell,
          features=features,
          labels=labels,
          mode=mode,
          num_unroll=num_unroll,
          batch_size=batch_size,
          sequence_feature_columns=sequence_feature_columns,
          context_feature_columns=context_feature_columns,
          num_threads=num_threads,
          queue_capacity=queue_capacity,
          seed=seed)
      sequence_features = batch.sequences
      context_features = batch.context
      if mode != model_fn.ModeKeys.INFER:
        labels = sequence_features.pop(rnn_common.RNNKeys.LABELS_KEY)
      inputs = _prepare_inputs_for_rnn(sequence_features, context_features,
                                       sequence_feature_columns, num_unroll)
      state_name = _get_state_names(cell)
      rnn_activations, final_state = construct_state_saving_rnn(
          cell=cell,
          inputs=inputs,
          num_label_columns=target_column.num_label_columns,
          state_saver=batch,
          state_name=state_name)

      loss = None  # Created below for modes TRAIN and EVAL.
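      # Predictions are made for every unrolled time step ('multi-value'
      # predictions), one per entry along the `num_unroll` axis.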
      prediction_dict = rnn_common.multi_value_predictions(
          rnn_activations, target_column, problem_type, predict_probabilities)
      if mode != model_fn.ModeKeys.INFER:
        loss = _multi_value_loss(rnn_activations, labels, batch.length,
                                 target_column, features)

      eval_metric_ops = None
      if mode != model_fn.ModeKeys.INFER:
        eval_metric_ops = rnn_common.get_eval_metric_ops(
            problem_type, rnn_common.PredictionType.MULTIPLE_VALUE,
            batch.length, prediction_dict, labels)

      state_dict = state_tuple_to_dict(final_state)
      prediction_dict.update(state_dict)

      train_op = None
      if mode == model_fn.ModeKeys.TRAIN:
        train_op = optimizers.optimize_loss(
            loss=loss,
            global_step=None,  # Get it internally.
            learning_rate=learning_rate,
            optimizer=optimizer,
            clip_gradients=gradient_clipping_norm,
            summaries=optimizers.OPTIMIZER_SUMMARIES)

      return model_fn.ModelFnOps(mode=mode,
                                 predictions=prediction_dict,
                                 loss=loss,
                                 train_op=train_op,
                                 eval_metric_ops=eval_metric_ops)
  return _rnn_model_fn


class StateSavingRnnEstimator(estimator.Estimator):
  """RNN with static unrolling and state saving (deprecated).

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.
  """

  def __init__(self,
               problem_type,
               num_unroll,
               batch_size,
               sequence_feature_columns,
               context_feature_columns=None,
               num_classes=None,
               num_units=None,
               cell_type='basic_rnn',
               optimizer_type='SGD',
               learning_rate=0.1,
               predict_probabilities=False,
               momentum=None,
               gradient_clipping_norm=5.0,
               dropout_keep_probabilities=None,
               model_dir=None,
               config=None,
               feature_engineering_fn=None,
               num_threads=3,
               queue_capacity=1000,
               seed=None):
    """Initializes a StateSavingRnnEstimator.

    Args:
      problem_type: `ProblemType.CLASSIFICATION` or
        `ProblemType.LINEAR_REGRESSION`.
      num_unroll: Python integer, how many time steps to unroll at a time.
        The input sequences of length `k` are then split into `k / num_unroll`
        many segments.
      batch_size: Python integer, the size of the minibatch.
      sequence_feature_columns: An iterable containing all the feature columns
        describing sequence features. All items in the set should be instances
        of classes derived from `FeatureColumn`.
      context_feature_columns: An iterable containing all the feature columns
        describing context features, i.e., features that apply across all time
        steps. All items in the set should be instances of classes derived from
        `FeatureColumn`.
      num_classes: The number of classes for categorization. Used only and
        required if `problem_type` is `ProblemType.CLASSIFICATION`.
      num_units: A list of integers indicating the number of units in the
        `RNNCell`s in each layer. Either `num_units` is specified or `cell_type`
        is an instance of `RNNCell`.
      cell_type: A subclass of `RNNCell` or one of 'basic_rnn', 'lstm' or 'gru'.
      optimizer_type: The type of optimizer to use. Either a subclass of
        `Optimizer`, an instance of an `Optimizer` or a string. Strings must be
        one of 'Adagrad', 'Adam', 'Ftrl', 'Momentum', 'RMSProp', or 'SGD'.
      learning_rate: Learning rate. This argument has no effect if `optimizer`
        is an instance of an `Optimizer`.
      predict_probabilities: A boolean indicating whether to predict
        probabilities for all classes. Used only if `problem_type` is
        `ProblemType.CLASSIFICATION`.
      momentum: Momentum value. Only used if `optimizer_type` is 'Momentum'.
      gradient_clipping_norm: Parameter used for gradient clipping. If `None`,
        then no clipping is performed.
      dropout_keep_probabilities: A list of dropout keep probabilities or
        `None`. If given a list, it must have length `len(num_units) + 1`.
      model_dir: The directory in which to save and restore the model graph,
        parameters, etc.
      config: A `RunConfig` instance.
      feature_engineering_fn: Takes features and labels which are the output of
        `input_fn` and returns features and labels which will be fed into
        `model_fn`. Please check `model_fn` for a definition of features and
        labels.
      num_threads: The Python integer number of threads enqueuing input examples
        into a queue. Defaults to 3.
      queue_capacity: The max capacity of the queue in number of examples.
        Needs to be at least `batch_size`. Defaults to 1000. When iterating
        over the same input example multiple times reusing their keys, the
        `queue_capacity` must be smaller than the number of examples.
      seed: Fixes the random seed used for generating input keys by the SQSS.

    Raises:
      ValueError: Both or neither of the following are true: (a) `num_units` is
        specified and (b) `cell_type` is an instance of `RNNCell`.
      ValueError: `problem_type` is not one of
        `ProblemType.LINEAR_REGRESSION` or `ProblemType.CLASSIFICATION`.
      ValueError: `problem_type` is `ProblemType.CLASSIFICATION` but
        `num_classes` is not specified.
    """
    name = 'MultiValueStateSavingRNN'
    if problem_type == constants.ProblemType.LINEAR_REGRESSION:
      name += 'Regressor'
      target_column = layers.regression_target()
    elif problem_type == constants.ProblemType.CLASSIFICATION:
      if not num_classes:
        raise ValueError('For CLASSIFICATION problem_type, num_classes must be '
                         'specified.')
      target_column = layers.multi_class_target(n_classes=num_classes)
      name += 'Classifier'
    else:
      raise ValueError(
          'problem_type must be either ProblemType.LINEAR_REGRESSION '
          'or ProblemType.CLASSIFICATION; got {}'.format(
              problem_type))

    if optimizer_type == 'Momentum':
      optimizer_type = momentum_opt.MomentumOptimizer(learning_rate, momentum)

    rnn_model_fn = _get_rnn_model_fn(
        cell_type=cell_type,
        target_column=target_column,
        problem_type=problem_type,
        optimizer=optimizer_type,
        num_unroll=num_unroll,
        num_units=num_units,
        num_threads=num_threads,
        queue_capacity=queue_capacity,
        batch_size=batch_size,
        sequence_feature_columns=sequence_feature_columns,
        context_feature_columns=context_feature_columns,
        predict_probabilities=predict_probabilities,
        learning_rate=learning_rate,
        gradient_clipping_norm=gradient_clipping_norm,
        dropout_keep_probabilities=dropout_keep_probabilities,
        name=name,
        seed=seed)

    super(StateSavingRnnEstimator, self).__init__(
        model_fn=rnn_model_fn,
        model_dir=model_dir,
        config=config,
        feature_engineering_fn=feature_engineering_fn)
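

# A minimal usage sketch (not part of the original module). The feature
# column, the `my_*_input_fn` functions, and all hyperparameter values below
# are illustrative assumptions only; the input_fn must yield the tensors
# expected by `sequence_feature_columns`:
#
#   measurements = layers.real_valued_column('measurements', dimension=4)
#   estimator = StateSavingRnnEstimator(
#       problem_type=constants.ProblemType.CLASSIFICATION,
#       num_unroll=10,
#       batch_size=16,
#       sequence_feature_columns=[measurements],
#       num_classes=3,
#       num_units=[32],
#       cell_type='lstm',
#       learning_rate=0.1)
#   estimator.fit(input_fn=my_train_input_fn, steps=100)
#   predictions = estimator.predict(input_fn=my_predict_input_fn)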