# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for rnn module."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
import timeit

import numpy as np

from six.moves import xrange  # pylint: disable=redefined-builtin
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables as variables_lib
import tensorflow.python.ops.data_flow_grad  # pylint: disable=unused-import
import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
import tensorflow.python.ops.sparse_grad  # pylint: disable=unused-import
import tensorflow.python.ops.tensor_array_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test
from tensorflow.python.training import saver


class Plus1RNNCell(rnn_cell_impl.RNNCell):
  """RNN Cell generating (output, new_state) = (input + 1, state + 1)."""

  @property
  def output_size(self):
    return 5

  @property
  def state_size(self):
    return 5

  def call(self, input_, state, scope=None):
    return (input_ + 1, state + 1)


class ScalarStateRNNCell(rnn_cell_impl.RNNCell):
  """RNN Cell with a scalar state, generating (input, state + 1)."""

  @property
  def output_size(self):
    return 1

  @property
  def state_size(self):
    return tensor_shape.TensorShape([])

  def zero_state(self, batch_size, dtype):
    return array_ops.zeros([], dtype=dtypes.int32)

  def call(self, input_, state, scope=None):
    return (input_, state + 1)


class UnbalancedOutputRNNCell(rnn_cell_impl.RNNCell):
  """RNN Cell generating a nested (unbalanced) output structure."""

  @property
  def output_size(self):
    return tensor_shape.TensorShape(1), tensor_shape.TensorShape((2,))

  @property
  def state_size(self):
    return tensor_shape.TensorShape([])

  def zero_state(self, batch_size, dtype):
    return array_ops.zeros([], dtype=dtypes.int32)

  def call(self, input_, state, scope=None):
    concatenated = array_ops.concat((input_, input_), axis=-1)
    return (input_, concatenated), state + 1
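

# NOTE: The toy cells above exist only to exercise rnn.dynamic_rnn with
# non-standard output/state structures in the tests below. As an illustrative
# sketch (not part of the test suite), such a cell is driven on a
# [batch, time, depth] input roughly like:
#
#   cell = Plus1RNNCell()
#   inputs = np.ones((3, 4, 5), dtype=np.float32)  # batch=3, time=4, depth=5
#   outputs, state = rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32)
#
# which is the pattern RNNTest.testBatchSizeFromInput exercises.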


class TensorArrayStateRNNCell(rnn_cell_impl.RNNCell):
  """RNN Cell that keeps part of its state in a TensorArray."""

  @property
  def output_size(self):
    return 1

  @property
  def state_size(self):
    return (tensor_shape.TensorShape([]), ())

  def zero_state(self, batch_size, dtype):
    return (array_ops.zeros([], dtype=dtypes.int32),
            tensor_array_ops.TensorArray(
                dtype=dtype, size=0, dynamic_size=True))

  def call(self, input_, state, scope=None):
    new_array = state[1].write(state[0], input_)
    return (input_, (state[0] + 1, new_array))


class RNNTest(test.TestCase):

  def setUp(self):
    self._seed = 23489
    np.random.seed(self._seed)

  @test_util.run_in_graph_and_eager_modes
  def testInvalidSequenceLengthShape(self):
    cell = Plus1RNNCell()
    if context.executing_eagerly():
      inputs = [constant_op.constant(np.ones((3, 4)))]
    else:
      inputs = [array_ops.placeholder(dtypes.float32, shape=(3, 4))]
    with self.assertRaisesRegex(ValueError, "must be a vector"):
      rnn.dynamic_rnn(
          cell,
          array_ops.stack(inputs),
          dtype=dtypes.float32,
          sequence_length=[[4]])

  @test_util.run_in_graph_and_eager_modes
  def testInvalidDtype(self):
    if context.executing_eagerly():
      inputs = np.zeros((3, 4, 5), dtype=np.int32)
    else:
      inputs = array_ops.placeholder(dtypes.int32, shape=(3, 4, 5))

    cells = [
        rnn_cell_impl.BasicRNNCell,
        rnn_cell_impl.GRUCell,
        rnn_cell_impl.BasicLSTMCell,
        rnn_cell_impl.LSTMCell,
    ]
    for cell_cls in cells:
      with self.cached_session():
        with self.assertRaisesRegex(ValueError,
                                    "RNN cell only supports floating"):
          cell = cell_cls(2, dtype=dtypes.int32)
          rnn.dynamic_rnn(cell, inputs, dtype=dtypes.int32)

  @test_util.run_in_graph_and_eager_modes
  def testBatchSizeFromInput(self):
    cell = Plus1RNNCell()
    in_eager_mode = context.executing_eagerly()
    # With static batch size
    if in_eager_mode:
      inputs = np.zeros((3, 4, 5), dtype=np.float32)
      initial_state = np.zeros((3, 5), dtype=np.float32)
    else:
      inputs = array_ops.placeholder(dtypes.float32, shape=(3, 4, 5))
      initial_state = array_ops.placeholder(dtypes.float32, shape=(3, 5))

    # - Without initial_state
    outputs, state = rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32)
    self.assertEqual(3, outputs.shape[0])
    self.assertEqual(3, state.shape[0])

    # - With initial_state
    outputs, state = rnn.dynamic_rnn(
        cell, inputs, initial_state=initial_state)
    self.assertEqual(3, outputs.shape[0])
    self.assertEqual(3, state.shape[0])

    # Without static batch size
    # Tensor shapes are fully determined with eager execution enabled,
    # so only run this test for graph construction.
    if not in_eager_mode:
      inputs = array_ops.placeholder(dtypes.float32, shape=(None, 4, 5))
      # - Without initial_state
      outputs, state = rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32)
      self.assertEqual(None, outputs.shape.dims[0].value)
      self.assertEqual(None, state.shape.dims[0].value)
      # - With initial_state
      outputs, state = rnn.dynamic_rnn(
          cell,
          inputs,
          initial_state=array_ops.placeholder(dtypes.float32, shape=(None, 5)))
      self.assertEqual(None, outputs.shape.dims[0].value)
      self.assertEqual(None, state.shape.dims[0].value)

  @test_util.run_in_graph_and_eager_modes
  def testScalarStateIsAccepted(self):
    cell = ScalarStateRNNCell()
    in_eager_mode = context.executing_eagerly()

    if in_eager_mode:
      inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32)
    else:
      inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))

    with self.cached_session() as sess:
      outputs, state = rnn.dynamic_rnn(
          cell, inputs, dtype=dtypes.float32, sequence_length=[4])
      if not in_eager_mode:
        outputs, state = sess.run(
            [outputs, state], feed_dict={inputs: [[[1], [2], [3], [4]]]})

    self.assertAllEqual([[[1], [2], [3], [4]]], outputs)
    self.assertAllEqual(4, state)

  @test_util.run_in_graph_and_eager_modes
  def testUnbalancedOutputIsAccepted(self):
    cell = UnbalancedOutputRNNCell()
    in_eager_mode = context.executing_eagerly()

    if in_eager_mode:
      inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32)
    else:
      inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))

    with self.cached_session() as sess:
      outputs, state = rnn.dynamic_rnn(
          cell, inputs, dtype=dtypes.float32, sequence_length=[4])
      if not in_eager_mode:
        outputs, state = sess.run(
            [outputs, state], feed_dict={inputs: [[[1], [2], [3], [4]]]})

    self.assertIsInstance(outputs, tuple)
    self.assertAllEqual([[[1], [2], [3], [4]]], outputs[0])
    self.assertAllEqual([[[1, 1], [2, 2], [3, 3], [4, 4]]], outputs[1])
    self.assertAllEqual(4, state)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testEagerMemory(self):
    with context.eager_mode():
      cell = TensorArrayStateRNNCell()
      inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32)
      rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32, sequence_length=[4])

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_v1_only("b/120545219")
  def testTensorArrayStateIsAccepted(self):
    cell = TensorArrayStateRNNCell()
    in_eager_mode = context.executing_eagerly()

    if in_eager_mode:
      inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32)
    else:
      inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))

    with self.cached_session() as sess:
      outputs, state = rnn.dynamic_rnn(
          cell, inputs, dtype=dtypes.float32, sequence_length=[4])
      state = (state[0], state[1].stack())
      if not in_eager_mode:
        outputs, state = sess.run(
            [outputs, state], feed_dict={
                inputs: [[[1], [2], [3], [4]]]
            })

    self.assertAllEqual([[[1], [2], [3], [4]]], outputs)
    self.assertAllEqual(4, state[0])
    self.assertAllEqual([[[1]], [[2]], [[3]], [[4]]], state[1])

  @test_util.run_deprecated_v1
  def testCellGetInitialState(self):
    cell = rnn_cell_impl.BasicRNNCell(5)
    with self.assertRaisesRegex(ValueError,
                                "batch_size and dtype cannot be None"):
      cell.get_initial_state(None, None, None)

    inputs = array_ops.placeholder(dtypes.float32, shape=(None, 4, 1))
    with self.assertRaisesRegex(
        ValueError, "batch size from input tensor is different from"):
      cell.get_initial_state(inputs=inputs, batch_size=50, dtype=None)

    with self.assertRaisesRegex(
        ValueError, "batch size from input tensor is different from"):
      cell.get_initial_state(
          inputs=inputs, batch_size=constant_op.constant(50), dtype=None)

    with self.assertRaisesRegex(ValueError,
                                "dtype from input tensor is different from"):
      cell.get_initial_state(inputs=inputs, batch_size=None, dtype=dtypes.int16)

    initial_state = cell.get_initial_state(
        inputs=inputs, batch_size=None, dtype=None)
    self.assertEqual(initial_state.shape.as_list(), [None, 5])
    self.assertEqual(initial_state.dtype, inputs.dtype)

    batch = array_ops.shape(inputs)[0]
    dtype = inputs.dtype
    initial_state = cell.get_initial_state(None, batch, dtype)
    self.assertEqual(initial_state.shape.as_list(), [None, 5])
    self.assertEqual(initial_state.dtype, inputs.dtype)

  def _assert_cell_builds(self, cell_class, dtype, batch_size, in_size,
                          out_size):
    cell = cell_class(out_size, dtype=dtype)
    in_shape = tensor_shape.TensorShape((batch_size, in_size))
    cell.build(in_shape)
    state_output = cell.get_initial_state(
        inputs=None, batch_size=batch_size, dtype=dtype)
    cell_output, _ = cell(array_ops.zeros(in_shape, dtype), state_output)
    self.assertAllEqual([batch_size, out_size], cell_output.shape.as_list())

  @test_util.run_in_graph_and_eager_modes
  def testCellsBuild(self):
    f32 = dtypes.float32
    f64 = dtypes.float64
    self._assert_cell_builds(rnn_cell_impl.BasicRNNCell, f32, 5, 7, 3)
    self._assert_cell_builds(rnn_cell_impl.BasicRNNCell, f64, 5, 7, 3)
    self._assert_cell_builds(rnn_cell_impl.BasicLSTMCell, f32, 5, 7, 3)
    self._assert_cell_builds(rnn_cell_impl.BasicLSTMCell, f64, 5, 7, 3)
    self._assert_cell_builds(rnn_cell_impl.GRUCell, f32, 5, 7, 3)
    self._assert_cell_builds(rnn_cell_impl.GRUCell, f64, 5, 7, 3)
    self._assert_cell_builds(rnn_cell_impl.LSTMCell, f32, 5, 7, 3)
    self._assert_cell_builds(rnn_cell_impl.LSTMCell, f64, 5, 7, 3)

  @test_util.run_deprecated_v1
  def testBasicLSTMCellInterchangeWithLSTMCell(self):
    with self.session(graph=ops_lib.Graph()) as sess:
      basic_cell = rnn_cell_impl.BasicLSTMCell(1)
      basic_cell(array_ops.ones([1, 1]),
                 state=basic_cell.get_initial_state(inputs=None,
                                                    batch_size=1,
                                                    dtype=dtypes.float32))
      self.evaluate([v.initializer for v in basic_cell.variables])
      self.evaluate(basic_cell._bias.assign([10.] * 4))
      save = saver.Saver()
      prefix = os.path.join(self.get_temp_dir(), "ckpt")
      save_path = save.save(sess, prefix)

    with self.session(graph=ops_lib.Graph()) as sess:
      lstm_cell = rnn_cell_impl.LSTMCell(1, name="basic_lstm_cell")
      lstm_cell(array_ops.ones([1, 1]),
                state=lstm_cell.get_initial_state(inputs=None,
                                                  batch_size=1,
                                                  dtype=dtypes.float32))
      self.evaluate([v.initializer for v in lstm_cell.variables])
      save = saver.Saver()
      save.restore(sess, save_path)
      self.assertAllEqual([10.] * 4, self.evaluate(lstm_cell._bias))

######### Benchmarking RNN code


def _static_vs_dynamic_rnn_benchmark_static(inputs_list_t, sequence_length):
  (_, input_size) = inputs_list_t[0].get_shape().as_list()
  initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=127)
  cell = rnn_cell_impl.LSTMCell(
      num_units=input_size,
      use_peepholes=True,
      initializer=initializer,
      state_is_tuple=False)
  outputs, final_state = rnn.static_rnn(
      cell,
      inputs_list_t,
      sequence_length=sequence_length,
      dtype=dtypes.float32)

  trainable_variables = ops_lib.get_collection(
      ops_lib.GraphKeys.TRAINABLE_VARIABLES)
  gradients = gradients_impl.gradients(outputs + [final_state],
                                       trainable_variables)

  return control_flow_ops.group(final_state, *(gradients + outputs))


def _static_vs_dynamic_rnn_benchmark_dynamic(inputs_t, sequence_length):
  (unused_0, unused_1, input_size) = inputs_t.get_shape().as_list()
  initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=127)
  cell = rnn_cell_impl.LSTMCell(
      num_units=input_size,
      use_peepholes=True,
      initializer=initializer,
      state_is_tuple=False)
  outputs, final_state = rnn.dynamic_rnn(
      cell, inputs_t, sequence_length=sequence_length, dtype=dtypes.float32)

  trainable_variables = ops_lib.get_collection(
      ops_lib.GraphKeys.TRAINABLE_VARIABLES)
  gradients = gradients_impl.gradients([outputs, final_state],
                                       trainable_variables)

  return control_flow_ops.group(final_state, outputs, *gradients)


def graph_creation_static_vs_dynamic_rnn_benchmark(max_time):
  config = config_pb2.ConfigProto()
  config.allow_soft_placement = True

  # These parameters don't matter
  batch_size = 512
  num_units = 512

  # Set up sequence lengths
  np.random.seed([127])
  sequence_length = np.random.randint(0, max_time, size=batch_size)
  inputs_list = [
      np.random.randn(batch_size, num_units).astype(np.float32)
      for _ in range(max_time)
  ]
  inputs = np.dstack(inputs_list).transpose([0, 2, 1])  # batch x time x depth

  def _create_static_rnn():
    with session.Session(config=config, graph=ops_lib.Graph()):
      inputs_list_t = [
          variables_lib.Variable(
              x, trainable=False).value() for x in inputs_list
      ]
      _static_vs_dynamic_rnn_benchmark_static(inputs_list_t, sequence_length)

  def _create_dynamic_rnn():
    with session.Session(config=config, graph=ops_lib.Graph()):
      inputs_t = variables_lib.Variable(inputs, trainable=False).value()
      _static_vs_dynamic_rnn_benchmark_dynamic(inputs_t, sequence_length)

  delta_static = timeit.timeit(_create_static_rnn, number=5)
  delta_dynamic = timeit.timeit(_create_dynamic_rnn, number=5)

  print("%d \t %f \t %f \t %f" %
        (max_time, delta_static, delta_dynamic, delta_dynamic / delta_static))
  return delta_static, delta_dynamic


def _timer(sess, ops):
  # Warm-up runs
  for _ in range(2):
    sess.run(ops)

  # Timing run
  runs = 20
  start = time.time()
  for _ in range(runs):
    sess.run(ops)
  end = time.time()
  return (end - start) / float(runs)


def static_vs_dynamic_rnn_benchmark(batch_size, max_time, num_units, use_gpu):
  config = config_pb2.ConfigProto()
  config.allow_soft_placement = True

  # Set up sequence lengths
  np.random.seed([127])
  sequence_length = np.random.randint(0, max_time, size=batch_size)
  inputs_list = [
      np.random.randn(batch_size, num_units).astype(np.float32)
      for _ in range(max_time)
  ]
  inputs = np.dstack(inputs_list).transpose([0, 2, 1])  # batch x time x depth

  # Using static_rnn()
  with session.Session(config=config, graph=ops_lib.Graph()) as sess:
    with ops_lib.device("/cpu:0" if not use_gpu else None):
      inputs_list_t = [
          variables_lib.Variable(
              x, trainable=False).value() for x in inputs_list
      ]
      ops = _static_vs_dynamic_rnn_benchmark_static(inputs_list_t,
                                                    sequence_length)
      variables_lib.global_variables_initializer().run()
      delta_static = _timer(sess, ops)

  # Using dynamic_rnn()
  with session.Session(config=config, graph=ops_lib.Graph()) as sess:
    with ops_lib.device("/cpu:0" if not use_gpu else None):
      inputs_t = variables_lib.Variable(inputs, trainable=False).value()
      ops = _static_vs_dynamic_rnn_benchmark_dynamic(inputs_t, sequence_length)
      variables_lib.global_variables_initializer().run()
      delta_dynamic = _timer(sess, ops)

  print("%d \t %d \t %d \t %s \t %f \t %f \t %f" %
        (batch_size, max_time, num_units, use_gpu, delta_static, delta_dynamic,
         delta_dynamic / delta_static))

  return delta_static, delta_dynamic


def _half_seq_len_vs_unroll_half_rnn_benchmark(inputs_list_t, sequence_length):
  (_, input_size) = inputs_list_t[0].get_shape().as_list()
  initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=127)
  cell = rnn_cell_impl.LSTMCell(
      num_units=input_size,
      use_peepholes=True,
      initializer=initializer,
      state_is_tuple=False)
  outputs, final_state = rnn.static_rnn(
      cell,
      inputs_list_t,
      sequence_length=sequence_length,
      dtype=dtypes.float32)

  trainable_variables = ops_lib.get_collection(
      ops_lib.GraphKeys.TRAINABLE_VARIABLES)
  gradients = gradients_impl.gradients(outputs + [final_state],
                                       trainable_variables)

  return control_flow_ops.group(final_state, *(gradients + outputs))


def half_seq_len_vs_unroll_half_rnn_benchmark(batch_size, max_time, num_units,
                                              use_gpu):
  config = config_pb2.ConfigProto()
  config.allow_soft_placement = True

  # Set up sequence lengths
  np.random.seed([127])
  sequence_length = max_time * np.ones((batch_size,))
  inputs_list = [
      np.random.randn(batch_size, num_units).astype(np.float32)
      for _ in range(max_time)
  ]

  # Halve the sequence length, full static unroll
  with session.Session(config=config, graph=ops_lib.Graph()) as sess:
    with ops_lib.device("/cpu:0" if not use_gpu else None):
      inputs_list_t = [
          variables_lib.Variable(
              x, trainable=False).value() for x in inputs_list
      ]
      ops = _half_seq_len_vs_unroll_half_rnn_benchmark(inputs_list_t,
                                                       sequence_length / 2)
      variables_lib.global_variables_initializer().run()
      delta_half_seq_len = _timer(sess, ops)

  # Halve the unroll size, don't use sequence length
  with session.Session(config=config, graph=ops_lib.Graph()) as sess:
    with ops_lib.device("/cpu:0" if not use_gpu else None):
      inputs_list_t = [
          variables_lib.Variable(
              x, trainable=False).value() for x in inputs_list
      ]
      ops = _half_seq_len_vs_unroll_half_rnn_benchmark(
          inputs_list_t[:(max_time // 2)], sequence_length / 2)
      variables_lib.global_variables_initializer().run()
      delta_unroll_half = _timer(sess, ops)
  print("%d \t %d \t\t %d \t %s \t %f \t\t %f \t\t %f" %
        (batch_size, max_time, num_units, use_gpu, delta_half_seq_len,
         delta_unroll_half, delta_half_seq_len / delta_unroll_half))

  return delta_half_seq_len, delta_unroll_half


def _concat_state_vs_tuple_state_rnn_benchmark(inputs_list_t, sequence_length,
                                               state_is_tuple):
  (_, input_size) = inputs_list_t[0].get_shape().as_list()
  initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=127)
  cell = rnn_cell_impl.LSTMCell(
      num_units=input_size,
      use_peepholes=True,
      initializer=initializer,
      state_is_tuple=state_is_tuple)
  outputs, final_state = rnn.static_rnn(
      cell,
      inputs_list_t,
      sequence_length=sequence_length,
      dtype=dtypes.float32)

  final_state = list(final_state) if state_is_tuple else [final_state]

  trainable_variables = ops_lib.get_collection(
      ops_lib.GraphKeys.TRAINABLE_VARIABLES)
  gradients = gradients_impl.gradients(outputs + final_state,
                                       trainable_variables)

  return control_flow_ops.group(*(final_state + gradients + outputs))


def concat_state_vs_tuple_state_rnn_benchmark(batch_size, max_time, num_units,
                                              use_gpu):
  config = config_pb2.ConfigProto()
  config.allow_soft_placement = True

  # Set up sequence lengths
  np.random.seed([127])
  sequence_length = max_time * np.ones((batch_size,))
  inputs_list = [
      np.random.randn(batch_size, num_units).astype(np.float32)
      for _ in range(max_time)
  ]

  # Run with concatenated states (default)
  with session.Session(config=config, graph=ops_lib.Graph()) as sess:
    with ops_lib.device("/cpu:0" if not use_gpu else None):
      inputs_list_t = [
          variables_lib.Variable(
              x, trainable=False).value() for x in inputs_list
      ]
      ops = _concat_state_vs_tuple_state_rnn_benchmark(
          inputs_list_t, sequence_length, state_is_tuple=False)
      variables_lib.global_variables_initializer().run()
      delta_concat_state = _timer(sess, ops)

  # Run with tuple states (new)
  with session.Session(config=config, graph=ops_lib.Graph()) as sess:
    with ops_lib.device("/cpu:0" if not use_gpu else None):
      inputs_list_t = [
          variables_lib.Variable(
              x, trainable=False).value() for x in inputs_list
      ]
      ops = _concat_state_vs_tuple_state_rnn_benchmark(
          inputs_list_t, sequence_length, state_is_tuple=True)
      variables_lib.global_variables_initializer().run()
      delta_tuple_state = _timer(sess, ops)
  print("%d \t %d \t %d \t %s \t %f \t\t %f \t\t %f" %
        (batch_size, max_time, num_units, use_gpu, delta_concat_state,
         delta_tuple_state, delta_concat_state / delta_tuple_state))

  return delta_concat_state, delta_tuple_state


def _dynamic_rnn_swap_memory_benchmark(inputs_t, sequence_length, swap_memory):
  (unused_0, unused_1, input_size) = inputs_t.get_shape().as_list()
  initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=127)
  cell = rnn_cell_impl.LSTMCell(
      num_units=input_size,
      use_peepholes=True,
      initializer=initializer,
      state_is_tuple=False)
  outputs, final_state = rnn.dynamic_rnn(
      cell,
      inputs_t,
      sequence_length=sequence_length,
      swap_memory=swap_memory,
      dtype=dtypes.float32)

  trainable_variables = ops_lib.get_collection(
      ops_lib.GraphKeys.TRAINABLE_VARIABLES)
  gradients = gradients_impl.gradients([outputs, final_state],
                                       trainable_variables)

  return control_flow_ops.group(final_state, outputs, *gradients)


def dynamic_rnn_swap_memory_benchmark(batch_size, max_time, num_units):
  config = config_pb2.ConfigProto()
  config.allow_soft_placement = True

  # Set up sequence lengths
  np.random.seed([127])
  sequence_length = np.random.randint(0, max_time, size=batch_size)
  inputs_list = [
      np.random.randn(batch_size, num_units).astype(np.float32)
      for _ in range(max_time)
  ]
  inputs = np.dstack(inputs_list).transpose([0, 2, 1])  # batch x time x depth

  # No memory swap
  with session.Session(config=config, graph=ops_lib.Graph()) as sess:
    inputs_t = variables_lib.Variable(inputs, trainable=False).value()
    ops = _dynamic_rnn_swap_memory_benchmark(
        inputs_t, sequence_length, swap_memory=False)
    variables_lib.global_variables_initializer().run()
    no_swap = _timer(sess, ops)

  # Memory swap
  with session.Session(config=config, graph=ops_lib.Graph()) as sess:
    inputs_t = variables_lib.Variable(inputs, trainable=False).value()
    ops = _dynamic_rnn_swap_memory_benchmark(
        inputs_t, sequence_length, swap_memory=True)
    variables_lib.global_variables_initializer().run()
    swap = _timer(sess, ops)

  print("%d \t %d \t %d \t %f \t %f \t %f" %
        (batch_size, max_time, num_units, no_swap, swap, swap / no_swap))
  return no_swap, swap


def rnn_long_sequence_benchmark(batch_size, seqlen, num_units, dynamic,
                                swap_memory, nn):
  config = config_pb2.ConfigProto()
  config.allow_soft_placement = True

  # Set up sequence lengths
  np.random.seed([127])
  sequence_length = [seqlen for _ in range(batch_size)]
  inputs_list = [
      np.random.randn(batch_size, num_units).astype(np.float32)
      for _ in range(seqlen)
  ]
  inputs = np.dstack(inputs_list).transpose([0, 2, 1])  # batch x time x depth

  for _ in range(nn):
    if dynamic:
      with session.Session(config=config, graph=ops_lib.Graph()) as sess:
        inputs_t = variables_lib.Variable(inputs, trainable=False).value()
        ops = _dynamic_rnn_swap_memory_benchmark(
            inputs_t, sequence_length, swap_memory=swap_memory)
        variables_lib.global_variables_initializer().run()
        elapsed = _timer(sess, ops)
    else:
      with session.Session(config=config, graph=ops_lib.Graph()) as sess:
        inputs_list_t = [
            variables_lib.Variable(
                x, trainable=False).value() for x in inputs_list
        ]
        ops = _static_vs_dynamic_rnn_benchmark_static(inputs_list_t,
                                                      sequence_length)
        variables_lib.global_variables_initializer().run()
        elapsed = _timer(sess, ops)

    print("%d \t %d \t %d \t %s \t %f \t %f" % (batch_size, seqlen, num_units,
                                                dynamic, elapsed,
                                                elapsed / seqlen))


class BenchmarkRNN(test.Benchmark):

  def benchmarkGraphCreationStaticVsDynamicLSTM(self):
    print("Graph Creation: Static Unroll vs. Dynamic Unroll LSTM")
    print("max_t \t dt(static) \t dt(dynamic) \t dt(dynamic)/dt(static)")
    for max_time in (1, 25, 50):
      s_dt, d_dt = graph_creation_static_vs_dynamic_rnn_benchmark(max_time)
      self.report_benchmark(
          name="graph_creation_time_static_T%02d" % max_time,
          iters=5,
          wall_time=s_dt)
      self.report_benchmark(
          name="graph_creation_time_dynamic_T%02d" % max_time,
          iters=5,
          wall_time=d_dt)

  def benchmarkStaticUnrollVsDynamicFlowLSTM(self):
    print("Calculation: Static Unroll with Dynamic Flow LSTM "
          "vs. Dynamic Unroll LSTM")
    print("batch \t max_t \t units \t gpu \t dt(static) \t dt(dynamic) "
          "\t dt(dynamic)/dt(static)")
    for batch_size in (256,):
      for max_time in (50,):
        for num_units in (512, 256, 128):
          for use_gpu in (False, True):
            s_dt, d_dt = static_vs_dynamic_rnn_benchmark(batch_size, max_time,
                                                         num_units, use_gpu)
            self.report_benchmark(
                name="static_unroll_time_T%02d_B%03d_N%03d_gpu_%s" %
                (max_time, batch_size, num_units, use_gpu),
                iters=20,
                wall_time=s_dt)
            self.report_benchmark(
                name="dynamic_unroll_time_T%02d_B%03d_N%03d_gpu_%s" %
                (max_time, batch_size, num_units, use_gpu),
                iters=20,
                wall_time=d_dt)

  def benchmarkDynamicLSTMNoMemorySwapVsMemorySwap(self):
    print("Calculation: Dynamic LSTM No Memory Swap vs. Memory Swap")
    print("batch \t max_t \t units \t no_swap \t swap \t swap/no_swap")
    for batch_size in (256, 512):
      for max_time in (100,):
        for num_units in (512, 256, 128):
          no_swap, swap = dynamic_rnn_swap_memory_benchmark(batch_size,
                                                            max_time,
                                                            num_units)
          self.report_benchmark(
              name="dynamic_lstm_no_memory_swap_T%02d_B%03d_N%03d" %
              (max_time, batch_size, num_units),
              iters=20,
              wall_time=no_swap)
          self.report_benchmark(
              name="dynamic_lstm_with_memory_swap_T%02d_B%03d_N%03d" %
              (max_time, batch_size, num_units),
              iters=20,
              wall_time=swap)

  def benchmarkStaticUnrollHalfSequenceLengthVsHalfUnroll(self):
    print("Calculation: Static Unroll with Halved Sequence Length "
          "vs. Half Static Unroll")
    print("batch \t full_t \t units \t gpu \t dt(half_seq_len) "
          "\t dt(unroll_half) \t dt(half_seq_len)/dt(unroll_half)")
    for batch_size in (128,):
      for max_time in (50,):
        for num_units in (256,):
          for use_gpu in (False, True):
            s_dt, d_dt = half_seq_len_vs_unroll_half_rnn_benchmark(batch_size,
                                                                   max_time,
                                                                   num_units,
                                                                   use_gpu)
            self.report_benchmark(
                name="half_seq_len_time_T%02d_B%03d_N%03d_gpu_%s" %
                (max_time, batch_size, num_units, use_gpu),
                iters=20,
                wall_time=s_dt)
            self.report_benchmark(
                name="unroll_half_time_T%02d_B%03d_N%03d_gpu_%s" %
                (max_time, batch_size, num_units, use_gpu),
                iters=20,
                wall_time=d_dt)

  def benchmarkStaticUnrollStateConcatVsStateTuple(self):
    print("Calculation: Static Unroll with Concatenated State "
          "vs. Tuple State")
    print("batch \t time \t units \t gpu \t dt(concat_state) "
          "\t dt(tuple_state) \t dt(concat_state)/dt(tuple_state)")
    for batch_size in (16, 128):
      for max_time in (50,):
        for num_units in (16, 128):
          for use_gpu in (False, True):
            c_dt, t_dt = concat_state_vs_tuple_state_rnn_benchmark(batch_size,
                                                                   max_time,
                                                                   num_units,
                                                                   use_gpu)
            self.report_benchmark(
                name="concat_state_time_T%02d_B%03d_N%03d_gpu_%s" %
                (max_time, batch_size, num_units, use_gpu),
                iters=20,
                wall_time=c_dt)
            self.report_benchmark(
                name="tuple_state_time_T%02d_B%03d_N%03d_gpu_%s" %
                (max_time, batch_size, num_units, use_gpu),
                iters=20,
                wall_time=t_dt)

  def _benchmarkDynamicLSTMMemorySwapLongSeq(self):
    """The memory swapping test for the SOSP submission."""
    print("Calculation: Long LSTM Sequence")
    print("batch \t len \t units \t dynamic \t elapsed_t \t elapsed_t/len")
    batch_size = 512
    seqlen = 800
    num_units = 512
    dynamic = True
    swap_memory = True
    # Some warming up.
    if swap_memory:
      rnn_long_sequence_benchmark(batch_size, seqlen, num_units,
                                  dynamic, swap_memory, 2)
    # Measure the performance.
    for slen in xrange(100, 1100, 100):
      rnn_long_sequence_benchmark(batch_size, slen, num_units, dynamic,
                                  swap_memory, 3)


if __name__ == "__main__":
  test.main()