# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""LSTM Block Cell ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np

from tensorflow.contrib.rnn.python.kernel_tests import benchmarking
from tensorflow.contrib.rnn.python.ops import lstm_ops
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_bitwise_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test

block_lstm = lstm_ops._block_lstm  # pylint: disable=protected-access


class _MaskedRandomUniformInitializer(init_ops.RandomUniform):
  """Initializer for uniformly distributed tensors with trailing bits zeroed.

  Returns tensors whose last few mantissa bits are set to 0. This helps avoid
  precision issues when testing low-precision (float16) computation.
  """

  def __init__(self,
               minval=0,
               maxval=None,
               seed=None,
               dtype=dtypes.float16,
               num_valid_mantissa_bits=4):
    """Constructor.

    Args:
      minval: A python scalar or a scalar tensor. Lower bound of the range of
        random values to generate.
      maxval: A python scalar or a scalar tensor. Upper bound of the range of
        random values to generate. Defaults to 1 for float types.
      seed: A Python integer. Used to create random seeds. See
        `tf.set_random_seed` for behavior.
      dtype: The data type. Only supports tf.float16 for now.
      num_valid_mantissa_bits: number of non-zero mantissa bits, defaults to 4.

    Raises:
      ValueError: An error if `dtype` is not tf.float16.
    """
    if dtype not in (dtypes.float16,):
      raise ValueError("dtype: %s not supported" % dtype.name)

    super(_MaskedRandomUniformInitializer, self).__init__(
        minval=minval, maxval=maxval, seed=seed, dtype=dtype)
    self._num_mantissa_bits = 10
    self._num_valid_mantissa_bits = num_valid_mantissa_bits

  def __call__(self, shape, dtype=dtypes.float16, partition_info=None):
    if dtype and dtype != dtypes.float16:
      raise ValueError("dtype: %s not supported" % dtype.name)
    res = super(_MaskedRandomUniformInitializer, self).__call__(
        shape, dtype, partition_info)
    # Get a uint16 view of the underlying buffer.
    res = gen_array_ops.bitcast(res, dtypes.uint16)

    # Mask out the last `shift` mantissa bits.
    shift = self._num_mantissa_bits - self._num_valid_mantissa_bits
    mask = (0xffff >> shift) << shift
    res = gen_bitwise_ops.bitwise_and(res, mask)

    # Restore the float16 view.
    return gen_array_ops.bitcast(res, dtype)
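

# A minimal NumPy sketch of the same masking idea, for illustration only.
# This helper is an assumption added for clarity and is not used by the tests:
# float16 has 10 mantissa bits, so clearing the low
# `10 - num_valid_mantissa_bits` bits of the uint16 view zeroes out the
# trailing mantissa bits, just as the initializer above does with bitcast and
# bitwise_and ops.
def _mask_mantissa_numpy_sketch(values, num_valid_mantissa_bits=4):
  """Zero out trailing float16 mantissa bits of `values` (illustrative)."""
  shift = 10 - num_valid_mantissa_bits
  mask = np.uint16((0xffff >> shift) << shift)
  bits = np.asarray(values, dtype=np.float16).view(np.uint16)
  return (bits & mask).view(np.float16)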


def _get_initializer(init_bound, dtype, seed):
  if dtype == dtypes.float16:
    return _MaskedRandomUniformInitializer(
        -init_bound, init_bound, dtype=dtype, seed=seed)
  else:
    return init_ops.random_uniform_initializer(
        -init_bound, init_bound, dtype=dtype, seed=seed)


def blocks_match(sess, use_peephole, dtype=dtypes.float32, cell_clip=None):
  batch_size = 2
  input_size = 3
  cell_size = 4
  sequence_length = 4

  inputs = []
  for _ in range(sequence_length):
    inp = ops.convert_to_tensor(
        np.random.randn(batch_size, input_size), dtype=dtype)
    inputs.append(inp)
  stacked_inputs = array_ops.stack(inputs)

  init_bound = 1e-1 if dtype == dtypes.float16 else 1e-2
  initializer = _get_initializer(init_bound, dtype=dtype, seed=19890212)

  with variable_scope.variable_scope("test", initializer=initializer):
    # Magic naming so that the cells pick up these variables and reuse them.
    if use_peephole:
      wci = variable_scope.get_variable(
          "rnn/lstm_cell/w_i_diag", shape=[cell_size], dtype=dtype)
      wcf = variable_scope.get_variable(
          "rnn/lstm_cell/w_f_diag", shape=[cell_size], dtype=dtype)
      wco = variable_scope.get_variable(
          "rnn/lstm_cell/w_o_diag", shape=[cell_size], dtype=dtype)

    w = variable_scope.get_variable(
        "rnn/lstm_cell/kernel",
        shape=[input_size + cell_size, cell_size * 4],
        dtype=dtype)
    b = variable_scope.get_variable(
        "rnn/lstm_cell/bias",
        shape=[cell_size * 4],
        dtype=dtype,
        initializer=init_ops.zeros_initializer())

    basic_cell = rnn_cell.LSTMCell(
        cell_size,
        use_peepholes=use_peephole,
        cell_clip=cell_clip,
        dtype=dtype,
        state_is_tuple=True,
        reuse=True)
    basic_outputs_op, basic_state_op = rnn.static_rnn(
        basic_cell, inputs, dtype=dtype)

    if use_peephole:
      _, _, _, _, _, _, block_outputs_op = block_lstm(
          ops.convert_to_tensor(sequence_length, dtype=dtypes.int64),
          inputs,
          w,
          b,
          wci=wci,
          wcf=wcf,
          wco=wco,
          cell_clip=cell_clip,
          use_peephole=True)
    else:
      _, _, _, _, _, _, block_outputs_op = block_lstm(
          ops.convert_to_tensor(sequence_length, dtype=dtypes.int64),
          inputs,
          w,
          b,
          cell_clip=cell_clip)

    fused_cell = lstm_ops.LSTMBlockFusedCell(
        cell_size,
        cell_clip=cell_clip,
        use_peephole=use_peephole,
        reuse=True,
        name="rnn/lstm_cell")
    fused_outputs_op, fused_state_op = fused_cell(stacked_inputs, dtype=dtype)

  sess.run([variables.global_variables_initializer()])
  basic_outputs, basic_state = sess.run([basic_outputs_op, basic_state_op[0]])
  basic_grads = sess.run(gradients_impl.gradients(basic_outputs_op, inputs))
  xs = [w, b]
  if use_peephole:
    xs += [wci, wcf, wco]
  basic_wgrads = sess.run(gradients_impl.gradients(basic_outputs_op, xs))

  block_outputs = sess.run(block_outputs_op)
  block_grads = sess.run(gradients_impl.gradients(block_outputs_op, inputs))
  block_wgrads = sess.run(gradients_impl.gradients(block_outputs_op, xs))

  xs = [w, b]
  if use_peephole:
    xs += [wci, wcf, wco]
  fused_outputs, fused_state = sess.run([fused_outputs_op, fused_state_op[0]])
  fused_grads = sess.run(gradients_impl.gradients(fused_outputs_op, inputs))
  fused_wgrads = sess.run(gradients_impl.gradients(fused_outputs_op, xs))

  return (basic_state, fused_state, basic_outputs, block_outputs,
          fused_outputs, basic_grads, block_grads, fused_grads, basic_wgrads,
          block_wgrads, fused_wgrads)
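

# For reference, the three implementations compared above are expected to
# compute the standard LSTM recurrence with optional peepholes and cell
# clipping. This is a sketch of the math only; the fused kernel's internal
# gate ordering and weight layout may differ, and the cells also add a
# constant forget_bias to the forget-gate pre-activation:
#
#   i_t = sigmoid(x_t W_xi + h_{t-1} W_hi + w_ci * c_{t-1} + b_i)
#   f_t = sigmoid(x_t W_xf + h_{t-1} W_hf + w_cf * c_{t-1} + b_f)
#   c_t = f_t * c_{t-1} + i_t * tanh(x_t W_xc + h_{t-1} W_hc + b_c)
#   c_t = clip(c_t, -cell_clip, cell_clip)        # only if cell_clip is set
#   o_t = sigmoid(x_t W_xo + h_{t-1} W_ho + w_co * c_t + b_o)
#   h_t = o_t * tanh(c_t)
#
# The peephole terms (w_ci, w_cf, w_co) are only present when
# use_peephole=True.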


class LSTMBlockCellTest(test.TestCase, parameterized.TestCase):

  TEST_CASES = ({
      "testcase_name": "Fp32",
      "dtype": dtypes.float32,
      "rtol": 1e-6,
      "atol": 1e-6
  }, {
      "testcase_name": "Fp16",
      "dtype": dtypes.float16,
      "rtol": 8e-3,
      "atol": 8e-4
  })

  def testNoneDimsWithDynamicRNN(self):
    with self.session(use_gpu=True, graph=ops.Graph()) as sess:
      batch_size = 4
      num_steps = 5
      input_dim = 6
      cell_size = 7

      cell = lstm_ops.LSTMBlockCell(cell_size)
      x = array_ops.placeholder(dtypes.float32, shape=(None, None, input_dim))

      output, _ = rnn.dynamic_rnn(
          cell, x, time_major=True, dtype=dtypes.float32)
      sess.run(variables.global_variables_initializer())
      feed = {}
      feed[x] = np.random.randn(num_steps, batch_size, input_dim)
      sess.run(output, feed)

  def testLSTMBlockCell(self):
    with self.session(use_gpu=True, graph=ops.Graph()) as sess:
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        x = array_ops.zeros([1, 2])
        m0 = array_ops.zeros([1, 2])
        m1 = array_ops.zeros([1, 2])
        m2 = array_ops.zeros([1, 2])
        m3 = array_ops.zeros([1, 2])
        g, ((out_m0, out_m1), (out_m2, out_m3)) = rnn_cell.MultiRNNCell(
            [lstm_ops.LSTMBlockCell(2)
             for _ in range(2)], state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
        sess.run([variables.global_variables_initializer()])
        res = sess.run([g, out_m0, out_m1, out_m2, out_m3], {
            x.name: np.array([[1., 1.]]),
            m0.name: 0.1 * np.ones([1, 2]),
            m1.name: 0.1 * np.ones([1, 2]),
            m2.name: 0.1 * np.ones([1, 2]),
            m3.name: 0.1 * np.ones([1, 2])
        })
        self.assertEqual(len(res), 5)
        self.assertAllClose(res[0], [[0.24024698, 0.24024698]])
        # These numbers are from testBasicLSTMCell and only test c/h.
        self.assertAllClose(res[1], [[0.68967271, 0.68967271]])
        self.assertAllClose(res[2], [[0.44848421, 0.44848421]])
        self.assertAllClose(res[3], [[0.39897051, 0.39897051]])
        self.assertAllClose(res[4], [[0.24024698, 0.24024698]])

  def testCompatibleNames(self):
    with self.session(use_gpu=True, graph=ops.Graph()):
      cell = rnn_cell.LSTMCell(10)
      pcell = rnn_cell.LSTMCell(10, use_peepholes=True)
      inputs = [array_ops.zeros([4, 5])] * 6
      rnn.static_rnn(cell, inputs, dtype=dtypes.float32, scope="basic")
      rnn.static_rnn(pcell, inputs, dtype=dtypes.float32, scope="peephole")
      basic_names = {
          v.name: v.get_shape()
          for v in variables.trainable_variables()
      }

    with self.session(use_gpu=True, graph=ops.Graph()):
      cell = lstm_ops.LSTMBlockCell(10)
      pcell = lstm_ops.LSTMBlockCell(10, use_peephole=True)
      inputs = [array_ops.zeros([4, 5])] * 6
      rnn.static_rnn(cell, inputs, dtype=dtypes.float32, scope="basic")
      rnn.static_rnn(pcell, inputs, dtype=dtypes.float32, scope="peephole")
      block_names = {
          v.name: v.get_shape()
          for v in variables.trainable_variables()
      }

    with self.session(use_gpu=True, graph=ops.Graph()):
      cell = lstm_ops.LSTMBlockFusedCell(10)
      pcell = lstm_ops.LSTMBlockFusedCell(10, use_peephole=True)
      inputs = array_ops.stack([array_ops.zeros([4, 5])] * 6)
      cell(inputs, dtype=dtypes.float32, scope="basic/lstm_cell")
      pcell(inputs, dtype=dtypes.float32, scope="peephole/lstm_cell")
      fused_names = {
          v.name: v.get_shape()
          for v in variables.trainable_variables()
      }

    self.assertEqual(basic_names, block_names)
    self.assertEqual(basic_names, fused_names)

  def testLSTMBasicToBlockCell(self):
    with self.session(use_gpu=True) as sess:
      x = array_ops.zeros([1, 2])
      x_values = np.random.randn(1, 2)

      m0_val = 0.1 * np.ones([1, 2])
      m1_val = -0.1 * np.ones([1, 2])
      m2_val = -0.2 * np.ones([1, 2])
      m3_val = 0.2 * np.ones([1, 2])

      initializer = init_ops.random_uniform_initializer(
          -0.01, 0.01, seed=19890212)
      with variable_scope.variable_scope("basic", initializer=initializer):
        m0 = array_ops.zeros([1, 2])
        m1 = array_ops.zeros([1, 2])
        m2 = array_ops.zeros([1, 2])
        m3 = array_ops.zeros([1, 2])
        g, ((out_m0, out_m1), (out_m2, out_m3)) = rnn_cell.MultiRNNCell(
            [rnn_cell.BasicLSTMCell(2, state_is_tuple=True) for _ in range(2)],
            state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
        sess.run([variables.global_variables_initializer()])
        basic_res = sess.run([g, out_m0, out_m1, out_m2, out_m3], {
            x.name: x_values,
            m0.name: m0_val,
            m1.name: m1_val,
            m2.name: m2_val,
            m3.name: m3_val
        })

      with variable_scope.variable_scope("block", initializer=initializer):
        m0 = array_ops.zeros([1, 2])
        m1 = array_ops.zeros([1, 2])
        m2 = array_ops.zeros([1, 2])
        m3 = array_ops.zeros([1, 2])
        g, ((out_m0, out_m1), (out_m2, out_m3)) = rnn_cell.MultiRNNCell(
            [lstm_ops.LSTMBlockCell(2)
             for _ in range(2)], state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
        sess.run([variables.global_variables_initializer()])
        block_res = sess.run([g, out_m0, out_m1, out_m2, out_m3], {
            x.name: x_values,
            m0.name: m0_val,
            m1.name: m1_val,
            m2.name: m2_val,
            m3.name: m3_val
        })

      self.assertEqual(len(basic_res), len(block_res))
      for basic, block in zip(basic_res, block_res):
        self.assertAllClose(basic, block)

  def testLSTMBasicToBlockCellPeeping(self):
    with self.session(use_gpu=True) as sess:
      x = array_ops.zeros([1, 2])
      x_values = np.random.randn(1, 2)

      m0_val = 0.1 * np.ones([1, 2])
      m1_val = -0.1 * np.ones([1, 2])
      m2_val = -0.2 * np.ones([1, 2])
      m3_val = 0.2 * np.ones([1, 2])

      initializer = init_ops.random_uniform_initializer(
          -0.01, 0.01, seed=19890212)
      with variable_scope.variable_scope("basic", initializer=initializer):
        m0 = array_ops.zeros([1, 2])
        m1 = array_ops.zeros([1, 2])
        m2 = array_ops.zeros([1, 2])
        m3 = array_ops.zeros([1, 2])
        g, ((out_m0, out_m1), (out_m2, out_m3)) = rnn_cell.MultiRNNCell(
            [
                rnn_cell.LSTMCell(2, use_peepholes=True, state_is_tuple=True)
                for _ in range(2)
            ],
            state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
        sess.run([variables.global_variables_initializer()])
        basic_res = sess.run([g, out_m0, out_m1, out_m2, out_m3], {
            x.name: x_values,
            m0.name: m0_val,
            m1.name: m1_val,
            m2.name: m2_val,
            m3.name: m3_val
        })

      with variable_scope.variable_scope("block", initializer=initializer):
        m0 = array_ops.zeros([1, 2])
        m1 = array_ops.zeros([1, 2])
        m2 = array_ops.zeros([1, 2])
        m3 = array_ops.zeros([1, 2])
        g, ((out_m0, out_m1), (out_m2, out_m3)) = rnn_cell.MultiRNNCell(
            [lstm_ops.LSTMBlockCell(2, use_peephole=True) for _ in range(2)],
            state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
        sess.run([variables.global_variables_initializer()])
        block_res = sess.run([g, out_m0, out_m1, out_m2, out_m3], {
            x.name: x_values,
            m0.name: m0_val,
            m1.name: m1_val,
            m2.name: m2_val,
            m3.name: m3_val
        })

      self.assertEqual(len(basic_res), len(block_res))
      for basic, block in zip(basic_res, block_res):
        self.assertAllClose(basic, block)

  def LSTMBasicToBlockTestHelper(self,
                                 dtype=dtypes.float32,
                                 use_peephole=False,
                                 cell_clip=None,
                                 rtol=1e-6,
                                 atol=1e-6):
    with self.session(use_gpu=True, graph=ops.Graph()) as sess:
      (basic_state, fused_state, basic_outputs, block_outputs, fused_outputs,
       basic_grads, block_grads, fused_grads, basic_wgrads, block_wgrads,
       fused_wgrads) = blocks_match(
           sess, use_peephole=use_peephole, dtype=dtype, cell_clip=cell_clip)

      self.assertAllClose(basic_outputs, block_outputs, rtol=rtol, atol=atol)
      self.assertAllClose(basic_grads, block_grads, rtol=rtol, atol=atol)
      for basic, block in zip(basic_wgrads, block_wgrads):
        self.assertAllClose(basic, block, rtol=rtol, atol=atol)

      self.assertAllClose(basic_outputs, fused_outputs, rtol=rtol, atol=atol)
      self.assertAllClose(basic_state, fused_state, rtol=rtol, atol=atol)
      self.assertAllClose(basic_grads, fused_grads, rtol=rtol, atol=atol)
      for basic, fused in zip(basic_wgrads, fused_wgrads):
        self.assertAllClose(basic, fused, rtol=rtol, atol=atol)

  @parameterized.named_parameters(*TEST_CASES)
  def testLSTMBasicToBlock(self, dtype, rtol, atol):
    self.LSTMBasicToBlockTestHelper(
        dtype, use_peephole=False, rtol=rtol, atol=atol)

  @parameterized.named_parameters(*TEST_CASES)
  def testLSTMBasicToBlockPeeping(self, dtype, rtol, atol):
    self.LSTMBasicToBlockTestHelper(
        dtype, use_peephole=True, rtol=rtol, atol=atol)

  @parameterized.named_parameters(*TEST_CASES)
  def testLSTMBasicToBlockCellClip(self, dtype, rtol, atol):
    self.LSTMBasicToBlockTestHelper(
        dtype, use_peephole=True, cell_clip=0.5, rtol=rtol, atol=atol)

  def testLSTMFusedSequenceLengths(self):
    """Verify proper support for sequence lengths in LSTMBlockFusedCell."""
    with self.session(use_gpu=True) as sess:
      batch_size = 3
      input_size = 4
      cell_size = 5
      max_sequence_length = 6

      inputs = []
      for _ in range(max_sequence_length):
        inp = ops.convert_to_tensor(
            np.random.randn(batch_size, input_size), dtype=dtypes.float32)
        inputs.append(inp)
      seq_lengths = constant_op.constant([3, 4, 5])
      cell_inputs = array_ops.stack(inputs)

      initializer = init_ops.random_uniform_initializer(
          -0.01, 0.01, seed=19890213)

      with variable_scope.variable_scope("lstm_cell", initializer=initializer):
        # Magic naming so that the cell picks up these variables and reuses
        # them.
        variable_scope.get_variable(
            "kernel",
            shape=[input_size + cell_size, cell_size * 4],
            dtype=dtypes.float32)

        variable_scope.get_variable(
            "bias",
            shape=[cell_size * 4],
            dtype=dtypes.float32,
            initializer=init_ops.zeros_initializer())

      cell = lstm_ops.LSTMBlockFusedCell(
          cell_size, cell_clip=0, use_peephole=False, reuse=True,
          name="lstm_cell")

      fused_outputs_op, fused_state_op = cell(
          cell_inputs, dtype=dtypes.float32, sequence_length=seq_lengths)

      cell_vars = [
          v for v in variables.trainable_variables()
          if v.name.endswith("kernel") or v.name.endswith("bias")
      ]

      # Verify that state propagation works if we turn our sequence into
      # tiny (single-time-step) subsequences, i.e. unfuse the cell.
      unfused_outputs_op = []
      state = None
      with variable_scope.variable_scope(
          variable_scope.get_variable_scope(), reuse=True):
        for i, inp in enumerate(inputs):
          lengths = [int(i < l) for l in seq_lengths.eval()]
          output, state = cell(
              array_ops.expand_dims(inp, 0),
              initial_state=state,
              dtype=dtypes.float32,
              sequence_length=lengths)
          unfused_outputs_op.append(output[0])
      unfused_outputs_op = array_ops.stack(unfused_outputs_op)

      sess.run([variables.global_variables_initializer()])
      unfused_outputs, unfused_state = sess.run([unfused_outputs_op, state[0]])
      unfused_grads = sess.run(
          gradients_impl.gradients(unfused_outputs_op, inputs))
      unfused_wgrads = sess.run(
          gradients_impl.gradients(unfused_outputs_op, cell_vars))

      fused_outputs, fused_state = sess.run(
          [fused_outputs_op, fused_state_op[0]])
      fused_grads = sess.run(gradients_impl.gradients(fused_outputs_op, inputs))
      fused_wgrads = sess.run(
          gradients_impl.gradients(fused_outputs_op, cell_vars))

      self.assertAllClose(fused_outputs, unfused_outputs)
      self.assertAllClose(fused_state, unfused_state)
      self.assertAllClose(fused_grads, unfused_grads)
      for fused, unfused in zip(fused_wgrads, unfused_wgrads):
        self.assertAllClose(fused, unfused, rtol=1e-6, atol=1e-6)


#### Benchmarking.
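

# The benchmarks below sweep configurations with benchmarking.dict_product().
# A minimal sketch of what such a helper presumably does (an illustrative
# assumption added here for clarity, not the contrib implementation): yield
# one config dict per element of the Cartesian product of the listed value
# lists. The benchmarks use the real benchmarking.dict_product, not this.
def _dict_product_sketch(params):
  """Illustrative only; yields dicts covering all value combinations."""
  import itertools
  keys = sorted(params)
  for combo in itertools.product(*(params[k] for k in keys)):
    yield dict(zip(keys, combo))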


class BenchmarkLSTMBlock(test.Benchmark):

  def benchmarkLSTMBlockCellFpropWithDynamicRNN(self):
    print("BlockLSTMCell forward propagation via dynamic_rnn().")
    print("--------------------------------------------------------------")
    print("LSTMBlockCell seconds per inference.")
    print("dtype,batch_size,cell_size,input_size,time_steps,use_gpu,wall_time")
    iters = 10
    for config in benchmarking.dict_product({
        "batch_size": [1, 8, 13, 32, 67, 128],
        "cell_size": [128, 250, 512, 650, 1024, 1350],
        "time_steps": [40],
        "use_gpu": [True, False],
        "dtype": ["float32", "float16"],
    }):
      dtype = dtypes.float32 if config["dtype"] == "float32" else dtypes.float16
      with ops.Graph().as_default():
        with benchmarking.device(use_gpu=config["use_gpu"]):
          inputs = variable_scope.get_variable(
              "x",
              dtype=dtype,
              shape=[
                  config["time_steps"], config["batch_size"],
                  config["cell_size"]
              ])
          cell = lstm_ops.LSTMBlockCell(config["cell_size"], dtype=dtype)
          outputs = rnn.dynamic_rnn(cell, inputs, time_major=True, dtype=dtype)
          init_op = variables.global_variables_initializer()

        with session.Session() as sess:
          sess.run(init_op)
          wall_time = benchmarking.seconds_per_run(outputs, sess, iters)

      # Print to stdout. If the TEST_REPORT_FILE_PREFIX environment variable
      # is set, this will produce a copy-paste-able CSV file.
      print(",".join(
          map(str, [
              config["dtype"], config["batch_size"], config["cell_size"],
              config["cell_size"], config["time_steps"], config["use_gpu"],
              wall_time
          ])))
      benchmark_name_template = "_".join([
          "LSTMBlockCell_fprop", "DT_%(dtype)s", "BS%(batch_size)i",
          "CS%(cell_size)i", "IS%(cell_size)i", "TS%(time_steps)i",
          "gpu_%(use_gpu)s"
      ])

      self.report_benchmark(
          name=benchmark_name_template % config,
          iters=iters,
          wall_time=wall_time,
          extras=config)

  def benchmarkLSTMBlockCellBpropWithDynamicRNN(self):
    print("BlockLSTMCell backward propagation via dynamic_rnn().")
    print("--------------------------------------------------------------")
    print("LSTMBlockCell seconds per inference.")
    print("dtype,batch_size,cell_size,input_size,time_steps,use_gpu,wall_time")
    iters = 10
    for config in benchmarking.dict_product({
        "batch_size": [1, 8, 13, 32, 67, 128],
        "cell_size": [128, 250, 512, 650, 1024, 1350],
        "time_steps": [40],
        "use_gpu": [True, False],
        "dtype": ["float32", "float16"],
    }):
      dtype = dtypes.float32 if config["dtype"] == "float32" else dtypes.float16
      with ops.Graph().as_default():
        with benchmarking.device(use_gpu=config["use_gpu"]):
          time_steps = config["time_steps"]
          batch_size = config["batch_size"]
          cell_size = input_size = config["cell_size"]
          inputs = variable_scope.get_variable(
              "x", [time_steps, batch_size, cell_size],
              trainable=False,
              dtype=dtype)
          with variable_scope.variable_scope(
              "rnn", reuse=variable_scope.AUTO_REUSE):
            w = variable_scope.get_variable(
                "rnn/lstm_cell/kernel",
                shape=[input_size + cell_size, cell_size * 4],
                dtype=dtype)
            b = variable_scope.get_variable(
                "rnn/lstm_cell/bias",
                shape=[cell_size * 4],
                dtype=dtype,
                initializer=init_ops.zeros_initializer())
            cell = lstm_ops.LSTMBlockCell(cell_size, dtype=dtype)
            outputs = rnn.dynamic_rnn(
                cell, inputs, time_major=True, dtype=dtype)
          grads = gradients_impl.gradients(outputs, [inputs, w, b])
          init_op = variables.global_variables_initializer()

        with session.Session() as sess:
          sess.run(init_op)
          wall_time = benchmarking.seconds_per_run(grads, sess, iters)

      # Print to stdout. If the TEST_REPORT_FILE_PREFIX environment variable
      # is set, this will produce a copy-paste-able CSV file.
      print(",".join(
          map(str, [
              config["dtype"], batch_size, cell_size, cell_size, time_steps,
              config["use_gpu"], wall_time
          ])))
      benchmark_name_template = "_".join([
          "LSTMBlockCell_bprop", "DT_%(dtype)s", "BS%(batch_size)i",
          "CS%(cell_size)i", "IS%(cell_size)i", "TS%(time_steps)i",
          "gpu_%(use_gpu)s"
      ])

      self.report_benchmark(
          name=benchmark_name_template % config,
          iters=iters,
          wall_time=wall_time,
          extras=config)


if __name__ == "__main__":
  test.main()