• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Sequence-to-sequence tests."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.contrib.learn.python.learn import ops
24from tensorflow.python.framework import dtypes
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import rnn_cell
27from tensorflow.python.platform import test
28
29
class Seq2SeqOpsTest(test.TestCase):
  """Tests for the sequence-to-sequence helpers in `learn.ops`."""

  @staticmethod
  def _float_placeholders(count, shape):
    """Returns a list of `count` float32 placeholders with static `shape`."""
    return [array_ops.placeholder(dtypes.float32, shape) for _ in range(count)]

  def test_sequence_classifier(self):
    # One [2, 2] score matrix per decoding step. The argmax over axis 2 below
    # relies on rows being batch entries and columns being classes.
    step_scores = [
        [[0.1, 0.9], [0.7, 0.3]],
        [[0.9, 0.1], [0.8, 0.2]],
        [[0.5, 0.5], [0.4, 0.6]],
    ]
    # The same one-hot targets are fed at every step.
    step_labels = [[[1, 0], [0, 1]]] * 3
    with self.cached_session() as session:
      decoding = self._float_placeholders(3, [2, 2])
      labels = self._float_placeholders(3, [2, 2])
      sampling_decoding = self._float_placeholders(3, [2, 2])
      predictions, loss = ops.sequence_classifier(decoding, labels,
                                                  sampling_decoding)
      # Key the feed by the placeholder tensors themselves rather than by their
      # auto-generated names. The sampling path gets the same scores as the
      # training path.
      feed_dict = dict(zip(decoding, step_scores))
      feed_dict.update(zip(labels, step_labels))
      feed_dict.update(zip(sampling_decoding, step_scores))
      pred, cost = session.run([predictions, loss], feed_dict=feed_dict)
    self.assertAllEqual(pred.argmax(axis=2), [[1, 0, 0], [0, 0, 1]])
    self.assertAllClose(cost, 4.7839908599)

  def test_seq2seq_inputs(self):
    # batch x len x height: 2 sequences with 3 input steps of height 2 and
    # 2 output steps of height 3.
    inp = np.array([[[1, 0], [0, 1], [1, 0]], [[0, 1], [1, 0], [0, 1]]])
    out = np.array([[[0, 1, 0], [1, 0, 0]], [[1, 0, 0], [0, 1, 0]]])
    with self.cached_session() as session:
      x = array_ops.placeholder(dtypes.float32, [2, 3, 2])
      y = array_ops.placeholder(dtypes.float32, [2, 2, 3])
      in_x, in_y, out_y = ops.seq2seq_inputs(x, y, 3, 2)
      # Fetch the encoder inputs with only `x` fed. Fetch both decoder tensors
      # in a single run with `x` and `y` fed.
      enc_inp = session.run(in_x, feed_dict={x: inp})
      dec_inp, dec_out = session.run([in_y, out_y],
                                     feed_dict={x: inp, y: out})
    # Swaps from batch x len x height to list of len of batch x height.
    self.assertAllEqual(enc_inp, np.swapaxes(inp, 0, 1))
    # Expected decoder inputs: an all-zero step, then the time-major steps
    # of `out`.
    self.assertAllEqual(dec_inp, [[[0, 0, 0], [0, 0, 0]],
                                  [[0, 1, 0], [1, 0, 0]],
                                  [[1, 0, 0], [0, 1, 0]]])
    # Expected decoder targets: the time-major steps of `out`, then an
    # all-zero step.
    self.assertAllEqual(dec_out, [[[0, 1, 0], [1, 0, 0]],
                                  [[1, 0, 0], [0, 1, 0]],
                                  [[0, 0, 0], [0, 0, 0]]])

  def test_rnn_decoder(self):
    with self.cached_session():
      decoder_inputs = self._float_placeholders(3, [2, 2])
      encoding = array_ops.placeholder(dtypes.float32, [2, 2])
      cell = rnn_cell.GRUCell(2)
      outputs, states, sampling_outputs, sampling_states = (
          ops.rnn_decoder(decoder_inputs, encoding, cell))
      # Outputs have one entry per decoder input. States have one more
      # (presumably the initial state derived from `encoding`; not verified
      # here).
      expectations = (
          (outputs, 3),
          (states, 4),
          (sampling_outputs, 3),
          (sampling_states, 4),
      )
      for tensors, expected_len in expectations:
        self.assertEqual(len(tensors), expected_len)
        self.assertEqual(tensors[0].get_shape(), [2, 2])
96
97
if __name__ == "__main__":
  # Run this module's test cases through TensorFlow's test entry point.
  test.main()
100