# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Recurrent layers backed by cuDNN."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from tensorflow.python.framework import constant_op
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.layers import recurrent_v2
from tensorflow.python.keras.layers.recurrent import RNN
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_cudnn_rnn_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.util.tf_export import keras_export


class _CuDNNRNN(RNN):
  """Private base class for CuDNNGRU and CuDNNLSTM layers.

  Arguments:
    return_sequences: Boolean. Whether to return the last output
      in the output sequence, or the full sequence.
    return_state: Boolean. Whether to return the last state
      in addition to the output.
    go_backwards: Boolean (default False).
      If True, process the input sequence backwards and return the
      reversed sequence.
    stateful: Boolean (default False). If True, the last state
      for each sample at index i in a batch will be used as initial
      state for the sample of index i in the following batch.
    time_major: Boolean (default False). If True, the inputs and outputs will
      be in shape `(timesteps, batch, ...)`, whereas in the False case, it
      will be `(batch, timesteps, ...)`.
  """

  def __init__(self,
               return_sequences=False,
               return_state=False,
               go_backwards=False,
               stateful=False,
               time_major=False,
               **kwargs):
    # We invoke the base layer's initializer directly here because we do not
    # want to create an RNN cell instance.
    super(RNN, self).__init__(**kwargs)  # pylint: disable=bad-super-call
    self.return_sequences = return_sequences
    self.return_state = return_state
    self.go_backwards = go_backwards
    self.stateful = stateful
    self.time_major = time_major
    self.supports_masking = False
    self.input_spec = [InputSpec(ndim=3)]
    if hasattr(self.cell.state_size, '__len__'):
      state_size = self.cell.state_size
    else:
      state_size = [self.cell.state_size]
    self.state_spec = [InputSpec(shape=(None, dim)) for dim in state_size]
    self.constants_spec = None
    self._states = None
    self._num_constants = None
    self._num_inputs = None
    # Target shape used to flatten each weight and bias into the single
    # parameter vector consumed by the cuDNN kernel.
    self._vector_shape = constant_op.constant([-1])

  def call(self, inputs, mask=None, training=None, initial_state=None):
    if isinstance(mask, list):
      mask = mask[0]
    if mask is not None:
      raise ValueError('Masking is not supported for CuDNN RNNs.')

    # input shape: `(samples, time (padded with zeros), input_dim)`
    # note that the .build() method of subclasses MUST define
    # self.input_spec and self.state_spec with complete input shapes.
    if isinstance(inputs, list):
      initial_state = inputs[1:]
      inputs = inputs[0]
    elif initial_state is not None:
      pass
    elif self.stateful:
      initial_state = self.states
    else:
      initial_state = self.get_initial_state(inputs)

    if len(initial_state) != len(self.states):
      raise ValueError('Layer has ' + str(len(self.states)) +
                       ' states but was passed ' + str(len(initial_state)) +
                       ' initial states.')

    if self.go_backwards:
      # Reverse time axis.
      inputs = K.reverse(inputs, 1)
    output, states = self._process_batch(inputs, initial_state)

    if self.stateful:
      updates = []
      for i in range(len(states)):
        updates.append(state_ops.assign(self.states[i], states[i]))
      self.add_update(updates, inputs)

    if self.return_state:
      return [output] + states
    else:
      return output

  def get_config(self):
    config = {
        'return_sequences': self.return_sequences,
        'return_state': self.return_state,
        'go_backwards': self.go_backwards,
        'stateful': self.stateful,
        'time_major': self.time_major,
    }
    base_config = super(  # pylint: disable=bad-super-call
        RNN, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config):
    return cls(**config)

  @property
  def trainable_weights(self):
    if self.trainable and self.built:
      return [self.kernel, self.recurrent_kernel, self.bias]
    return []

  @property
  def non_trainable_weights(self):
    if not self.trainable and self.built:
      return [self.kernel, self.recurrent_kernel, self.bias]
    return []

  @property
  def losses(self):
    return super(RNN, self).losses

  def get_losses_for(self, inputs=None):
    return super(  # pylint: disable=bad-super-call
        RNN, self).get_losses_for(inputs=inputs)
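

# A minimal sketch, not part of the original module, of what
# `recurrent_v2._canonical_to_params` is assumed to do with the weight and
# bias lists built in `_process_batch` below: flatten each tensor to a
# vector and concatenate them into the single opaque parameter buffer that
# the cuDNN kernel consumes. The name `_canonical_to_params_sketch` is
# hypothetical and the function is illustrative only.
def _canonical_to_params_sketch(weights, biases, shape):
  """Illustrative only: flatten `weights` and `biases` into one vector."""
  weights = [array_ops.reshape(x, shape) for x in weights]
  biases = [array_ops.reshape(x, shape) for x in biases]
  return array_ops.concat(weights + biases, axis=0)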


@keras_export(v1=['keras.layers.CuDNNGRU'])
class CuDNNGRU(_CuDNNRNN):
  """Fast GRU implementation backed by cuDNN.

  More information about cuDNN can be found on the [NVIDIA
  developer website](https://developer.nvidia.com/cudnn).
  Can only be run on GPU.

  Arguments:
    units: Positive integer, dimensionality of the output space.
    kernel_initializer: Initializer for the `kernel` weights matrix, used for
      the linear transformation of the inputs.
    recurrent_initializer: Initializer for the `recurrent_kernel` weights
      matrix, used for the linear transformation of the recurrent state.
    bias_initializer: Initializer for the bias vector.
    kernel_regularizer: Regularizer function applied to the `kernel` weights
      matrix.
    recurrent_regularizer: Regularizer function applied to the
      `recurrent_kernel` weights matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    activity_regularizer: Regularizer function applied to the output of the
      layer (its "activation").
    kernel_constraint: Constraint function applied to the `kernel` weights
      matrix.
    recurrent_constraint: Constraint function applied to the
      `recurrent_kernel` weights matrix.
    bias_constraint: Constraint function applied to the bias vector.
    return_sequences: Boolean. Whether to return the last output in the output
      sequence, or the full sequence.
    return_state: Boolean. Whether to return the last state in addition to the
      output.
    go_backwards: Boolean (default False). If True, process the input sequence
      backwards and return the reversed sequence.
    stateful: Boolean (default False). If True, the last state for each sample
      at index i in a batch will be used as initial state for the sample of
      index i in the following batch.
  """

  def __init__(self,
               units,
               kernel_initializer='glorot_uniform',
               recurrent_initializer='orthogonal',
               bias_initializer='zeros',
               kernel_regularizer=None,
               recurrent_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               recurrent_constraint=None,
               bias_constraint=None,
               return_sequences=False,
               return_state=False,
               go_backwards=False,
               stateful=False,
               **kwargs):
    self.units = units
    cell_spec = collections.namedtuple('cell', 'state_size')
    self._cell = cell_spec(state_size=self.units)
    super(CuDNNGRU, self).__init__(
        return_sequences=return_sequences,
        return_state=return_state,
        go_backwards=go_backwards,
        stateful=stateful,
        **kwargs)

    self.kernel_initializer = initializers.get(kernel_initializer)
    self.recurrent_initializer = initializers.get(recurrent_initializer)
    self.bias_initializer = initializers.get(bias_initializer)

    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)
    self.activity_regularizer = regularizers.get(activity_regularizer)

    self.kernel_constraint = constraints.get(kernel_constraint)
    self.recurrent_constraint = constraints.get(recurrent_constraint)
    self.bias_constraint = constraints.get(bias_constraint)

  @property
  def cell(self):
    return self._cell

  def build(self, input_shape):
    super(CuDNNGRU, self).build(input_shape)
    if isinstance(input_shape, list):
      input_shape = input_shape[0]
    input_dim = int(input_shape[-1])

    self.kernel = self.add_weight(
        shape=(input_dim, self.units * 3),
        name='kernel',
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint)

    self.recurrent_kernel = self.add_weight(
        shape=(self.units, self.units * 3),
        name='recurrent_kernel',
        initializer=self.recurrent_initializer,
        regularizer=self.recurrent_regularizer,
        constraint=self.recurrent_constraint)

    self.bias = self.add_weight(
        shape=(self.units * 6,),
        name='bias',
        initializer=self.bias_initializer,
        regularizer=self.bias_regularizer,
        constraint=self.bias_constraint)

    self.built = True
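
  # Note on the layout in `_process_batch` below: the Keras-canonical GRU
  # kernels concatenate the update (`z`), reset (`r`) and candidate (`h`)
  # sub-matrices, and the slicing swaps the first two so they are handed
  # to cuDNN as (reset, update, candidate). The `units * 6` bias is
  # likewise split into three input-bias and three recurrent-bias
  # segments, reordered the same way. (This describes the slicing visible
  # below; the cuDNN-side ordering is as assumed by
  # `recurrent_v2._canonical_to_params`.)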

  def _process_batch(self, inputs, initial_state):
    if not self.time_major:
      inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
    input_h = initial_state[0]
    input_h = array_ops.expand_dims(input_h, axis=0)

    params = recurrent_v2._canonical_to_params(  # pylint: disable=protected-access
        weights=[
            self.kernel[:, self.units:self.units * 2],
            self.kernel[:, :self.units],
            self.kernel[:, self.units * 2:],
            self.recurrent_kernel[:, self.units:self.units * 2],
            self.recurrent_kernel[:, :self.units],
            self.recurrent_kernel[:, self.units * 2:],
        ],
        biases=[
            self.bias[self.units:self.units * 2],
            self.bias[:self.units],
            self.bias[self.units * 2:self.units * 3],
            self.bias[self.units * 4:self.units * 5],
            self.bias[self.units * 3:self.units * 4],
            self.bias[self.units * 5:],
        ],
        shape=self._vector_shape)

    outputs, h, _, _ = gen_cudnn_rnn_ops.cudnn_rnn(
        inputs,
        input_h=input_h,
        input_c=0,
        params=params,
        is_training=True,
        rnn_mode='gru')

    if self.stateful or self.return_state:
      h = h[0]
    if self.return_sequences:
      if self.time_major:
        output = outputs
      else:
        output = array_ops.transpose(outputs, perm=(1, 0, 2))
    else:
      output = outputs[-1]
    return output, [h]

  def get_config(self):
    config = {
        'units': self.units,
        'kernel_initializer': initializers.serialize(self.kernel_initializer),
        'recurrent_initializer':
            initializers.serialize(self.recurrent_initializer),
        'bias_initializer': initializers.serialize(self.bias_initializer),
        'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
        'recurrent_regularizer':
            regularizers.serialize(self.recurrent_regularizer),
        'bias_regularizer': regularizers.serialize(self.bias_regularizer),
        'activity_regularizer':
            regularizers.serialize(self.activity_regularizer),
        'kernel_constraint': constraints.serialize(self.kernel_constraint),
        'recurrent_constraint':
            constraints.serialize(self.recurrent_constraint),
        'bias_constraint': constraints.serialize(self.bias_constraint)
    }
    base_config = super(CuDNNGRU, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
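

# A hedged usage sketch, not part of the original module: wiring CuDNNGRU
# into a small Sequential model. Running it requires a CUDA-enabled GPU and
# cuDNN; the helper name `_example_cudnn_gru_model` and all shapes are
# illustrative assumptions.
def _example_cudnn_gru_model(timesteps=10, input_dim=8, units=32):
  """Illustrative only: a Sequential model wrapping a CuDNNGRU layer."""
  from tensorflow.python.keras.layers import Dense  # local import: sketch only
  from tensorflow.python.keras.models import Sequential
  model = Sequential()
  # With the default `return_sequences=False`, only the last output of the
  # sequence is emitted, with shape `(batch, units)`.
  model.add(CuDNNGRU(units, input_shape=(timesteps, input_dim)))
  model.add(Dense(1))
  return model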


@keras_export(v1=['keras.layers.CuDNNLSTM'])
class CuDNNLSTM(_CuDNNRNN):
  """Fast LSTM implementation backed by cuDNN.

  More information about cuDNN can be found on the [NVIDIA
  developer website](https://developer.nvidia.com/cudnn).
  Can only be run on GPU.

  Arguments:
    units: Positive integer, dimensionality of the output space.
    kernel_initializer: Initializer for the `kernel` weights matrix, used for
      the linear transformation of the inputs.
    unit_forget_bias: Boolean. If True, add 1 to the bias of the forget gate
      at initialization. Setting it to True will also force
      `bias_initializer="zeros"`. This is recommended in [Jozefowicz et
      al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
    recurrent_initializer: Initializer for the `recurrent_kernel` weights
      matrix, used for the linear transformation of the recurrent state.
    bias_initializer: Initializer for the bias vector.
    kernel_regularizer: Regularizer function applied to the `kernel` weights
      matrix.
    recurrent_regularizer: Regularizer function applied to the
      `recurrent_kernel` weights matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    activity_regularizer: Regularizer function applied to the output of the
      layer (its "activation").
    kernel_constraint: Constraint function applied to the `kernel` weights
      matrix.
    recurrent_constraint: Constraint function applied to the
      `recurrent_kernel` weights matrix.
    bias_constraint: Constraint function applied to the bias vector.
    return_sequences: Boolean. Whether to return the last output in the
      output sequence, or the full sequence.
    return_state: Boolean. Whether to return the last state in addition to the
      output.
    go_backwards: Boolean (default False). If True, process the input sequence
      backwards and return the reversed sequence.
    stateful: Boolean (default False). If True, the last state for each sample
      at index i in a batch will be used as initial state for the sample of
      index i in the following batch.
  """

  def __init__(self,
               units,
               kernel_initializer='glorot_uniform',
               recurrent_initializer='orthogonal',
               bias_initializer='zeros',
               unit_forget_bias=True,
               kernel_regularizer=None,
               recurrent_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               recurrent_constraint=None,
               bias_constraint=None,
               return_sequences=False,
               return_state=False,
               go_backwards=False,
               stateful=False,
               **kwargs):
    self.units = units
    cell_spec = collections.namedtuple('cell', 'state_size')
    self._cell = cell_spec(state_size=(self.units, self.units))
    super(CuDNNLSTM, self).__init__(
        return_sequences=return_sequences,
        return_state=return_state,
        go_backwards=go_backwards,
        stateful=stateful,
        **kwargs)

    self.kernel_initializer = initializers.get(kernel_initializer)
    self.recurrent_initializer = initializers.get(recurrent_initializer)
    self.bias_initializer = initializers.get(bias_initializer)
    self.unit_forget_bias = unit_forget_bias

    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)
    self.activity_regularizer = regularizers.get(activity_regularizer)

    self.kernel_constraint = constraints.get(kernel_constraint)
    self.recurrent_constraint = constraints.get(recurrent_constraint)
    self.bias_constraint = constraints.get(bias_constraint)

  @property
  def cell(self):
    return self._cell

  def build(self, input_shape):
    super(CuDNNLSTM, self).build(input_shape)
    if isinstance(input_shape, list):
      input_shape = input_shape[0]
    input_dim = int(input_shape[-1])

    self.kernel = self.add_weight(
        shape=(input_dim, self.units * 4),
        name='kernel',
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint)

    self.recurrent_kernel = self.add_weight(
        shape=(self.units, self.units * 4),
        name='recurrent_kernel',
        initializer=self.recurrent_initializer,
        regularizer=self.recurrent_regularizer,
        constraint=self.recurrent_constraint)

    if self.unit_forget_bias:

      def bias_initializer(_, *args, **kwargs):
        return array_ops.concat([
            self.bias_initializer((self.units * 5,), *args, **kwargs),
            initializers.Ones()((self.units,), *args, **kwargs),
            self.bias_initializer((self.units * 2,), *args, **kwargs),
        ], axis=0)
    else:
      bias_initializer = self.bias_initializer
    self.bias = self.add_weight(
        shape=(self.units * 8,),
        name='bias',
        initializer=bias_initializer,
        regularizer=self.bias_regularizer,
        constraint=self.bias_constraint)

    self.built = True
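
  # Note on the layout in `_process_batch` below: the kernels are sliced in
  # the Keras-canonical LSTM gate order (input `i`, forget `f`, cell `c`,
  # output `o`) and passed through without reordering, and the `units * 8`
  # bias holds four input-bias segments followed by four recurrent-bias
  # segments in the same gate order.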

  def _process_batch(self, inputs, initial_state):
    if not self.time_major:
      inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
    input_h = initial_state[0]
    input_c = initial_state[1]
    input_h = array_ops.expand_dims(input_h, axis=0)
    input_c = array_ops.expand_dims(input_c, axis=0)

    params = recurrent_v2._canonical_to_params(  # pylint: disable=protected-access
        weights=[
            self.kernel[:, :self.units],
            self.kernel[:, self.units:self.units * 2],
            self.kernel[:, self.units * 2:self.units * 3],
            self.kernel[:, self.units * 3:],
            self.recurrent_kernel[:, :self.units],
            self.recurrent_kernel[:, self.units:self.units * 2],
            self.recurrent_kernel[:, self.units * 2:self.units * 3],
            self.recurrent_kernel[:, self.units * 3:],
        ],
        biases=[
            self.bias[:self.units],
            self.bias[self.units:self.units * 2],
            self.bias[self.units * 2:self.units * 3],
            self.bias[self.units * 3:self.units * 4],
            self.bias[self.units * 4:self.units * 5],
            self.bias[self.units * 5:self.units * 6],
            self.bias[self.units * 6:self.units * 7],
            self.bias[self.units * 7:],
        ],
        shape=self._vector_shape)

    outputs, h, c, _ = gen_cudnn_rnn_ops.cudnn_rnn(
        inputs,
        input_h=input_h,
        input_c=input_c,
        params=params,
        is_training=True)

    if self.stateful or self.return_state:
      h = h[0]
      c = c[0]
    if self.return_sequences:
      if self.time_major:
        output = outputs
      else:
        output = array_ops.transpose(outputs, perm=(1, 0, 2))
    else:
      output = outputs[-1]
    return output, [h, c]

  def get_config(self):
    config = {
        'units': self.units,
        'kernel_initializer': initializers.serialize(self.kernel_initializer),
        'recurrent_initializer':
            initializers.serialize(self.recurrent_initializer),
        'bias_initializer': initializers.serialize(self.bias_initializer),
        'unit_forget_bias': self.unit_forget_bias,
        'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
        'recurrent_regularizer':
            regularizers.serialize(self.recurrent_regularizer),
        'bias_regularizer': regularizers.serialize(self.bias_regularizer),
        'activity_regularizer':
            regularizers.serialize(self.activity_regularizer),
        'kernel_constraint': constraints.serialize(self.kernel_constraint),
        'recurrent_constraint':
            constraints.serialize(self.recurrent_constraint),
        'bias_constraint': constraints.serialize(self.bias_constraint)
    }
    base_config = super(CuDNNLSTM, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
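

# A hedged usage sketch, not part of the original module: using CuDNNLSTM
# with `return_state=True` to recover the final hidden and cell states.
# Running it requires a CUDA-enabled GPU and cuDNN; the helper name
# `_example_cudnn_lstm_states` and all shapes are illustrative assumptions.
def _example_cudnn_lstm_states(timesteps=10, input_dim=8, units=32):
  """Illustrative only: a functional model exposing the LSTM end states."""
  from tensorflow.python.keras.layers import Input  # local import: sketch only
  from tensorflow.python.keras.models import Model
  inputs = Input(shape=(timesteps, input_dim))
  # With `return_state=True` the layer call returns `[output, h, c]`.
  output, state_h, state_c = CuDNNLSTM(units, return_state=True)(inputs)
  return Model(inputs, [output, state_h, state_c])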