1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for learn.estimators.dynamic_rnn_estimator.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import tempfile 22 23import numpy as np 24 25from tensorflow.contrib import rnn 26from tensorflow.contrib.layers.python.layers import feature_column 27from tensorflow.contrib.layers.python.layers import target_column as target_column_lib 28from tensorflow.contrib.learn.python.learn.estimators import constants 29from tensorflow.contrib.learn.python.learn.estimators import dynamic_rnn_estimator 30from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib 31from tensorflow.contrib.learn.python.learn.estimators import prediction_key 32from tensorflow.contrib.learn.python.learn.estimators import rnn_common 33from tensorflow.contrib.learn.python.learn.estimators import run_config 34from tensorflow.python.client import session 35from tensorflow.python.framework import constant_op 36from tensorflow.python.framework import dtypes 37from tensorflow.python.framework import random_seed 38from tensorflow.python.framework import sparse_tensor 39from tensorflow.python.ops import array_ops 40from tensorflow.python.ops import lookup_ops 41from tensorflow.python.ops import map_fn 42from 
tensorflow.python.ops import math_ops 43from tensorflow.python.ops import random_ops 44from tensorflow.python.ops import rnn_cell 45from tensorflow.python.ops import variables 46from tensorflow.python.platform import test 47 48 49class IdentityRNNCell(rnn.RNNCell): 50 51 def __init__(self, state_size, output_size): 52 self._state_size = state_size 53 self._output_size = output_size 54 55 @property 56 def state_size(self): 57 return self._state_size 58 59 @property 60 def output_size(self): 61 return self._output_size 62 63 def __call__(self, inputs, state): 64 return array_ops.identity(inputs), array_ops.ones( 65 [array_ops.shape(inputs)[0], self.state_size]) 66 67 68class MockTargetColumn(object): 69 70 def __init__(self, num_label_columns=None): 71 self._num_label_columns = num_label_columns 72 73 def get_eval_ops(self, features, activations, labels, metrics): 74 raise NotImplementedError( 75 'MockTargetColumn.get_eval_ops called unexpectedly.') 76 77 def logits_to_predictions(self, flattened_activations, proba=False): 78 raise NotImplementedError( 79 'MockTargetColumn.logits_to_predictions called unexpectedly.') 80 81 def loss(self, activations, labels, features): 82 raise NotImplementedError('MockTargetColumn.loss called unexpectedly.') 83 84 @property 85 def num_label_columns(self): 86 if self._num_label_columns is None: 87 raise ValueError('MockTargetColumn.num_label_columns has not been set.') 88 return self._num_label_columns 89 90 def set_num_label_columns(self, n): 91 self._num_label_columns = n 92 93 94def sequence_length_mask(values, lengths): 95 masked = values 96 for i, length in enumerate(lengths): 97 masked[i, length:, :] = np.zeros_like(masked[i, length:, :]) 98 return masked 99 100 101class DynamicRnnEstimatorTest(test.TestCase): 102 103 NUM_RNN_CELL_UNITS = 8 104 NUM_LABEL_COLUMNS = 6 105 INPUTS_COLUMN = feature_column.real_valued_column( 106 'inputs', dimension=NUM_LABEL_COLUMNS) 107 108 def setUp(self): 109 super(DynamicRnnEstimatorTest, 
self).setUp() 110 self.rnn_cell = rnn_cell.BasicRNNCell(self.NUM_RNN_CELL_UNITS) 111 self.mock_target_column = MockTargetColumn( 112 num_label_columns=self.NUM_LABEL_COLUMNS) 113 114 location = feature_column.sparse_column_with_keys( 115 'location', keys=['west_side', 'east_side', 'nyc']) 116 location_onehot = feature_column.one_hot_column(location) 117 self.context_feature_columns = [location_onehot] 118 119 wire_cast = feature_column.sparse_column_with_keys( 120 'wire_cast', ['marlo', 'omar', 'stringer']) 121 wire_cast_embedded = feature_column.embedding_column(wire_cast, dimension=8) 122 measurements = feature_column.real_valued_column( 123 'measurements', dimension=2) 124 self.sequence_feature_columns = [measurements, wire_cast_embedded] 125 126 def GetColumnsToTensors(self): 127 """Get columns_to_tensors matching setUp(), in the current default graph.""" 128 return { 129 'location': 130 sparse_tensor.SparseTensor( 131 indices=[[0, 0], [1, 0], [2, 0]], 132 values=['west_side', 'west_side', 'nyc'], 133 dense_shape=[3, 1]), 134 'wire_cast': 135 sparse_tensor.SparseTensor( 136 indices=[[0, 0, 0], [0, 1, 0], 137 [1, 0, 0], [1, 1, 0], [1, 1, 1], 138 [2, 0, 0]], 139 values=[b'marlo', b'stringer', 140 b'omar', b'stringer', b'marlo', 141 b'marlo'], 142 dense_shape=[3, 2, 2]), 143 'measurements': 144 random_ops.random_uniform( 145 [3, 2, 2], seed=4711) 146 } 147 148 def GetClassificationTargetsOrNone(self, mode): 149 """Get targets matching setUp() and mode, in the current default graph.""" 150 return (random_ops.random_uniform( 151 [3, 2, 1], 0, 2, dtype=dtypes.int64, seed=1412) if 152 mode != model_fn_lib.ModeKeys.INFER else None) 153 154 def testBuildSequenceInputInput(self): 155 sequence_input = dynamic_rnn_estimator.build_sequence_input( 156 self.GetColumnsToTensors(), self.sequence_feature_columns, 157 self.context_feature_columns) 158 with self.cached_session() as sess: 159 sess.run(variables.global_variables_initializer()) 160 
sess.run(lookup_ops.tables_initializer()) 161 sequence_input_val = sess.run(sequence_input) 162 expected_shape = np.array([ 163 3, # expected batch size 164 2, # padded sequence length 165 3 + 8 + 2 # location keys + embedding dim + measurement dimension 166 ]) 167 self.assertAllEqual(expected_shape, sequence_input_val.shape) 168 169 def testConstructRNN(self): 170 initial_state = None 171 sequence_input = dynamic_rnn_estimator.build_sequence_input( 172 self.GetColumnsToTensors(), self.sequence_feature_columns, 173 self.context_feature_columns) 174 activations_t, final_state_t = dynamic_rnn_estimator.construct_rnn( 175 initial_state, sequence_input, self.rnn_cell, 176 self.mock_target_column.num_label_columns) 177 178 # Obtain values of activations and final state. 179 with session.Session() as sess: 180 sess.run(variables.global_variables_initializer()) 181 sess.run(lookup_ops.tables_initializer()) 182 activations, final_state = sess.run([activations_t, final_state_t]) 183 184 expected_activations_shape = np.array([3, 2, self.NUM_LABEL_COLUMNS]) 185 self.assertAllEqual(expected_activations_shape, activations.shape) 186 expected_state_shape = np.array([3, self.NUM_RNN_CELL_UNITS]) 187 self.assertAllEqual(expected_state_shape, final_state.shape) 188 189 def testGetOutputAlternatives(self): 190 test_cases = ( 191 (rnn_common.PredictionType.SINGLE_VALUE, 192 constants.ProblemType.CLASSIFICATION, 193 {prediction_key.PredictionKey.CLASSES: True, 194 prediction_key.PredictionKey.PROBABILITIES: True, 195 dynamic_rnn_estimator._get_state_name(0): True}, 196 {'dynamic_rnn_output': 197 (constants.ProblemType.CLASSIFICATION, 198 {prediction_key.PredictionKey.CLASSES: True, 199 prediction_key.PredictionKey.PROBABILITIES: True})}), 200 201 (rnn_common.PredictionType.SINGLE_VALUE, 202 constants.ProblemType.LINEAR_REGRESSION, 203 {prediction_key.PredictionKey.SCORES: True, 204 dynamic_rnn_estimator._get_state_name(0): True, 205 dynamic_rnn_estimator._get_state_name(1): True}, 206 
{'dynamic_rnn_output': 207 (constants.ProblemType.LINEAR_REGRESSION, 208 {prediction_key.PredictionKey.SCORES: True})}), 209 210 (rnn_common.PredictionType.MULTIPLE_VALUE, 211 constants.ProblemType.CLASSIFICATION, 212 {prediction_key.PredictionKey.CLASSES: True, 213 prediction_key.PredictionKey.PROBABILITIES: True, 214 dynamic_rnn_estimator._get_state_name(0): True}, 215 None)) 216 217 for pred_type, prob_type, pred_dict, expected_alternatives in test_cases: 218 actual_alternatives = dynamic_rnn_estimator._get_output_alternatives( 219 pred_type, prob_type, pred_dict) 220 self.assertEqual(expected_alternatives, actual_alternatives) 221 222 # testGetDynamicRnnModelFn{Train,Eval,Infer}() test which fields 223 # of ModelFnOps are set depending on mode. 224 def testGetDynamicRnnModelFnTrain(self): 225 model_fn_ops = self._GetModelFnOpsForMode(model_fn_lib.ModeKeys.TRAIN) 226 self.assertIsNotNone(model_fn_ops.predictions) 227 self.assertIsNotNone(model_fn_ops.loss) 228 self.assertIsNotNone(model_fn_ops.train_op) 229 # None may get normalized to {}; we accept neither. 230 self.assertNotEqual(len(model_fn_ops.eval_metric_ops), 0) 231 232 def testGetDynamicRnnModelFnEval(self): 233 model_fn_ops = self._GetModelFnOpsForMode(model_fn_lib.ModeKeys.EVAL) 234 self.assertIsNotNone(model_fn_ops.predictions) 235 self.assertIsNotNone(model_fn_ops.loss) 236 self.assertIsNone(model_fn_ops.train_op) 237 # None may get normalized to {}; we accept neither. 238 self.assertNotEqual(len(model_fn_ops.eval_metric_ops), 0) 239 240 def testGetDynamicRnnModelFnInfer(self): 241 model_fn_ops = self._GetModelFnOpsForMode(model_fn_lib.ModeKeys.INFER) 242 self.assertIsNotNone(model_fn_ops.predictions) 243 self.assertIsNone(model_fn_ops.loss) 244 self.assertIsNone(model_fn_ops.train_op) 245 # None may get normalized to {}; we accept both. 
246 self.assertFalse(model_fn_ops.eval_metric_ops) 247 248 def _GetModelFnOpsForMode(self, mode): 249 """Helper for testGetDynamicRnnModelFn{Train,Eval,Infer}().""" 250 model_fn = dynamic_rnn_estimator._get_dynamic_rnn_model_fn( 251 cell_type='basic_rnn', 252 num_units=[10], 253 target_column=target_column_lib.multi_class_target(n_classes=2), 254 # Only CLASSIFICATION yields eval metrics to test for. 255 problem_type=constants.ProblemType.CLASSIFICATION, 256 prediction_type=rnn_common.PredictionType.MULTIPLE_VALUE, 257 optimizer='SGD', 258 sequence_feature_columns=self.sequence_feature_columns, 259 context_feature_columns=self.context_feature_columns, 260 learning_rate=0.1) 261 labels = self.GetClassificationTargetsOrNone(mode) 262 model_fn_ops = model_fn( 263 features=self.GetColumnsToTensors(), labels=labels, mode=mode) 264 return model_fn_ops 265 266 def testExport(self): 267 input_feature_key = 'magic_input_feature_key' 268 269 def get_input_fn(mode): 270 271 def input_fn(): 272 features = self.GetColumnsToTensors() 273 if mode == model_fn_lib.ModeKeys.INFER: 274 input_examples = array_ops.placeholder(dtypes.string) 275 features[input_feature_key] = input_examples 276 # Real code would now parse features out of input_examples, 277 # but this test can just stick to the constants above. 278 return features, self.GetClassificationTargetsOrNone(mode) 279 280 return input_fn 281 282 model_dir = tempfile.mkdtemp() 283 284 def estimator_fn(): 285 return dynamic_rnn_estimator.DynamicRnnEstimator( 286 problem_type=constants.ProblemType.CLASSIFICATION, 287 prediction_type=rnn_common.PredictionType.MULTIPLE_VALUE, 288 num_classes=2, 289 num_units=self.NUM_RNN_CELL_UNITS, 290 sequence_feature_columns=self.sequence_feature_columns, 291 context_feature_columns=self.context_feature_columns, 292 predict_probabilities=True, 293 model_dir=model_dir) 294 295 # Train a bit to create an exportable checkpoint. 
296 estimator_fn().fit(input_fn=get_input_fn(model_fn_lib.ModeKeys.TRAIN), 297 steps=100) 298 # Now export, but from a fresh estimator instance, like you would 299 # in an export binary. That means .export() has to work without 300 # .fit() being called on the same object. 301 export_dir = tempfile.mkdtemp() 302 print('Exporting to', export_dir) 303 estimator_fn().export( 304 export_dir, 305 input_fn=get_input_fn(model_fn_lib.ModeKeys.INFER), 306 use_deprecated_input_fn=False, 307 input_feature_key=input_feature_key) 308 309 def testStateTupleDictConversion(self): 310 """Test `state_tuple_to_dict` and `dict_to_state_tuple`.""" 311 cell_sizes = [5, 3, 7] 312 # A MultiRNNCell of LSTMCells is both a common choice and an interesting 313 # test case, because it has two levels of nesting, with an inner class that 314 # is not a plain tuple. 315 cell = rnn_cell.MultiRNNCell( 316 [rnn_cell.LSTMCell(i) for i in cell_sizes]) 317 state_dict = { 318 dynamic_rnn_estimator._get_state_name(i): 319 array_ops.expand_dims(math_ops.range(cell_size), 0) 320 for i, cell_size in enumerate([5, 5, 3, 3, 7, 7]) 321 } 322 expected_state = (rnn_cell.LSTMStateTuple( 323 np.reshape(np.arange(5), [1, -1]), np.reshape(np.arange(5), [1, -1])), 324 rnn_cell.LSTMStateTuple( 325 np.reshape(np.arange(3), [1, -1]), 326 np.reshape(np.arange(3), [1, -1])), 327 rnn_cell.LSTMStateTuple( 328 np.reshape(np.arange(7), [1, -1]), 329 np.reshape(np.arange(7), [1, -1]))) 330 actual_state = dynamic_rnn_estimator.dict_to_state_tuple(state_dict, cell) 331 flattened_state = dynamic_rnn_estimator.state_tuple_to_dict(actual_state) 332 333 with self.cached_session() as sess: 334 (state_dict_val, actual_state_val, flattened_state_val) = sess.run( 335 [state_dict, actual_state, flattened_state]) 336 337 def _recursive_assert_equal(x, y): 338 self.assertEqual(type(x), type(y)) 339 if isinstance(x, (list, tuple)): 340 self.assertEqual(len(x), len(y)) 341 for i, _ in enumerate(x): 342 _recursive_assert_equal(x[i], y[i]) 343 
elif isinstance(x, np.ndarray): 344 np.testing.assert_array_equal(x, y) 345 else: 346 self.fail('Unexpected type: {}'.format(type(x))) 347 348 for k in state_dict_val.keys(): 349 np.testing.assert_array_almost_equal( 350 state_dict_val[k], 351 flattened_state_val[k], 352 err_msg='Wrong value for state component {}.'.format(k)) 353 _recursive_assert_equal(expected_state, actual_state_val) 354 355 def testMultiRNNState(self): 356 """Test that state flattening/reconstruction works for `MultiRNNCell`.""" 357 batch_size = 11 358 sequence_length = 16 359 train_steps = 5 360 cell_sizes = [4, 8, 7] 361 learning_rate = 0.1 362 363 def get_shift_input_fn(batch_size, sequence_length, seed=None): 364 365 def input_fn(): 366 random_sequence = random_ops.random_uniform( 367 [batch_size, sequence_length + 1], 368 0, 369 2, 370 dtype=dtypes.int32, 371 seed=seed) 372 labels = array_ops.slice(random_sequence, [0, 0], 373 [batch_size, sequence_length]) 374 inputs = array_ops.expand_dims( 375 math_ops.cast( 376 array_ops.slice(random_sequence, [0, 1], 377 [batch_size, sequence_length]), 378 dtypes.float32), 2) 379 input_dict = { 380 dynamic_rnn_estimator._get_state_name(i): random_ops.random_uniform( 381 [batch_size, cell_size], seed=((i + 1) * seed)) 382 for i, cell_size in enumerate([4, 4, 8, 8, 7, 7]) 383 } 384 input_dict['inputs'] = inputs 385 return input_dict, labels 386 387 return input_fn 388 389 seq_columns = [feature_column.real_valued_column('inputs', dimension=1)] 390 config = run_config.RunConfig(tf_random_seed=21212) 391 cell_type = 'lstm' 392 sequence_estimator = dynamic_rnn_estimator.DynamicRnnEstimator( 393 problem_type=constants.ProblemType.CLASSIFICATION, 394 prediction_type=rnn_common.PredictionType.MULTIPLE_VALUE, 395 num_classes=2, 396 num_units=cell_sizes, 397 sequence_feature_columns=seq_columns, 398 cell_type=cell_type, 399 learning_rate=learning_rate, 400 config=config, 401 predict_probabilities=True) 402 403 train_input_fn = get_shift_input_fn(batch_size, 
sequence_length, seed=12321) 404 eval_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=32123) 405 406 sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps) 407 408 prediction_dict = sequence_estimator.predict( 409 input_fn=eval_input_fn, as_iterable=False) 410 for i, state_size in enumerate([4, 4, 8, 8, 7, 7]): 411 state_piece = prediction_dict[dynamic_rnn_estimator._get_state_name(i)] 412 self.assertListEqual(list(state_piece.shape), [batch_size, state_size]) 413 414 def testMultipleRuns(self): 415 """Tests resuming training by feeding state.""" 416 cell_sizes = [4, 7] 417 batch_size = 11 418 learning_rate = 0.1 419 train_sequence_length = 21 420 train_steps = 121 421 dropout_keep_probabilities = [0.5, 0.5, 0.5] 422 prediction_steps = [3, 2, 5, 11, 6] 423 424 def get_input_fn(batch_size, sequence_length, state_dict, starting_step=0): 425 426 def input_fn(): 427 sequence = constant_op.constant( 428 [[(starting_step + i + j) % 2 for j in range(sequence_length + 1)] 429 for i in range(batch_size)], 430 dtype=dtypes.int32) 431 labels = array_ops.slice(sequence, [0, 0], 432 [batch_size, sequence_length]) 433 inputs = array_ops.expand_dims( 434 math_ops.cast( 435 array_ops.slice(sequence, [0, 1], [batch_size, sequence_length 436 ]), 437 dtypes.float32), 2) 438 input_dict = state_dict 439 input_dict['inputs'] = inputs 440 return input_dict, labels 441 442 return input_fn 443 444 seq_columns = [feature_column.real_valued_column('inputs', dimension=1)] 445 config = run_config.RunConfig(tf_random_seed=21212) 446 447 model_dir = tempfile.mkdtemp() 448 sequence_estimator = dynamic_rnn_estimator.DynamicRnnEstimator( 449 problem_type=constants.ProblemType.CLASSIFICATION, 450 prediction_type=rnn_common.PredictionType.MULTIPLE_VALUE, 451 num_classes=2, 452 sequence_feature_columns=seq_columns, 453 num_units=cell_sizes, 454 cell_type='lstm', 455 dropout_keep_probabilities=dropout_keep_probabilities, 456 learning_rate=learning_rate, 457 config=config, 
458 model_dir=model_dir) 459 460 train_input_fn = get_input_fn( 461 batch_size, train_sequence_length, state_dict={}) 462 463 sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps) 464 465 def incremental_predict(estimator, increments): 466 """Run `estimator.predict` for `i` steps for `i` in `increments`.""" 467 step = 0 468 incremental_state_dict = {} 469 for increment in increments: 470 input_fn = get_input_fn( 471 batch_size, 472 increment, 473 state_dict=incremental_state_dict, 474 starting_step=step) 475 prediction_dict = estimator.predict( 476 input_fn=input_fn, as_iterable=False) 477 step += increment 478 incremental_state_dict = { 479 k: v 480 for (k, v) in prediction_dict.items() 481 if k.startswith(rnn_common.RNNKeys.STATE_PREFIX) 482 } 483 return prediction_dict 484 485 pred_all_at_once = incremental_predict(sequence_estimator, 486 [sum(prediction_steps)]) 487 pred_step_by_step = incremental_predict(sequence_estimator, 488 prediction_steps) 489 490 # Check that the last `prediction_steps[-1]` steps give the same 491 # predictions. 492 np.testing.assert_array_equal( 493 pred_all_at_once[prediction_key.PredictionKey.CLASSES] 494 [:, -1 * prediction_steps[-1]:], 495 pred_step_by_step[prediction_key.PredictionKey.CLASSES], 496 err_msg='Mismatch on last {} predictions.'.format(prediction_steps[-1])) 497 # Check that final states are identical. 498 for k, v in pred_all_at_once.items(): 499 if k.startswith(rnn_common.RNNKeys.STATE_PREFIX): 500 np.testing.assert_array_equal( 501 v, pred_step_by_step[k], err_msg='Mismatch on state {}.'.format(k)) 502 503 504# TODO(jamieas): move all tests below to a benchmark test. 
class DynamicRNNEstimatorLearningTest(test.TestCase):
  """End-to-end checks that a `DynamicRnnEstimator` can fit toy problems."""

  def testLearnSineFunction(self):
    """Fits a regression estimator on sine curves and checks the eval loss."""
    batch_size = 8
    sequence_length = 64
    train_steps = 200
    eval_steps = 20
    cell_size = [4]
    learning_rate = 0.1
    loss_threshold = 0.02

    def get_sin_input_fn(batch_size, sequence_length, increment, seed=None):

      def _sin_fn(x):
        # `x` is the 1-tuple handed over by `map_fn`; its only element is
        # reshaped to a scalar and used as the start of the grid.
        start = array_ops.reshape(x[0], [])
        stop = (sequence_length - 1) * increment
        grid = math_ops.linspace(start, stop, sequence_length + 1)
        return math_ops.sin(grid)

      def input_fn():
        starts = random_ops.random_uniform(
            [batch_size], maxval=(2 * np.pi), seed=seed)
        curves = map_fn.map_fn(_sin_fn, (starts,), dtype=dtypes.float32)
        # Inputs are the first `sequence_length` samples of each curve;
        # labels are the same window shifted one sample later.
        leading = array_ops.slice(
            curves, [0, 0], [batch_size, sequence_length])
        inputs = array_ops.expand_dims(leading, 2)
        labels = array_ops.slice(
            curves, [0, 1], [batch_size, sequence_length])
        return {'inputs': inputs}, labels

      return input_fn

    input_column = feature_column.real_valued_column(
        'inputs', dimension=cell_size[0])
    run_cfg = run_config.RunConfig(tf_random_seed=1234)
    estimator = dynamic_rnn_estimator.DynamicRnnEstimator(
        problem_type=constants.ProblemType.LINEAR_REGRESSION,
        prediction_type=rnn_common.PredictionType.MULTIPLE_VALUE,
        num_units=cell_size,
        sequence_feature_columns=[input_column],
        learning_rate=learning_rate,
        dropout_keep_probabilities=[0.9, 0.9],
        config=run_cfg)

    increment = np.pi / 32
    train_input_fn = get_sin_input_fn(
        batch_size, sequence_length, increment, seed=1234)
    eval_input_fn = get_sin_input_fn(
        batch_size, sequence_length, increment, seed=4321)

    estimator.fit(input_fn=train_input_fn, steps=train_steps)
    metrics = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
    loss = metrics['loss']
    failure_message = 'Loss should be less than {}; got {}'.format(
        loss_threshold, loss)
    self.assertLess(loss, loss_threshold, failure_message)

  def testLearnShiftByOne(self):
    """Tests learning a 'shift-by-one' example.

    Each label sequence consists of the input sequence 'shifted' by one place.
    The RNN must learn to 'remember' the previous input.
    """
    batch_size = 16
    sequence_length = 32
    train_steps = 200
    eval_steps = 20
    cell_size = 4
    learning_rate = 0.3
    accuracy_threshold = 0.9

    def get_shift_input_fn(batch_size, sequence_length, seed=None):

      def input_fn():
        bits = random_ops.random_uniform(
            [batch_size, sequence_length + 1],
            0,
            2,
            dtype=dtypes.int32,
            seed=seed)
        window = [batch_size, sequence_length]
        labels = array_ops.slice(bits, [0, 0], window)
        shifted = array_ops.slice(bits, [0, 1], window)
        inputs = array_ops.expand_dims(
            math_ops.cast(shifted, dtypes.float32), 2)
        return {'inputs': inputs}, labels

      return input_fn

    input_column = feature_column.real_valued_column(
        'inputs', dimension=cell_size)
    run_cfg = run_config.RunConfig(tf_random_seed=21212)
    estimator = dynamic_rnn_estimator.DynamicRnnEstimator(
        problem_type=constants.ProblemType.CLASSIFICATION,
        prediction_type=rnn_common.PredictionType.MULTIPLE_VALUE,
        num_classes=2,
        num_units=cell_size,
        sequence_feature_columns=[input_column],
        learning_rate=learning_rate,
        config=run_cfg,
        predict_probabilities=True)

    train_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=12321)
    eval_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=32123)

    estimator.fit(input_fn=train_input_fn, steps=train_steps)

    metrics = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
    accuracy = metrics['accuracy']
    failure_message = 'Accuracy should be higher than {}; got {}'.format(
        accuracy_threshold, accuracy)
    self.assertGreater(accuracy, accuracy_threshold, failure_message)

    # Testing `predict` when `predict_probabilities=True`.
    prediction_dict = estimator.predict(
        input_fn=eval_input_fn, as_iterable=False)
    expected_keys = sorted([
        prediction_key.PredictionKey.CLASSES,
        prediction_key.PredictionKey.PROBABILITIES,
        dynamic_rnn_estimator._get_state_name(0)
    ])
    self.assertListEqual(sorted(prediction_dict.keys()), expected_keys)
    classes = prediction_dict[prediction_key.PredictionKey.CLASSES]
    probabilities = prediction_dict[prediction_key.PredictionKey.PROBABILITIES]
    self.assertListEqual(list(classes.shape), [batch_size, sequence_length])
    self.assertListEqual(
        list(probabilities.shape), [batch_size, sequence_length, 2])

  def testLearnMean(self):
    """Test learning to calculate a mean."""
    batch_size = 16
    sequence_length = 3
    train_steps = 200
    eval_steps = 20
    cell_type = 'basic_rnn'
    cell_size = 8
    optimizer_type = 'Momentum'
    learning_rate = 0.1
    momentum = 0.9
    loss_threshold = 0.1

    def get_mean_input_fn(batch_size, sequence_length, seed=None):

      def input_fn():
        # Create examples by choosing 'centers' and adding uniform noise.
        center_column = random_ops.random_uniform(
            [batch_size, 1], -0.75, 0.75, dtype=dtypes.float32, seed=seed)
        ones_row = array_ops.ones([1, sequence_length])
        # The matmul repeats each row's center across the sequence axis.
        centers = math_ops.matmul(center_column, ones_row)
        noise = random_ops.random_uniform(
            [batch_size, sequence_length],
            -0.25,
            0.25,
            dtype=dtypes.float32,
            seed=seed)
        sequences = centers + noise

        inputs = array_ops.expand_dims(sequences, 2)
        labels = math_ops.reduce_mean(sequences, axis=[1])
        return {'inputs': inputs}, labels

      return input_fn

    input_column = feature_column.real_valued_column(
        'inputs', dimension=cell_size)
    run_cfg = run_config.RunConfig(tf_random_seed=6)
    estimator = dynamic_rnn_estimator.DynamicRnnEstimator(
        problem_type=constants.ProblemType.LINEAR_REGRESSION,
        prediction_type=rnn_common.PredictionType.SINGLE_VALUE,
        num_units=cell_size,
        sequence_feature_columns=[input_column],
        cell_type=cell_type,
        optimizer=optimizer_type,
        learning_rate=learning_rate,
        momentum=momentum,
        config=run_cfg)

    train_input_fn = get_mean_input_fn(batch_size, sequence_length, seed=121)
    eval_input_fn = get_mean_input_fn(batch_size, sequence_length, seed=212)

    estimator.fit(input_fn=train_input_fn, steps=train_steps)
    metrics = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
    loss = metrics['loss']
    failure_message = 'Loss should be less than {}; got {}'.format(
        loss_threshold, loss)
    self.assertLess(loss, loss_threshold, failure_message)

  def DISABLED_testLearnMajority(self):
    """Test learning the 'majority' function."""
    batch_size = 16
    sequence_length = 7
    train_steps = 500
    eval_steps = 20
    cell_type = 'lstm'
    cell_size = 4
    optimizer_type = 'Momentum'
    learning_rate = 2.0
    momentum = 0.9
    accuracy_threshold = 0.6

    def get_majority_input_fn(batch_size, sequence_length, seed=None):
      # Note: the seed is set when the input_fn is built, not when it is run.
      random_seed.set_random_seed(seed)

      def input_fn():
        bits = random_ops.random_uniform(
            [batch_size, sequence_length], 0, 2, dtype=dtypes.int32, seed=seed)
        inputs = array_ops.expand_dims(
            math_ops.cast(bits, dtypes.float32), 2)
        # Label is 1 when more than half of the sequence entries are 1.
        ones_count = math_ops.reduce_sum(inputs, axis=[1])
        is_majority = ones_count > (sequence_length / 2.0)
        labels = math_ops.cast(array_ops.squeeze(is_majority), dtypes.int32)
        return {'inputs': inputs}, labels

      return input_fn

    input_column = feature_column.real_valued_column(
        'inputs', dimension=cell_size)
    run_cfg = run_config.RunConfig(tf_random_seed=77)
    estimator = dynamic_rnn_estimator.DynamicRnnEstimator(
        problem_type=constants.ProblemType.CLASSIFICATION,
        prediction_type=rnn_common.PredictionType.SINGLE_VALUE,
        num_classes=2,
        num_units=cell_size,
        sequence_feature_columns=[input_column],
        cell_type=cell_type,
        optimizer=optimizer_type,
        learning_rate=learning_rate,
        momentum=momentum,
        config=run_cfg,
        predict_probabilities=True)

    train_input_fn = get_majority_input_fn(
        batch_size, sequence_length, seed=1111)
    eval_input_fn = get_majority_input_fn(
        batch_size, sequence_length, seed=2222)

    estimator.fit(input_fn=train_input_fn, steps=train_steps)
    metrics = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
    accuracy = metrics['accuracy']
    failure_message = 'Accuracy should be higher than {}; got {}'.format(
        accuracy_threshold, accuracy)
    self.assertGreater(accuracy, accuracy_threshold, failure_message)

    # Testing `predict` when `predict_probabilities=True`.
    prediction_dict = estimator.predict(
        input_fn=eval_input_fn, as_iterable=False)
    expected_keys = sorted([
        prediction_key.PredictionKey.CLASSES,
        prediction_key.PredictionKey.PROBABILITIES,
        dynamic_rnn_estimator._get_state_name(0),
        dynamic_rnn_estimator._get_state_name(1)
    ])
    self.assertListEqual(sorted(prediction_dict.keys()), expected_keys)
    classes = prediction_dict[prediction_key.PredictionKey.CLASSES]
    probabilities = prediction_dict[prediction_key.PredictionKey.PROBABILITIES]
    self.assertListEqual(list(classes.shape), [batch_size])
    self.assertListEqual(list(probabilities.shape), [batch_size, 2])


if __name__ == '__main__':
  test.main()