• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for learn.estimators.dynamic_rnn_estimator."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import tempfile
22
23import numpy as np
24
25from tensorflow.contrib import rnn
26from tensorflow.contrib.layers.python.layers import feature_column
27from tensorflow.contrib.layers.python.layers import target_column as target_column_lib
28from tensorflow.contrib.learn.python.learn.estimators import constants
29from tensorflow.contrib.learn.python.learn.estimators import dynamic_rnn_estimator
30from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
31from tensorflow.contrib.learn.python.learn.estimators import prediction_key
32from tensorflow.contrib.learn.python.learn.estimators import rnn_common
33from tensorflow.contrib.learn.python.learn.estimators import run_config
34from tensorflow.python.client import session
35from tensorflow.python.framework import constant_op
36from tensorflow.python.framework import dtypes
37from tensorflow.python.framework import random_seed
38from tensorflow.python.framework import sparse_tensor
39from tensorflow.python.ops import array_ops
40from tensorflow.python.ops import lookup_ops
41from tensorflow.python.ops import map_fn
42from tensorflow.python.ops import math_ops
43from tensorflow.python.ops import random_ops
44from tensorflow.python.ops import rnn_cell
45from tensorflow.python.ops import variables
46from tensorflow.python.platform import test
47
48
class IdentityRNNCell(rnn.RNNCell):
  """RNN cell test double whose output is simply a copy of its input.

  The incoming `state` is ignored; the state handed back is an all-ones
  tensor with one row per input row and `state_size` columns.
  """

  def __init__(self, state_size, output_size):
    """Stores the sizes reported by `state_size` and `output_size`."""
    self._state_size = state_size
    self._output_size = output_size

  @property
  def state_size(self):
    """Size given at construction time."""
    return self._state_size

  @property
  def output_size(self):
    """Size given at construction time."""
    return self._output_size

  def __call__(self, inputs, state):
    """Returns `(identity(inputs), ones([rows_of_inputs, state_size]))`."""
    outputs = array_ops.identity(inputs)
    num_rows = array_ops.shape(inputs)[0]
    next_state = array_ops.ones([num_rows, self.state_size])
    return outputs, next_state
66
67
class MockTargetColumn(object):
  """Target column stand-in that only knows its number of label columns.

  Every other target-column method raises `NotImplementedError`, so a test
  fails loudly if the code under test calls one of them.
  """

  def __init__(self, num_label_columns=None):
    """Remembers `num_label_columns`; `None` means "not set yet"."""
    self._num_label_columns = num_label_columns

  def get_eval_ops(self, features, activations, labels, metrics):
    """Always raises `NotImplementedError`."""
    message = 'MockTargetColumn.get_eval_ops called unexpectedly.'
    raise NotImplementedError(message)

  def logits_to_predictions(self, flattened_activations, proba=False):
    """Always raises `NotImplementedError`."""
    message = 'MockTargetColumn.logits_to_predictions called unexpectedly.'
    raise NotImplementedError(message)

  def loss(self, activations, labels, features):
    """Always raises `NotImplementedError`."""
    message = 'MockTargetColumn.loss called unexpectedly.'
    raise NotImplementedError(message)

  @property
  def num_label_columns(self):
    """The configured number of label columns.

    Raises:
      ValueError: if no value has been provided yet.
    """
    count = self._num_label_columns
    if count is not None:
      return count
    raise ValueError('MockTargetColumn.num_label_columns has not been set.')

  def set_num_label_columns(self, n):
    """Overrides the number of label columns with `n`."""
    self._num_label_columns = n
92
93
def sequence_length_mask(values, lengths):
  """Returns a copy of `values` with steps past each sequence length zeroed.

  Args:
    values: array-like indexed as `[i, step, :]`.
    lengths: iterable of ints; `lengths[i]` is the number of leading steps of
      `values[i]` to keep.

  Returns:
    A new `np.ndarray` equal to `values` except that `[i, lengths[i]:, :]` is
    zero for every `i`. The `values` argument itself is left unmodified.
  """
  # Work on a copy: masking in place would silently corrupt the caller's data.
  masked = np.array(values, copy=True)
  for i, length in enumerate(lengths):
    masked[i, length:, :] = 0
  return masked
