• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for learn.estimators.state_saving_rnn_estimator."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import tempfile
22
23import numpy as np
24
25from tensorflow.contrib import lookup
26from tensorflow.contrib.layers.python.layers import feature_column
27from tensorflow.contrib.layers.python.layers import target_column as target_column_lib
28from tensorflow.contrib.learn.python.learn.estimators import constants
29from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
30from tensorflow.contrib.learn.python.learn.estimators import prediction_key
31from tensorflow.contrib.learn.python.learn.estimators import rnn_common
32from tensorflow.contrib.learn.python.learn.estimators import run_config
33from tensorflow.contrib.learn.python.learn.estimators import state_saving_rnn_estimator as ssre
34from tensorflow.python.framework import constant_op
35from tensorflow.python.framework import dtypes
36from tensorflow.python.framework import sparse_tensor
37from tensorflow.python.ops import array_ops
38from tensorflow.python.ops import init_ops
39from tensorflow.python.ops import lookup_ops
40from tensorflow.python.ops import math_ops
41from tensorflow.python.ops import random_ops
42from tensorflow.python.ops import variables
43from tensorflow.python.platform import test
44
45
class PrepareInputsForRnnTest(test.TestCase):
  """Tests for `ssre._prepare_inputs_for_rnn`.

  Every test builds sequence features (and optionally context features),
  passes them through `_prepare_inputs_for_rnn`, evaluates the result and
  compares it with a list holding one expected array per unrolled step.
  """

  @staticmethod
  def _two_scalar_sequence_columns():
    """Real-valued, dimension-1 columns for 'seq_feature0' and 'seq_feature1'."""
    return [
        feature_column.real_valued_column(column_name, dimension=1)
        for column_name in ('seq_feature0', 'seq_feature1')
    ]

  @staticmethod
  def _wire_cast_sparse_tensor():
    """The dense_shape [3, 2, 2] string SparseTensor used by the sparse tests."""
    return sparse_tensor.SparseTensor(
        indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0], [1, 1, 1],
                 [2, 0, 0], [2, 1, 1]],
        values=[
            b'marlo', b'stringer', b'omar', b'stringer', b'marlo', b'marlo',
            b'omar'
        ],
        dense_shape=[3, 2, 2])

  @staticmethod
  def _wire_cast_embedding(embedding_dimension):
    """Sum-combined embedding of 'wire_cast' whose weights are all ones."""
    keyed_column = feature_column.sparse_column_with_keys(
        'wire_cast', ['marlo', 'omar', 'stringer'])
    return feature_column.embedding_column(
        keyed_column,
        dimension=embedding_dimension,
        combiner='sum',
        initializer=init_ops.ones_initializer())

  def _test_prepare_inputs_for_rnn(self, sequence_features, context_features,
                                   sequence_feature_columns, num_unroll,
                                   expected):
    """Evaluates `_prepare_inputs_for_rnn` and asserts it equals `expected`.

    Args:
      sequence_features: dict of sequence feature tensors.
      context_features: dict of context feature tensors, or None.
      sequence_feature_columns: columns describing `sequence_features`.
      num_unroll: number of time steps to split the inputs into.
      expected: list of numpy arrays, one per unrolled step.
    """
    per_step_inputs = ssre._prepare_inputs_for_rnn(
        sequence_features, context_features, sequence_feature_columns,
        num_unroll)

    with self.cached_session() as sess:
      # Embedding columns may own variables and keyed sparse columns may own
      # lookup tables; both must be initialized before evaluation.
      sess.run(variables.global_variables_initializer())
      sess.run(lookup_ops.tables_initializer())
      self.assertAllEqual(expected, sess.run(per_step_inputs))

  def testPrepareInputsForRnnBatchSize1(self):
    sequence_features = {
        'seq_feature0': constant_op.constant([[11., 12., 13.]]),
        'seq_feature1': constant_op.constant([[31., 32., 33.]])
    }
    context_features = {
        'ctx_feature0': constant_op.constant([[5.]]),
        'ctx_feature1': constant_op.constant([[7.]])
    }
    # Step t holds both sequence values at t followed by the context values.
    expected = [np.array([[11. + t, 31. + t, 5., 7.]]) for t in range(3)]

    self._test_prepare_inputs_for_rnn(
        sequence_features, context_features,
        self._two_scalar_sequence_columns(), num_unroll=3, expected=expected)

  def testPrepareInputsForRnnBatchSize2(self):
    sequence_features = {
        'seq_feature0':
            constant_op.constant([[11., 12., 13.], [21., 22., 23.]]),
        'seq_feature1':
            constant_op.constant([[31., 32., 33.], [41., 42., 43.]])
    }
    context_features = {
        'ctx_feature0': constant_op.constant([[5.], [6.]]),
        'ctx_feature1': constant_op.constant([[7.], [8.]])
    }
    expected = [
        np.array([[11. + t, 31. + t, 5., 7.], [21. + t, 41. + t, 6., 8.]])
        for t in range(3)
    ]

    self._test_prepare_inputs_for_rnn(
        sequence_features, context_features,
        self._two_scalar_sequence_columns(), num_unroll=3, expected=expected)

  def testPrepareInputsForRnnNoContext(self):
    sequence_features = {
        'seq_feature0':
            constant_op.constant([[11., 12., 13.], [21., 22., 23.]]),
        'seq_feature1':
            constant_op.constant([[31., 32., 33.], [41., 42., 43.]])
    }
    # Without context, each step holds only the sequence values.
    expected = [
        np.array([[11. + t, 31. + t], [21. + t, 41. + t]]) for t in range(3)
    ]

    self._test_prepare_inputs_for_rnn(
        sequence_features, None, self._two_scalar_sequence_columns(),
        num_unroll=3, expected=expected)

  def testPrepareInputsForRnnSparse(self):
    embedding_dimension = 8
    sequence_features = {'wire_cast': self._wire_cast_sparse_tensor()}
    sequence_feature_columns = [self._wire_cast_embedding(embedding_dimension)]
    # With all-ones embeddings and a 'sum' combiner, every embedded entry
    # equals the number of ids present for that example at that step.
    expected = [
        np.ones([3, embedding_dimension]),
        np.tile([[1.], [2.], [1.]], [1, embedding_dimension]),
    ]

    self._test_prepare_inputs_for_rnn(
        sequence_features, None, sequence_feature_columns, num_unroll=2,
        expected=expected)

  def testPrepareInputsForRnnSparseAndDense(self):
    embedding_dimension = 8
    dense_dimension = 2
    sequence_features = {
        'wire_cast':
            self._wire_cast_sparse_tensor(),
        'seq_feature0':
            constant_op.constant([[[111., 112.], [121., 122.]],
                                  [[211., 212.], [221., 222.]],
                                  [[311., 312.], [321., 322.]]])
    }
    sequence_feature_columns = [
        feature_column.real_valued_column(
            'seq_feature0', dimension=dense_dimension),
        self._wire_cast_embedding(embedding_dimension),
    ]
    # Per step, the embedded sparse values come first, then the dense values.
    embedded_per_step = [
        np.ones([3, embedding_dimension]),
        np.tile([[1.], [2.], [1.]], [1, embedding_dimension]),
    ]
    dense_per_step = [
        np.array([[111., 112.], [211., 212.], [311., 312.]]),
        np.array([[121., 122.], [221., 222.], [321., 322.]]),
    ]
    expected = [
        np.concatenate([embedded, dense], axis=1)
        for embedded, dense in zip(embedded_per_step, dense_per_step)
    ]

    self._test_prepare_inputs_for_rnn(
        sequence_features, None, sequence_feature_columns, num_unroll=2,
        expected=expected)
