• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""LSTM Block Cell ops."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from absl.testing import parameterized
22import numpy as np
23
24from tensorflow.contrib.rnn.python.kernel_tests import benchmarking
25from tensorflow.contrib.rnn.python.ops import lstm_ops
26from tensorflow.python.client import session
27from tensorflow.python.framework import constant_op
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import ops
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import gen_array_ops
32from tensorflow.python.ops import gen_bitwise_ops
33from tensorflow.python.ops import gradients_impl
34from tensorflow.python.ops import init_ops
35from tensorflow.python.ops import rnn
36from tensorflow.python.ops import rnn_cell
37from tensorflow.python.ops import variable_scope
38from tensorflow.python.ops import variables
39from tensorflow.python.platform import test
40
# Module-level shorthand for the private block LSTM op wrapper so the tests
# below can call it directly.
block_lstm = lstm_ops._block_lstm  # pylint: disable=protected-access
42
43
class _MaskedRandomUniformInitializer(init_ops.RandomUniform):
  """Initializer for uniform dist tensors with trailing bits zeroed-out.

  Allow returning tensors with last few mantissa bits set to 0. This potentially
  helps avoid getting into precision issues when testing low precision (float16)
  computation.
  """

  def __init__(self,
               minval=0,
               maxval=None,
               seed=None,
               dtype=dtypes.float16,
               num_valid_mantissa_bits=4):
    """Constructor.

    Args:
      minval: A python scalar or a scalar tensor. Lower bound of the range of
        random values to generate.
      maxval: A python scalar or a scalar tensor. Upper bound of the range of
        random values to generate.  Defaults to 1 for float types.
      seed: A Python integer. Used to create random seeds. See
        `tf.set_random_seed` for behavior.
      dtype: The data type. Only supports tf.float16 for now.
      num_valid_mantissa_bits: number of non-zero mantissa bits, default to 4.

    Raises:
      ValueError: An error if `dtype` is not tf.float16.
    """
    if dtype not in (dtypes.float16,):
      raise ValueError("dtype: %s not supported" % dtype.name)

    super(_MaskedRandomUniformInitializer, self).__init__(
        minval=minval, maxval=maxval, seed=seed, dtype=dtype)
    # Width of the float16 mantissa field; the mask in `__call__` is derived
    # from the difference between this and the number of bits to keep.
    self._num_mantissa_bits = 10
    self._num_valid_mantissa_bits = num_valid_mantissa_bits

  def __call__(self, shape, dtype=dtypes.float16, partition_info=None):
    """Generates a float16 tensor whose trailing mantissa bits are zero.

    Args:
      shape: Shape of the tensor to generate.
      dtype: Requested data type. Only tf.float16 is supported; a falsy value
        (e.g. `None`) is treated as tf.float16.
      partition_info: Forwarded unchanged to the parent initializer.

    Returns:
      A tf.float16 tensor of the requested shape with the lowest
      `10 - num_valid_mantissa_bits` bits of each element cleared.

    Raises:
      ValueError: An error if `dtype` is given and is not tf.float16.
    """
    if not dtype:
      # The final bitcast below needs a concrete dtype, so resolve the
      # "unspecified" case here instead of passing None through.
      dtype = dtypes.float16
    elif dtype != dtypes.float16:
      raise ValueError("dtype: %s not supported" % dtype.name)
    res = super(_MaskedRandomUniformInitializer, self).__call__(
        shape, dtype, partition_info)
    # get uint16 view of the underlying buffer.
    res = gen_array_ops.bitcast(res, dtypes.uint16)

    # mask the last `shift` mantissa bits.
    shift = self._num_mantissa_bits - self._num_valid_mantissa_bits
    mask = (0xffff >> shift) << shift
    res = gen_bitwise_ops.bitwise_and(res, mask)

    # restore float16 view.
    return gen_array_ops.bitcast(res, dtype)
96
97
def _get_initializer(init_bound, dtype, seed):
  """Returns a uniform initializer over [-init_bound, init_bound].

  float16 gets the masked initializer defined in this file; every other dtype
  gets the stock `random_uniform_initializer`.

  Args:
    init_bound: Magnitude of the symmetric sampling range.
    dtype: Data type of the values to generate.
    seed: Seed forwarded to the selected initializer.

  Returns:
    An initializer instance configured with the given bound, dtype and seed.
  """
  factory = (
      _MaskedRandomUniformInitializer
      if dtype == dtypes.float16 else init_ops.random_uniform_initializer)
  return factory(-init_bound, init_bound, dtype=dtype, seed=seed)
