• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Tests for estimators.rnn_common."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.contrib.learn.python.learn.estimators import rnn_common
24from tensorflow.python.client import session
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.platform import test
28
29
class RnnCommonTest(test.TestCase):
  """Tests for the sequence helpers in `rnn_common`."""

  def testMaskActivationsAndLabels(self):
    """Test `mask_activations_and_labels`.

    Random activations and labels are padded to `padded_length` steps, with
    per-example sequence lengths drawn from [0, padded_length]. The masked
    outputs must contain exactly the first `sequence_length[i]` steps of each
    example, concatenated in example-then-time order.
    """
    batch_size = 4
    padded_length = 6
    num_classes = 4
    np.random.seed(1234)
    sequence_length = np.random.randint(0, padded_length + 1, batch_size)
    activations = np.random.rand(batch_size, padded_length, num_classes)
    labels = np.random.randint(0, num_classes, [batch_size, padded_length])
    (activations_masked_t,
     labels_masked_t) = rnn_common.mask_activations_and_labels(
         constant_op.constant(activations, dtype=dtypes.float32),
         constant_op.constant(labels, dtype=dtypes.int32),
         constant_op.constant(sequence_length, dtype=dtypes.int32))

    with self.cached_session() as sess:
      activations_masked, labels_masked = sess.run(
          [activations_masked_t, labels_masked_t])

    # Every unpadded step of every example should survive the mask once.
    num_valid_steps = sum(sequence_length)

    expected_activations_shape = [num_valid_steps, num_classes]
    np.testing.assert_equal(
        expected_activations_shape, activations_masked.shape,
        'Wrong activations shape. Expected {}; got {}.'.format(
            expected_activations_shape, activations_masked.shape))

    expected_labels_shape = [num_valid_steps]
    np.testing.assert_equal(expected_labels_shape, labels_masked.shape,
                            'Wrong labels shape. Expected {}; got {}.'.format(
                                expected_labels_shape, labels_masked.shape))

    # Walk the valid steps in example-then-time order; `masked_index` is the
    # matching row in the flattened outputs.
    masked_index = 0
    for i in range(batch_size):
      for j in range(sequence_length[i]):
        actual_activations = activations_masked[masked_index]
        expected_activations = activations[i, j, :]
        # Activations were fed as float32, so compare approximately.
        np.testing.assert_almost_equal(
            expected_activations,
            actual_activations,
            err_msg='Unexpected logit value at index [{}, {}, :].'
            '  Expected {}; got {}.'.format(i, j, expected_activations,
                                            actual_activations))

        actual_labels = labels_masked[masked_index]
        expected_labels = labels[i, j]
        # Labels are integer class ids, so they must match exactly.
        np.testing.assert_equal(
            expected_labels,
            actual_labels,
            err_msg='Unexpected label value at index [{}, {}].'
            '  Expected {}; got {}.'.format(i, j, expected_labels,
                                            actual_labels))
        masked_index += 1

  def testSelectLastActivations(self):
    """Test `select_last_activations`.

    For each example, the selected row must equal the activation at step
    `sequence_length[i] - 1`, i.e. the last unpadded step.
    """
    batch_size = 4
    padded_length = 6
    num_classes = 4
    np.random.seed(4444)
    # Draw lengths from [1, padded_length] (randint's upper bound is
    # exclusive). A zero-length sequence has no last step: the expected value
    # below would index step -1, i.e. the final *padded* step, which is not a
    # meaningful "last activation" and makes the test depend on the seed.
    sequence_length = np.random.randint(1, padded_length + 1, batch_size)
    activations = np.random.rand(batch_size, padded_length, num_classes)
    last_activations_t = rnn_common.select_last_activations(
        constant_op.constant(activations, dtype=dtypes.float32),
        constant_op.constant(sequence_length, dtype=dtypes.int32))

    # Use the test case's managed session, matching the other test, instead
    # of constructing a standalone session.
    with self.cached_session() as sess:
      last_activations = sess.run(last_activations_t)

    expected_activations_shape = [batch_size, num_classes]
    np.testing.assert_equal(
        expected_activations_shape, last_activations.shape,
        'Wrong activations shape. Expected {}; got {}.'.format(
            expected_activations_shape, last_activations.shape))

    for i in range(batch_size):
      actual_activations = last_activations[i, :]
      expected_activations = activations[i, sequence_length[i] - 1, :]
      # Activations were fed as float32, so compare approximately.
      np.testing.assert_almost_equal(
          expected_activations,
          actual_activations,
          err_msg='Unexpected logit value at index [{}, :].'
          '  Expected {}; got {}.'.format(i, expected_activations,
                                          actual_activations))
113
114
if __name__ == '__main__':
  # Entry point when executed directly: hand off to the TensorFlow test
  # runner imported at the top of this module.
  test.main()
117