239
240
class StateSavingRnnEstimatorTest(test.TestCase):
  """Unit tests for the StateSavingRnnEstimator helpers and export path."""

  def testPrepareFeaturesForSQSS(self):
    """Checks the sequence/context split done by `_prepare_features_for_sqss`.

    Labels are expected to be folded into the sequence dict under
    `RNNKeys.LABELS_KEY`. Sequence and context features are expected to be
    routed by the feature columns they belong to.
    """
    mode = model_fn_lib.ModeKeys.TRAIN
    seq_feature_name = 'seq_feature'
    sparse_seq_feature_name = 'wire_cast'
    ctx_feature_name = 'ctx_feature'
    sequence_length = 4
    embedding_dimension = 8

    features = {
        sparse_seq_feature_name:
            sparse_tensor.SparseTensor(
                indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0], [1, 1, 1],
                         [2, 0, 0], [2, 1, 1]],
                values=[
                    b'marlo', b'stringer', b'omar', b'stringer', b'marlo',
                    b'marlo', b'omar'
                ],
                dense_shape=[3, 2, 2]),
        seq_feature_name:
            constant_op.constant(1.0, shape=[sequence_length]),
        ctx_feature_name:
            constant_op.constant(2.0)
    }

    labels = constant_op.constant(5.0, shape=[sequence_length])

    wire_cast = feature_column.sparse_column_with_keys(
        'wire_cast', ['marlo', 'omar', 'stringer'])
    sequence_feature_columns = [
        feature_column.real_valued_column(seq_feature_name, dimension=1),
        feature_column.embedding_column(
            wire_cast,
            dimension=embedding_dimension,
            initializer=init_ops.ones_initializer())
    ]

    context_feature_columns = [
        feature_column.real_valued_column(ctx_feature_name, dimension=1)
    ]

    expected_sequence = {
        rnn_common.RNNKeys.LABELS_KEY:
            np.array([5., 5., 5., 5.]),
        seq_feature_name:
            np.array([1., 1., 1., 1.]),
        sparse_seq_feature_name:
            sparse_tensor.SparseTensor(
                indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0], [1, 1, 1],
                         [2, 0, 0], [2, 1, 1]],
                values=[
                    b'marlo', b'stringer', b'omar', b'stringer', b'marlo',
                    b'marlo', b'omar'
                ],
                dense_shape=[3, 2, 2]),
    }

    expected_context = {ctx_feature_name: 2.}

    sequence, context = ssre._prepare_features_for_sqss(
        features, labels, mode, sequence_feature_columns,
        context_feature_columns)

    def assert_equal(expected, got):
      """Compares a dict of expected values against evaluated results."""
      self.assertEqual(sorted(expected), sorted(got))
      for k, v in expected.items():
        if isinstance(v, sparse_tensor.SparseTensor):
          # Expected sparse values are graph tensors; evaluate each component
          # in the enclosing default session before comparing.
          self.assertAllEqual(v.values.eval(), got[k].values)
          self.assertAllEqual(v.indices.eval(), got[k].indices)
          self.assertAllEqual(v.dense_shape.eval(), got[k].dense_shape)
        else:
          self.assertAllEqual(v, got[k])

    with self.cached_session() as sess:
      sess.run(variables.global_variables_initializer())
      sess.run(lookup_ops.tables_initializer())
      actual_sequence, actual_context = sess.run([sequence, context])
      assert_equal(expected_sequence, actual_sequence)
      assert_equal(expected_context, actual_context)

  def _getModelFnOpsForMode(self, mode):
    """Helper for testGetRnnModelFn{Train,Eval,Infer}().

    Builds a tiny single-layer 'basic_rnn' classification model_fn and calls
    it once in `mode`.

    Args:
      mode: a `model_fn_lib.ModeKeys` value.

    Returns:
      The `ModelFnOps` returned by the model_fn.
    """
    num_units = [4]
    seq_columns = [feature_column.real_valued_column('inputs', dimension=1)]
    features = {
        'inputs': constant_op.constant([1., 2., 3.]),
    }
    labels = constant_op.constant([1., 0., 1.])
    model_fn = ssre._get_rnn_model_fn(
        cell_type='basic_rnn',
        target_column=target_column_lib.multi_class_target(n_classes=2),
        optimizer='SGD',
        num_unroll=2,
        num_units=num_units,
        num_threads=1,
        queue_capacity=10,
        batch_size=1,
        # Only CLASSIFICATION yields eval metrics to test for.
        problem_type=constants.ProblemType.CLASSIFICATION,
        sequence_feature_columns=seq_columns,
        context_feature_columns=None,
        learning_rate=0.1)
    return model_fn(features=features, labels=labels, mode=mode)

  # testGetRnnModelFn{Train,Eval,Infer}() test which fields
  # of ModelFnOps are set depending on mode.
  def testGetRnnModelFnTrain(self):
    model_fn_ops = self._getModelFnOpsForMode(model_fn_lib.ModeKeys.TRAIN)
    self.assertIsNotNone(model_fn_ops.predictions)
    self.assertIsNotNone(model_fn_ops.loss)
    self.assertIsNotNone(model_fn_ops.train_op)
    # None may get normalized to {}; we accept neither.
    self.assertNotEqual(len(model_fn_ops.eval_metric_ops), 0)

  def testGetRnnModelFnEval(self):
    model_fn_ops = self._getModelFnOpsForMode(model_fn_lib.ModeKeys.EVAL)
    self.assertIsNotNone(model_fn_ops.predictions)
    self.assertIsNotNone(model_fn_ops.loss)
    self.assertIsNone(model_fn_ops.train_op)
    # None may get normalized to {}; we accept neither.
    self.assertNotEqual(len(model_fn_ops.eval_metric_ops), 0)

  def testGetRnnModelFnInfer(self):
    model_fn_ops = self._getModelFnOpsForMode(model_fn_lib.ModeKeys.INFER)
    self.assertIsNotNone(model_fn_ops.predictions)
    self.assertIsNone(model_fn_ops.loss)
    self.assertIsNone(model_fn_ops.train_op)
    # None may get normalized to {}; we accept both.
    self.assertFalse(model_fn_ops.eval_metric_ops)

  def testExport(self):
    """Trains briefly, then exports from a fresh estimator instance."""
    input_feature_key = 'magic_input_feature_key'
    batch_size = 8
    num_units = [4]
    sequence_length = 10
    num_unroll = 2
    num_classes = 2

    seq_columns = [feature_column.real_valued_column('inputs', dimension=4)]

    def get_input_fn(mode, seed):
      """Returns an input_fn producing a random 0/1 shift-by-one sequence.

      In INFER mode a string placeholder is added under `input_feature_key`
      and labels are dropped.
      """

      def input_fn():
        random_sequence = random_ops.random_uniform(
            [sequence_length + 1], 0, 2, dtype=dtypes.int32, seed=seed)
        labels = array_ops.slice(random_sequence, [0], [sequence_length])
        inputs = math_ops.cast(
            array_ops.slice(random_sequence, [1], [sequence_length]),
            dtypes.float32)
        features = {'inputs': inputs}

        if mode == model_fn_lib.ModeKeys.INFER:
          input_examples = array_ops.placeholder(dtypes.string)
          features[input_feature_key] = input_examples
          labels = None
        return features, labels

      return input_fn

    # Nest the directories under the per-test temp dir (the pattern documented
    # for get_temp_dir) instead of creating them directly in the system temp
    # directory.
    model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())

    def estimator_fn():
      return ssre.StateSavingRnnEstimator(
          constants.ProblemType.CLASSIFICATION,
          num_units=num_units,
          num_unroll=num_unroll,
          batch_size=batch_size,
          sequence_feature_columns=seq_columns,
          num_classes=num_classes,
          predict_probabilities=True,
          model_dir=model_dir,
          queue_capacity=2 + batch_size,
          seed=1234)

    # Train a bit to create an exportable checkpoint.
    estimator_fn().fit(
        input_fn=get_input_fn(model_fn_lib.ModeKeys.TRAIN, seed=1234),
        steps=100)
    # Now export, but from a fresh estimator instance, like you would
    # in an export binary. That means .export() has to work without
    # .fit() being called on the same object.
    export_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
    print('Exporting to', export_dir)
    estimator_fn().export(
        export_dir,
        input_fn=get_input_fn(model_fn_lib.ModeKeys.INFER, seed=4321),
        use_deprecated_input_fn=False,
        input_feature_key=input_feature_key)