105
106
def blocks_match(sess, use_peephole, dtype=dtypes.float32, cell_clip=None):
  """Evaluates three LSTM implementations on the same inputs and weights.

  Builds an `LSTMCell` unrolled with `static_rnn`, the block LSTM op and an
  `LSTMBlockFusedCell` inside one variable scope so they share the explicitly
  created variables, then runs outputs, input gradients and weight gradients
  for each of them.

  Args:
    sess: Session used to initialize variables and evaluate all tensors.
    use_peephole: Whether peephole weights are created and used.
    dtype: Data type of the inputs and variables.
    cell_clip: Cell clipping value forwarded to every implementation.

  Returns:
    A tuple `(basic_state, fused_state, basic_outputs, block_outputs,
    fused_outputs, basic_grads, block_grads, fused_grads, basic_wgrads,
    block_wgrads, fused_wgrads)` of evaluated values, where the two state
    entries are the first element of the respective final state.
  """
  batch_size, input_size = 2, 3
  cell_size = 4
  sequence_length = 4

  inputs = [
      ops.convert_to_tensor(
          np.random.randn(batch_size, input_size), dtype=dtype)
      for _ in range(sequence_length)
  ]
  stacked_inputs = array_ops.stack(inputs)

  # float16 uses a wider initialization range than the other dtypes.
  init_bound = 1e-2
  if dtype == dtypes.float16:
    init_bound = 1e-1
  initializer = _get_initializer(init_bound, dtype=dtype, seed=19890212)

  def _eval_grads(ys, wrt):
    """Builds d(ys)/d(wrt) and evaluates it in `sess`."""
    return sess.run(gradients_impl.gradients(ys, wrt))

  with variable_scope.variable_scope("test", initializer=initializer):
    # magic naming so that the cells pick up these variables and reuse them
    peephole_weights = []
    peephole_kwargs = {}
    if use_peephole:
      wci = variable_scope.get_variable(
          "rnn/lstm_cell/w_i_diag", shape=[cell_size], dtype=dtype)
      wcf = variable_scope.get_variable(
          "rnn/lstm_cell/w_f_diag", shape=[cell_size], dtype=dtype)
      wco = variable_scope.get_variable(
          "rnn/lstm_cell/w_o_diag", shape=[cell_size], dtype=dtype)
      peephole_weights = [wci, wcf, wco]
      peephole_kwargs = dict(wci=wci, wcf=wcf, wco=wco, use_peephole=True)

    w = variable_scope.get_variable(
        "rnn/lstm_cell/kernel",
        shape=[input_size + cell_size, cell_size * 4],
        dtype=dtype)
    b = variable_scope.get_variable(
        "rnn/lstm_cell/bias",
        shape=[cell_size * 4],
        dtype=dtype,
        initializer=init_ops.zeros_initializer())
    # Differentiation targets for the weight gradients of all three variants.
    xs = [w, b] + peephole_weights

    # Reference implementation: LSTMCell unrolled step by step.
    basic_cell = rnn_cell.LSTMCell(
        cell_size,
        use_peepholes=use_peephole,
        cell_clip=cell_clip,
        dtype=dtype,
        state_is_tuple=True,
        reuse=True)
    basic_outputs_op, basic_state_op = rnn.static_rnn(
        basic_cell, inputs, dtype=dtype)

    # Block LSTM op; only its last return value (the outputs) is compared.
    seq_len_max = ops.convert_to_tensor(sequence_length, dtype=dtypes.int64)
    block_results = block_lstm(
        seq_len_max, inputs, w, b, cell_clip=cell_clip, **peephole_kwargs)
    _, _, _, _, _, _, block_outputs_op = block_results

    # Fused cell operating on the time-stacked inputs.
    fused_cell = lstm_ops.LSTMBlockFusedCell(
        cell_size,
        cell_clip=cell_clip,
        use_peephole=use_peephole,
        reuse=True,
        name="rnn/lstm_cell")
    fused_outputs_op, fused_state_op = fused_cell(stacked_inputs, dtype=dtype)

    sess.run([variables.global_variables_initializer()])

    basic_outputs, basic_state = sess.run([basic_outputs_op, basic_state_op[0]])
    basic_grads = _eval_grads(basic_outputs_op, inputs)
    basic_wgrads = _eval_grads(basic_outputs_op, xs)

    block_outputs = sess.run(block_outputs_op)
    block_grads = _eval_grads(block_outputs_op, inputs)
    block_wgrads = _eval_grads(block_outputs_op, xs)

    fused_outputs, fused_state = sess.run([fused_outputs_op, fused_state_op[0]])
    fused_grads = _eval_grads(fused_outputs_op, inputs)
    fused_wgrads = _eval_grads(fused_outputs_op, xs)

    return (basic_state, fused_state, basic_outputs, block_outputs,
            fused_outputs, basic_grads, block_grads, fused_grads, basic_wgrads,
            block_wgrads, fused_wgrads)