99
100
101class DynamicRnnEstimatorTest(test.TestCase):
102
103  NUM_RNN_CELL_UNITS = 8
104  NUM_LABEL_COLUMNS = 6
105  INPUTS_COLUMN = feature_column.real_valued_column(
106      'inputs', dimension=NUM_LABEL_COLUMNS)
107
108  def setUp(self):
109    super(DynamicRnnEstimatorTest, self).setUp()
110    self.rnn_cell = rnn_cell.BasicRNNCell(self.NUM_RNN_CELL_UNITS)
111    self.mock_target_column = MockTargetColumn(
112        num_label_columns=self.NUM_LABEL_COLUMNS)
113
114    location = feature_column.sparse_column_with_keys(
115        'location', keys=['west_side', 'east_side', 'nyc'])
116    location_onehot = feature_column.one_hot_column(location)
117    self.context_feature_columns = [location_onehot]
118
119    wire_cast = feature_column.sparse_column_with_keys(
120        'wire_cast', ['marlo', 'omar', 'stringer'])
121    wire_cast_embedded = feature_column.embedding_column(wire_cast, dimension=8)
122    measurements = feature_column.real_valued_column(
123        'measurements', dimension=2)
124    self.sequence_feature_columns = [measurements, wire_cast_embedded]
125
126  def GetColumnsToTensors(self):
127    """Get columns_to_tensors matching setUp(), in the current default graph."""
128    return {
129        'location':
130            sparse_tensor.SparseTensor(
131                indices=[[0, 0], [1, 0], [2, 0]],
132                values=['west_side', 'west_side', 'nyc'],
133                dense_shape=[3, 1]),
134        'wire_cast':
135            sparse_tensor.SparseTensor(
136                indices=[[0, 0, 0], [0, 1, 0],
137                         [1, 0, 0], [1, 1, 0], [1, 1, 1],
138                         [2, 0, 0]],
139                values=[b'marlo', b'stringer',
140                        b'omar', b'stringer', b'marlo',
141                        b'marlo'],
142                dense_shape=[3, 2, 2]),
143        'measurements':
144            random_ops.random_uniform(
145                [3, 2, 2], seed=4711)
146    }
147
148  def GetClassificationTargetsOrNone(self, mode):
149    """Get targets matching setUp() and mode, in the current default graph."""
150    return (random_ops.random_uniform(
151        [3, 2, 1], 0, 2, dtype=dtypes.int64, seed=1412) if
152            mode != model_fn_lib.ModeKeys.INFER else None)
153
154  def testBuildSequenceInputInput(self):
155    sequence_input = dynamic_rnn_estimator.build_sequence_input(
156        self.GetColumnsToTensors(), self.sequence_feature_columns,
157        self.context_feature_columns)
158    with self.cached_session() as sess:
159      sess.run(variables.global_variables_initializer())
160      sess.run(lookup_ops.tables_initializer())
161      sequence_input_val = sess.run(sequence_input)
162    expected_shape = np.array([
163        3,  # expected batch size
164        2,  # padded sequence length
165        3 + 8 + 2  # location keys + embedding dim + measurement dimension
166    ])
167    self.assertAllEqual(expected_shape, sequence_input_val.shape)
168
169  def testConstructRNN(self):
170    initial_state = None
171    sequence_input = dynamic_rnn_estimator.build_sequence_input(
172        self.GetColumnsToTensors(), self.sequence_feature_columns,
173        self.context_feature_columns)
174    activations_t, final_state_t = dynamic_rnn_estimator.construct_rnn(
175        initial_state, sequence_input, self.rnn_cell,
176        self.mock_target_column.num_label_columns)
177
178    # Obtain values of activations and final state.
179    with session.Session() as sess:
180      sess.run(variables.global_variables_initializer())
181      sess.run(lookup_ops.tables_initializer())
182      activations, final_state = sess.run([activations_t, final_state_t])
183
184    expected_activations_shape = np.array([3, 2, self.NUM_LABEL_COLUMNS])
185    self.assertAllEqual(expected_activations_shape, activations.shape)
186    expected_state_shape = np.array([3, self.NUM_RNN_CELL_UNITS])
187    self.assertAllEqual(expected_state_shape, final_state.shape)
188
189  def testGetOutputAlternatives(self):
190    test_cases = (
191        (rnn_common.PredictionType.SINGLE_VALUE,
192         constants.ProblemType.CLASSIFICATION,
193         {prediction_key.PredictionKey.CLASSES: True,
194          prediction_key.PredictionKey.PROBABILITIES: True,
195          dynamic_rnn_estimator._get_state_name(0): True},
196         {'dynamic_rnn_output':
197          (constants.ProblemType.CLASSIFICATION,
198           {prediction_key.PredictionKey.CLASSES: True,
199            prediction_key.PredictionKey.PROBABILITIES: True})}),
200
201        (rnn_common.PredictionType.SINGLE_VALUE,
202         constants.ProblemType.LINEAR_REGRESSION,
203         {prediction_key.PredictionKey.SCORES: True,
204          dynamic_rnn_estimator._get_state_name(0): True,
205          dynamic_rnn_estimator._get_state_name(1): True},
206         {'dynamic_rnn_output':
207          (constants.ProblemType.LINEAR_REGRESSION,
208           {prediction_key.PredictionKey.SCORES: True})}),
209
210        (rnn_common.PredictionType.MULTIPLE_VALUE,
211         constants.ProblemType.CLASSIFICATION,
212         {prediction_key.PredictionKey.CLASSES: True,
213          prediction_key.PredictionKey.PROBABILITIES: True,
214          dynamic_rnn_estimator._get_state_name(0): True},
215         None))
216
217    for pred_type, prob_type, pred_dict, expected_alternatives in test_cases:
218      actual_alternatives = dynamic_rnn_estimator._get_output_alternatives(
219          pred_type, prob_type, pred_dict)
220      self.assertEqual(expected_alternatives, actual_alternatives)
221
222  # testGetDynamicRnnModelFn{Train,Eval,Infer}() test which fields
223  # of ModelFnOps are set depending on mode.
224  def testGetDynamicRnnModelFnTrain(self):
225    model_fn_ops = self._GetModelFnOpsForMode(model_fn_lib.ModeKeys.TRAIN)
226    self.assertIsNotNone(model_fn_ops.predictions)
227    self.assertIsNotNone(model_fn_ops.loss)
228    self.assertIsNotNone(model_fn_ops.train_op)
229    # None may get normalized to {}; we accept neither.
230    self.assertNotEqual(len(model_fn_ops.eval_metric_ops), 0)
231
232  def testGetDynamicRnnModelFnEval(self):
233    model_fn_ops = self._GetModelFnOpsForMode(model_fn_lib.ModeKeys.EVAL)
234    self.assertIsNotNone(model_fn_ops.predictions)
235    self.assertIsNotNone(model_fn_ops.loss)
236    self.assertIsNone(model_fn_ops.train_op)
237    # None may get normalized to {}; we accept neither.
238    self.assertNotEqual(len(model_fn_ops.eval_metric_ops), 0)
239
240  def testGetDynamicRnnModelFnInfer(self):
241    model_fn_ops = self._GetModelFnOpsForMode(model_fn_lib.ModeKeys.INFER)
242    self.assertIsNotNone(model_fn_ops.predictions)
243    self.assertIsNone(model_fn_ops.loss)
244    self.assertIsNone(model_fn_ops.train_op)
245    # None may get normalized to {}; we accept both.
246    self.assertFalse(model_fn_ops.eval_metric_ops)
247
248  def _GetModelFnOpsForMode(self, mode):
249    """Helper for testGetDynamicRnnModelFn{Train,Eval,Infer}()."""
250    model_fn = dynamic_rnn_estimator._get_dynamic_rnn_model_fn(
251        cell_type='basic_rnn',
252        num_units=[10],
253        target_column=target_column_lib.multi_class_target(n_classes=2),
254        # Only CLASSIFICATION yields eval metrics to test for.
255        problem_type=constants.ProblemType.CLASSIFICATION,
256        prediction_type=rnn_common.PredictionType.MULTIPLE_VALUE,
257        optimizer='SGD',
258        sequence_feature_columns=self.sequence_feature_columns,
259        context_feature_columns=self.context_feature_columns,
260        learning_rate=0.1)
261    labels = self.GetClassificationTargetsOrNone(mode)
262    model_fn_ops = model_fn(
263        features=self.GetColumnsToTensors(), labels=labels, mode=mode)
264    return model_fn_ops
265
266  def testExport(self):
267    input_feature_key = 'magic_input_feature_key'
268
269    def get_input_fn(mode):
270
271      def input_fn():
272        features = self.GetColumnsToTensors()
273        if mode == model_fn_lib.ModeKeys.INFER:
274          input_examples = array_ops.placeholder(dtypes.string)
275          features[input_feature_key] = input_examples
276          # Real code would now parse features out of input_examples,
277          # but this test can just stick to the constants above.
278        return features, self.GetClassificationTargetsOrNone(mode)
279
280      return input_fn
281
282    model_dir = tempfile.mkdtemp()
283
284    def estimator_fn():
285      return dynamic_rnn_estimator.DynamicRnnEstimator(
286          problem_type=constants.ProblemType.CLASSIFICATION,
287          prediction_type=rnn_common.PredictionType.MULTIPLE_VALUE,
288          num_classes=2,
289          num_units=self.NUM_RNN_CELL_UNITS,
290          sequence_feature_columns=self.sequence_feature_columns,
291          context_feature_columns=self.context_feature_columns,
292          predict_probabilities=True,
293          model_dir=model_dir)
294
295    # Train a bit to create an exportable checkpoint.
296    estimator_fn().fit(input_fn=get_input_fn(model_fn_lib.ModeKeys.TRAIN),
297                       steps=100)
298    # Now export, but from a fresh estimator instance, like you would
299    # in an export binary. That means .export() has to work without
300    # .fit() being called on the same object.
301    export_dir = tempfile.mkdtemp()
302    print('Exporting to', export_dir)
303    estimator_fn().export(
304        export_dir,
305        input_fn=get_input_fn(model_fn_lib.ModeKeys.INFER),
306        use_deprecated_input_fn=False,
307        input_feature_key=input_feature_key)
308
309  def testStateTupleDictConversion(self):
310    """Test `state_tuple_to_dict` and `dict_to_state_tuple`."""
311    cell_sizes = [5, 3, 7]
312    # A MultiRNNCell of LSTMCells is both a common choice and an interesting
313    # test case, because it has two levels of nesting, with an inner class that
314    # is not a plain tuple.
315    cell = rnn_cell.MultiRNNCell(
316        [rnn_cell.LSTMCell(i) for i in cell_sizes])
317    state_dict = {
318        dynamic_rnn_estimator._get_state_name(i):
319        array_ops.expand_dims(math_ops.range(cell_size), 0)
320        for i, cell_size in enumerate([5, 5, 3, 3, 7, 7])
321    }
322    expected_state = (rnn_cell.LSTMStateTuple(
323        np.reshape(np.arange(5), [1, -1]), np.reshape(np.arange(5), [1, -1])),
324                      rnn_cell.LSTMStateTuple(
325                          np.reshape(np.arange(3), [1, -1]),
326                          np.reshape(np.arange(3), [1, -1])),
327                      rnn_cell.LSTMStateTuple(
328                          np.reshape(np.arange(7), [1, -1]),
329                          np.reshape(np.arange(7), [1, -1])))
330    actual_state = dynamic_rnn_estimator.dict_to_state_tuple(state_dict, cell)
331    flattened_state = dynamic_rnn_estimator.state_tuple_to_dict(actual_state)
332
333    with self.cached_session() as sess:
334      (state_dict_val, actual_state_val, flattened_state_val) = sess.run(
335          [state_dict, actual_state, flattened_state])
336
337    def _recursive_assert_equal(x, y):
338      self.assertEqual(type(x), type(y))
339      if isinstance(x, (list, tuple)):
340        self.assertEqual(len(x), len(y))
341        for i, _ in enumerate(x):
342          _recursive_assert_equal(x[i], y[i])
343      elif isinstance(x, np.ndarray):
344        np.testing.assert_array_equal(x, y)
345      else:
346        self.fail('Unexpected type: {}'.format(type(x)))
347
348    for k in state_dict_val.keys():
349      np.testing.assert_array_almost_equal(
350          state_dict_val[k],
351          flattened_state_val[k],
352          err_msg='Wrong value for state component {}.'.format(k))
353    _recursive_assert_equal(expected_state, actual_state_val)
354
355  def testMultiRNNState(self):
356    """Test that state flattening/reconstruction works for `MultiRNNCell`."""
357    batch_size = 11
358    sequence_length = 16
359    train_steps = 5
360    cell_sizes = [4, 8, 7]
361    learning_rate = 0.1
362
363    def get_shift_input_fn(batch_size, sequence_length, seed=None):
364
365      def input_fn():
366        random_sequence = random_ops.random_uniform(
367            [batch_size, sequence_length + 1],
368            0,
369            2,
370            dtype=dtypes.int32,
371            seed=seed)
372        labels = array_ops.slice(random_sequence, [0, 0],
373                                 [batch_size, sequence_length])
374        inputs = array_ops.expand_dims(
375            math_ops.cast(
376                array_ops.slice(random_sequence, [0, 1],
377                                [batch_size, sequence_length]),
378                dtypes.float32), 2)
379        input_dict = {
380            dynamic_rnn_estimator._get_state_name(i): random_ops.random_uniform(
381                [batch_size, cell_size], seed=((i + 1) * seed))
382            for i, cell_size in enumerate([4, 4, 8, 8, 7, 7])
383        }
384        input_dict['inputs'] = inputs
385        return input_dict, labels
386
387      return input_fn
388
389    seq_columns = [feature_column.real_valued_column('inputs', dimension=1)]
390    config = run_config.RunConfig(tf_random_seed=21212)
391    cell_type = 'lstm'
392    sequence_estimator = dynamic_rnn_estimator.DynamicRnnEstimator(
393        problem_type=constants.ProblemType.CLASSIFICATION,
394        prediction_type=rnn_common.PredictionType.MULTIPLE_VALUE,
395        num_classes=2,
396        num_units=cell_sizes,
397        sequence_feature_columns=seq_columns,
398        cell_type=cell_type,
399        learning_rate=learning_rate,
400        config=config,
401        predict_probabilities=True)
402
403    train_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=12321)
404    eval_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=32123)
405
406    sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)
407
408    prediction_dict = sequence_estimator.predict(
409        input_fn=eval_input_fn, as_iterable=False)
410    for i, state_size in enumerate([4, 4, 8, 8, 7, 7]):
411      state_piece = prediction_dict[dynamic_rnn_estimator._get_state_name(i)]
412      self.assertListEqual(list(state_piece.shape), [batch_size, state_size])
413
414  def testMultipleRuns(self):
415    """Tests resuming training by feeding state."""
416    cell_sizes = [4, 7]
417    batch_size = 11
418    learning_rate = 0.1
419    train_sequence_length = 21
420    train_steps = 121
421    dropout_keep_probabilities = [0.5, 0.5, 0.5]
422    prediction_steps = [3, 2, 5, 11, 6]
423
424    def get_input_fn(batch_size, sequence_length, state_dict, starting_step=0):
425
426      def input_fn():
427        sequence = constant_op.constant(
428            [[(starting_step + i + j) % 2 for j in range(sequence_length + 1)]
429             for i in range(batch_size)],
430            dtype=dtypes.int32)
431        labels = array_ops.slice(sequence, [0, 0],
432                                 [batch_size, sequence_length])
433        inputs = array_ops.expand_dims(
434            math_ops.cast(
435                array_ops.slice(sequence, [0, 1], [batch_size, sequence_length
436                                                  ]),
437                dtypes.float32), 2)
438        input_dict = state_dict
439        input_dict['inputs'] = inputs
440        return input_dict, labels
441
442      return input_fn
443
444    seq_columns = [feature_column.real_valued_column('inputs', dimension=1)]
445    config = run_config.RunConfig(tf_random_seed=21212)
446
447    model_dir = tempfile.mkdtemp()
448    sequence_estimator = dynamic_rnn_estimator.DynamicRnnEstimator(
449        problem_type=constants.ProblemType.CLASSIFICATION,
450        prediction_type=rnn_common.PredictionType.MULTIPLE_VALUE,
451        num_classes=2,
452        sequence_feature_columns=seq_columns,
453        num_units=cell_sizes,
454        cell_type='lstm',
455        dropout_keep_probabilities=dropout_keep_probabilities,
456        learning_rate=learning_rate,
457        config=config,
458        model_dir=model_dir)
459
460    train_input_fn = get_input_fn(
461        batch_size, train_sequence_length, state_dict={})
462
463    sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)
464
465    def incremental_predict(estimator, increments):
466      """Run `estimator.predict` for `i` steps for `i` in `increments`."""
467      step = 0
468      incremental_state_dict = {}
469      for increment in increments:
470        input_fn = get_input_fn(
471            batch_size,
472            increment,
473            state_dict=incremental_state_dict,
474            starting_step=step)
475        prediction_dict = estimator.predict(
476            input_fn=input_fn, as_iterable=False)
477        step += increment
478        incremental_state_dict = {
479            k: v
480            for (k, v) in prediction_dict.items()
481            if k.startswith(rnn_common.RNNKeys.STATE_PREFIX)
482        }
483      return prediction_dict
484
485    pred_all_at_once = incremental_predict(sequence_estimator,
486                                           [sum(prediction_steps)])
487    pred_step_by_step = incremental_predict(sequence_estimator,
488                                            prediction_steps)
489
490    # Check that the last `prediction_steps[-1]` steps give the same
491    # predictions.
492    np.testing.assert_array_equal(
493        pred_all_at_once[prediction_key.PredictionKey.CLASSES]
494        [:, -1 * prediction_steps[-1]:],
495        pred_step_by_step[prediction_key.PredictionKey.CLASSES],
496        err_msg='Mismatch on last {} predictions.'.format(prediction_steps[-1]))
497    # Check that final states are identical.
498    for k, v in pred_all_at_once.items():
499      if k.startswith(rnn_common.RNNKeys.STATE_PREFIX):
500        np.testing.assert_array_equal(
501            v, pred_step_by_step[k], err_msg='Mismatch on state {}.'.format(k))
502
503
504# TODO(jamieas): move all tests below to a benchmark test.
505class DynamicRNNEstimatorLearningTest(test.TestCase):
506  """Learning tests for dynamic RNN Estimators."""
507
508  def testLearnSineFunction(self):
509    """Tests learning a sine function."""
510    batch_size = 8
511    sequence_length = 64
512    train_steps = 200
513    eval_steps = 20
514    cell_size = [4]
515    learning_rate = 0.1
516    loss_threshold = 0.02
517
518    def get_sin_input_fn(batch_size, sequence_length, increment, seed=None):
519
520      def _sin_fn(x):
521        ranger = math_ops.linspace(
522            array_ops.reshape(x[0], []), (sequence_length - 1) * increment,
523            sequence_length + 1)
524        return math_ops.sin(ranger)
525
526      def input_fn():
527        starts = random_ops.random_uniform(
528            [batch_size], maxval=(2 * np.pi), seed=seed)
529        sin_curves = map_fn.map_fn(
530            _sin_fn, (starts,), dtype=dtypes.float32)
531        inputs = array_ops.expand_dims(
532            array_ops.slice(sin_curves, [0, 0], [batch_size, sequence_length]),
533            2)
534        labels = array_ops.slice(sin_curves, [0, 1],
535                                 [batch_size, sequence_length])
536        return {'inputs': inputs}, labels
537
538      return input_fn
539
540    seq_columns = [
541        feature_column.real_valued_column(
542            'inputs', dimension=cell_size[0])
543    ]
544    config = run_config.RunConfig(tf_random_seed=1234)
545    sequence_estimator = dynamic_rnn_estimator.DynamicRnnEstimator(
546        problem_type=constants.ProblemType.LINEAR_REGRESSION,
547        prediction_type=rnn_common.PredictionType.MULTIPLE_VALUE,
548        num_units=cell_size,
549        sequence_feature_columns=seq_columns,
550        learning_rate=learning_rate,
551        dropout_keep_probabilities=[0.9, 0.9],
552        config=config)
553
554    train_input_fn = get_sin_input_fn(
555        batch_size, sequence_length, np.pi / 32, seed=1234)
556    eval_input_fn = get_sin_input_fn(
557        batch_size, sequence_length, np.pi / 32, seed=4321)
558
559    sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)
560    loss = sequence_estimator.evaluate(
561        input_fn=eval_input_fn, steps=eval_steps)['loss']
562    self.assertLess(loss, loss_threshold,
563                    'Loss should be less than {}; got {}'.format(loss_threshold,
564                                                                 loss))
565
566  def testLearnShiftByOne(self):
567    """Tests that learning a 'shift-by-one' example.
568
569    Each label sequence consists of the input sequence 'shifted' by one place.
570    The RNN must learn to 'remember' the previous input.
571    """
572    batch_size = 16
573    sequence_length = 32
574    train_steps = 200
575    eval_steps = 20
576    cell_size = 4
577    learning_rate = 0.3
578    accuracy_threshold = 0.9
579
580    def get_shift_input_fn(batch_size, sequence_length, seed=None):
581
582      def input_fn():
583        random_sequence = random_ops.random_uniform(
584            [batch_size, sequence_length + 1],
585            0,
586            2,
587            dtype=dtypes.int32,
588            seed=seed)
589        labels = array_ops.slice(random_sequence, [0, 0],
590                                 [batch_size, sequence_length])
591        inputs = array_ops.expand_dims(
592            math_ops.cast(
593                array_ops.slice(random_sequence, [0, 1],
594                                [batch_size, sequence_length]),
595                dtypes.float32),
596            2)
597        return {'inputs': inputs}, labels
598
599      return input_fn
600
601    seq_columns = [
602        feature_column.real_valued_column(
603            'inputs', dimension=cell_size)
604    ]
605    config = run_config.RunConfig(tf_random_seed=21212)
606    sequence_estimator = dynamic_rnn_estimator.DynamicRnnEstimator(
607        problem_type=constants.ProblemType.CLASSIFICATION,
608        prediction_type=rnn_common.PredictionType.MULTIPLE_VALUE,
609        num_classes=2,
610        num_units=cell_size,
611        sequence_feature_columns=seq_columns,
612        learning_rate=learning_rate,
613        config=config,
614        predict_probabilities=True)
615
616    train_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=12321)
617    eval_input_fn = get_shift_input_fn(batch_size, sequence_length, seed=32123)
618
619    sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)
620
621    evaluation = sequence_estimator.evaluate(
622        input_fn=eval_input_fn, steps=eval_steps)
623    accuracy = evaluation['accuracy']
624    self.assertGreater(accuracy, accuracy_threshold,
625                       'Accuracy should be higher than {}; got {}'.format(
626                           accuracy_threshold, accuracy))
627
628    # Testing `predict` when `predict_probabilities=True`.
629    prediction_dict = sequence_estimator.predict(
630        input_fn=eval_input_fn, as_iterable=False)
631    self.assertListEqual(
632        sorted(list(prediction_dict.keys())),
633        sorted([
634            prediction_key.PredictionKey.CLASSES,
635            prediction_key.PredictionKey.PROBABILITIES,
636            dynamic_rnn_estimator._get_state_name(0)
637        ]))
638    predictions = prediction_dict[prediction_key.PredictionKey.CLASSES]
639    probabilities = prediction_dict[
640        prediction_key.PredictionKey.PROBABILITIES]
641    self.assertListEqual(list(predictions.shape), [batch_size, sequence_length])
642    self.assertListEqual(
643        list(probabilities.shape), [batch_size, sequence_length, 2])
644
645  def testLearnMean(self):
646    """Test learning to calculate a mean."""
647    batch_size = 16
648    sequence_length = 3
649    train_steps = 200
650    eval_steps = 20
651    cell_type = 'basic_rnn'
652    cell_size = 8
653    optimizer_type = 'Momentum'
654    learning_rate = 0.1
655    momentum = 0.9
656    loss_threshold = 0.1
657
658    def get_mean_input_fn(batch_size, sequence_length, seed=None):
659
660      def input_fn():
661        # Create examples by choosing 'centers' and adding uniform noise.
662        centers = math_ops.matmul(
663            random_ops.random_uniform(
664                [batch_size, 1], -0.75, 0.75, dtype=dtypes.float32, seed=seed),
665            array_ops.ones([1, sequence_length]))
666        noise = random_ops.random_uniform(
667            [batch_size, sequence_length],
668            -0.25,
669            0.25,
670            dtype=dtypes.float32,
671            seed=seed)
672        sequences = centers + noise
673
674        inputs = array_ops.expand_dims(sequences, 2)
675        labels = math_ops.reduce_mean(sequences, axis=[1])
676        return {'inputs': inputs}, labels
677
678      return input_fn
679
680    seq_columns = [
681        feature_column.real_valued_column(
682            'inputs', dimension=cell_size)
683    ]
684    config = run_config.RunConfig(tf_random_seed=6)
685    sequence_estimator = dynamic_rnn_estimator.DynamicRnnEstimator(
686        problem_type=constants.ProblemType.LINEAR_REGRESSION,
687        prediction_type=rnn_common.PredictionType.SINGLE_VALUE,
688        num_units=cell_size,
689        sequence_feature_columns=seq_columns,
690        cell_type=cell_type,
691        optimizer=optimizer_type,
692        learning_rate=learning_rate,
693        momentum=momentum,
694        config=config)
695
696    train_input_fn = get_mean_input_fn(batch_size, sequence_length, 121)
697    eval_input_fn = get_mean_input_fn(batch_size, sequence_length, 212)
698
699    sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)
700    evaluation = sequence_estimator.evaluate(
701        input_fn=eval_input_fn, steps=eval_steps)
702    loss = evaluation['loss']
703    self.assertLess(loss, loss_threshold,
704                    'Loss should be less than {}; got {}'.format(loss_threshold,
705                                                                 loss))
706
707  def DISABLED_testLearnMajority(self):
708    """Test learning the 'majority' function."""
709    batch_size = 16
710    sequence_length = 7
711    train_steps = 500
712    eval_steps = 20
713    cell_type = 'lstm'
714    cell_size = 4
715    optimizer_type = 'Momentum'
716    learning_rate = 2.0
717    momentum = 0.9
718    accuracy_threshold = 0.6
719
720    def get_majority_input_fn(batch_size, sequence_length, seed=None):
721      random_seed.set_random_seed(seed)
722
723      def input_fn():
724        random_sequence = random_ops.random_uniform(
725            [batch_size, sequence_length], 0, 2, dtype=dtypes.int32, seed=seed)
726        inputs = array_ops.expand_dims(
727            math_ops.cast(random_sequence, dtypes.float32), 2)
728        labels = math_ops.cast(
729            array_ops.squeeze(
730                math_ops.reduce_sum(inputs, axis=[1]) > (
731                    sequence_length / 2.0)),
732            dtypes.int32)
733        return {'inputs': inputs}, labels
734
735      return input_fn
736
737    seq_columns = [
738        feature_column.real_valued_column(
739            'inputs', dimension=cell_size)
740    ]
741    config = run_config.RunConfig(tf_random_seed=77)
742    sequence_estimator = dynamic_rnn_estimator.DynamicRnnEstimator(
743        problem_type=constants.ProblemType.CLASSIFICATION,
744        prediction_type=rnn_common.PredictionType.SINGLE_VALUE,
745        num_classes=2,
746        num_units=cell_size,
747        sequence_feature_columns=seq_columns,
748        cell_type=cell_type,
749        optimizer=optimizer_type,
750        learning_rate=learning_rate,
751        momentum=momentum,
752        config=config,
753        predict_probabilities=True)
754
755    train_input_fn = get_majority_input_fn(batch_size, sequence_length, 1111)
756    eval_input_fn = get_majority_input_fn(batch_size, sequence_length, 2222)
757
758    sequence_estimator.fit(input_fn=train_input_fn, steps=train_steps)
759    evaluation = sequence_estimator.evaluate(
760        input_fn=eval_input_fn, steps=eval_steps)
761    accuracy = evaluation['accuracy']
762    self.assertGreater(accuracy, accuracy_threshold,
763                       'Accuracy should be higher than {}; got {}'.format(
764                           accuracy_threshold, accuracy))
765
766    # Testing `predict` when `predict_probabilities=True`.
767    prediction_dict = sequence_estimator.predict(
768        input_fn=eval_input_fn, as_iterable=False)
769    self.assertListEqual(
770        sorted(list(prediction_dict.keys())),
771        sorted([
772            prediction_key.PredictionKey.CLASSES,
773            prediction_key.PredictionKey.PROBABILITIES,
774            dynamic_rnn_estimator._get_state_name(0),
775            dynamic_rnn_estimator._get_state_name(1)
776        ]))
777    predictions = prediction_dict[prediction_key.PredictionKey.CLASSES]
778    probabilities = prediction_dict[
779        prediction_key.PredictionKey.PROBABILITIES]
780    self.assertListEqual(list(predictions.shape), [batch_size])
781    self.assertListEqual(list(probabilities.shape), [batch_size, 2])
782
783
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
  test.main()
786