1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for layers.rnn_common.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.contrib.learn.python.learn.estimators import rnn_common 24from tensorflow.python.client import session 25from tensorflow.python.framework import constant_op 26from tensorflow.python.framework import dtypes 27from tensorflow.python.platform import test 28 29 30class RnnCommonTest(test.TestCase): 31 32 def testMaskActivationsAndLabels(self): 33 """Test `mask_activations_and_labels`.""" 34 batch_size = 4 35 padded_length = 6 36 num_classes = 4 37 np.random.seed(1234) 38 sequence_length = np.random.randint(0, padded_length + 1, batch_size) 39 activations = np.random.rand(batch_size, padded_length, num_classes) 40 labels = np.random.randint(0, num_classes, [batch_size, padded_length]) 41 (activations_masked_t, 42 labels_masked_t) = rnn_common.mask_activations_and_labels( 43 constant_op.constant(activations, dtype=dtypes.float32), 44 constant_op.constant(labels, dtype=dtypes.int32), 45 constant_op.constant(sequence_length, dtype=dtypes.int32)) 46 47 with self.cached_session() as sess: 48 activations_masked, labels_masked = sess.run( 49 [activations_masked_t, labels_masked_t]) 50 51 expected_activations_shape = [sum(sequence_length), num_classes] 52 np.testing.assert_equal( 53 expected_activations_shape, activations_masked.shape, 54 'Wrong activations shape. Expected {}; got {}.'.format( 55 expected_activations_shape, activations_masked.shape)) 56 57 expected_labels_shape = [sum(sequence_length)] 58 np.testing.assert_equal(expected_labels_shape, labels_masked.shape, 59 'Wrong labels shape. Expected {}; got {}.'.format( 60 expected_labels_shape, labels_masked.shape)) 61 masked_index = 0 62 for i in range(batch_size): 63 for j in range(sequence_length[i]): 64 actual_activations = activations_masked[masked_index] 65 expected_activations = activations[i, j, :] 66 np.testing.assert_almost_equal( 67 expected_activations, 68 actual_activations, 69 err_msg='Unexpected logit value at index [{}, {}, :].' 70 ' Expected {}; got {}.'.format(i, j, expected_activations, 71 actual_activations)) 72 73 actual_labels = labels_masked[masked_index] 74 expected_labels = labels[i, j] 75 np.testing.assert_almost_equal( 76 expected_labels, 77 actual_labels, 78 err_msg='Unexpected logit value at index [{}, {}].' 79 ' Expected {}; got {}.'.format(i, j, expected_labels, 80 actual_labels)) 81 masked_index += 1 82 83 def testSelectLastActivations(self): 84 """Test `select_last_activations`.""" 85 batch_size = 4 86 padded_length = 6 87 num_classes = 4 88 np.random.seed(4444) 89 sequence_length = np.random.randint(0, padded_length + 1, batch_size) 90 activations = np.random.rand(batch_size, padded_length, num_classes) 91 last_activations_t = rnn_common.select_last_activations( 92 constant_op.constant(activations, dtype=dtypes.float32), 93 constant_op.constant(sequence_length, dtype=dtypes.int32)) 94 95 with session.Session() as sess: 96 last_activations = sess.run(last_activations_t) 97 98 expected_activations_shape = [batch_size, num_classes] 99 np.testing.assert_equal( 100 expected_activations_shape, last_activations.shape, 101 'Wrong activations shape. Expected {}; got {}.'.format( 102 expected_activations_shape, last_activations.shape)) 103 104 for i in range(batch_size): 105 actual_activations = last_activations[i, :] 106 expected_activations = activations[i, sequence_length[i] - 1, :] 107 np.testing.assert_almost_equal( 108 expected_activations, 109 actual_activations, 110 err_msg='Unexpected logit value at index [{}, :].' 111 ' Expected {}; got {}.'.format(i, expected_activations, 112 actual_activations)) 113 114 115if __name__ == '__main__': 116 test.main() 117