1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for learn.estimators.state_saving_rnn_estimator.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import tempfile 22 23import numpy as np 24 25from tensorflow.contrib import lookup 26from tensorflow.contrib.layers.python.layers import feature_column 27from tensorflow.contrib.layers.python.layers import target_column as target_column_lib 28from tensorflow.contrib.learn.python.learn.estimators import constants 29from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib 30from tensorflow.contrib.learn.python.learn.estimators import prediction_key 31from tensorflow.contrib.learn.python.learn.estimators import rnn_common 32from tensorflow.contrib.learn.python.learn.estimators import run_config 33from tensorflow.contrib.learn.python.learn.estimators import state_saving_rnn_estimator as ssre 34from tensorflow.python.framework import constant_op 35from tensorflow.python.framework import dtypes 36from tensorflow.python.framework import sparse_tensor 37from tensorflow.python.ops import array_ops 38from tensorflow.python.ops import init_ops 39from tensorflow.python.ops import lookup_ops 40from tensorflow.python.ops import math_ops 41from tensorflow.python.ops import 
random_ops 42from tensorflow.python.ops import variables 43from tensorflow.python.platform import test 44 45 46class PrepareInputsForRnnTest(test.TestCase): 47 48 def _test_prepare_inputs_for_rnn(self, sequence_features, context_features, 49 sequence_feature_columns, num_unroll, 50 expected): 51 features_by_time = ssre._prepare_inputs_for_rnn(sequence_features, 52 context_features, 53 sequence_feature_columns, 54 num_unroll) 55 56 with self.cached_session() as sess: 57 sess.run(variables.global_variables_initializer()) 58 sess.run(lookup_ops.tables_initializer()) 59 features_val = sess.run(features_by_time) 60 self.assertAllEqual(expected, features_val) 61 62 def testPrepareInputsForRnnBatchSize1(self): 63 num_unroll = 3 64 65 expected = [ 66 np.array([[11., 31., 5., 7.]]), np.array([[12., 32., 5., 7.]]), 67 np.array([[13., 33., 5., 7.]]) 68 ] 69 70 sequence_features = { 71 'seq_feature0': constant_op.constant([[11., 12., 13.]]), 72 'seq_feature1': constant_op.constant([[31., 32., 33.]]) 73 } 74 75 sequence_feature_columns = [ 76 feature_column.real_valued_column( 77 'seq_feature0', dimension=1), 78 feature_column.real_valued_column( 79 'seq_feature1', dimension=1), 80 ] 81 82 context_features = { 83 'ctx_feature0': constant_op.constant([[5.]]), 84 'ctx_feature1': constant_op.constant([[7.]]) 85 } 86 self._test_prepare_inputs_for_rnn(sequence_features, context_features, 87 sequence_feature_columns, num_unroll, 88 expected) 89 90 def testPrepareInputsForRnnBatchSize2(self): 91 92 num_unroll = 3 93 94 expected = [ 95 np.array([[11., 31., 5., 7.], [21., 41., 6., 8.]]), 96 np.array([[12., 32., 5., 7.], [22., 42., 6., 8.]]), 97 np.array([[13., 33., 5., 7.], [23., 43., 6., 8.]]) 98 ] 99 100 sequence_features = { 101 'seq_feature0': 102 constant_op.constant([[11., 12., 13.], [21., 22., 23.]]), 103 'seq_feature1': 104 constant_op.constant([[31., 32., 33.], [41., 42., 43.]]) 105 } 106 107 sequence_feature_columns = [ 108 feature_column.real_valued_column( 109 
'seq_feature0', dimension=1), 110 feature_column.real_valued_column( 111 'seq_feature1', dimension=1), 112 ] 113 114 context_features = { 115 'ctx_feature0': constant_op.constant([[5.], [6.]]), 116 'ctx_feature1': constant_op.constant([[7.], [8.]]) 117 } 118 119 self._test_prepare_inputs_for_rnn(sequence_features, context_features, 120 sequence_feature_columns, num_unroll, 121 expected) 122 123 def testPrepareInputsForRnnNoContext(self): 124 num_unroll = 3 125 126 expected = [ 127 np.array([[11., 31.], [21., 41.]]), np.array([[12., 32.], [22., 42.]]), 128 np.array([[13., 33.], [23., 43.]]) 129 ] 130 131 sequence_features = { 132 'seq_feature0': 133 constant_op.constant([[11., 12., 13.], [21., 22., 23.]]), 134 'seq_feature1': 135 constant_op.constant([[31., 32., 33.], [41., 42., 43.]]) 136 } 137 138 sequence_feature_columns = [ 139 feature_column.real_valued_column( 140 'seq_feature0', dimension=1), 141 feature_column.real_valued_column( 142 'seq_feature1', dimension=1), 143 ] 144 145 context_features = None 146 147 self._test_prepare_inputs_for_rnn(sequence_features, context_features, 148 sequence_feature_columns, num_unroll, 149 expected) 150 151 def testPrepareInputsForRnnSparse(self): 152 num_unroll = 2 153 embedding_dimension = 8 154 155 expected = [ 156 np.array([[1., 1., 1., 1., 1., 1., 1., 1.], 157 [1., 1., 1., 1., 1., 1., 1., 1.], 158 [1., 1., 1., 1., 1., 1., 1., 1.]]), 159 np.array([[1., 1., 1., 1., 1., 1., 1., 1.], 160 [2., 2., 2., 2., 2., 2., 2., 2.], 161 [1., 1., 1., 1., 1., 1., 1., 1.]]) 162 ] 163 164 sequence_features = { 165 'wire_cast': 166 sparse_tensor.SparseTensor( 167 indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0], [1, 1, 1], 168 [2, 0, 0], [2, 1, 1]], 169 values=[ 170 b'marlo', b'stringer', b'omar', b'stringer', b'marlo', 171 b'marlo', b'omar' 172 ], 173 dense_shape=[3, 2, 2]) 174 } 175 176 wire_cast = feature_column.sparse_column_with_keys( 177 'wire_cast', ['marlo', 'omar', 'stringer']) 178 sequence_feature_columns = [ 179 
feature_column.embedding_column( 180 wire_cast, 181 dimension=embedding_dimension, 182 combiner='sum', 183 initializer=init_ops.ones_initializer()) 184 ] 185 186 context_features = None 187 188 self._test_prepare_inputs_for_rnn(sequence_features, context_features, 189 sequence_feature_columns, num_unroll, 190 expected) 191 192 def testPrepareInputsForRnnSparseAndDense(self): 193 num_unroll = 2 194 embedding_dimension = 8 195 dense_dimension = 2 196 197 expected = [ 198 np.array([[1., 1., 1., 1., 1., 1., 1., 1., 111., 112.], 199 [1., 1., 1., 1., 1., 1., 1., 1., 211., 212.], 200 [1., 1., 1., 1., 1., 1., 1., 1., 311., 312.]]), 201 np.array([[1., 1., 1., 1., 1., 1., 1., 1., 121., 122.], 202 [2., 2., 2., 2., 2., 2., 2., 2., 221., 222.], 203 [1., 1., 1., 1., 1., 1., 1., 1., 321., 322.]]) 204 ] 205 206 sequence_features = { 207 'wire_cast': 208 sparse_tensor.SparseTensor( 209 indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0], [1, 1, 1], 210 [2, 0, 0], [2, 1, 1]], 211 values=[ 212 b'marlo', b'stringer', b'omar', b'stringer', b'marlo', 213 b'marlo', b'omar' 214 ], 215 dense_shape=[3, 2, 2]), 216 'seq_feature0': 217 constant_op.constant([[[111., 112.], [121., 122.]], 218 [[211., 212.], [221., 222.]], 219 [[311., 312.], [321., 322.]]]) 220 } 221 222 wire_cast = feature_column.sparse_column_with_keys( 223 'wire_cast', ['marlo', 'omar', 'stringer']) 224 wire_cast_embedded = feature_column.embedding_column( 225 wire_cast, 226 dimension=embedding_dimension, 227 combiner='sum', 228 initializer=init_ops.ones_initializer()) 229 seq_feature0_column = feature_column.real_valued_column( 230 'seq_feature0', dimension=dense_dimension) 231 232 sequence_feature_columns = [seq_feature0_column, wire_cast_embedded] 233 234 context_features = None 235 236 self._test_prepare_inputs_for_rnn(sequence_features, context_features, 237 sequence_feature_columns, num_unroll, 238 expected) 239 240 241class StateSavingRnnEstimatorTest(test.TestCase): 242 243 def testPrepareFeaturesForSQSS(self): 244 
mode = model_fn_lib.ModeKeys.TRAIN 245 seq_feature_name = 'seq_feature' 246 sparse_seq_feature_name = 'wire_cast' 247 ctx_feature_name = 'ctx_feature' 248 sequence_length = 4 249 embedding_dimension = 8 250 251 features = { 252 sparse_seq_feature_name: 253 sparse_tensor.SparseTensor( 254 indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0], [1, 1, 1], 255 [2, 0, 0], [2, 1, 1]], 256 values=[ 257 b'marlo', b'stringer', b'omar', b'stringer', b'marlo', 258 b'marlo', b'omar' 259 ], 260 dense_shape=[3, 2, 2]), 261 seq_feature_name: 262 constant_op.constant( 263 1.0, shape=[sequence_length]), 264 ctx_feature_name: 265 constant_op.constant(2.0) 266 } 267 268 labels = constant_op.constant(5.0, shape=[sequence_length]) 269 270 wire_cast = feature_column.sparse_column_with_keys( 271 'wire_cast', ['marlo', 'omar', 'stringer']) 272 sequence_feature_columns = [ 273 feature_column.real_valued_column( 274 seq_feature_name, dimension=1), feature_column.embedding_column( 275 wire_cast, 276 dimension=embedding_dimension, 277 initializer=init_ops.ones_initializer()) 278 ] 279 280 context_feature_columns = [ 281 feature_column.real_valued_column( 282 ctx_feature_name, dimension=1) 283 ] 284 285 expected_sequence = { 286 rnn_common.RNNKeys.LABELS_KEY: 287 np.array([5., 5., 5., 5.]), 288 seq_feature_name: 289 np.array([1., 1., 1., 1.]), 290 sparse_seq_feature_name: 291 sparse_tensor.SparseTensor( 292 indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0], [1, 1, 1], 293 [2, 0, 0], [2, 1, 1]], 294 values=[ 295 b'marlo', b'stringer', b'omar', b'stringer', b'marlo', 296 b'marlo', b'omar' 297 ], 298 dense_shape=[3, 2, 2]), 299 } 300 301 expected_context = {ctx_feature_name: 2.} 302 303 sequence, context = ssre._prepare_features_for_sqss( 304 features, labels, mode, sequence_feature_columns, 305 context_feature_columns) 306 307 def assert_equal(expected, got): 308 self.assertEqual(sorted(expected), sorted(got)) 309 for k, v in expected.items(): 310 if isinstance(v, sparse_tensor.SparseTensor): 
311 self.assertAllEqual(v.values.eval(), got[k].values) 312 self.assertAllEqual(v.indices.eval(), got[k].indices) 313 self.assertAllEqual(v.dense_shape.eval(), got[k].dense_shape) 314 else: 315 self.assertAllEqual(v, got[k]) 316 317 with self.cached_session() as sess: 318 sess.run(variables.global_variables_initializer()) 319 sess.run(lookup_ops.tables_initializer()) 320 actual_sequence, actual_context = sess.run( 321 [sequence, context]) 322 assert_equal(expected_sequence, actual_sequence) 323 assert_equal(expected_context, actual_context) 324 325 def _getModelFnOpsForMode(self, mode): 326 """Helper for testGetRnnModelFn{Train,Eval,Infer}().""" 327 num_units = [4] 328 seq_columns = [ 329 feature_column.real_valued_column( 330 'inputs', dimension=1) 331 ] 332 features = { 333 'inputs': constant_op.constant([1., 2., 3.]), 334 } 335 labels = constant_op.constant([1., 0., 1.]) 336 model_fn = ssre._get_rnn_model_fn( 337 cell_type='basic_rnn', 338 target_column=target_column_lib.multi_class_target(n_classes=2), 339 optimizer='SGD', 340 num_unroll=2, 341 num_units=num_units, 342 num_threads=1, 343 queue_capacity=10, 344 batch_size=1, 345 # Only CLASSIFICATION yields eval metrics to test for. 346 problem_type=constants.ProblemType.CLASSIFICATION, 347 sequence_feature_columns=seq_columns, 348 context_feature_columns=None, 349 learning_rate=0.1) 350 model_fn_ops = model_fn(features=features, labels=labels, mode=mode) 351 return model_fn_ops 352 353 # testGetRnnModelFn{Train,Eval,Infer}() test which fields 354 # of ModelFnOps are set depending on mode. 355 def testGetRnnModelFnTrain(self): 356 model_fn_ops = self._getModelFnOpsForMode(model_fn_lib.ModeKeys.TRAIN) 357 self.assertIsNotNone(model_fn_ops.predictions) 358 self.assertIsNotNone(model_fn_ops.loss) 359 self.assertIsNotNone(model_fn_ops.train_op) 360 # None may get normalized to {}; we accept neither. 
361 self.assertNotEqual(len(model_fn_ops.eval_metric_ops), 0) 362 363 def testGetRnnModelFnEval(self): 364 model_fn_ops = self._getModelFnOpsForMode(model_fn_lib.ModeKeys.EVAL) 365 self.assertIsNotNone(model_fn_ops.predictions) 366 self.assertIsNotNone(model_fn_ops.loss) 367 self.assertIsNone(model_fn_ops.train_op) 368 # None may get normalized to {}; we accept neither. 369 self.assertNotEqual(len(model_fn_ops.eval_metric_ops), 0) 370 371 def testGetRnnModelFnInfer(self): 372 model_fn_ops = self._getModelFnOpsForMode(model_fn_lib.ModeKeys.INFER) 373 self.assertIsNotNone(model_fn_ops.predictions) 374 self.assertIsNone(model_fn_ops.loss) 375 self.assertIsNone(model_fn_ops.train_op) 376 # None may get normalized to {}; we accept both. 377 self.assertFalse(model_fn_ops.eval_metric_ops) 378 379 def testExport(self): 380 input_feature_key = 'magic_input_feature_key' 381 batch_size = 8 382 num_units = [4] 383 sequence_length = 10 384 num_unroll = 2 385 num_classes = 2 386 387 seq_columns = [ 388 feature_column.real_valued_column( 389 'inputs', dimension=4) 390 ] 391 392 def get_input_fn(mode, seed): 393 394 def input_fn(): 395 features = {} 396 random_sequence = random_ops.random_uniform( 397 [sequence_length + 1], 0, 2, dtype=dtypes.int32, seed=seed) 398 labels = array_ops.slice(random_sequence, [0], [sequence_length]) 399 inputs = math_ops.cast( 400 array_ops.slice(random_sequence, [1], [sequence_length]), 401 dtypes.float32) 402 features = {'inputs': inputs} 403 404 if mode == model_fn_lib.ModeKeys.INFER: 405 input_examples = array_ops.placeholder(dtypes.string) 406 features[input_feature_key] = input_examples 407 labels = None 408 return features, labels 409 410 return input_fn 411 412 model_dir = tempfile.mkdtemp() 413 414 def estimator_fn(): 415 return ssre.StateSavingRnnEstimator( 416 constants.ProblemType.CLASSIFICATION, 417 num_units=num_units, 418 num_unroll=num_unroll, 419 batch_size=batch_size, 420 sequence_feature_columns=seq_columns, 421 
num_classes=num_classes, 422 predict_probabilities=True, 423 model_dir=model_dir, 424 queue_capacity=2 + batch_size, 425 seed=1234) 426 427 # Train a bit to create an exportable checkpoint. 428 estimator_fn().fit(input_fn=get_input_fn( 429 model_fn_lib.ModeKeys.TRAIN, seed=1234), 430 steps=100) 431 # Now export, but from a fresh estimator instance, like you would 432 # in an export binary. That means .export() has to work without 433 # .fit() being called on the same object. 434 export_dir = tempfile.mkdtemp() 435 print('Exporting to', export_dir) 436 estimator_fn().export( 437 export_dir, 438 input_fn=get_input_fn( 439 model_fn_lib.ModeKeys.INFER, seed=4321), 440 use_deprecated_input_fn=False, 441 input_feature_key=input_feature_key) 442 443 444# Smoke tests to ensure deprecated constructor functions still work. 445class LegacyConstructorTest(test.TestCase): 446 447 def _get_input_fn(self, 448 sequence_length, 449 seed=None): 450 def input_fn(): 451 random_sequence = random_ops.random_uniform( 452 [sequence_length + 1], 0, 2, dtype=dtypes.int32, seed=seed) 453 labels = array_ops.slice(random_sequence, [0], [sequence_length]) 454 inputs = math_ops.cast( 455 array_ops.slice(random_sequence, [1], [sequence_length]), 456 dtypes.float32) 457 return {'inputs': inputs}, labels 458 return input_fn 459 460 461# TODO(jtbates): move all tests below to a benchmark test. 
class StateSavingRNNEstimatorLearningTest(test.TestCase):
  """Learning tests for state saving RNN Estimators."""

  def testLearnSineFunction(self):
    """Tests learning a sine function."""
    batch_size = 8
    num_unroll = 5
    sequence_length = 64
    train_steps = 250
    eval_steps = 20
    num_rnn_layers = 1
    num_units = [4] * num_rnn_layers
    learning_rate = 0.3
    loss_threshold = 0.035

    def get_sin_input_fn(sequence_length, increment, seed=None):
      """Returns an input_fn yielding a sine curve and its one-step shift."""

      def input_fn():
        start = random_ops.random_uniform(
            (), minval=0, maxval=(np.pi * 2.0), dtype=dtypes.float32, seed=seed)
        # NOTE(review): the linspace stop value is not offset by `start` --
        # confirm this is intended.
        sin_curves = math_ops.sin(
            math_ops.linspace(start, (sequence_length - 1) * increment,
                              sequence_length + 1))
        # The label at each step is the curve value one step ahead.
        inputs = array_ops.slice(sin_curves, [0], [sequence_length])
        labels = array_ops.slice(sin_curves, [1], [sequence_length])
        return {'inputs': inputs}, labels

      return input_fn

    seq_columns = [
        feature_column.real_valued_column(
            'inputs', dimension=1)
    ]
    config = run_config.RunConfig(tf_random_seed=1234)
    dropout_keep_probabilities = [0.9] * (num_rnn_layers + 1)
    sequence_estimator = ssre.StateSavingRnnEstimator(
        constants.ProblemType.LINEAR_REGRESSION,
        num_units=num_units,
        cell_type='lstm',
        num_unroll=num_unroll,
        batch_size=batch_size,
        sequence_feature_columns=seq_columns,
        learning_rate=learning_rate,
        dropout_keep_probabilities=dropout_keep_probabilities,
        config=config,
        queue_capacity=2 * batch_size,
        seed=1234)

    train_input_fn = get_sin_input_fn(sequence_length, np.pi / 32, seed=1234)
    eval_input_fn = get_sin_input_fn(sequence_length, np.pi / 32, seed=4321)

    sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)
    loss = sequence_estimator.evaluate(
        input_fn=eval_input_fn, steps=eval_steps)['loss']
    self.assertLess(loss, loss_threshold,
                    'Loss should be less than {}; got {}'.format(loss_threshold,
                                                                 loss))

  def testLearnShiftByOne(self):
    """Tests learning a 'shift-by-one' example.

    Each label sequence consists of the input sequence 'shifted' by one place.
    The RNN must learn to 'remember' the previous input.
    """
    batch_size = 16
    num_classes = 2
    num_unroll = 32
    sequence_length = 32
    train_steps = 300
    eval_steps = 20
    num_units = [4]
    learning_rate = 0.5
    accuracy_threshold = 0.9

    def get_shift_input_fn(sequence_length, seed=None):
      """Returns an input_fn yielding a random 0/1 sequence and its shift."""

      def input_fn():
        random_sequence = random_ops.random_uniform(
            [sequence_length + 1], 0, 2, dtype=dtypes.int32, seed=seed)
        labels = array_ops.slice(random_sequence, [0], [sequence_length])
        inputs = math_ops.cast(
            array_ops.slice(random_sequence, [1], [sequence_length]),
            dtypes.float32)
        return {'inputs': inputs}, labels

      return input_fn

    seq_columns = [
        feature_column.real_valued_column(
            'inputs', dimension=1)
    ]
    config = run_config.RunConfig(tf_random_seed=21212)
    sequence_estimator = ssre.StateSavingRnnEstimator(
        constants.ProblemType.CLASSIFICATION,
        num_units=num_units,
        cell_type='lstm',
        num_unroll=num_unroll,
        batch_size=batch_size,
        sequence_feature_columns=seq_columns,
        num_classes=num_classes,
        learning_rate=learning_rate,
        config=config,
        predict_probabilities=True,
        queue_capacity=2 + batch_size,
        seed=1234)

    train_input_fn = get_shift_input_fn(sequence_length, seed=12321)
    eval_input_fn = get_shift_input_fn(sequence_length, seed=32123)

    sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)

    evaluation = sequence_estimator.evaluate(
        input_fn=eval_input_fn, steps=eval_steps)
    accuracy = evaluation['accuracy']
    self.assertGreater(accuracy, accuracy_threshold,
                       'Accuracy should be higher than {}; got {}'.format(
                           accuracy_threshold, accuracy))

    # Testing `predict` when `predict_probabilities=True`.
    prediction_dict = sequence_estimator.predict(
        input_fn=eval_input_fn, as_iterable=False)
    self.assertListEqual(
        sorted(list(prediction_dict.keys())),
        sorted([
            prediction_key.PredictionKey.CLASSES,
            prediction_key.PredictionKey.PROBABILITIES, ssre._get_state_name(0)
        ]))
    predictions = prediction_dict[prediction_key.PredictionKey.CLASSES]
    probabilities = prediction_dict[prediction_key.PredictionKey.PROBABILITIES]
    self.assertListEqual(list(predictions.shape), [batch_size, sequence_length])
    self.assertListEqual(
        list(probabilities.shape), [batch_size, sequence_length, 2])

  def testLearnLyrics(self):
    """Tests learning to predict the next word of a short, cyclic lyric."""
    lyrics = 'if I go there will be trouble and if I stay it will be double'
    lyrics_list = lyrics.split()
    sequence_length = len(lyrics_list)
    # Sort the vocabulary so the word -> id assignment is the same on every
    # run; iterating a bare set of strings depends on hash randomization,
    # which would make the test non-reproducible despite the fixed seeds.
    vocab = sorted(set(lyrics_list))
    batch_size = 16
    num_classes = len(vocab)
    num_unroll = 7  # not a divisor of sequence_length
    train_steps = 350
    eval_steps = 30
    num_units = [4]
    learning_rate = 0.4
    accuracy_threshold = 0.65

    def get_lyrics_input_fn(seed):
      """Returns an input_fn yielding the lyric from a random start word."""

      def input_fn():
        start = random_ops.random_uniform(
            (), minval=0, maxval=sequence_length, dtype=dtypes.int32, seed=seed)
        # Concatenate lyrics_list so inputs and labels wrap when start > 0.
        lyrics_list_concat = lyrics_list + lyrics_list
        inputs_dense = array_ops.slice(lyrics_list_concat, [start],
                                       [sequence_length])
        # One string value per time step: indices are [[0, 0], [1, 0], ...].
        indices = array_ops.constant(
            [[i, 0] for i in range(sequence_length)], dtype=dtypes.int64)
        dense_shape = [sequence_length, 1]
        inputs = sparse_tensor.SparseTensor(
            indices=indices, values=inputs_dense, dense_shape=dense_shape)
        table = lookup.string_to_index_table_from_tensor(
            mapping=vocab, default_value=-1, name='lookup')
        # Labels are the ids of the words one position after the inputs.
        labels = table.lookup(
            array_ops.slice(lyrics_list_concat, [start + 1], [sequence_length]))
        return {'lyrics': inputs}, labels

      return input_fn

    sequence_feature_columns = [
        feature_column.embedding_column(
            feature_column.sparse_column_with_keys('lyrics', vocab),
            dimension=8)
    ]
    config = run_config.RunConfig(tf_random_seed=21212)
    sequence_estimator = ssre.StateSavingRnnEstimator(
        constants.ProblemType.CLASSIFICATION,
        num_units=num_units,
        cell_type='basic_rnn',
        num_unroll=num_unroll,
        batch_size=batch_size,
        sequence_feature_columns=sequence_feature_columns,
        num_classes=num_classes,
        learning_rate=learning_rate,
        config=config,
        predict_probabilities=True,
        queue_capacity=2 + batch_size,
        seed=1234)

    train_input_fn = get_lyrics_input_fn(seed=12321)
    eval_input_fn = get_lyrics_input_fn(seed=32123)

    sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)

    evaluation = sequence_estimator.evaluate(
        input_fn=eval_input_fn, steps=eval_steps)
    accuracy = evaluation['accuracy']
    self.assertGreater(accuracy, accuracy_threshold,
                       'Accuracy should be higher than {}; got {}'.format(
                           accuracy_threshold, accuracy))


if __name__ == '__main__':
  test.main()