1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Sequence-to-sequence tests.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.contrib.learn.python.learn import ops 24from tensorflow.python.framework import dtypes 25from tensorflow.python.ops import array_ops 26from tensorflow.python.ops import rnn_cell 27from tensorflow.python.platform import test 28 29 30class Seq2SeqOpsTest(test.TestCase): 31 """Sequence-to-sequence tests.""" 32 33 def test_sequence_classifier(self): 34 with self.cached_session() as session: 35 decoding = [ 36 array_ops.placeholder(dtypes.float32, [2, 2]) for _ in range(3) 37 ] 38 labels = [array_ops.placeholder(dtypes.float32, [2, 2]) for _ in range(3)] 39 sampling_decoding = [ 40 array_ops.placeholder(dtypes.float32, [2, 2]) for _ in range(3) 41 ] 42 predictions, loss = ops.sequence_classifier(decoding, labels, 43 sampling_decoding) 44 pred, cost = session.run( 45 [predictions, loss], 46 feed_dict={ 47 decoding[0].name: [[0.1, 0.9], [0.7, 0.3]], 48 decoding[1].name: [[0.9, 0.1], [0.8, 0.2]], 49 decoding[2].name: [[0.5, 0.5], [0.4, 0.6]], 50 labels[0].name: [[1, 0], [0, 1]], 51 labels[1].name: [[1, 0], [0, 1]], 52 labels[2].name: [[1, 0], [0, 1]], 53 sampling_decoding[0].name: [[0.1, 0.9], [0.7, 0.3]], 54 sampling_decoding[1].name: [[0.9, 0.1], [0.8, 0.2]], 55 sampling_decoding[2].name: [[0.5, 0.5], [0.4, 0.6]], 56 }) 57 self.assertAllEqual(pred.argmax(axis=2), [[1, 0, 0], [0, 0, 1]]) 58 self.assertAllClose(cost, 4.7839908599) 59 60 def test_seq2seq_inputs(self): 61 inp = np.array([[[1, 0], [0, 1], [1, 0]], [[0, 1], [1, 0], [0, 1]]]) 62 out = np.array([[[0, 1, 0], [1, 0, 0]], [[1, 0, 0], [0, 1, 0]]]) 63 with self.cached_session() as session: 64 x = array_ops.placeholder(dtypes.float32, [2, 3, 2]) 65 y = array_ops.placeholder(dtypes.float32, [2, 2, 3]) 66 in_x, in_y, out_y = ops.seq2seq_inputs(x, y, 3, 2) 67 enc_inp = session.run(in_x, feed_dict={x.name: inp}) 68 dec_inp = session.run(in_y, feed_dict={x.name: inp, y.name: out}) 69 dec_out = session.run(out_y, feed_dict={x.name: inp, y.name: out}) 70 # Swaps from batch x len x height to list of len of batch x height. 71 self.assertAllEqual(enc_inp, np.swapaxes(inp, 0, 1)) 72 self.assertAllEqual(dec_inp, [[[0, 0, 0], [0, 0, 0]], 73 [[0, 1, 0], [1, 0, 0]], 74 [[1, 0, 0], [0, 1, 0]]]) 75 self.assertAllEqual(dec_out, [[[0, 1, 0], [1, 0, 0]], 76 [[1, 0, 0], [0, 1, 0]], 77 [[0, 0, 0], [0, 0, 0]]]) 78 79 def test_rnn_decoder(self): 80 with self.cached_session(): 81 decoder_inputs = [ 82 array_ops.placeholder(dtypes.float32, [2, 2]) for _ in range(3) 83 ] 84 encoding = array_ops.placeholder(dtypes.float32, [2, 2]) 85 cell = rnn_cell.GRUCell(2) 86 outputs, states, sampling_outputs, sampling_states = ( 87 ops.rnn_decoder(decoder_inputs, encoding, cell)) 88 self.assertEqual(len(outputs), 3) 89 self.assertEqual(outputs[0].get_shape(), [2, 2]) 90 self.assertEqual(len(states), 4) 91 self.assertEqual(states[0].get_shape(), [2, 2]) 92 self.assertEqual(len(sampling_outputs), 3) 93 self.assertEqual(sampling_outputs[0].get_shape(), [2, 2]) 94 self.assertEqual(len(sampling_states), 4) 95 self.assertEqual(sampling_states[0].get_shape(), [2, 2]) 96 97 98if __name__ == "__main__": 99 test.main() 100