202
203
class LSTMBlockCellTest(test.TestCase, parameterized.TestCase):
  """Compares the block and fused LSTM cells against the reference cells."""

  # Named parameter sets for the dtype-parameterized comparison tests; each
  # entry supplies the dtype under test and the tolerances used with it.
  TEST_CASES = ({
      "testcase_name": "Fp32",
      "dtype": dtypes.float32,
      "rtol": 1e-6,
      "atol": 1e-6
  }, {
      "testcase_name": "Fp16",
      "dtype": dtypes.float16,
      "rtol": 8e-3,
      "atol": 8e-4
  })

  def testNoneDimsWithDynamicRNN(self):
    """Runs LSTMBlockCell under dynamic_rnn with unknown time/batch dims."""
    with self.session(use_gpu=True, graph=ops.Graph()) as sess:
      batch_size = 4
      num_steps = 5
      input_dim = 6
      cell_size = 7

      cell = lstm_ops.LSTMBlockCell(cell_size)
      x = array_ops.placeholder(dtypes.float32, shape=(None, None, input_dim))

      output, _ = rnn.dynamic_rnn(
          cell, x, time_major=True, dtype=dtypes.float32)
      sess.run(variables.global_variables_initializer())
      feed = {}
      feed[x] = np.random.randn(num_steps, batch_size, input_dim)
      sess.run(output, feed)

  def testLSTMBlockCell(self):
    """Checks a two-layer LSTMBlockCell stack against fixed expected values."""
    with self.session(use_gpu=True, graph=ops.Graph()) as sess:
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        x = array_ops.zeros([1, 2])
        m0 = array_ops.zeros([1, 2])
        m1 = array_ops.zeros([1, 2])
        m2 = array_ops.zeros([1, 2])
        m3 = array_ops.zeros([1, 2])
        g, ((out_m0, out_m1), (out_m2, out_m3)) = rnn_cell.MultiRNNCell(
            [lstm_ops.LSTMBlockCell(2)
             for _ in range(2)], state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
        sess.run([variables.global_variables_initializer()])
        # The zero tensors act as feed points: their values are overridden by
        # name at run time.
        res = sess.run([g, out_m0, out_m1, out_m2, out_m3], {
            x.name: np.array([[1., 1.]]),
            m0.name: 0.1 * np.ones([1, 2]),
            m1.name: 0.1 * np.ones([1, 2]),
            m2.name: 0.1 * np.ones([1, 2]),
            m3.name: 0.1 * np.ones([1, 2])
        })
        self.assertEqual(len(res), 5)
        self.assertAllClose(res[0], [[0.24024698, 0.24024698]])
        # These numbers are from testBasicLSTMCell and only test c/h.
        self.assertAllClose(res[1], [[0.68967271, 0.68967271]])
        self.assertAllClose(res[2], [[0.44848421, 0.44848421]])
        self.assertAllClose(res[3], [[0.39897051, 0.39897051]])
        self.assertAllClose(res[4], [[0.24024698, 0.24024698]])

  def testCompatibleNames(self):
    """Checks all three implementations create identically named variables."""
    # Each implementation is built in its own graph so that the trainable
    # variable collections do not mix.
    with self.session(use_gpu=True, graph=ops.Graph()):
      cell = rnn_cell.LSTMCell(10)
      pcell = rnn_cell.LSTMCell(10, use_peepholes=True)
      inputs = [array_ops.zeros([4, 5])] * 6
      rnn.static_rnn(cell, inputs, dtype=dtypes.float32, scope="basic")
      rnn.static_rnn(pcell, inputs, dtype=dtypes.float32, scope="peephole")
      basic_names = {
          v.name: v.get_shape()
          for v in variables.trainable_variables()
      }

    with self.session(use_gpu=True, graph=ops.Graph()):
      cell = lstm_ops.LSTMBlockCell(10)
      pcell = lstm_ops.LSTMBlockCell(10, use_peephole=True)
      inputs = [array_ops.zeros([4, 5])] * 6
      rnn.static_rnn(cell, inputs, dtype=dtypes.float32, scope="basic")
      rnn.static_rnn(pcell, inputs, dtype=dtypes.float32, scope="peephole")
      block_names = {
          v.name: v.get_shape()
          for v in variables.trainable_variables()
      }

    with self.session(use_gpu=True, graph=ops.Graph()):
      cell = lstm_ops.LSTMBlockFusedCell(10)
      pcell = lstm_ops.LSTMBlockFusedCell(10, use_peephole=True)
      inputs = array_ops.stack([array_ops.zeros([4, 5])] * 6)
      cell(inputs, dtype=dtypes.float32, scope="basic/lstm_cell")
      pcell(inputs, dtype=dtypes.float32, scope="peephole/lstm_cell")
      fused_names = {
          v.name: v.get_shape()
          for v in variables.trainable_variables()
      }

    self.assertEqual(basic_names, block_names)
    self.assertEqual(basic_names, fused_names)

  def _RunTwoLayerCell(self, sess, scope, initializer, cell_fn, x, x_values):
    """Builds a two-layer MultiRNNCell in `scope` and evaluates one step.

    Args:
      sess: Session used to initialize variables and run the cell.
      scope: Name of the variable scope the cells are built in.
      initializer: Default initializer for variables created in `scope`.
      cell_fn: Zero-argument callable returning a new cell for one layer.
      x: Input tensor whose value is overridden with `x_values`.
      x_values: Numpy array fed in place of `x`.

    Returns:
      A list `[g, out_m0, out_m1, out_m2, out_m3]` of evaluated values: the
      stack output followed by the four new state tensors.
    """
    m0_val = 0.1 * np.ones([1, 2])
    m1_val = -0.1 * np.ones([1, 2])
    m2_val = -0.2 * np.ones([1, 2])
    m3_val = 0.2 * np.ones([1, 2])

    with variable_scope.variable_scope(scope, initializer=initializer):
      m0 = array_ops.zeros([1, 2])
      m1 = array_ops.zeros([1, 2])
      m2 = array_ops.zeros([1, 2])
      m3 = array_ops.zeros([1, 2])
      g, ((out_m0, out_m1), (out_m2, out_m3)) = rnn_cell.MultiRNNCell(
          [cell_fn() for _ in range(2)],
          state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
      sess.run([variables.global_variables_initializer()])
      # The zero tensors act as feed points: their values are overridden by
      # name at run time.
      return sess.run([g, out_m0, out_m1, out_m2, out_m3], {
          x.name: x_values,
          m0.name: m0_val,
          m1.name: m1_val,
          m2.name: m2_val,
          m3.name: m3_val
      })

  def testLSTMBasicToBlockCell(self):
    """Checks LSTMBlockCell matches BasicLSTMCell for one step."""
    with self.session(use_gpu=True) as sess:
      x = array_ops.zeros([1, 2])
      x_values = np.random.randn(1, 2)

      # The same seeded initializer is used for both scopes.
      initializer = init_ops.random_uniform_initializer(
          -0.01, 0.01, seed=19890212)
      basic_res = self._RunTwoLayerCell(
          sess, "basic", initializer,
          lambda: rnn_cell.BasicLSTMCell(2, state_is_tuple=True), x, x_values)
      block_res = self._RunTwoLayerCell(
          sess, "block", initializer, lambda: lstm_ops.LSTMBlockCell(2), x,
          x_values)

      self.assertEqual(len(basic_res), len(block_res))
      for basic, block in zip(basic_res, block_res):
        self.assertAllClose(basic, block)

  def testLSTMBasicToBlockCellPeeping(self):
    """Checks LSTMBlockCell matches LSTMCell with peepholes for one step."""
    with self.session(use_gpu=True) as sess:
      x = array_ops.zeros([1, 2])
      x_values = np.random.randn(1, 2)

      # The same seeded initializer is used for both scopes.
      initializer = init_ops.random_uniform_initializer(
          -0.01, 0.01, seed=19890212)
      basic_res = self._RunTwoLayerCell(
          sess, "basic", initializer,
          lambda: rnn_cell.LSTMCell(
              2, use_peepholes=True, state_is_tuple=True), x, x_values)
      block_res = self._RunTwoLayerCell(
          sess, "block", initializer,
          lambda: lstm_ops.LSTMBlockCell(2, use_peephole=True), x, x_values)

      self.assertEqual(len(basic_res), len(block_res))
      for basic, block in zip(basic_res, block_res):
        self.assertAllClose(basic, block)

  def LSTMBasicToBlockTestHelper(self,
                                 dtype=dtypes.float32,
                                 use_peephole=False,
                                 cell_clip=None,
                                 rtol=1e-6,
                                 atol=1e-6):
    """Asserts block and fused results agree with the LSTMCell reference.

    Args:
      dtype: Data type of inputs and variables.
      use_peephole: Whether peephole connections are enabled.
      cell_clip: Cell clipping value forwarded to `blocks_match`.
      rtol: Relative tolerance for every comparison.
      atol: Absolute tolerance for every comparison.
    """
    with self.session(use_gpu=True, graph=ops.Graph()) as sess:
      (basic_state, fused_state, basic_outputs, block_outputs, fused_outputs,
       basic_grads, block_grads, fused_grads, basic_wgrads, block_wgrads,
       fused_wgrads) = blocks_match(
           sess, use_peephole=use_peephole, dtype=dtype, cell_clip=cell_clip)

      # Block op vs. reference.
      self.assertAllClose(basic_outputs, block_outputs, rtol=rtol, atol=atol)
      self.assertAllClose(basic_grads, block_grads, rtol=rtol, atol=atol)
      for basic, block in zip(basic_wgrads, block_wgrads):
        self.assertAllClose(basic, block, rtol=rtol, atol=atol)

      # Fused cell vs. reference.
      self.assertAllClose(basic_outputs, fused_outputs, rtol=rtol, atol=atol)
      self.assertAllClose(basic_state, fused_state, rtol=rtol, atol=atol)
      self.assertAllClose(basic_grads, fused_grads, rtol=rtol, atol=atol)
      for basic, fused in zip(basic_wgrads, fused_wgrads):
        self.assertAllClose(basic, fused, rtol=rtol, atol=atol)

  @parameterized.named_parameters(*TEST_CASES)
  def testLSTMBasicToBlock(self, dtype, rtol, atol):
    """Compares implementations without peepholes or clipping."""
    self.LSTMBasicToBlockTestHelper(
        dtype, use_peephole=False, rtol=rtol, atol=atol)

  @parameterized.named_parameters(*TEST_CASES)
  def testLSTMBasicToBlockPeeping(self, dtype, rtol, atol):
    """Compares implementations with peepholes enabled."""
    self.LSTMBasicToBlockTestHelper(
        dtype, use_peephole=True, rtol=rtol, atol=atol)

  @parameterized.named_parameters(*TEST_CASES)
  def testLSTMBasicToBlockCellClip(self, dtype, rtol, atol):
    """Compares implementations with peepholes and cell clipping at 0.5."""
    self.LSTMBasicToBlockTestHelper(
        dtype, use_peephole=True, cell_clip=0.5, rtol=rtol, atol=atol)

  def testLSTMFusedSequenceLengths(self):
    """Verify proper support for sequence lengths in LSTMBlockFusedCell."""
    with self.session(use_gpu=True) as sess:
      batch_size = 3
      input_size = 4
      cell_size = 5
      max_sequence_length = 6

      inputs = []
      for _ in range(max_sequence_length):
        inp = ops.convert_to_tensor(
            np.random.randn(batch_size, input_size), dtype=dtypes.float32)
        inputs.append(inp)
      seq_lengths = constant_op.constant([3, 4, 5])
      cell_inputs = array_ops.stack(inputs)

      initializer = init_ops.random_uniform_initializer(
          -0.01, 0.01, seed=19890213)

      with variable_scope.variable_scope("lstm_cell", initializer=initializer):
        # magic naming so that the cells pick up these variables and reuse them
        variable_scope.get_variable(
            "kernel",
            shape=[input_size + cell_size, cell_size * 4],
            dtype=dtypes.float32)

        variable_scope.get_variable(
            "bias",
            shape=[cell_size * 4],
            dtype=dtypes.float32,
            initializer=init_ops.zeros_initializer())

      cell = lstm_ops.LSTMBlockFusedCell(
          cell_size, cell_clip=0, use_peephole=False, reuse=True,
          name="lstm_cell")

      fused_outputs_op, fused_state_op = cell(
          cell_inputs, dtype=dtypes.float32, sequence_length=seq_lengths)

      cell_vars = [
          v for v in variables.trainable_variables()
          if v.name.endswith("kernel") or v.name.endswith("bias")
      ]

      # The sequence lengths are constant, so evaluate them once up front
      # instead of on every time step of the loop below.
      seq_length_values = sess.run(seq_lengths)

      # Verify that state propagation works if we turn our sequence into
      # tiny (single-time) subsequences, i.e. unfuse the cell
      unfused_outputs_op = []
      state = None
      with variable_scope.variable_scope(
          variable_scope.get_variable_scope(), reuse=True):
        for i, inp in enumerate(inputs):
          # Per-example length for this single step: 1 while the example is
          # still within its sequence, 0 afterwards.
          lengths = [int(i < l) for l in seq_length_values]
          output, state = cell(
              array_ops.expand_dims(inp, 0),
              initial_state=state,
              dtype=dtypes.float32,
              sequence_length=lengths)
          unfused_outputs_op.append(output[0])
      unfused_outputs_op = array_ops.stack(unfused_outputs_op)

      sess.run([variables.global_variables_initializer()])
      unfused_outputs, unfused_state = sess.run([unfused_outputs_op, state[0]])
      unfused_grads = sess.run(
          gradients_impl.gradients(unfused_outputs_op, inputs))
      unfused_wgrads = sess.run(
          gradients_impl.gradients(unfused_outputs_op, cell_vars))

      fused_outputs, fused_state = sess.run(
          [fused_outputs_op, fused_state_op[0]])
      fused_grads = sess.run(gradients_impl.gradients(fused_outputs_op, inputs))
      fused_wgrads = sess.run(
          gradients_impl.gradients(fused_outputs_op, cell_vars))

      self.assertAllClose(fused_outputs, unfused_outputs)
      self.assertAllClose(fused_state, unfused_state)
      self.assertAllClose(fused_grads, unfused_grads)
      for fused, unfused in zip(fused_wgrads, unfused_wgrads):
        self.assertAllClose(fused, unfused, rtol=1e-6, atol=1e-6)
519
520#### Benchmarking.
521
522
class BenchmarkLSTMBlock(test.Benchmark):
  """Wall-time benchmarks for LSTMBlockCell driven by dynamic_rnn."""

  def benchmarkLSTMBlockCellFpropWithDynamicRNN(self):
    """Times the forward pass over a grid of sizes, devices and dtypes.

    One CSV row is printed and one benchmark entry is reported per
    configuration. The input size always equals the cell size.
    """
    print("BlockLSTMCell forward propagation via dynamic_rnn().")
    print("--------------------------------------------------------------")
    print("LSTMBlockCell Seconds per inference.")
    # Header columns match the row printed per configuration below, which
    # starts with the dtype.
    print("dtype,batch_size,cell_size,input_size,time_steps,use_gpu,wall_time")
    iters = 10
    for config in benchmarking.dict_product({
        "batch_size": [1, 8, 13, 32, 67, 128],
        "cell_size": [128, 250, 512, 650, 1024, 1350],
        "time_steps": [40],
        "use_gpu": [True, False],
        "dtype": ["float32", "float16"],
    }):
      dtype = dtypes.float32 if config["dtype"] == "float32" else dtypes.float16
      # A fresh graph per configuration keeps variables from colliding.
      with ops.Graph().as_default():
        with benchmarking.device(use_gpu=config["use_gpu"]):
          inputs = variable_scope.get_variable(
              "x",
              dtype=dtype,
              shape=[
                  config["time_steps"], config["batch_size"],
                  config["cell_size"]
              ])
          cell = lstm_ops.LSTMBlockCell(config["cell_size"], dtype=dtype)
          outputs = rnn.dynamic_rnn(cell, inputs, time_major=True, dtype=dtype)
          init_op = variables.global_variables_initializer()

        with session.Session() as sess:
          sess.run(init_op)
          wall_time = benchmarking.seconds_per_run(outputs, sess, iters)

        # Print to stdout. If the TEST_REPORT_FILE_PREFIX environment variable
        # is set, this will produce a copy-paste-able CSV file.
        print(",".join(
            map(str, [
                config["dtype"], config["batch_size"], config["cell_size"],
                config["cell_size"], config["time_steps"], config["use_gpu"],
                wall_time
            ])))
        benchmark_name_template = "_".join([
            "LSTMBlockCell_fprop", "DT_%(dtype)s", "BS%(batch_size)i",
            "CS%(cell_size)i", "IS%(cell_size)i", "TS%(time_steps)i",
            "gpu_%(use_gpu)s"
        ])

        self.report_benchmark(
            name=benchmark_name_template % config,
            iters=iters,
            wall_time=wall_time,
            extras=config)

  def benchmarkLSTMBlockCellBpropWithDynamicRNN(self):
    """Times gradient computation over a grid of sizes, devices and dtypes.

    Gradients are taken with respect to the inputs and the explicitly created
    kernel and bias variables. One CSV row is printed and one benchmark entry
    is reported per configuration. The input size always equals the cell size.
    """
    print("BlockLSTMCell backward propagation via dynamic_rnn().")
    print("--------------------------------------------------------------")
    print("LSTMBlockCell Seconds per inference.")
    # Header columns match the row printed per configuration below, which
    # starts with the dtype.
    print("dtype,batch_size,cell_size,input_size,time_steps,use_gpu,wall_time")
    iters = 10
    for config in benchmarking.dict_product({
        "batch_size": [1, 8, 13, 32, 67, 128],
        "cell_size": [128, 250, 512, 650, 1024, 1350],
        "time_steps": [40],
        "use_gpu": [True, False],
        "dtype": ["float32", "float16"],
    }):
      dtype = dtypes.float32 if config["dtype"] == "float32" else dtypes.float16
      # A fresh graph per configuration keeps variables from colliding.
      with ops.Graph().as_default():
        with benchmarking.device(use_gpu=config["use_gpu"]):
          time_steps = config["time_steps"]
          batch_size = config["batch_size"]
          cell_size = input_size = config["cell_size"]
          inputs = variable_scope.get_variable(
              "x", [time_steps, batch_size, cell_size],
              trainable=False,
              dtype=dtype)
          # NOTE(review): the kernel/bias names under the "rnn" scope with
          # AUTO_REUSE appear chosen so the cell built by dynamic_rnn picks up
          # these same variables -- confirm against the variable names
          # dynamic_rnn/LSTMBlockCell create.
          with variable_scope.variable_scope(
              "rnn", reuse=variable_scope.AUTO_REUSE):
            w = variable_scope.get_variable(
                "rnn/lstm_cell/kernel",
                shape=[input_size + cell_size, cell_size * 4],
                dtype=dtype)
            b = variable_scope.get_variable(
                "rnn/lstm_cell/bias",
                shape=[cell_size * 4],
                dtype=dtype,
                initializer=init_ops.zeros_initializer())
            cell = lstm_ops.LSTMBlockCell(cell_size, dtype=dtype)
            outputs = rnn.dynamic_rnn(
                cell, inputs, time_major=True, dtype=dtype)
          grads = gradients_impl.gradients(outputs, [inputs, w, b])
          init_op = variables.global_variables_initializer()

        with session.Session() as sess:
          sess.run(init_op)
          wall_time = benchmarking.seconds_per_run(grads, sess, iters)

        # Print to stdout. If the TEST_REPORT_FILE_PREFIX environment variable
        # is set, this will produce a copy-paste-able CSV file.
        print(",".join(
            map(str, [
                config["dtype"], batch_size, cell_size, cell_size, time_steps,
                config["use_gpu"], wall_time
            ])))
        benchmark_name_template = "_".join([
            "LSTMBlockCell_bprop", "DT_%(dtype)s", "BS%(batch_size)i",
            "CS%(cell_size)i", "IS%(cell_size)i", "TS%(time_steps)i",
            "gpu_%(use_gpu)s"
        ])

        self.report_benchmark(
            name=benchmark_name_template % config,
            iters=iters,
            wall_time=wall_time,
            extras=config)
638
639
# Hand control to the TensorFlow test runner when executed as a script.
if __name__ == "__main__":
  test.main()
642