# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for rnn module."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
import timeit

import numpy as np

from six.moves import xrange  # pylint: disable=redefined-builtin
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables as variables_lib
import tensorflow.python.ops.data_flow_grad  # pylint: disable=unused-import
import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
import tensorflow.python.ops.sparse_grad  # pylint: disable=unused-import
import tensorflow.python.ops.tensor_array_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test
from tensorflow.python.training import saver


class Plus1RNNCell(rnn_cell_impl.RNNCell):
  """RNN Cell generating (output, new_state) = (input + 1, state + 1)."""

  @property
  def output_size(self):
    return 5

  @property
  def state_size(self):
    return 5

  def call(self, input_, state, scope=None):
    return (input_ + 1, state + 1)


class ScalarStateRNNCell(rnn_cell_impl.RNNCell):
  """RNN Cell generating (output, new_state) = (input, state + 1)."""

  @property
  def output_size(self):
    return 1

  @property
  def state_size(self):
    return tensor_shape.TensorShape([])

  def zero_state(self, batch_size, dtype):
    return array_ops.zeros([], dtype=dtypes.int32)

  def call(self, input_, state, scope=None):
    return (input_, state + 1)


class UnbalancedOutputRNNCell(rnn_cell_impl.RNNCell):
  """RNN Cell generating ((input, concat(input, input)), state + 1)."""

  @property
  def output_size(self):
    return tensor_shape.TensorShape(1), tensor_shape.TensorShape((2))

  @property
  def state_size(self):
    return tensor_shape.TensorShape([])

  def zero_state(self, batch_size, dtype):
    return array_ops.zeros([], dtype=dtypes.int32)

  def call(self, input_, state, scope=None):
    concatenated = array_ops.concat((input_, input_), axis=-1)
    return (input_, concatenated), state + 1


class TensorArrayStateRNNCell(rnn_cell_impl.RNNCell):
  """RNN Cell that maintains its state as a TensorArray."""

  @property
  def output_size(self):
    return 1

  @property
  def state_size(self):
    return (tensor_shape.TensorShape([]), ())

  def zero_state(self, batch_size, dtype):
    return (array_ops.zeros([], dtype=dtypes.int32),
            tensor_array_ops.TensorArray(
                dtype=dtype, size=0, dynamic_size=True))

  def call(self, input_, state, scope=None):
    new_array = state[1].write(state[0], input_)
    return (input_, (state[0] + 1, new_array))


class RNNTest(test.TestCase):

  def setUp(self):
    self._seed = 23489
    np.random.seed(self._seed)

  @test_util.run_in_graph_and_eager_modes
  def testInvalidSequenceLengthShape(self):
    cell = Plus1RNNCell()
    if context.executing_eagerly():
      inputs = [constant_op.constant(np.ones((3, 4)))]
    else:
      inputs = [array_ops.placeholder(dtypes.float32, shape=(3, 4))]
    with self.assertRaisesRegex(ValueError, "must be a vector"):
      rnn.dynamic_rnn(
          cell,
          array_ops.stack(inputs),
          dtype=dtypes.float32,
          sequence_length=[[4]])

  @test_util.run_in_graph_and_eager_modes
  def testInvalidDtype(self):
    if context.executing_eagerly():
      inputs = np.zeros((3, 4, 5), dtype=np.int32)
    else:
      inputs = array_ops.placeholder(dtypes.int32, shape=(3, 4, 5))

    cells = [
        rnn_cell_impl.BasicRNNCell,
        rnn_cell_impl.GRUCell,
        rnn_cell_impl.BasicLSTMCell,
        rnn_cell_impl.LSTMCell,
    ]
    for cell_cls in cells:
      with self.cached_session():
        with self.assertRaisesRegex(ValueError,
                                    "RNN cell only supports floating"):
          cell = cell_cls(2, dtype=dtypes.int32)
          rnn.dynamic_rnn(cell, inputs, dtype=dtypes.int32)

  @test_util.run_in_graph_and_eager_modes
  def testBatchSizeFromInput(self):
    cell = Plus1RNNCell()
    in_eager_mode = context.executing_eagerly()
    # With static batch size
    if in_eager_mode:
      inputs = np.zeros((3, 4, 5), dtype=np.float32)
      initial_state = np.zeros((3, 5), dtype=np.float32)
    else:
      inputs = array_ops.placeholder(dtypes.float32, shape=(3, 4, 5))
      initial_state = array_ops.placeholder(dtypes.float32, shape=(3, 5))

    # - Without initial_state
    outputs, state = rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32)
    self.assertEqual(3, outputs.shape[0])
    self.assertEqual(3, state.shape[0])

    # - With initial_state
    outputs, state = rnn.dynamic_rnn(
        cell, inputs, initial_state=initial_state)
    self.assertEqual(3, outputs.shape[0])
    self.assertEqual(3, state.shape[0])

    # Without static batch size
    # Tensor shapes are fully determined with eager execution enabled,
    # so only run this test for graph construction.
    if not in_eager_mode:
      inputs = array_ops.placeholder(dtypes.float32, shape=(None, 4, 5))
      # - Without initial_state
      outputs, state = rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32)
      self.assertEqual(None, outputs.shape.dims[0].value)
      self.assertEqual(None, state.shape.dims[0].value)
      # - With initial_state
      outputs, state = rnn.dynamic_rnn(
          cell,
          inputs,
          initial_state=array_ops.placeholder(dtypes.float32, shape=(None, 5)))
      self.assertEqual(None, outputs.shape.dims[0].value)
      self.assertEqual(None, state.shape.dims[0].value)

  @test_util.run_in_graph_and_eager_modes
  def testScalarStateIsAccepted(self):
    cell = ScalarStateRNNCell()
    in_eager_mode = context.executing_eagerly()

    if in_eager_mode:
      inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32)
    else:
      inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))

    with self.cached_session() as sess:
      outputs, state = rnn.dynamic_rnn(
          cell, inputs, dtype=dtypes.float32, sequence_length=[4])
      if not in_eager_mode:
        outputs, state = sess.run(
            [outputs, state], feed_dict={inputs: [[[1], [2], [3], [4]]]})

    self.assertAllEqual([[[1], [2], [3], [4]]], outputs)
    self.assertAllEqual(4, state)

  @test_util.run_in_graph_and_eager_modes
  def testUnbalancedOutputIsAccepted(self):
    cell = UnbalancedOutputRNNCell()
    in_eager_mode = context.executing_eagerly()

    if in_eager_mode:
      inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32)
    else:
      inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))

    with self.cached_session() as sess:
      outputs, state = rnn.dynamic_rnn(
          cell, inputs, dtype=dtypes.float32, sequence_length=[4])
      if not in_eager_mode:
        outputs, state = sess.run(
            [outputs, state], feed_dict={inputs: [[[1], [2], [3], [4]]]})

    self.assertIsInstance(outputs, tuple)
    self.assertAllEqual([[[1], [2], [3], [4]]], outputs[0])
    self.assertAllEqual([[[1, 1], [2, 2], [3, 3], [4, 4]]], outputs[1])
    self.assertAllEqual(4, state)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testEagerMemory(self):
    with context.eager_mode():
      cell = TensorArrayStateRNNCell()
      inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32)
      rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32, sequence_length=[4])

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_v1_only("b/120545219")
  def testTensorArrayStateIsAccepted(self):
    cell = TensorArrayStateRNNCell()
    in_eager_mode = context.executing_eagerly()

    if in_eager_mode:
      inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32)
    else:
      inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))

    with self.cached_session() as sess:
      outputs, state = rnn.dynamic_rnn(
          cell, inputs, dtype=dtypes.float32, sequence_length=[4])
      state = (state[0], state[1].stack())
      if not in_eager_mode:
        outputs, state = sess.run(
            [outputs, state], feed_dict={
                inputs: [[[1], [2], [3], [4]]]
            })

    self.assertAllEqual([[[1], [2], [3], [4]]], outputs)
    self.assertAllEqual(4, state[0])
    self.assertAllEqual([[[1]], [[2]], [[3]], [[4]]], state[1])

  @test_util.run_deprecated_v1
  def testCellGetInitialState(self):
    cell = rnn_cell_impl.BasicRNNCell(5)
    with self.assertRaisesRegex(ValueError,
                                "batch_size and dtype cannot be None"):
      cell.get_initial_state(None, None, None)

    inputs = array_ops.placeholder(dtypes.float32, shape=(None, 4, 1))
    with self.assertRaisesRegex(
        ValueError, "batch size from input tensor is different from"):
      cell.get_initial_state(inputs=inputs, batch_size=50, dtype=None)

    with self.assertRaisesRegex(
        ValueError, "batch size from input tensor is different from"):
      cell.get_initial_state(
          inputs=inputs, batch_size=constant_op.constant(50), dtype=None)

    with self.assertRaisesRegex(ValueError,
                                "dtype from input tensor is different from"):
      cell.get_initial_state(inputs=inputs, batch_size=None, dtype=dtypes.int16)

    initial_state = cell.get_initial_state(
        inputs=inputs, batch_size=None, dtype=None)
    self.assertEqual(initial_state.shape.as_list(), [None, 5])
    self.assertEqual(initial_state.dtype, inputs.dtype)

    batch = array_ops.shape(inputs)[0]
    dtype = inputs.dtype
    initial_state = cell.get_initial_state(None, batch, dtype)
    self.assertEqual(initial_state.shape.as_list(), [None, 5])
    self.assertEqual(initial_state.dtype, inputs.dtype)

  def _assert_cell_builds(self, cell_class, dtype, batch_size, in_size,
                          out_size):
    cell = cell_class(out_size, dtype=dtype)
    in_shape = tensor_shape.TensorShape((batch_size, in_size))
    cell.build(in_shape)
    state_output = cell.get_initial_state(
        inputs=None, batch_size=batch_size, dtype=dtype)
    cell_output, _ = cell(array_ops.zeros(in_shape, dtype), state_output)
    self.assertAllEqual([batch_size, out_size], cell_output.shape.as_list())

  @test_util.run_in_graph_and_eager_modes
  def testCellsBuild(self):
    f32 = dtypes.float32
    f64 = dtypes.float64
    self._assert_cell_builds(rnn_cell_impl.BasicRNNCell, f32, 5, 7, 3)
    self._assert_cell_builds(rnn_cell_impl.BasicRNNCell, f64, 5, 7, 3)
    self._assert_cell_builds(rnn_cell_impl.BasicLSTMCell, f32, 5, 7, 3)
    self._assert_cell_builds(rnn_cell_impl.BasicLSTMCell, f64, 5, 7, 3)
    self._assert_cell_builds(rnn_cell_impl.GRUCell, f32, 5, 7, 3)
    self._assert_cell_builds(rnn_cell_impl.GRUCell, f64, 5, 7, 3)
    self._assert_cell_builds(rnn_cell_impl.LSTMCell, f32, 5, 7, 3)
    self._assert_cell_builds(rnn_cell_impl.LSTMCell, f64, 5, 7, 3)

  @test_util.run_deprecated_v1
  def testBasicLSTMCellInterchangeWithLSTMCell(self):
    with self.session(graph=ops_lib.Graph()) as sess:
      basic_cell = rnn_cell_impl.BasicLSTMCell(1)
      basic_cell(array_ops.ones([1, 1]),
                 state=basic_cell.get_initial_state(inputs=None,
                                                    batch_size=1,
                                                    dtype=dtypes.float32))
      self.evaluate([v.initializer for v in basic_cell.variables])
      self.evaluate(basic_cell._bias.assign([10.] * 4))
      save = saver.Saver()
      prefix = os.path.join(self.get_temp_dir(), "ckpt")
      save_path = save.save(sess, prefix)

    with self.session(graph=ops_lib.Graph()) as sess:
      lstm_cell = rnn_cell_impl.LSTMCell(1, name="basic_lstm_cell")
      lstm_cell(array_ops.ones([1, 1]),
                state=lstm_cell.get_initial_state(inputs=None,
                                                  batch_size=1,
                                                  dtype=dtypes.float32))
      self.evaluate([v.initializer for v in lstm_cell.variables])
      save = saver.Saver()
      save.restore(sess, save_path)
      self.assertAllEqual([10.] * 4, self.evaluate(lstm_cell._bias))

######### Benchmarking RNN code
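#
# Each "_*_benchmark" helper below builds the forward pass of an LSTM together
# with the gradients of its outputs w.r.t. all trainable variables, and groups
# everything into a single op so that `_timer` can time one full
# forward/backward step. The public functions either time graph construction
# (graph_creation_...) or time execution of that grouped op, and print one
# tab-separated result line per configuration. `BenchmarkRNN` (a
# `test.Benchmark`) reports the same numbers via `report_benchmark`; such
# benchmarks are typically selected with a `--benchmarks` regex when running
# the test binary rather than through `test.main()`.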


def _static_vs_dynamic_rnn_benchmark_static(inputs_list_t, sequence_length):
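  """Builds an unrolled static_rnn LSTM and the gradients of its outputs.

  Returns a single op grouping the final state, outputs and gradients.
  """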
  (_, input_size) = inputs_list_t[0].get_shape().as_list()
  initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=127)
  cell = rnn_cell_impl.LSTMCell(
      num_units=input_size,
      use_peepholes=True,
      initializer=initializer,
      state_is_tuple=False)
  outputs, final_state = rnn.static_rnn(
      cell,
      inputs_list_t,
      sequence_length=sequence_length,
      dtype=dtypes.float32)

  trainable_variables = ops_lib.get_collection(
      ops_lib.GraphKeys.TRAINABLE_VARIABLES)
  gradients = gradients_impl.gradients(outputs + [final_state],
                                       trainable_variables)

  return control_flow_ops.group(final_state, *(gradients + outputs))


def _static_vs_dynamic_rnn_benchmark_dynamic(inputs_t, sequence_length):
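  """Builds a dynamic_rnn LSTM and the gradients of its outputs.

  Returns a single op grouping the final state, outputs and gradients.
  """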
  (unused_0, unused_1, input_size) = inputs_t.get_shape().as_list()
  initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=127)
  cell = rnn_cell_impl.LSTMCell(
      num_units=input_size,
      use_peepholes=True,
      initializer=initializer,
      state_is_tuple=False)
  outputs, final_state = rnn.dynamic_rnn(
      cell, inputs_t, sequence_length=sequence_length, dtype=dtypes.float32)

  trainable_variables = ops_lib.get_collection(
      ops_lib.GraphKeys.TRAINABLE_VARIABLES)
  gradients = gradients_impl.gradients([outputs, final_state],
                                       trainable_variables)

  return control_flow_ops.group(final_state, outputs, *gradients)


def graph_creation_static_vs_dynamic_rnn_benchmark(max_time):
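  """Times graph construction (not execution) of static_rnn vs. dynamic_rnn."""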
  config = config_pb2.ConfigProto()
  config.allow_soft_placement = True

  # These parameters don't matter
  batch_size = 512
  num_units = 512

  # Set up sequence lengths
  np.random.seed([127])
  sequence_length = np.random.randint(0, max_time, size=batch_size)
  inputs_list = [
      np.random.randn(batch_size, num_units).astype(np.float32)
      for _ in range(max_time)
  ]
  inputs = np.dstack(inputs_list).transpose([0, 2, 1])  # batch x time x depth

  def _create_static_rnn():
    with session.Session(config=config, graph=ops_lib.Graph()):
      inputs_list_t = [
          variables_lib.Variable(
              x, trainable=False).value() for x in inputs_list
      ]
      _static_vs_dynamic_rnn_benchmark_static(inputs_list_t, sequence_length)

  def _create_dynamic_rnn():
    with session.Session(config=config, graph=ops_lib.Graph()):
      inputs_t = variables_lib.Variable(inputs, trainable=False).value()
      _static_vs_dynamic_rnn_benchmark_dynamic(inputs_t, sequence_length)

  delta_static = timeit.timeit(_create_static_rnn, number=5)
  delta_dynamic = timeit.timeit(_create_dynamic_rnn, number=5)

  print("%d \t %f \t %f \t %f" %
        (max_time, delta_static, delta_dynamic, delta_dynamic / delta_static))
  return delta_static, delta_dynamic


def _timer(sess, ops):
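  """Runs `ops` twice to warm up, then returns average seconds over 20 runs."""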
  # Warm up
  for _ in range(2):
    sess.run(ops)

  # Timing run
  runs = 20
  start = time.time()
  for _ in range(runs):
    sess.run(ops)
  end = time.time()
  return (end - start) / float(runs)


def static_vs_dynamic_rnn_benchmark(batch_size, max_time, num_units, use_gpu):
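  """Times one forward/backward step of static_rnn vs. dynamic_rnn LSTMs."""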
  config = config_pb2.ConfigProto()
  config.allow_soft_placement = True

  # Set up sequence lengths
  np.random.seed([127])
  sequence_length = np.random.randint(0, max_time, size=batch_size)
  inputs_list = [
      np.random.randn(batch_size, num_units).astype(np.float32)
      for _ in range(max_time)
  ]
  inputs = np.dstack(inputs_list).transpose([0, 2, 1])  # batch x time x depth

  # Using static_rnn()
  with session.Session(config=config, graph=ops_lib.Graph()) as sess:
    with ops_lib.device("/cpu:0" if not use_gpu else None):
      inputs_list_t = [
          variables_lib.Variable(
              x, trainable=False).value() for x in inputs_list
      ]
      ops = _static_vs_dynamic_rnn_benchmark_static(inputs_list_t,
                                                    sequence_length)
    variables_lib.global_variables_initializer().run()
    delta_static = _timer(sess, ops)

  # Using dynamic_rnn()
  with session.Session(config=config, graph=ops_lib.Graph()) as sess:
    with ops_lib.device("/cpu:0" if not use_gpu else None):
      inputs_t = variables_lib.Variable(inputs, trainable=False).value()
      ops = _static_vs_dynamic_rnn_benchmark_dynamic(inputs_t, sequence_length)
    variables_lib.global_variables_initializer().run()
    delta_dynamic = _timer(sess, ops)

  print("%d \t %d \t %d \t %s \t %f \t %f \t %f" %
        (batch_size, max_time, num_units, use_gpu, delta_static, delta_dynamic,
         delta_dynamic / delta_static))

  return delta_static, delta_dynamic


def _half_seq_len_vs_unroll_half_rnn_benchmark(inputs_list_t, sequence_length):
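  """Builds a static_rnn LSTM (honoring sequence_length) plus its gradients."""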
  (_, input_size) = inputs_list_t[0].get_shape().as_list()
  initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=127)
  cell = rnn_cell_impl.LSTMCell(
      num_units=input_size,
      use_peepholes=True,
      initializer=initializer,
      state_is_tuple=False)
  outputs, final_state = rnn.static_rnn(
      cell,
      inputs_list_t,
      sequence_length=sequence_length,
      dtype=dtypes.float32)

  trainable_variables = ops_lib.get_collection(
      ops_lib.GraphKeys.TRAINABLE_VARIABLES)
  gradients = gradients_impl.gradients(outputs + [final_state],
                                       trainable_variables)

  return control_flow_ops.group(final_state, *(gradients + outputs))


def half_seq_len_vs_unroll_half_rnn_benchmark(batch_size, max_time, num_units,
                                              use_gpu):
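  """Times full unroll with halved sequence_length vs. a half-length unroll."""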
  config = config_pb2.ConfigProto()
  config.allow_soft_placement = True

  # Set up sequence lengths
  np.random.seed([127])
  sequence_length = max_time * np.ones((batch_size,))
  inputs_list = [
      np.random.randn(batch_size, num_units).astype(np.float32)
      for _ in range(max_time)
  ]

  # Halve the sequence length, full static unroll
  with session.Session(config=config, graph=ops_lib.Graph()) as sess:
    with ops_lib.device("/cpu:0" if not use_gpu else None):
      inputs_list_t = [
          variables_lib.Variable(
              x, trainable=False).value() for x in inputs_list
      ]
      ops = _half_seq_len_vs_unroll_half_rnn_benchmark(inputs_list_t,
                                                       sequence_length / 2)
    variables_lib.global_variables_initializer().run()
    delta_half_seq_len = _timer(sess, ops)

  # Halve the unroll size, don't use sequence length
  with session.Session(config=config, graph=ops_lib.Graph()) as sess:
    with ops_lib.device("/cpu:0" if not use_gpu else None):
      inputs_list_t = [
          variables_lib.Variable(
              x, trainable=False).value() for x in inputs_list
      ]
      ops = _half_seq_len_vs_unroll_half_rnn_benchmark(
          inputs_list_t[:(max_time // 2)], sequence_length / 2)
    variables_lib.global_variables_initializer().run()
    delta_unroll_half = _timer(sess, ops)
  print("%d \t %d \t\t %d \t %s \t %f \t\t %f \t\t %f" %
        (batch_size, max_time, num_units, use_gpu, delta_half_seq_len,
         delta_unroll_half, delta_half_seq_len / delta_unroll_half))

  return delta_half_seq_len, delta_unroll_half


def _concat_state_vs_tuple_state_rnn_benchmark(inputs_list_t, sequence_length,
                                               state_is_tuple):
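  """Builds a static_rnn LSTM with concat or tuple state, plus gradients."""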
  (_, input_size) = inputs_list_t[0].get_shape().as_list()
  initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=127)
  cell = rnn_cell_impl.LSTMCell(
      num_units=input_size,
      use_peepholes=True,
      initializer=initializer,
      state_is_tuple=state_is_tuple)
  outputs, final_state = rnn.static_rnn(
      cell,
      inputs_list_t,
      sequence_length=sequence_length,
      dtype=dtypes.float32)

  final_state = list(final_state) if state_is_tuple else [final_state]

  trainable_variables = ops_lib.get_collection(
      ops_lib.GraphKeys.TRAINABLE_VARIABLES)
  gradients = gradients_impl.gradients(outputs + final_state,
                                       trainable_variables)

  return control_flow_ops.group(*(final_state + gradients + outputs))


def concat_state_vs_tuple_state_rnn_benchmark(batch_size, max_time, num_units,
                                              use_gpu):
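  """Times static_rnn with concatenated state vs. tuple state."""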
  config = config_pb2.ConfigProto()
  config.allow_soft_placement = True

  # Set up sequence lengths
  np.random.seed([127])
  sequence_length = max_time * np.ones((batch_size,))
  inputs_list = [
      np.random.randn(batch_size, num_units).astype(np.float32)
      for _ in range(max_time)
  ]

  # Run with concatenated states (default)
  with session.Session(config=config, graph=ops_lib.Graph()) as sess:
    with ops_lib.device("/cpu:0" if not use_gpu else None):
      inputs_list_t = [
          variables_lib.Variable(
              x, trainable=False).value() for x in inputs_list
      ]
      ops = _concat_state_vs_tuple_state_rnn_benchmark(
          inputs_list_t, sequence_length, state_is_tuple=False)
    variables_lib.global_variables_initializer().run()
    delta_concat_state = _timer(sess, ops)

  # Run with tuple states (new)
  with session.Session(config=config, graph=ops_lib.Graph()) as sess:
    with ops_lib.device("/cpu:0" if not use_gpu else None):
      inputs_list_t = [
          variables_lib.Variable(
              x, trainable=False).value() for x in inputs_list
      ]
      ops = _concat_state_vs_tuple_state_rnn_benchmark(
          inputs_list_t, sequence_length, state_is_tuple=True)
    variables_lib.global_variables_initializer().run()
    delta_tuple_state = _timer(sess, ops)
  print("%d \t %d \t %d \t %s \t %f \t\t %f \t\t %f" %
        (batch_size, max_time, num_units, use_gpu, delta_concat_state,
         delta_tuple_state, delta_concat_state / delta_tuple_state))

  return delta_concat_state, delta_tuple_state


def _dynamic_rnn_swap_memory_benchmark(inputs_t, sequence_length, swap_memory):
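  """Builds a dynamic_rnn LSTM with swap_memory, plus its gradients."""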
  (unused_0, unused_1, input_size) = inputs_t.get_shape().as_list()
  initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=127)
  cell = rnn_cell_impl.LSTMCell(
      num_units=input_size,
      use_peepholes=True,
      initializer=initializer,
      state_is_tuple=False)
  outputs, final_state = rnn.dynamic_rnn(
      cell,
      inputs_t,
      sequence_length=sequence_length,
      swap_memory=swap_memory,
      dtype=dtypes.float32)

  trainable_variables = ops_lib.get_collection(
      ops_lib.GraphKeys.TRAINABLE_VARIABLES)
  gradients = gradients_impl.gradients([outputs, final_state],
                                       trainable_variables)

  return control_flow_ops.group(final_state, outputs, *gradients)


def dynamic_rnn_swap_memory_benchmark(batch_size, max_time, num_units):
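  """Times dynamic_rnn with swap_memory=False vs. swap_memory=True."""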
  config = config_pb2.ConfigProto()
  config.allow_soft_placement = True

  # Set up sequence lengths
  np.random.seed([127])
  sequence_length = np.random.randint(0, max_time, size=batch_size)
  inputs_list = [
      np.random.randn(batch_size, num_units).astype(np.float32)
      for _ in range(max_time)
  ]
  inputs = np.dstack(inputs_list).transpose([0, 2, 1])  # batch x time x depth

  # No memory swap
  with session.Session(config=config, graph=ops_lib.Graph()) as sess:
    inputs_t = variables_lib.Variable(inputs, trainable=False).value()
    ops = _dynamic_rnn_swap_memory_benchmark(
        inputs_t, sequence_length, swap_memory=False)
    variables_lib.global_variables_initializer().run()
    no_swap = _timer(sess, ops)

  # Memory swap
  with session.Session(config=config, graph=ops_lib.Graph()) as sess:
    inputs_t = variables_lib.Variable(inputs, trainable=False).value()
    ops = _dynamic_rnn_swap_memory_benchmark(
        inputs_t, sequence_length, swap_memory=True)
    variables_lib.global_variables_initializer().run()
    swap = _timer(sess, ops)

  print("%d \t %d \t %d \t %f \t %f \t %f" %
        (batch_size, max_time, num_units, no_swap, swap, swap / no_swap))
  return no_swap, swap


def rnn_long_sequence_benchmark(batch_size, seqlen, num_units, dynamic,
                                swap_memory, nn):
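  """Runs the long-sequence LSTM benchmark `nn` times, static or dynamic."""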
  config = config_pb2.ConfigProto()
  config.allow_soft_placement = True

  # Set up sequence lengths
  np.random.seed([127])
  sequence_length = [seqlen for _ in range(batch_size)]
  inputs_list = [
      np.random.randn(batch_size, num_units).astype(np.float32)
      for _ in range(seqlen)
  ]
  inputs = np.dstack(inputs_list).transpose([0, 2, 1])  # batch x time x depth

  for _ in range(nn):
    if dynamic:
      with session.Session(config=config, graph=ops_lib.Graph()) as sess:
        inputs_t = variables_lib.Variable(inputs, trainable=False).value()
        ops = _dynamic_rnn_swap_memory_benchmark(
            inputs_t, sequence_length, swap_memory=swap_memory)
        variables_lib.global_variables_initializer().run()
        elapsed = _timer(sess, ops)
    else:
      with session.Session(config=config, graph=ops_lib.Graph()) as sess:
        inputs_list_t = [
            variables_lib.Variable(
                x, trainable=False).value() for x in inputs_list
        ]
        ops = _static_vs_dynamic_rnn_benchmark_static(inputs_list_t,
                                                      sequence_length)
        variables_lib.global_variables_initializer().run()
        elapsed = _timer(sess, ops)

    print("%d \t %d \t %d \t %s \t %f \t %f" % (batch_size, seqlen, num_units,
                                                dynamic, elapsed,
                                                elapsed / seqlen))


class BenchmarkRNN(test.Benchmark):

  def benchmarkGraphCreationStaticVsDynamicLSTM(self):
    print("Graph Creation: Static Unroll vs. Dynamic Unroll LSTM")
    print("max_t \t dt(static) \t dt(dynamic) \t dt(dynamic)/dt(static)")
    for max_time in (1, 25, 50):
      s_dt, d_dt = graph_creation_static_vs_dynamic_rnn_benchmark(max_time)
      self.report_benchmark(
          name="graph_creation_time_static_T%02d" % max_time,
          iters=5,
          wall_time=s_dt)
      self.report_benchmark(
          name="graph_creation_time_dynamic_T%02d" % max_time,
          iters=5,
          wall_time=d_dt)

  def benchmarkStaticUnrollVsDynamicFlowLSTM(self):
    print("Calculation: Static Unroll with Dynamic Flow LSTM "
          "vs. Dynamic Unroll LSTM")
    print("batch \t max_t \t units \t gpu \t dt(static) \t dt(dynamic) "
          "\t dt(dynamic)/dt(static)")
    for batch_size in (256,):
      for max_time in (50,):
        for num_units in (512, 256, 128):
          for use_gpu in (False, True):
            s_dt, d_dt = static_vs_dynamic_rnn_benchmark(batch_size, max_time,
                                                         num_units, use_gpu)
            self.report_benchmark(
                name="static_unroll_time_T%02d_B%03d_N%03d_gpu_%s" %
                (max_time, batch_size, num_units, use_gpu),
                iters=20,
                wall_time=s_dt)
            self.report_benchmark(
                name="dynamic_unroll_time_T%02d_B%03d_N%03d_gpu_%s" %
                (max_time, batch_size, num_units, use_gpu),
                iters=20,
                wall_time=d_dt)

  def benchmarkDynamicLSTMNoMemorySwapVsMemorySwap(self):
    print("Calculation: Dynamic LSTM No Memory Swap vs. Memory Swap")
    print("batch \t max_t \t units \t no_swap \t swap \t swap/no_swap")
    for batch_size in (256, 512):
      for max_time in (100,):
        for num_units in (512, 256, 128):
          no_swap, swap = dynamic_rnn_swap_memory_benchmark(batch_size,
                                                            max_time, num_units)
          self.report_benchmark(
              name="dynamic_lstm_no_memory_swap_T%02d_B%03d_N%03d" %
              (max_time, batch_size, num_units),
              iters=20,
              wall_time=no_swap)
          self.report_benchmark(
              name="dynamic_lstm_with_memory_swap_T%02d_B%03d_N%03d" %
              (max_time, batch_size, num_units),
              iters=20,
              wall_time=swap)

  def benchmarkStaticUnrollHalfSequenceLengthVsHalfUnroll(self):
    print("Calculation: Static Unroll with Halved Sequence Length "
          "vs. Half Static Unroll")
    print("batch \t full_t \t units \t gpu \t dt(half_seq_len) "
          "\t dt(unroll_half) \t dt(half_seq_len)/dt(unroll_half)")
    for batch_size in (128,):
      for max_time in (50,):
        for num_units in (256,):
          for use_gpu in (False, True):
            s_dt, d_dt = half_seq_len_vs_unroll_half_rnn_benchmark(batch_size,
                                                                   max_time,
                                                                   num_units,
                                                                   use_gpu)
            self.report_benchmark(
                name="half_seq_len_time_T%02d_B%03d_N%03d_gpu_%s" %
                (max_time, batch_size, num_units, use_gpu),
                iters=20,
                wall_time=s_dt)
            self.report_benchmark(
                name="unroll_half_time_T%02d_B%03d_N%03d_gpu_%s" %
                (max_time, batch_size, num_units, use_gpu),
                iters=20,
                wall_time=d_dt)

  def benchmarkStaticUnrollStateConcatVsStateTuple(self):
    print("Calculation: Static Unroll with Concatenated State "
          "vs. Tuple State")
    print("batch \t time \t units \t gpu \t dt(concat_state) "
          "\t dt(tuple_state) \t dt(concat_state)/dt(tuple_state)")
    for batch_size in (
        16,
        128,):
      for max_time in (50,):
        for num_units in (
            16,
            128,):
          for use_gpu in (False, True):
            c_dt, t_dt = concat_state_vs_tuple_state_rnn_benchmark(batch_size,
                                                                   max_time,
                                                                   num_units,
                                                                   use_gpu)
            self.report_benchmark(
                name="concat_state_time_T%02d_B%03d_N%03d_gpu_%s" %
                (max_time, batch_size, num_units, use_gpu),
                iters=20,
                wall_time=c_dt)
            self.report_benchmark(
                name="tuple_state_time_T%02d_B%03d_N%03d_gpu_%s" %
                (max_time, batch_size, num_units, use_gpu),
                iters=20,
                wall_time=t_dt)

  def _benchmarkDynamicLSTMMemorySwapLongSeq(self):
    """The memory swapping test for the SOSP submission."""
    print("Calculation: Long LSTM Sequence")
    print("batch \t len \t units \t dynamic \t elapsed_t \t elapsed_t/len")
    batch_size = 512
    seqlen = 800
    num_units = 512
    dynamic = True
    swap_memory = True
    # Some warming up.
    if swap_memory:
      rnn_long_sequence_benchmark(batch_size, seqlen, num_units,
                                  dynamic, swap_memory, 2)
    # Measure the performance.
    for slen in xrange(100, 1100, 100):
      rnn_long_sequence_benchmark(batch_size, slen, num_units, dynamic,
                                  swap_memory, 3)


if __name__ == "__main__":
  test.main()