442
443
# Smoke tests to ensure deprecated constructor functions still work.
# NOTE(review): the class currently defines no test methods, only the
# input_fn helper below.
class LegacyConstructorTest(test.TestCase):
  """Home for smoke tests of deprecated StateSavingRnnEstimator constructors."""

  def _get_input_fn(self, sequence_length, seed=None):
    """Builds an input_fn for a random 0/1 'shift-by-one' sequence.

    Args:
      sequence_length: number of time steps in the produced sequence.
      seed: optional op-level seed for the random draw.

    Returns:
      A zero-argument callable returning `({'inputs': float32 tensor},
      int32 labels)`, where the inputs are the draw shifted by one step.
    """

    def input_fn():
      draw = random_ops.random_uniform(
          [sequence_length + 1], 0, 2, dtype=dtypes.int32, seed=seed)
      targets = array_ops.slice(draw, [0], [sequence_length])
      shifted = array_ops.slice(draw, [1], [sequence_length])
      return {'inputs': math_ops.cast(shifted, dtypes.float32)}, targets

    return input_fn
459
460
# TODO(jtbates): move all tests below to a benchmark test.
class StateSavingRNNEstimatorLearningTest(test.TestCase):
  """Learning tests for state saving RNN Estimators."""

  def testLearnSineFunction(self):
    """Tests learning a sine function."""
    batch_size = 8
    num_unroll = 5
    sequence_length = 64
    train_steps = 250
    eval_steps = 20
    num_rnn_layers = 1
    num_units = [4] * num_rnn_layers
    learning_rate = 0.3
    loss_threshold = 0.035

    def get_sin_input_fn(sequence_length, increment, seed=None):
      """Returns an input_fn yielding (sin(x_t), sin(x_{t+1})) pairs."""

      def input_fn():
        start = random_ops.random_uniform(
            (), minval=0, maxval=(np.pi * 2.0), dtype=dtypes.float32, seed=seed)
        # NOTE(review): linspace ends at a fixed stop value rather than at
        # start + sequence_length * increment, so the actual spacing depends
        # on the random start and is not exactly `increment`. Left unchanged
        # so the training data (and thus loss_threshold) stays the same.
        sin_curves = math_ops.sin(
            math_ops.linspace(start, (sequence_length - 1) * increment,
                              sequence_length + 1))
        inputs = array_ops.slice(sin_curves, [0], [sequence_length])
        labels = array_ops.slice(sin_curves, [1], [sequence_length])
        return {'inputs': inputs}, labels

      return input_fn

    seq_columns = [feature_column.real_valued_column('inputs', dimension=1)]
    config = run_config.RunConfig(tf_random_seed=1234)
    dropout_keep_probabilities = [0.9] * (num_rnn_layers + 1)
    sequence_estimator = ssre.StateSavingRnnEstimator(
        constants.ProblemType.LINEAR_REGRESSION,
        num_units=num_units,
        cell_type='lstm',
        num_unroll=num_unroll,
        batch_size=batch_size,
        sequence_feature_columns=seq_columns,
        learning_rate=learning_rate,
        dropout_keep_probabilities=dropout_keep_probabilities,
        config=config,
        queue_capacity=2 * batch_size,
        seed=1234)

    train_input_fn = get_sin_input_fn(sequence_length, np.pi / 32, seed=1234)
    eval_input_fn = get_sin_input_fn(sequence_length, np.pi / 32, seed=4321)

    sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)
    loss = sequence_estimator.evaluate(
        input_fn=eval_input_fn, steps=eval_steps)['loss']
    self.assertLess(loss, loss_threshold,
                    'Loss should be less than {}; got {}'.format(loss_threshold,
                                                                 loss))

  def testLearnShiftByOne(self):
    """Tests that learning a 'shift-by-one' example.

    Each label sequence consists of the input sequence 'shifted' by one place.
    The RNN must learn to 'remember' the previous input.
    """
    batch_size = 16
    num_classes = 2
    num_unroll = 32
    sequence_length = 32
    train_steps = 300
    eval_steps = 20
    num_units = [4]
    learning_rate = 0.5
    accuracy_threshold = 0.9

    def get_shift_input_fn(sequence_length, seed=None):
      """Returns an input_fn whose labels are the inputs shifted by one."""

      def input_fn():
        random_sequence = random_ops.random_uniform(
            [sequence_length + 1], 0, 2, dtype=dtypes.int32, seed=seed)
        labels = array_ops.slice(random_sequence, [0], [sequence_length])
        inputs = math_ops.cast(
            array_ops.slice(random_sequence, [1], [sequence_length]),
            dtypes.float32)
        return {'inputs': inputs}, labels

      return input_fn

    seq_columns = [feature_column.real_valued_column('inputs', dimension=1)]
    config = run_config.RunConfig(tf_random_seed=21212)
    sequence_estimator = ssre.StateSavingRnnEstimator(
        constants.ProblemType.CLASSIFICATION,
        num_units=num_units,
        cell_type='lstm',
        num_unroll=num_unroll,
        batch_size=batch_size,
        sequence_feature_columns=seq_columns,
        num_classes=num_classes,
        learning_rate=learning_rate,
        config=config,
        predict_probabilities=True,
        queue_capacity=2 + batch_size,
        seed=1234)

    train_input_fn = get_shift_input_fn(sequence_length, seed=12321)
    eval_input_fn = get_shift_input_fn(sequence_length, seed=32123)

    sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)

    evaluation = sequence_estimator.evaluate(
        input_fn=eval_input_fn, steps=eval_steps)
    accuracy = evaluation['accuracy']
    self.assertGreater(accuracy, accuracy_threshold,
                       'Accuracy should be higher than {}; got {}'.format(
                           accuracy_threshold, accuracy))

    # Testing `predict` when `predict_probabilities=True`.
    prediction_dict = sequence_estimator.predict(
        input_fn=eval_input_fn, as_iterable=False)
    self.assertListEqual(
        sorted(prediction_dict),
        sorted([
            prediction_key.PredictionKey.CLASSES,
            prediction_key.PredictionKey.PROBABILITIES, ssre._get_state_name(0)
        ]))
    predictions = prediction_dict[prediction_key.PredictionKey.CLASSES]
    probabilities = prediction_dict[prediction_key.PredictionKey.PROBABILITIES]
    self.assertListEqual(list(predictions.shape), [batch_size, sequence_length])
    self.assertListEqual(
        list(probabilities.shape), [batch_size, sequence_length, num_classes])

  def testLearnLyrics(self):
    """Tests next-word prediction on a short lyric with sparse inputs."""
    lyrics = 'if I go there will be trouble and if I stay it will be double'
    lyrics_list = lyrics.split()
    sequence_length = len(lyrics_list)
    # Sort the vocabulary: iteration order of a set of str depends on the
    # per-process string hash seed, so an unsorted set would assign different
    # label ids / embedding keys on each run of this seeded test.
    vocab = sorted(set(lyrics_list))
    batch_size = 16
    num_classes = len(vocab)
    num_unroll = 7  # not a divisor of sequence_length
    train_steps = 350
    eval_steps = 30
    num_units = [4]
    learning_rate = 0.4
    accuracy_threshold = 0.65

    def get_lyrics_input_fn(seed):
      """Returns an input_fn yielding a rotated lyric and its next words."""

      def input_fn():
        start = random_ops.random_uniform(
            (), minval=0, maxval=sequence_length, dtype=dtypes.int32, seed=seed)
        # Concatenate lyrics_list so inputs and labels wrap when start > 0.
        lyrics_list_concat = lyrics_list + lyrics_list
        inputs_dense = array_ops.slice(lyrics_list_concat, [start],
                                       [sequence_length])
        indices = array_ops.constant(
            [[i, 0] for i in range(sequence_length)], dtype=dtypes.int64)
        dense_shape = [sequence_length, 1]
        inputs = sparse_tensor.SparseTensor(
            indices=indices, values=inputs_dense, dense_shape=dense_shape)
        table = lookup.string_to_index_table_from_tensor(
            mapping=vocab, default_value=-1, name='lookup')
        labels = table.lookup(
            array_ops.slice(lyrics_list_concat, [start + 1], [sequence_length]))
        return {'lyrics': inputs}, labels

      return input_fn

    sequence_feature_columns = [
        feature_column.embedding_column(
            feature_column.sparse_column_with_keys('lyrics', vocab),
            dimension=8)
    ]
    config = run_config.RunConfig(tf_random_seed=21212)
    sequence_estimator = ssre.StateSavingRnnEstimator(
        constants.ProblemType.CLASSIFICATION,
        num_units=num_units,
        cell_type='basic_rnn',
        num_unroll=num_unroll,
        batch_size=batch_size,
        sequence_feature_columns=sequence_feature_columns,
        num_classes=num_classes,
        learning_rate=learning_rate,
        config=config,
        predict_probabilities=True,
        queue_capacity=2 + batch_size,
        seed=1234)

    train_input_fn = get_lyrics_input_fn(seed=12321)
    eval_input_fn = get_lyrics_input_fn(seed=32123)

    sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)

    evaluation = sequence_estimator.evaluate(
        input_fn=eval_input_fn, steps=eval_steps)
    accuracy = evaluation['accuracy']
    self.assertGreater(accuracy, accuracy_threshold,
                       'Accuracy should be higher than {}; got {}'.format(
                           accuracy_threshold, accuracy))
662
663
if __name__ == '__main__':
  # Run every test case defined in this module via TensorFlow's test runner.
  test.main()
666