# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for Functional RNN."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23
24from tensorflow.contrib.recurrent.python.ops import functional_rnn
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import test_util
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import gradients_impl
30from tensorflow.python.ops import rnn as rnn_lib
31from tensorflow.python.ops import rnn_cell_impl
32from tensorflow.python.ops import variables
33import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
34import tensorflow.python.ops.tensor_array_grad  # pylint: disable=unused-import
35from tensorflow.python.platform import test as test_lib
36from tensorflow.python.platform import tf_logging as logging
37
38
def _CreateStackedLstmCell(*cell_sizes):
  """Creates a MultiRNNCell of LSTMCells, one per entry in cell_sizes."""
  subcells = [rnn_cell_impl.LSTMCell(cell_size) for cell_size in cell_sizes]
  return rnn_cell_impl.MultiRNNCell(subcells)
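# For example, _CreateStackedLstmCell(7, 7, 7) builds a three-layer
# MultiRNNCell with 7 units per layer; this is how the 'stacked_lstm'
# entry in _CELLDEFS below uses it (with _NUM_UNITS == 7).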


class FunctionalRnnTest(test_util.TensorFlowTestCase):

  _BATCH_SIZE = 3
  _TOTAL_TIME = 5
  _INPUT_SIZE = 11
  _NUM_UNITS = 7

  # Set this to a file path if you want testRunLstm to dump the LSTM
  # graph def there.
  _LSTM_GRAPH_DEF_FILEPATH = None

  _CELLDEFS = {
      'gru': (rnn_cell_impl.GRUCell, [_NUM_UNITS]),
      'lstm': (rnn_cell_impl.LSTMCell, [_NUM_UNITS]),
      'stacked_lstm': (_CreateStackedLstmCell, [_NUM_UNITS] * 3)
  }

  def _CreateCell(self, celldef_name):
    func, args = self._CELLDEFS[celldef_name]
    return func(*args)

  def _CreateInputs(self, time_major=False):
    if time_major:
      inputs = np.random.random([
          FunctionalRnnTest._TOTAL_TIME, FunctionalRnnTest._BATCH_SIZE,
          FunctionalRnnTest._INPUT_SIZE
      ])
    else:
      inputs = np.random.random([
          FunctionalRnnTest._BATCH_SIZE, FunctionalRnnTest._TOTAL_TIME,
          FunctionalRnnTest._INPUT_SIZE
      ])
    # Always leave one time slot empty, to check max_length behavior.
    sequence_length = np.random.randint(
        0, high=FunctionalRnnTest._TOTAL_TIME - 1,
        size=FunctionalRnnTest._BATCH_SIZE,
        dtype=np.int32)  # np.int is a removed NumPy alias; int32 matches tf_slen.
    return (inputs, sequence_length)

  def _CreateSymmetricInputs(self):
    # total time = batch size
    inputs = np.zeros(
        (FunctionalRnnTest._BATCH_SIZE, FunctionalRnnTest._BATCH_SIZE,
         FunctionalRnnTest._INPUT_SIZE))
    for i in range(FunctionalRnnTest._BATCH_SIZE):
      for j in range(i, FunctionalRnnTest._BATCH_SIZE):
        inputs[i][j] = np.random.random([FunctionalRnnTest._INPUT_SIZE])
        inputs[j][i] = inputs[i][j]

    # Always leave one time slot empty, to check max_length behavior.
    sequence_length = np.random.randint(
        0,
        high=FunctionalRnnTest._BATCH_SIZE - 1,
        size=FunctionalRnnTest._BATCH_SIZE,
        dtype=np.int32)
    return (inputs, sequence_length)
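
  # Sanity sketch (not part of the original test): the construction above
  # makes swapping the batch and time axes a no-op, e.g.
  #
  #   inputs, _ = self._CreateSymmetricInputs()
  #   assert np.allclose(inputs, np.transpose(inputs, [1, 0, 2]))
  #
  # which is what lets the symmetric test below compare time-major runs
  # against batch-major runs directly.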

  def _CreateRnnGraph(self,
                      create_rnn_computation_func,
                      cell,
                      tf_inputs,
                      tf_sequence_length,
                      is_bidirectional,
                      initial_state=None,
                      time_major=None,
                      scope=None):
    if is_bidirectional:
      tf_result = create_rnn_computation_func(
          cell_fw=cell,
          cell_bw=cell,
          inputs=tf_inputs,
          sequence_length=tf_sequence_length,
          dtype=dtypes.float32,
          time_major=time_major,
          scope=scope)
    else:
      tf_result = create_rnn_computation_func(
          cell=cell,
          inputs=tf_inputs,
          sequence_length=tf_sequence_length,
          initial_state=initial_state,
          dtype=dtypes.float32,
          time_major=time_major,
          scope=scope)
    grad = gradients_impl.gradients(tf_result, variables.trainable_variables())
    return {'inference': tf_result, 'grad': grad}

  def _MaybeResetVariables(self, variable_cache, sess, var_list):
    """Resets variables to cached values; caches any not yet seen."""
    reset_ops = []
    fetches = []
    for var in var_list:
      if var.name in variable_cache:
        reset_ops += [var.assign(variable_cache[var.name])]
      else:
        fetches += [(var.name, var)]
    if reset_ops:
      sess.run(reset_ops)
    if fetches:
      val = sess.run(dict(fetches))
      for n, v in val.items():
        assert n not in variable_cache
        variable_cache[n] = v
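
  # Why the cache matters: each call to _RunRnn builds a fresh graph with
  # freshly initialized variables. Replaying the values captured on the first
  # run gives every later run identical weights, so functional and dynamic
  # RNN outputs and gradients can be compared directly.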

  def _RunRnn(self, numpy_inputs, numpy_slen, cell_name, variable_cache,
              is_dynamic, time_major=None, is_bidirectional=False):
    with ops.Graph().as_default() as graph:
      tf_inputs = array_ops.placeholder(
          dtypes.float32, shape=numpy_inputs.shape)
      tf_slen = array_ops.placeholder(dtypes.int32)
      feeds = {tf_inputs: numpy_inputs, tf_slen: numpy_slen}
      cell = self._CreateCell(cell_name)
      if is_dynamic:
        if is_bidirectional:
          fn = rnn_lib.bidirectional_dynamic_rnn
        else:
          fn = rnn_lib.dynamic_rnn
      else:
        if is_bidirectional:
          fn = functional_rnn.bidirectional_functional_rnn
        else:
          fn = functional_rnn.functional_rnn

      fetches = self._CreateRnnGraph(
          fn, cell, tf_inputs, tf_slen, is_bidirectional, time_major=time_major)
      with self.session(graph=graph) as sess:
        sess.run(variables.global_variables_initializer())
        # Note that cell.trainable_variables is not always set.
        self._MaybeResetVariables(variable_cache, sess,
                                  variables.trainable_variables())
        val = sess.run(fetches, feed_dict=feeds)
      graph_def = graph.as_graph_def()
      return graph_def, val
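
  # _RunRnn returns (graph_def, values), where values is the dict built by
  # _CreateRnnGraph: 'inference' holds the RNN outputs and final state, and
  # 'grad' holds gradients with respect to all trainable variables.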

  def testRunLstm(self):
    """Runs a simple LSTM. Does not check output."""
    np_inputs, np_slen = self._CreateInputs()
    var_cache = {}
    graphdef, _ = self._RunRnn(np_inputs, np_slen, 'lstm', var_cache, False)
    logging.info('graphdef: %s', graphdef)
    if self._LSTM_GRAPH_DEF_FILEPATH:
      with open(self._LSTM_GRAPH_DEF_FILEPATH, 'w') as f:
        f.write(str(graphdef))

  def testLstm(self):
    """Checks an LSTM against the reference implementation."""
    np_inputs, np_slen = self._CreateInputs()
    var_cache = {}
    _, func_rnn = self._RunRnn(np_inputs, np_slen, 'lstm', var_cache, False)
    _, dyn_rnn = self._RunRnn(np_inputs, np_slen, 'lstm', var_cache, True)
    self.assertAllClose(dyn_rnn['inference'], func_rnn['inference'])
    self.assertAllClose(dyn_rnn['grad'], func_rnn['grad'])

  def testGru(self):
    """Checks a GRU cell against the reference implementation."""
    np_inputs, np_slen = self._CreateInputs()
    var_cache = {}
    _, func_rnn = self._RunRnn(np_inputs, np_slen, 'gru', var_cache, False)
    _, dyn_rnn = self._RunRnn(np_inputs, np_slen, 'gru', var_cache, True)
    self.assertAllClose(dyn_rnn['inference'], func_rnn['inference'])
    self.assertAllClose(dyn_rnn['grad'], func_rnn['grad'])

  def testStackedLstm(self):
    """Checks a stacked LSTM cell against the reference implementation."""
    np_inputs, np_slen = self._CreateInputs()
    var_cache = {}
    args = [np_inputs, np_slen, 'stacked_lstm', var_cache]
    _, func_rnn = self._RunRnn(*(args + [False]))
    _, dyn_rnn = self._RunRnn(*(args + [True]))
    self.assertAllClose(dyn_rnn['inference'], func_rnn['inference'])
    self.assertAllClose(dyn_rnn['grad'], func_rnn['grad'])

  def testLstmWithTimeMajorInputs(self):
    """Checks an LSTM against the reference implementation, with time_major."""
    time_major = True
    np_inputs, np_slen = self._CreateInputs(time_major=True)
    var_cache = {}
    args = [np_inputs, np_slen, 'lstm', var_cache]
    _, func_rnn = self._RunRnn(*(args + [False]), time_major=time_major)
    _, dyn_rnn = self._RunRnn(*(args + [True]), time_major=time_major)
    self.assertAllClose(dyn_rnn['inference'], func_rnn['inference'])
    self.assertAllClose(dyn_rnn['grad'], func_rnn['grad'])

  def testBidirectionalLstmWithTimeMajorInputs(self):
    """Checks a bi-directional LSTM with time-major inputs."""
    time_major = True
    np_inputs, np_slen = self._CreateInputs(time_major)
    var_cache = {}
    args = [np_inputs, np_slen, 'lstm', var_cache]
    _, func_rnn = self._RunRnn(
        *(args + [False]), time_major=time_major, is_bidirectional=True)
    _, dyn_rnn = self._RunRnn(
        *(args + [True]), time_major=time_major, is_bidirectional=True)
    self.assertAllClose(dyn_rnn['inference'], func_rnn['inference'])
    # TODO(b/112170761): re-enable this check once the bug is fixed.
    # self.assertAllClose(dyn_rnn['grad'], func_rnn['grad'])

  def testBidirectionalLstm(self):
    """Checks that time-major and batch-major RNNs produce the same results."""
    time_major_inputs, np_slen = self._CreateInputs(True)
    batch_major_inputs = np.transpose(time_major_inputs, [1, 0, 2])
    var_cache = {}
    args = [np_slen, 'lstm', var_cache, False]
    _, time_major_rnn = self._RunRnn(
        *([time_major_inputs] + args), time_major=True, is_bidirectional=True)
    _, batch_major_rnn = self._RunRnn(
        *([batch_major_inputs] + args),
        time_major=False,
        is_bidirectional=True)
    # Convert the batch-major outputs to time-major before the comparison.
    outputs, state = batch_major_rnn['inference']
    outputs = [np.transpose(x, [1, 0, 2]) for x in outputs]
    batch_major_rnn['inference'] = [outputs, state]
    self.assertAllClose(time_major_rnn['inference'],
                        batch_major_rnn['inference'])
    self.assertAllClose(time_major_rnn['grad'], batch_major_rnn['grad'])

  def testBidirectionalLstmWithSymmetricInputs(self):
    """Checks a bi-directional LSTM with symmetric inputs.

    With symmetric inputs, time-major and batch-major RNNs produce the same
    results.
    """
    np_inputs, np_slen = self._CreateSymmetricInputs()
    var_cache = {}
    args = [np_inputs, np_slen, 'lstm', var_cache]
    _, time_major_func_rnn = self._RunRnn(
        *(args + [False]), time_major=True, is_bidirectional=True)
    _, batch_major_func_rnn = self._RunRnn(
        *(args + [False]), time_major=False, is_bidirectional=True)
    _, time_major_dyn_rnn = self._RunRnn(
        *(args + [True]), time_major=True, is_bidirectional=True)
    _, batch_major_dyn_rnn = self._RunRnn(
        *(args + [True]), time_major=False, is_bidirectional=True)
    self.assertAllClose(time_major_func_rnn['inference'],
                        batch_major_func_rnn['inference'])
    self.assertAllClose(time_major_func_rnn['grad'],
                        batch_major_func_rnn['grad'])
    self.assertAllClose(time_major_dyn_rnn['inference'],
                        batch_major_dyn_rnn['inference'])
    self.assertAllClose(time_major_dyn_rnn['grad'], batch_major_dyn_rnn['grad'])
    self.assertAllClose(time_major_func_rnn['inference'],
                        batch_major_dyn_rnn['inference'])
    self.assertAllClose(time_major_func_rnn['grad'],
                        batch_major_dyn_rnn['grad'])


if __name__ == '__main__':
  test_lib.main()