# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Recurrent layers for TF 2.0."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import uuid

from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.layers import recurrent
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_cudnn_rnn_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.util.tf_export import keras_export


# The following string constants are used by the Defun approach for the
# unified backend of LSTM and GRU.
_DEFUN_API_NAME_ATTRIBUTE = 'api_implements'
_DEFUN_DEVICE_ATTRIBUTE = 'api_preferred_device'
_CPU_DEVICE_NAME = 'CPU'
_GPU_DEVICE_NAME = 'GPU'


@keras_export('keras.layers.GRU', v1=[])
class GRU(recurrent.DropoutRNNCellMixin, recurrent.GRU):
  """Gated Recurrent Unit - Cho et al. 2014.

  Based on available runtime hardware and constraints, this layer
  will choose different implementations (cuDNN-based or pure-TensorFlow)
  to maximize the performance. If a GPU is available and all
  the arguments to the layer meet the requirements of the CuDNN kernel
  (see below for details), the layer will use a fast cuDNN implementation.

  The requirements to use the cuDNN implementation are:

  1. `activation` == 'tanh'
  2. `recurrent_activation` == 'sigmoid'
  3. `recurrent_dropout` == 0
  4. `unroll` is False
  5. `use_bias` is True
  6. `reset_after` is True
  7. No use of masking.

  There are two variants of the GRU implementation. The default one is based on
  [v3](https://arxiv.org/abs/1406.1078v3) and has the reset gate applied to the
  hidden state before the matrix multiplication. The other one is based on the
  [original version](https://arxiv.org/abs/1406.1078v1) and has the order
  reversed.

  The second variant is compatible with CuDNNGRU (GPU-only) and allows
  inference on CPU. Thus it has separate biases for `kernel` and
  `recurrent_kernel`. To use this variant, set `reset_after=True` and
  `recurrent_activation='sigmoid'`.

  Arguments:
    units: Positive integer, dimensionality of the output space.
    activation: Activation function to use.
      Default: hyperbolic tangent (`tanh`).
      If you pass `None`, no activation is applied
      (i.e. "linear" activation: `a(x) = x`).
    recurrent_activation: Activation function to use
      for the recurrent step.
      Default: sigmoid (`sigmoid`).
      If you pass `None`, no activation is applied
      (i.e. "linear" activation: `a(x) = x`).
    use_bias: Boolean, whether the layer uses a bias vector.
    kernel_initializer: Initializer for the `kernel` weights matrix,
      used for the linear transformation of the inputs.
    recurrent_initializer: Initializer for the `recurrent_kernel`
      weights matrix,
      used for the linear transformation of the recurrent state.
    bias_initializer: Initializer for the bias vector.
    kernel_regularizer: Regularizer function applied to
      the `kernel` weights matrix.
    recurrent_regularizer: Regularizer function applied to
      the `recurrent_kernel` weights matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    activity_regularizer: Regularizer function applied to
      the output of the layer (its "activation").
    kernel_constraint: Constraint function applied to
      the `kernel` weights matrix.
    recurrent_constraint: Constraint function applied to
      the `recurrent_kernel` weights matrix.
    bias_constraint: Constraint function applied to the bias vector.
    dropout: Float between 0 and 1.
      Fraction of the units to drop for the linear transformation of the
      inputs.
    recurrent_dropout: Float between 0 and 1.
      Fraction of the units to drop for
      the linear transformation of the recurrent state.
    implementation: Implementation mode, either 1 or 2.
      Mode 1 will structure its operations as a larger number of
      smaller dot products and additions, whereas mode 2 will
      batch them into fewer, larger operations. These modes will
      have different performance profiles on different hardware and
      for different applications.
    return_sequences: Boolean. Whether to return the last output
      in the output sequence, or the full sequence.
    return_state: Boolean. Whether to return the last state
      in addition to the output.
    go_backwards: Boolean (default False).
      If True, process the input sequence backwards and return the
      reversed sequence.
    stateful: Boolean (default False). If True, the last state
      for each sample at index i in a batch will be used as initial
      state for the sample of index i in the following batch.
    unroll: Boolean (default False).
      If True, the network will be unrolled,
      else a symbolic loop will be used.
      Unrolling can speed-up a RNN,
      although it tends to be more memory-intensive.
      Unrolling is only suitable for short sequences.
    reset_after: GRU convention (whether to apply the reset gate after or
      before the matrix multiplication). False = "before",
      True = "after" (default and CuDNN compatible).

  Call arguments:
    inputs: A 3D tensor.
    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
      a given timestep should be masked.
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode. This argument is passed to the cell
      when calling it. This is only relevant if `dropout` or
      `recurrent_dropout` is used.
    initial_state: List of initial state tensors to be passed to the first
      call of the cell.
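
  Example:

  A minimal usage sketch (the batch, timestep, and feature sizes below are
  arbitrary placeholders):

  ```python
  import numpy as np
  import tensorflow as tf

  inputs = np.random.random((32, 10, 8)).astype(np.float32)
  gru = tf.keras.layers.GRU(4, return_sequences=True, return_state=True)
  whole_sequence_output, final_state = gru(inputs)
  # whole_sequence_output has shape (32, 10, 4); final_state has shape (32, 4).
  ```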
  """

  def __init__(self,
               units,
               activation='tanh',
               recurrent_activation='sigmoid',
               use_bias=True,
               kernel_initializer='glorot_uniform',
               recurrent_initializer='orthogonal',
               bias_initializer='zeros',
               kernel_regularizer=None,
               recurrent_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               recurrent_constraint=None,
               bias_constraint=None,
               dropout=0.,
               recurrent_dropout=0.,
               implementation=1,
               return_sequences=False,
               return_state=False,
               go_backwards=False,
               stateful=False,
               unroll=False,
               time_major=False,
               reset_after=True,
               **kwargs):
    # return_runtime is a flag for testing, which shows the real backend
    # implementation chosen by grappler in graph mode.
    self._return_runtime = kwargs.pop('return_runtime', False)

    super(GRU, self).__init__(
        units,
        activation=activation,
        recurrent_activation=recurrent_activation,
        use_bias=use_bias,
        kernel_initializer=kernel_initializer,
        recurrent_initializer=recurrent_initializer,
        bias_initializer=bias_initializer,
        kernel_regularizer=kernel_regularizer,
        recurrent_regularizer=recurrent_regularizer,
        bias_regularizer=bias_regularizer,
        activity_regularizer=activity_regularizer,
        kernel_constraint=kernel_constraint,
        recurrent_constraint=recurrent_constraint,
        bias_constraint=bias_constraint,
        dropout=dropout,
        recurrent_dropout=recurrent_dropout,
        implementation=implementation,
        return_sequences=return_sequences,
        return_state=return_state,
        go_backwards=go_backwards,
        stateful=stateful,
        unroll=unroll,
        time_major=time_major,
        reset_after=reset_after,
        **kwargs)
    # CuDNN uses the following settings by default and they are not
    # configurable.
    self.could_use_cudnn = (
        activation == 'tanh' and recurrent_activation == 'sigmoid' and
        recurrent_dropout == 0 and not unroll and use_bias and
        reset_after)

  def call(self, inputs, mask=None, training=None, initial_state=None):
    # GRU does not support constants. Ignore them during processing.
    inputs, initial_state, _ = self._process_inputs(inputs, initial_state, None)

    if isinstance(mask, list):
      mask = mask[0]

    input_shape = K.int_shape(inputs)
    timesteps = input_shape[0] if self.time_major else input_shape[1]

    if mask is not None or not self.could_use_cudnn:
      # CuDNN does not support masking, fall back to the normal GRU.
      kwargs = {'training': training}

      def step(cell_inputs, cell_states):
        return self.cell.call(cell_inputs, cell_states, **kwargs)

      last_output, outputs, states = K.rnn(
          step,
          inputs,
          initial_state,
          constants=None,
          go_backwards=self.go_backwards,
          mask=mask,
          unroll=self.unroll,
          input_length=timesteps,
          time_major=self.time_major,
          zero_output_for_mask=self.zero_output_for_mask)
      # This is a dummy tensor for testing purposes.
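      # 'unknown' signals that this fallback path did not select a specific
      # (CPU or cuDNN) kernel, so tests reading the runtime cannot rely on it.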
      runtime = _runtime('unknown')
    else:
      last_output, outputs, runtime, states = self._defun_gru_call(
          inputs, initial_state, training)

    if self.stateful:
      updates = [state_ops.assign(self.states[0], states[0])]
      self.add_update(updates, inputs)

    if self.return_sequences:
      output = outputs
    else:
      output = last_output

    if self.return_state:
      return [output] + list(states)
    elif self._return_runtime:
      return output, runtime
    else:
      return output

  def _defun_gru_call(self, inputs, initial_state, training):
    # Use the new defun approach for the backend implementation swap.
    # Note that the different implementations need to have the same function
    # signature, e.g. the tensor parameters need to have the same shapes and
    # dtypes.
    if self.go_backwards:
      # Reverse time axis.
      inputs = K.reverse(inputs, 0 if self.time_major else 1)

    self.reset_dropout_mask()
    dropout_mask = self.get_dropout_mask_for_cell(inputs, training, count=3)
    if dropout_mask is not None:
      inputs *= dropout_mask[0]
    if context.executing_eagerly():
      device_type = _get_context_device_type()
      if device_type == _GPU_DEVICE_NAME or (
          device_type is None and context.num_gpus() > 0):
        # Under eager context, check the device placement and prefer the
        # GPU implementation when GPU is available.
        last_output, outputs, new_h, runtime = cudnn_gru(
            inputs=inputs,
            init_h=initial_state[0],
            kernel=self.cell.kernel,
            recurrent_kernel=self.cell.recurrent_kernel,
            bias=self.cell.bias,
            time_major=self.time_major)
      else:
        last_output, outputs, new_h, runtime = standard_gru(
            inputs=inputs,
            init_h=initial_state[0],
            kernel=self.cell.kernel,
            recurrent_kernel=self.cell.recurrent_kernel,
            bias=self.cell.bias,
            activation=self.activation,
            recurrent_activation=self.recurrent_activation,
            time_major=self.time_major)
    else:
      api_name = 'gru_' + str(uuid.uuid4())
      defun_standard_gru = _generate_defun_backend(
          api_name, _CPU_DEVICE_NAME, standard_gru)
      defun_cudnn_gru = _generate_defun_backend(
          api_name, _GPU_DEVICE_NAME, cudnn_gru)
      # Call the normal GRU impl and register the CuDNN impl function. The
      # grappler will kick in during session execution to optimize the graph.
      last_output, outputs, new_h, runtime = defun_standard_gru(
          inputs=inputs,
          init_h=initial_state[0],
          kernel=self.cell.kernel,
          recurrent_kernel=self.cell.recurrent_kernel,
          bias=self.cell.bias,
          activation=self.activation,
          recurrent_activation=self.recurrent_activation,
          time_major=self.time_major)

      function.register(defun_cudnn_gru, inputs, initial_state[0],
                        self.cell.kernel, self.cell.recurrent_kernel,
                        self.cell.bias, self.time_major)
    states = [new_h]
    return last_output, outputs, runtime, states


def standard_gru(inputs, init_h, kernel, recurrent_kernel, bias, activation,
                 recurrent_activation, time_major):
  """GRU with standard kernel implementation.

  This implementation can be run on all types of hardware.

  This implementation lifts out all the layer weights and makes them function
  parameters. It has the same number of tensor input parameters as the cuDNN
  counterpart. The RNN step logic has been simplified, e.g. dropout and masking
  are removed since the cuDNN implementation does not support them.

  Arguments:
    inputs: input tensor of the GRU layer.
    init_h: initial state tensor for the cell output.
    kernel: weights for the cell kernel.
    recurrent_kernel: weights for the cell recurrent kernel.
    bias: weights for the cell kernel bias and recurrent bias. The bias contains
      the combined input_bias and recurrent_bias.
    activation: Activation function to use for the output.
    recurrent_activation: Activation function to use for the hidden recurrent
      state.
    time_major: boolean, whether the inputs are in the format of
      [time, batch, feature] or [batch, time, feature].

  Returns:
    last_output: output tensor for the last timestep, which has shape
      [batch, units].
    outputs: output tensor for all timesteps, which has shape
      [batch, time, units].
    state_0: the cell output, which has the same shape as init_h.
    runtime: constant string tensor which indicates the real runtime hardware.
      This value is for testing purposes and should not be used by users.
  """
  input_shape = K.int_shape(inputs)
  timesteps = input_shape[0] if time_major else input_shape[1]

  input_bias, recurrent_bias = array_ops.unstack(bias)

  def step(cell_inputs, cell_states):
    """Step function that will be used by the Keras RNN backend."""
    h_tm1 = cell_states[0]

    # inputs projected by all gate matrices at once
    matrix_x = K.dot(cell_inputs, kernel)
    matrix_x = K.bias_add(matrix_x, input_bias)

    x_z, x_r, x_h = array_ops.split(matrix_x, 3, axis=1)

    # hidden state projected by all gate matrices at once
    matrix_inner = K.dot(h_tm1, recurrent_kernel)
    matrix_inner = K.bias_add(matrix_inner, recurrent_bias)

    recurrent_z, recurrent_r, recurrent_h = array_ops.split(matrix_inner, 3,
                                                            axis=1)
    z = recurrent_activation(x_z + recurrent_z)
    r = recurrent_activation(x_r + recurrent_r)
    hh = activation(x_h + r * recurrent_h)

    # previous and candidate state mixed by the update gate
    h = z * h_tm1 + (1 - z) * hh
    return h, [h]

  last_output, outputs, new_states = K.rnn(
      step,
      inputs, [init_h],
      constants=None,
      unroll=False,
      time_major=time_major,
      input_length=timesteps)
  return last_output, outputs, new_states[0], _runtime('cpu')


def cudnn_gru(inputs, init_h, kernel, recurrent_kernel, bias, time_major):
  """GRU with CuDNN implementation which is only available for GPU."""
  if not time_major:
    inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
  init_h = array_ops.expand_dims(init_h, axis=0)

  weights = array_ops.split(kernel, 3, axis=1)
  weights += array_ops.split(recurrent_kernel, 3, axis=1)
  # Note that the bias was initialized as shape (2, 3 * units), flatten it into
  # (6 * units).
  bias = array_ops.split(K.flatten(bias), 6)
  # Note that the gate order for CuDNN is different from the canonical format.
  # The canonical format is [z, r, h], whereas CuDNN uses [r, z, h]. The swap
  # needs to be done for kernel, recurrent_kernel, input_bias and
  # recurrent_bias.
  # z is the update gate weights.
  # r is the reset gate weights.
  # h is the candidate (new) gate weights.
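  # Swap the z and r blocks of the kernel splits (indices 0/1), the recurrent
  # kernel splits (indices 3/4), and the corresponding bias splits below.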
  weights[0], weights[1] = weights[1], weights[0]
  weights[3], weights[4] = weights[4], weights[3]
  bias[0], bias[1] = bias[1], bias[0]
  bias[3], bias[4] = bias[4], bias[3]

  params = _canonical_to_params(
      weights=weights,
      biases=bias,
      shape=constant_op.constant([-1]),
      transpose_weights=True)

  outputs, h, _, _ = gen_cudnn_rnn_ops.cudnn_rnn(
      inputs,
      input_h=init_h,
      input_c=0,
      params=params,
      is_training=True,
      rnn_mode='gru')
  last_output = outputs[-1]
  if not time_major:
    outputs = array_ops.transpose(outputs, perm=[1, 0, 2])
  h = h[0]
  return last_output, outputs, h, _runtime('cudnn')


@keras_export('keras.layers.LSTM', v1=[])
class LSTM(recurrent.DropoutRNNCellMixin, recurrent.LSTM):
  """Long Short-Term Memory layer - Hochreiter 1997.

  Based on available runtime hardware and constraints, this layer
  will choose different implementations (cuDNN-based or pure-TensorFlow)
  to maximize the performance. If a GPU is available and all
  the arguments to the layer meet the requirements of the CuDNN kernel
  (see below for details), the layer will use a fast cuDNN implementation.

  The requirements to use the cuDNN implementation are:

  1. `activation` == 'tanh'
  2. `recurrent_activation` == 'sigmoid'
  3. `recurrent_dropout` == 0
  4. `unroll` is False
  5. `use_bias` is True
  6. No use of masking.

  Arguments:
    units: Positive integer, dimensionality of the output space.
    activation: Activation function to use.
      Default: hyperbolic tangent (`tanh`). If you pass `None`, no activation
      is applied (i.e. "linear" activation: `a(x) = x`).
    recurrent_activation: Activation function to use for the recurrent step.
      Default: sigmoid (`sigmoid`). If you pass `None`, no activation is
      applied (i.e. "linear" activation: `a(x) = x`).
    use_bias: Boolean, whether the layer uses a bias vector.
    kernel_initializer: Initializer for the `kernel` weights matrix, used for
      the linear transformation of the inputs.
    recurrent_initializer: Initializer for the `recurrent_kernel` weights
      matrix, used for the linear transformation of the recurrent state.
    bias_initializer: Initializer for the bias vector.
    unit_forget_bias: Boolean. If True, add 1 to the bias of the forget gate at
      initialization. Setting it to true will also force
      `bias_initializer="zeros"`. This is recommended in [Jozefowicz et
      al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf).
    kernel_regularizer: Regularizer function applied to the `kernel` weights
      matrix.
    recurrent_regularizer: Regularizer function applied to the
      `recurrent_kernel` weights matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    activity_regularizer: Regularizer function applied to the output of the
      layer (its "activation").
    kernel_constraint: Constraint function applied to the `kernel` weights
      matrix.
    recurrent_constraint: Constraint function applied to the `recurrent_kernel`
      weights matrix.
    bias_constraint: Constraint function applied to the bias vector.
    dropout: Float between 0 and 1. Fraction of the units to drop for the
      linear transformation of the inputs.
    recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for
      the linear transformation of the recurrent state.
    implementation: Implementation mode, either 1 or 2.
      Mode 1 will structure its operations as a larger number of smaller dot
      products and additions, whereas mode 2 will batch them into fewer, larger
      operations. These modes will have different performance profiles on
      different hardware and for different applications.
    return_sequences: Boolean. Whether to return the last output in the output
      sequence, or the full sequence.
    return_state: Boolean. Whether to return the last state in addition to the
      output.
    go_backwards: Boolean (default False). If True, process the input sequence
      backwards and return the reversed sequence.
    stateful: Boolean (default False). If True, the last state for each sample
      at index i in a batch will be used as initial state for the sample of
      index i in the following batch.
    unroll: Boolean (default False). If True, the network will be unrolled,
      else a symbolic loop will be used. Unrolling can speed-up a RNN, although
      it tends to be more memory-intensive. Unrolling is only suitable for
      short sequences.

  Call arguments:
    inputs: A 3D tensor.
    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
      a given timestep should be masked.
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode. This argument is passed to the cell
      when calling it. This is only relevant if `dropout` or
      `recurrent_dropout` is used.
    initial_state: List of initial state tensors to be passed to the first
      call of the cell.
  """

  def __init__(self,
               units,
               activation='tanh',
               recurrent_activation='sigmoid',
               use_bias=True,
               kernel_initializer='glorot_uniform',
               recurrent_initializer='orthogonal',
               bias_initializer='zeros',
               unit_forget_bias=True,
               kernel_regularizer=None,
               recurrent_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               recurrent_constraint=None,
               bias_constraint=None,
               dropout=0.,
               recurrent_dropout=0.,
               implementation=1,
               return_sequences=False,
               return_state=False,
               go_backwards=False,
               stateful=False,
               time_major=False,
               unroll=False,
               **kwargs):
    # return_runtime is a flag for testing, which shows the real backend
    # implementation chosen by grappler in graph mode.
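    # Pop it from kwargs so that it is not forwarded to the parent
    # constructor, which does not accept this argument.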
    self.return_runtime = kwargs.pop('return_runtime', False)

    super(LSTM, self).__init__(
        units,
        activation=activation,
        recurrent_activation=recurrent_activation,
        use_bias=use_bias,
        kernel_initializer=kernel_initializer,
        recurrent_initializer=recurrent_initializer,
        bias_initializer=bias_initializer,
        unit_forget_bias=unit_forget_bias,
        kernel_regularizer=kernel_regularizer,
        recurrent_regularizer=recurrent_regularizer,
        bias_regularizer=bias_regularizer,
        activity_regularizer=activity_regularizer,
        kernel_constraint=kernel_constraint,
        recurrent_constraint=recurrent_constraint,
        bias_constraint=bias_constraint,
        dropout=dropout,
        recurrent_dropout=recurrent_dropout,
        implementation=implementation,
        return_sequences=return_sequences,
        return_state=return_state,
        go_backwards=go_backwards,
        stateful=stateful,
        time_major=time_major,
        unroll=unroll,
        **kwargs)

    self.state_spec = [
        InputSpec(shape=(None, dim)) for dim in (self.units, self.units)
    ]
    self.could_use_cudnn = (
        activation == 'tanh' and recurrent_activation == 'sigmoid' and
        recurrent_dropout == 0 and not unroll and use_bias)

  def call(self, inputs, mask=None, training=None, initial_state=None):
    # LSTM does not support constants. Ignore them during processing.
    inputs, initial_state, _ = self._process_inputs(inputs, initial_state, None)

    if isinstance(mask, list):
      mask = mask[0]

    input_shape = K.int_shape(inputs)
    timesteps = input_shape[0] if self.time_major else input_shape[1]

    if mask is not None or not self.could_use_cudnn:
      # CuDNN does not support masking, fall back to the normal LSTM.
      kwargs = {'training': training}

      def step(inputs, states):
        return self.cell.call(inputs, states, **kwargs)

      last_output, outputs, states = K.rnn(
          step,
          inputs,
          initial_state,
          constants=None,
          go_backwards=self.go_backwards,
          mask=mask,
          unroll=self.unroll,
          input_length=timesteps,
          time_major=self.time_major,
          zero_output_for_mask=self.zero_output_for_mask)
      runtime = _runtime('unknown')
    else:
      # Use the new defun approach for the backend implementation swap.
      # Note that the different implementations need to have the same function
      # signature, e.g. the tensor parameters need to have the same shapes and
      # dtypes.
      # Since cuDNN has an extra set of biases, those biases will be passed to
      # both the normal and cuDNN implementations.
      if self.go_backwards:
        # Reverse time axis.
        inputs = K.reverse(inputs, 0 if self.time_major else 1)

      self.reset_dropout_mask()
      dropout_mask = self.get_dropout_mask_for_cell(inputs, training, count=4)
      if dropout_mask is not None:
        inputs *= dropout_mask[0]

      if context.executing_eagerly():
        device_type = _get_context_device_type()
        if device_type == _GPU_DEVICE_NAME or (
            device_type is None and context.num_gpus() > 0):
          # Under eager context, check the device placement and prefer the
          # GPU implementation when GPU is available.
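          # cudnn_lstm transposes the inputs to time-major and packs the Keras
          # weights into the flat cuDNN parameter buffer internally (see
          # cudnn_lstm below).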
          last_output, outputs, new_h, new_c, runtime = cudnn_lstm(
              inputs, initial_state[0], initial_state[1], self.cell.kernel,
              self.cell.recurrent_kernel, self.cell.bias, self.time_major)
        else:
          last_output, outputs, new_h, new_c, runtime = standard_lstm(
              inputs, initial_state[0], initial_state[1], self.cell.kernel,
              self.cell.recurrent_kernel, self.cell.bias, self.activation,
              self.recurrent_activation, self.time_major)
      else:
        # Each time a `tf.function` is called, we will give it a unique
        # identifiable API name, so that Grappler won't get confused when it
        # sees multiple LSTM layers added into the same graph, and it will be
        # able to pair up the different implementations across them.
        api_name = 'lstm_' + str(uuid.uuid4())
        defun_standard_lstm = _generate_defun_backend(
            api_name, _CPU_DEVICE_NAME, standard_lstm)
        defun_cudnn_lstm = _generate_defun_backend(
            api_name, _GPU_DEVICE_NAME, cudnn_lstm)

        # Call the normal LSTM impl and register the CuDNN impl function. The
        # grappler will kick in during session execution to optimize the graph.
        last_output, outputs, new_h, new_c, runtime = defun_standard_lstm(
            inputs, initial_state[0], initial_state[1], self.cell.kernel,
            self.cell.recurrent_kernel, self.cell.bias, self.activation,
            self.recurrent_activation, self.time_major)

        function.register(defun_cudnn_lstm, inputs, initial_state[0],
                          initial_state[1], self.cell.kernel,
                          self.cell.recurrent_kernel, self.cell.bias,
                          self.time_major)
      states = [new_h, new_c]

    if self.stateful:
      updates = []
      for i in range(len(states)):
        updates.append(state_ops.assign(self.states[i], states[i]))
      self.add_update(updates, inputs)

    if self.return_sequences:
      output = outputs
    else:
      output = last_output

    if self.return_state:
      return [output] + list(states)
    elif self.return_runtime:
      return output, runtime
    else:
      return output


def _canonical_to_params(weights, biases, shape, transpose_weights=False):
  """Utility function to convert the variables to CuDNN-compatible parameters.

  Note that the Keras weights for the kernels are different from the CuDNN
  format, for example:

  ```
    Keras                 CuDNN
    [[0, 1, 2],  <--->  [[0, 2, 4],
     [3, 4, 5]]          [1, 3, 5]]
  ```

  If the input weights need to be in a unified format, then set
  `transpose_weights=True` to convert the weights.

  Args:
    weights: list of weights for the individual kernels and recurrent kernels.
    biases: list of biases for the individual gates.
    shape: the shape for the converted variables that will be fed to CuDNN.
    transpose_weights: boolean, whether to transpose the weights.

  Returns:
    The converted weights that can be fed to CuDNN ops as params.
  """
  def convert(w):
    return array_ops.transpose(w) if transpose_weights else w

  weights = [array_ops.reshape(convert(x), shape) for x in weights]
  biases = [array_ops.reshape(x, shape) for x in biases]
  return array_ops.concat(weights + biases, axis=0)


def standard_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias,
                  activation, recurrent_activation, time_major):
  """LSTM with standard kernel implementation.

  This implementation can be run on all types of hardware.

  This implementation lifts out all the layer weights and makes them function
  parameters. It has the same number of tensor input parameters as the cuDNN
  counterpart.
  The RNN step logic has been simplified, e.g. dropout and masking are removed
  since the cuDNN implementation does not support them.

  Note that the first half of the bias tensor should be ignored by this impl.
  The cuDNN impl needs an extra set of input gate biases. In order to make both
  functions take parameters of the same shape, that extra set of biases is also
  fed here.

  Args:
    inputs: input tensor of the LSTM layer.
    init_h: initial state tensor for the cell output.
    init_c: initial state tensor for the cell hidden state.
    kernel: weights for the cell kernel.
    recurrent_kernel: weights for the cell recurrent kernel.
    bias: weights for the cell kernel bias and recurrent bias. Only the
      recurrent bias is used in this case.
    activation: Activation function to use for the output.
    recurrent_activation: Activation function to use for the hidden recurrent
      state.
    time_major: boolean, whether the inputs are in the format of
      [time, batch, feature] or [batch, time, feature].

  Returns:
    last_output: output tensor for the last timestep, which has shape
      [batch, units].
    outputs: output tensor for all timesteps, which has shape
      [batch, time, units].
    state_0: the cell output, which has the same shape as init_h.
    state_1: the cell hidden state, which has the same shape as init_c.
    runtime: constant string tensor which indicates the real runtime hardware.
      This value is for testing purposes and should not be used by users.
  """
  input_shape = K.int_shape(inputs)
  timesteps = input_shape[0] if time_major else input_shape[1]

  def step(cell_inputs, cell_states):
    """Step function that will be used by the Keras RNN backend."""
    h_tm1 = cell_states[0]  # previous memory state
    c_tm1 = cell_states[1]  # previous carry state

    z = K.dot(cell_inputs, kernel)
    z += K.dot(h_tm1, recurrent_kernel)
    z = K.bias_add(z, bias)

    z0, z1, z2, z3 = array_ops.split(z, 4, axis=1)

    i = recurrent_activation(z0)
    f = recurrent_activation(z1)
    c = f * c_tm1 + i * activation(z2)
    o = recurrent_activation(z3)

    h = o * activation(c)
    return h, [h, c]

  last_output, outputs, new_states = K.rnn(
      step,
      inputs, [init_h, init_c],
      constants=None,
      unroll=False,
      time_major=time_major,
      input_length=timesteps)
  return last_output, outputs, new_states[0], new_states[1], _runtime('cpu')


def cudnn_lstm(inputs, input_h, input_c, kernel, recurrent_kernel, bias,
               time_major):
  """LSTM with CuDNN implementation which is only available for GPU."""
  if not time_major:
    inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
  input_h = array_ops.expand_dims(input_h, axis=0)
  input_c = array_ops.expand_dims(input_c, axis=0)

  weights = array_ops.split(kernel, 4, axis=1)
  weights += array_ops.split(recurrent_kernel, 4, axis=1)
  # CuDNN has an extra set of biases for the inputs; we disable them (by
  # setting them to 0) so that mathematically it is the same as the canonical
  # LSTM implementation.
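  # full_bias has shape (8 * units,): four zeroed input biases followed by the
  # four canonical Keras biases (split into eight pieces below).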
  full_bias = array_ops.concat((array_ops.zeros_like(bias), bias), 0)

  params = _canonical_to_params(
      weights=weights,
      biases=array_ops.split(full_bias, 8),
      shape=constant_op.constant([-1]),
      transpose_weights=True)

  outputs, h, c, _ = gen_cudnn_rnn_ops.cudnn_rnn(
      inputs, input_h=input_h, input_c=input_c, params=params, is_training=True)
  last_output = outputs[-1]
  if not time_major:
    outputs = array_ops.transpose(outputs, perm=[1, 0, 2])
  h = h[0]
  c = c[0]

  return last_output, outputs, h, c, _runtime('cudnn')


def _generate_defun_backend(unique_api_name, preferred_device, func):
  function_attributes = {
      _DEFUN_API_NAME_ATTRIBUTE: unique_api_name,
      _DEFUN_DEVICE_ATTRIBUTE: preferred_device,
  }
  return function.defun_with_attributes(func=func,
                                        attributes=function_attributes)


def _get_context_device_type():
  """Parse the current context and return the device type, e.g. CPU or GPU."""
  current_device = context.context().device_name
  if current_device is None:
    return None
  return device.DeviceSpec.from_string(current_device).device_type


def _runtime(runtime_name):
  with ops.device('/cpu:0'):
    return constant_op.constant(
        runtime_name, dtype=dtypes.string, name='runtime')