# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module for constructing RNN Cells."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import math

from tensorflow.contrib.compiler import jit
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.rnn.python.ops import core_rnn_cell
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl  # pylint: disable=unused-import
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import partitioned_variables  # pylint: disable=unused-import
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest


def _get_concat_variable(name, shape, dtype, num_shards):
  """Get a sharded variable concatenated into one tensor."""
  sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
  if len(sharded_variable) == 1:
    return sharded_variable[0]

  concat_name = name + "/concat"
  concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
  for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
    if value.name == concat_full_name:
      return value

  concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name)
  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES, concat_variable)
  return concat_variable


def _get_sharded_variable(name, shape, dtype, num_shards):
  """Get a list of sharded variables with the given dtype."""
  if num_shards > shape[0]:
    raise ValueError("Too many shards: shape=%s, num_shards=%d" % (shape,
                                                                   num_shards))
  unit_shard_size = int(math.floor(shape[0] / num_shards))
  remaining_rows = shape[0] - unit_shard_size * num_shards

  shards = []
  for i in range(num_shards):
    current_size = unit_shard_size
    if i < remaining_rows:
      current_size += 1
    shards.append(
        vs.get_variable(
            name + "_%d" % i, [current_size] + shape[1:], dtype=dtype))
  return shards


def _norm(g, b, inp, scope):
  shape = inp.get_shape()[-1:]
  gamma_init = init_ops.constant_initializer(g)
  beta_init = init_ops.constant_initializer(b)
  with vs.variable_scope(scope):
    # Initialize beta and gamma for use by layer_norm.
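    # Note: layers.layer_norm below picks these variables up via reuse=True
    # and computes, roughly, gamma * (inp - mean) / sqrt(variance + epsilon)
    # + beta over the last dimension of `inp`.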
    vs.get_variable("gamma", shape=shape, initializer=gamma_init)
    vs.get_variable("beta", shape=shape, initializer=beta_init)
  normalized = layers.layer_norm(inp, reuse=True, scope=scope)
  return normalized


class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
  """Long short-term memory unit (LSTM) recurrent network cell.

  The default non-peephole implementation is based on:

    http://www.bioinf.jku.at/publications/older/2604.pdf

  S. Hochreiter and J. Schmidhuber.
  "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.

  The peephole implementation is based on:

    https://research.google.com/pubs/archive/43905.pdf

  Hasim Sak, Andrew Senior, and Francoise Beaufays.
  "Long short-term memory recurrent neural network architectures for
  large scale acoustic modeling." INTERSPEECH, 2014.

  The coupling of input and forget gate is based on:

    http://arxiv.org/pdf/1503.04069.pdf

  Greff et al. "LSTM: A Search Space Odyssey"

  The class uses optional peep-hole connections, and an optional projection
  layer.
  Layer normalization implementation is based on:

    https://arxiv.org/abs/1607.06450.

  "Layer Normalization"
  Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton

  and is applied before the internal nonlinearities.
  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               initializer=None,
               num_proj=None,
               proj_clip=None,
               num_unit_shards=1,
               num_proj_shards=1,
               forget_bias=1.0,
               state_is_tuple=True,
               activation=math_ops.tanh,
               reuse=None,
               layer_norm=False,
               norm_gain=1.0,
               norm_shift=0.0):
    """Initialize the parameters for an LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell
      use_peepholes: bool, set True to enable diagonal/peephole connections.
      initializer: (optional) The initializer to use for the weight and
        projection matrices.
      num_proj: (optional) int, The output dimensionality for the projection
        matrices. If None, no projection is performed.
      proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is
        provided, then the projected values are clipped elementwise to within
        `[-proj_clip, proj_clip]`.
      num_unit_shards: How to split the weight matrix. If >1, the weight
        matrix is stored across num_unit_shards.
      num_proj_shards: How to split the projection matrix. If >1, the
        projection matrix is stored across num_proj_shards.
      forget_bias: Biases of the forget gate are initialized by default to 1
        in order to reduce the scale of forgetting at the beginning of
        the training.
      state_is_tuple: If True (the default), accepted and returned states are
        2-tuples of the `c_state` and `m_state`. If False, they are
        concatenated along the column axis; this behavior will soon be
        deprecated.
      activation: Activation function of the inner states.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
      layer_norm: If `True`, layer normalization will be applied.
      norm_gain: float, The layer normalization gain initial value. If
        `layer_norm` has been set to `False`, this argument will be ignored.
      norm_shift: float, The layer normalization shift initial value. If
        `layer_norm` has been set to `False`, this argument will be ignored.
    """
    super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse)
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated. Use state_is_tuple=True.", self)
    self._num_units = num_units
    self._use_peepholes = use_peepholes
    self._initializer = initializer
    self._num_proj = num_proj
    self._proj_clip = proj_clip
    self._num_unit_shards = num_unit_shards
    self._num_proj_shards = num_proj_shards
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._activation = activation
    self._reuse = reuse
    self._layer_norm = layer_norm
    self._norm_gain = norm_gain
    self._norm_shift = norm_shift

    if num_proj:
      self._state_size = (
          rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
          if state_is_tuple else num_units + num_proj)
      self._output_size = num_proj
    else:
      self._state_size = (
          rnn_cell_impl.LSTMStateTuple(num_units, num_units)
          if state_is_tuple else 2 * num_units)
      self._output_size = num_units

  @property
  def state_size(self):
    return self._state_size

  @property
  def output_size(self):
    return self._output_size

  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, batch x num_units.
      state: if `state_is_tuple` is False, this must be a state Tensor,
        `2-D, batch x state_size`. If `state_is_tuple` is True, this must be a
        tuple of state Tensors, both `2-D`, with column sizes `c_state` and
        `m_state`.

    Returns:
      A tuple containing:
      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
        LSTM after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - Tensor(s) representing the new state of LSTM after reading `inputs` when
        the previous state was `state`. Same type and shape(s) as `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
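    # Coupled input-forget-gate update, summarizing the code below:
    #   f = sigmoid(W_f * [x, m_prev] + b_f [+ w_f_diag * c_prev])
    #   c = f * c_prev + (1 - f) * activation(W_j * [x, m_prev] + b_j)
    #   m = sigmoid(W_o * [x, m_prev] + b_o [+ w_o_diag * c]) * activation(c)
    # i.e. the input gate is tied to the forget gate as i = 1 - f.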
    sigmoid = math_ops.sigmoid

    num_proj = self._num_units if self._num_proj is None else self._num_proj

    if self._state_is_tuple:
      (c_prev, m_prev) = state
    else:
      c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
      m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])

    dtype = inputs.dtype
    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
    concat_w = _get_concat_variable(
        "W", [input_size.value + num_proj, 3 * self._num_units], dtype,
        self._num_unit_shards)

    b = vs.get_variable(
        "B",
        shape=[3 * self._num_units],
        initializer=init_ops.zeros_initializer(),
        dtype=dtype)

    # j = new_input, f = forget_gate, o = output_gate
    cell_inputs = array_ops.concat([inputs, m_prev], 1)
    lstm_matrix = math_ops.matmul(cell_inputs, concat_w)

    # If layer normalization is applied, do not add bias
    if not self._layer_norm:
      lstm_matrix = nn_ops.bias_add(lstm_matrix, b)

    j, f, o = array_ops.split(value=lstm_matrix, num_or_size_splits=3, axis=1)

    # Apply layer normalization
    if self._layer_norm:
      j = _norm(self._norm_gain, self._norm_shift, j, "transform")
      f = _norm(self._norm_gain, self._norm_shift, f, "forget")
      o = _norm(self._norm_gain, self._norm_shift, o, "output")

    # Diagonal connections
    if self._use_peepholes:
      w_f_diag = vs.get_variable(
          "W_F_diag", shape=[self._num_units], dtype=dtype)
      w_o_diag = vs.get_variable(
          "W_O_diag", shape=[self._num_units], dtype=dtype)

    if self._use_peepholes:
      f_act = sigmoid(f + self._forget_bias + w_f_diag * c_prev)
    else:
      f_act = sigmoid(f + self._forget_bias)
    c = (f_act * c_prev + (1 - f_act) * self._activation(j))

    # Apply layer normalization
    if self._layer_norm:
      c = _norm(self._norm_gain, self._norm_shift, c, "state")

    if self._use_peepholes:
      m = sigmoid(o + w_o_diag * c) * self._activation(c)
    else:
      m = sigmoid(o) * self._activation(c)

    if self._num_proj is not None:
      concat_w_proj = _get_concat_variable("W_P",
                                           [self._num_units, self._num_proj],
                                           dtype, self._num_proj_shards)

      m = math_ops.matmul(m, concat_w_proj)
      if self._proj_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
        # pylint: enable=invalid-unary-operand-type

    new_state = (
        rnn_cell_impl.LSTMStateTuple(c, m)
        if self._state_is_tuple else array_ops.concat([c, m], 1))
    return m, new_state

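
# --- Illustrative usage sketch (not part of the original API) ---------------
# The helper below shows, under assumed shapes (batch of 8, 32 input features,
# 64 units), how this cell is typically driven for a single step. The function
# name and all constants here are illustrative only.
def _coupled_input_forget_gate_lstm_usage_sketch():
  """Builds and runs one step of CoupledInputForgetGateLSTMCell (sketch)."""
  cell = CoupledInputForgetGateLSTMCell(
      num_units=64, use_peepholes=True, layer_norm=True)
  inputs = array_ops.zeros([8, 32], dtype=dtypes.float32)
  state = cell.zero_state(8, dtypes.float32)
  output, new_state = cell(inputs, state)  # output: [8, 64]
  return output, new_state
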

class TimeFreqLSTMCell(rnn_cell_impl.RNNCell):
  """Time-Frequency Long short-term memory unit (LSTM) recurrent network cell.

  This implementation is based on:

    Tara N. Sainath and Bo Li
    "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures
    for LVCSR Tasks." submitted to INTERSPEECH, 2016.

  It uses peep-hole connections and optional cell clipping.
  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               cell_clip=None,
               initializer=None,
               num_unit_shards=1,
               forget_bias=1.0,
               feature_size=None,
               frequency_skip=1,
               reuse=None):
    """Initialize the parameters for an LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell
      use_peepholes: bool, set True to enable diagonal/peephole connections.
      cell_clip: (optional) A float value, if provided the cell state is clipped
        by this value prior to the cell output activation.
      initializer: (optional) The initializer to use for the weight and
        projection matrices.
      num_unit_shards: int, How to split the weight matrix. If >1, the weight
        matrix is stored across num_unit_shards.
      forget_bias: float, Biases of the forget gate are initialized by default
        to 1 in order to reduce the scale of forgetting at the beginning
        of the training.
      feature_size: int, The size of the input feature the LSTM spans over.
      frequency_skip: int, The amount the LSTM filter is shifted by in
        frequency.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(TimeFreqLSTMCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._use_peepholes = use_peepholes
    self._cell_clip = cell_clip
    self._initializer = initializer
    self._num_unit_shards = num_unit_shards
    self._forget_bias = forget_bias
    self._feature_size = feature_size
    self._frequency_skip = frequency_skip
    self._state_size = 2 * num_units
    self._output_size = num_units
    self._reuse = reuse

  @property
  def output_size(self):
    return self._output_size

  @property
  def state_size(self):
    return self._state_size

  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, batch x num_units.
      state: state Tensor, 2D, batch x state_size.

    Returns:
      A tuple containing:
      - A 2D, batch x output_dim, Tensor representing the output of the LSTM
        after reading "inputs" when previous state was "state".
        Here output_dim is num_units.
      - A 2D, batch x state_size, Tensor representing the new state of LSTM
        after reading "inputs" when previous state was "state".

    Raises:
      ValueError: if an input_size was specified and the provided inputs have
        a different dimension.
    """
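    # The flat `state` holds one (c, m) pair per frequency step; the loop below
    # walks the frequency windows produced by _make_tf_features, reusing a
    # single shared weight matrix across all of them.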
    sigmoid = math_ops.sigmoid
    tanh = math_ops.tanh

    freq_inputs = self._make_tf_features(inputs)
    dtype = inputs.dtype
    actual_input_size = freq_inputs[0].get_shape().as_list()[1]

    concat_w = _get_concat_variable(
        "W", [actual_input_size + 2 * self._num_units, 4 * self._num_units],
        dtype, self._num_unit_shards)

    b = vs.get_variable(
        "B",
        shape=[4 * self._num_units],
        initializer=init_ops.zeros_initializer(),
        dtype=dtype)

    # Diagonal connections
    if self._use_peepholes:
      w_f_diag = vs.get_variable(
          "W_F_diag", shape=[self._num_units], dtype=dtype)
      w_i_diag = vs.get_variable(
          "W_I_diag", shape=[self._num_units], dtype=dtype)
      w_o_diag = vs.get_variable(
          "W_O_diag", shape=[self._num_units], dtype=dtype)

    # initialize the first freq state to be zero
    m_prev_freq = array_ops.zeros(
        [inputs.shape[0].value or inputs.get_shape()[0], self._num_units],
        dtype)
    for fq in range(len(freq_inputs)):
      c_prev = array_ops.slice(state, [0, 2 * fq * self._num_units],
                               [-1, self._num_units])
      m_prev = array_ops.slice(state, [0, (2 * fq + 1) * self._num_units],
                               [-1, self._num_units])
      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      cell_inputs = array_ops.concat([freq_inputs[fq], m_prev, m_prev_freq], 1)
      lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
      i, j, f, o = array_ops.split(
          value=lstm_matrix, num_or_size_splits=4, axis=1)

      if self._use_peepholes:
        c = (
            sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
            sigmoid(i + w_i_diag * c_prev) * tanh(j))
      else:
        c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * tanh(j))

      if self._cell_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
        # pylint: enable=invalid-unary-operand-type

      if self._use_peepholes:
        m = sigmoid(o + w_o_diag * c) * tanh(c)
      else:
        m = sigmoid(o) * tanh(c)
      m_prev_freq = m
      if fq == 0:
        state_out = array_ops.concat([c, m], 1)
        m_out = m
      else:
        state_out = array_ops.concat([state_out, c, m], 1)
        m_out = array_ops.concat([m_out, m], 1)
    return m_out, state_out

  def _make_tf_features(self, input_feat):
    """Make the frequency features.

    Args:
      input_feat: input Tensor, 2D, batch x num_units.

    Returns:
      A list of frequency features, with each element containing:
      - A 2D, batch x output_dim, Tensor representing the time-frequency feature
        for that frequency index. Here output_dim is feature_size.

    Raises:
      ValueError: if input_size cannot be inferred from static shape inference.
    """
    input_size = input_feat.get_shape().with_rank(2)[-1].value
    if input_size is None:
      raise ValueError("Cannot infer input_size from static shape inference.")
    num_feats = int(
        (input_size - self._feature_size) / (self._frequency_skip)) + 1
    freq_inputs = []
    for f in range(num_feats):
      cur_input = array_ops.slice(input_feat, [0, f * self._frequency_skip],
                                  [-1, self._feature_size])
      freq_inputs.append(cur_input)
    return freq_inputs

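
# --- Illustrative usage sketch (not part of the original API) ---------------
# Assumed numbers: 12 input frequency bins, windows of 4 bins shifted by 2, and
# 16 units, giving (12 - 4) / 2 + 1 = 5 frequency steps. Note that the flat
# state tensor must hold one (c, m) pair per frequency step, i.e. it is
# 2 * num_units * num_steps wide, not cell.state_size.
def _time_freq_lstm_usage_sketch():
  """Builds and runs one step of TimeFreqLSTMCell (sketch)."""
  cell = TimeFreqLSTMCell(num_units=16, feature_size=4, frequency_skip=2)
  inputs = array_ops.zeros([3, 12], dtype=dtypes.float32)
  state = array_ops.zeros([3, 2 * 16 * 5], dtype=dtypes.float32)
  output, new_state = cell(inputs, state)  # output: [3, 5 * 16]
  return output, new_state
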

class GridLSTMCell(rnn_cell_impl.RNNCell):
  """Grid Long short-term memory unit (LSTM) recurrent network cell.

  The default is based on:
    Nal Kalchbrenner, Ivo Danihelka and Alex Graves
    "Grid Long Short-Term Memory," Proc. ICLR 2016.
    http://arxiv.org/abs/1507.01526

  When peephole connections are used, the implementation is based on:
    Tara N. Sainath and Bo Li
    "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures
    for LVCSR Tasks." submitted to INTERSPEECH, 2016.

  The code uses optional peephole connections, shared_weights and cell clipping.
  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               share_time_frequency_weights=False,
               cell_clip=None,
               initializer=None,
               num_unit_shards=1,
               forget_bias=1.0,
               feature_size=None,
               frequency_skip=None,
               num_frequency_blocks=None,
               start_freqindex_list=None,
               end_freqindex_list=None,
               couple_input_forget_gates=False,
               state_is_tuple=True,
               reuse=None):
    """Initialize the parameters for an LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell
      use_peepholes: (optional) bool, default False. Set True to enable
        diagonal/peephole connections.
      share_time_frequency_weights: (optional) bool, default False. Set True to
        enable shared cell weights between time and frequency LSTMs.
      cell_clip: (optional) A float value, default None, if provided the cell
        state is clipped by this value prior to the cell output activation.
      initializer: (optional) The initializer to use for the weight and
        projection matrices, default None.
      num_unit_shards: (optional) int, default 1, How to split the weight
        matrix. If > 1, the weight matrix is stored across num_unit_shards.
      forget_bias: (optional) float, default 1.0, The initial bias of the
        forget gates, used to reduce the scale of forgetting at the beginning
        of the training.
      feature_size: (optional) int, default None, The size of the input feature
        the LSTM spans over.
      frequency_skip: (optional) int, default None, The amount the LSTM filter
        is shifted by in frequency.
      num_frequency_blocks: [required] A list of frequency blocks needed to
        cover the whole input feature splitting defined by start_freqindex_list
        and end_freqindex_list.
      start_freqindex_list: [optional], list of ints, default None, The
        starting frequency index for each frequency block.
      end_freqindex_list: [optional], list of ints, default None. The ending
        frequency index for each frequency block.
      couple_input_forget_gates: (optional) bool, default False, Whether to
        couple the input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce
        model parameters and computation cost.
      state_is_tuple: If True (the default), accepted and returned states are
        2-tuples of the `c_state` and `m_state`. If False, they are
        concatenated along the column axis; this behavior will soon be
        deprecated.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.

    Raises:
      ValueError: if the num_frequency_blocks list is not specified
    """
    super(GridLSTMCell, self).__init__(_reuse=reuse)
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated. Use state_is_tuple=True.", self)
    self._num_units = num_units
    self._use_peepholes = use_peepholes
    self._share_time_frequency_weights = share_time_frequency_weights
    self._couple_input_forget_gates = couple_input_forget_gates
    self._state_is_tuple = state_is_tuple
    self._cell_clip = cell_clip
    self._initializer = initializer
    self._num_unit_shards = num_unit_shards
    self._forget_bias = forget_bias
    self._feature_size = feature_size
    self._frequency_skip = frequency_skip
    self._start_freqindex_list = start_freqindex_list
    self._end_freqindex_list = end_freqindex_list
    self._num_frequency_blocks = num_frequency_blocks
    self._total_blocks = 0
    self._reuse = reuse
    if self._num_frequency_blocks is None:
      raise ValueError("Must specify num_frequency_blocks")

    for block_index in range(len(self._num_frequency_blocks)):
      self._total_blocks += int(self._num_frequency_blocks[block_index])
    if state_is_tuple:
      state_names = ""
      for block_index in range(len(self._num_frequency_blocks)):
        for freq_index in range(self._num_frequency_blocks[block_index]):
          name_prefix = "state_f%02d_b%02d" % (freq_index, block_index)
          state_names += ("%s_c, %s_m," % (name_prefix, name_prefix))
      self._state_tuple_type = collections.namedtuple("GridLSTMStateTuple",
                                                      state_names.strip(","))
      self._state_size = self._state_tuple_type(*(
          [num_units, num_units] * self._total_blocks))
    else:
      self._state_tuple_type = None
      self._state_size = num_units * self._total_blocks * 2
    self._output_size = num_units * self._total_blocks * 2

  @property
  def output_size(self):
    return self._output_size

  @property
  def state_size(self):
    return self._state_size

  @property
  def state_tuple_type(self):
    return self._state_tuple_type

  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, [batch, feature_size].
      state: Tensor or tuple of Tensors, 2D, [batch, state_size], depends on the
        flag self._state_is_tuple.

    Returns:
      A tuple containing:
      - A 2D, [batch, output_dim], Tensor representing the output of the LSTM
        after reading "inputs" when previous state was "state".
        Here output_dim is num_units.
      - A 2D, [batch, state_size], Tensor representing the new state of LSTM
        after reading "inputs" when previous state was "state".

    Raises:
      ValueError: if an input_size was specified and the provided inputs have
        a different dimension.
    """
    batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
    freq_inputs = self._make_tf_features(inputs)
    m_out_lst = []
    state_out_lst = []
    for block in range(len(freq_inputs)):
      m_out_lst_current, state_out_lst_current = self._compute(
          freq_inputs[block],
          block,
          state,
          batch_size,
          state_is_tuple=self._state_is_tuple)
      m_out_lst.extend(m_out_lst_current)
      state_out_lst.extend(state_out_lst_current)
    if self._state_is_tuple:
      state_out = self._state_tuple_type(*state_out_lst)
    else:
      state_out = array_ops.concat(state_out_lst, 1)
    m_out = array_ops.concat(m_out_lst, 1)
    return m_out, state_out

  def _compute(self,
               freq_inputs,
               block,
               state,
               batch_size,
               state_prefix="state",
               state_is_tuple=True):
    """Run the actual computation of one step LSTM.

    Args:
      freq_inputs: list of Tensors, 2D, [batch, feature_size].
      block: int, current frequency block index to process.
      state: Tensor or tuple of Tensors, 2D, [batch, state_size], it depends on
        the flag state_is_tuple.
      batch_size: int32, batch size.
      state_prefix: (optional) string, name prefix for states, defaults to
        "state".
      state_is_tuple: boolean, indicates whether the state is a tuple or Tensor.

    Returns:
      A tuple, containing:
      - A list of [batch, output_dim] Tensors, representing the output of the
        LSTM given the inputs and state.
      - A list of [batch, state_size] Tensors, representing the LSTM state
        values given the inputs and previous state.
    """
    sigmoid = math_ops.sigmoid
    tanh = math_ops.tanh
    num_gates = 3 if self._couple_input_forget_gates else 4
    dtype = freq_inputs[0].dtype
    actual_input_size = freq_inputs[0].get_shape().as_list()[1]

    concat_w_f = _get_concat_variable(
        "W_f_%d" % block,
        [actual_input_size + 2 * self._num_units, num_gates * self._num_units],
        dtype, self._num_unit_shards)
    b_f = vs.get_variable(
        "B_f_%d" % block,
        shape=[num_gates * self._num_units],
        initializer=init_ops.zeros_initializer(),
        dtype=dtype)
    if not self._share_time_frequency_weights:
      concat_w_t = _get_concat_variable("W_t_%d" % block, [
          actual_input_size + 2 * self._num_units, num_gates * self._num_units
      ], dtype, self._num_unit_shards)
      b_t = vs.get_variable(
          "B_t_%d" % block,
          shape=[num_gates * self._num_units],
          initializer=init_ops.zeros_initializer(),
          dtype=dtype)

    if self._use_peepholes:
      # Diagonal connections
      if not self._couple_input_forget_gates:
        w_f_diag_freqf = vs.get_variable(
            "W_F_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
        w_f_diag_freqt = vs.get_variable(
            "W_F_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype)
      w_i_diag_freqf = vs.get_variable(
          "W_I_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
      w_i_diag_freqt = vs.get_variable(
          "W_I_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype)
      w_o_diag_freqf = vs.get_variable(
          "W_O_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
      w_o_diag_freqt = vs.get_variable(
          "W_O_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype)
      if not self._share_time_frequency_weights:
        if not self._couple_input_forget_gates:
          w_f_diag_timef = vs.get_variable(
              "W_F_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype)
          w_f_diag_timet = vs.get_variable(
              "W_F_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype)
        w_i_diag_timef = vs.get_variable(
            "W_I_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype)
        w_i_diag_timet = vs.get_variable(
            "W_I_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype)
        w_o_diag_timef = vs.get_variable(
            "W_O_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype)
        w_o_diag_timet = vs.get_variable(
            "W_O_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype)

    # initialize the first freq state to be zero
    m_prev_freq = array_ops.zeros([batch_size, self._num_units], dtype)
    c_prev_freq = array_ops.zeros([batch_size, self._num_units], dtype)
    for freq_index in range(len(freq_inputs)):
      if state_is_tuple:
        name_prefix = "%s_f%02d_b%02d" % (state_prefix, freq_index, block)
        c_prev_time = getattr(state, name_prefix + "_c")
        m_prev_time = getattr(state, name_prefix + "_m")
      else:
        c_prev_time = array_ops.slice(
            state, [0, 2 * freq_index * self._num_units],
            [-1, self._num_units])
        m_prev_time = array_ops.slice(
            state, [0, (2 * freq_index + 1) * self._num_units],
            [-1, self._num_units])

      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      cell_inputs = array_ops.concat(
          [freq_inputs[freq_index], m_prev_time, m_prev_freq], 1)

      # F-LSTM
      lstm_matrix_freq = nn_ops.bias_add(
          math_ops.matmul(cell_inputs, concat_w_f), b_f)
      if self._couple_input_forget_gates:
        i_freq, j_freq, o_freq = array_ops.split(
            value=lstm_matrix_freq, num_or_size_splits=num_gates, axis=1)
        f_freq = None
      else:
        i_freq, j_freq, f_freq, o_freq = array_ops.split(
            value=lstm_matrix_freq, num_or_size_splits=num_gates, axis=1)
      # T-LSTM
      if self._share_time_frequency_weights:
        i_time = i_freq
        j_time = j_freq
        f_time = f_freq
        o_time = o_freq
      else:
        lstm_matrix_time = nn_ops.bias_add(
            math_ops.matmul(cell_inputs, concat_w_t), b_t)
        if self._couple_input_forget_gates:
          i_time, j_time, o_time = array_ops.split(
              value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1)
          f_time = None
        else:
          i_time, j_time, f_time, o_time = array_ops.split(
              value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1)

      # F-LSTM c_freq
      # input gate activations
      if self._use_peepholes:
        i_freq_g = sigmoid(i_freq + w_i_diag_freqf * c_prev_freq +
                           w_i_diag_freqt * c_prev_time)
      else:
        i_freq_g = sigmoid(i_freq)
      # forget gate activations
      if self._couple_input_forget_gates:
        f_freq_g = 1.0 - i_freq_g
      else:
        if self._use_peepholes:
          f_freq_g = sigmoid(f_freq + self._forget_bias + w_f_diag_freqf *
                             c_prev_freq + w_f_diag_freqt * c_prev_time)
        else:
          f_freq_g = sigmoid(f_freq + self._forget_bias)
      # cell state
      c_freq = f_freq_g * c_prev_freq + i_freq_g * tanh(j_freq)
      if self._cell_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        c_freq = clip_ops.clip_by_value(c_freq, -self._cell_clip,
                                        self._cell_clip)
        # pylint: enable=invalid-unary-operand-type

      # T-LSTM c_time
      # input gate activations
      if self._use_peepholes:
        if self._share_time_frequency_weights:
          i_time_g = sigmoid(i_time + w_i_diag_freqf * c_prev_freq +
                             w_i_diag_freqt * c_prev_time)
        else:
          i_time_g = sigmoid(i_time + w_i_diag_timef * c_prev_freq +
                             w_i_diag_timet * c_prev_time)
      else:
        i_time_g = sigmoid(i_time)
      # forget gate activations
      if self._couple_input_forget_gates:
        f_time_g = 1.0 - i_time_g
      else:
        if self._use_peepholes:
          if self._share_time_frequency_weights:
            f_time_g = sigmoid(f_time + self._forget_bias + w_f_diag_freqf *
                               c_prev_freq + w_f_diag_freqt * c_prev_time)
          else:
            f_time_g = sigmoid(f_time + self._forget_bias + w_f_diag_timef *
                               c_prev_freq + w_f_diag_timet * c_prev_time)
        else:
          f_time_g = sigmoid(f_time + self._forget_bias)
      # cell state
      c_time = f_time_g * c_prev_time + i_time_g * tanh(j_time)
      if self._cell_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        c_time = clip_ops.clip_by_value(c_time, -self._cell_clip,
                                        self._cell_clip)
        # pylint: enable=invalid-unary-operand-type

      # F-LSTM m_freq
      if self._use_peepholes:
        m_freq = sigmoid(o_freq + w_o_diag_freqf * c_freq +
                         w_o_diag_freqt * c_time) * tanh(c_freq)
      else:
        m_freq = sigmoid(o_freq) * tanh(c_freq)

      # T-LSTM m_time
      if self._use_peepholes:
        if self._share_time_frequency_weights:
          m_time = sigmoid(o_time + w_o_diag_freqf * c_freq +
                           w_o_diag_freqt * c_time) * tanh(c_time)
        else:
          m_time = sigmoid(o_time + w_o_diag_timef * c_freq +
                           w_o_diag_timet * c_time) * tanh(c_time)
      else:
        m_time = sigmoid(o_time) * tanh(c_time)

      m_prev_freq = m_freq
      c_prev_freq = c_freq
      # Concatenate the outputs for T-LSTM and F-LSTM for each shift
      if freq_index == 0:
        state_out_lst = [c_time, m_time]
        m_out_lst = [m_time, m_freq]
      else:
        state_out_lst.extend([c_time, m_time])
        m_out_lst.extend([m_time, m_freq])

    return m_out_lst, state_out_lst

  def _make_tf_features(self, input_feat, slice_offset=0):
    """Make the frequency features.

    Args:
      input_feat: input Tensor, 2D, [batch, num_units].
      slice_offset: (optional) Python int, default 0, the slicing offset is only
        used for the backward processing in the BidirectionalGridLSTMCell. It
        specifies a different starting point instead of always 0 to enable the
        forward and backward processing to look at different frequency blocks.

    Returns:
      A list of frequency features, with each element containing:
      - A 2D, [batch, output_dim], Tensor representing the time-frequency
        feature for that frequency index. Here output_dim is feature_size.

    Raises:
      ValueError: if input_size cannot be inferred from static shape inference.
    """
    input_size = input_feat.get_shape().with_rank(2)[-1].value
    if input_size is None:
      raise ValueError("Cannot infer input_size from static shape inference.")
    if slice_offset > 0:
      # Padding to the end
      inputs = array_ops.pad(input_feat,
                             array_ops.constant(
                                 [0, 0, 0, slice_offset],
                                 shape=[2, 2],
                                 dtype=dtypes.int32), "CONSTANT")
    elif slice_offset < 0:
      # Padding to the front
      inputs = array_ops.pad(input_feat,
                             array_ops.constant(
                                 [0, 0, -slice_offset, 0],
                                 shape=[2, 2],
                                 dtype=dtypes.int32), "CONSTANT")
      slice_offset = 0
    else:
      inputs = input_feat
    freq_inputs = []
    if not self._start_freqindex_list:
      if len(self._num_frequency_blocks) != 1:
        raise ValueError("Length of num_frequency_blocks"
                         " is not 1, but instead is %d",
                         len(self._num_frequency_blocks))
      num_feats = int(
          (input_size - self._feature_size) / (self._frequency_skip)) + 1
      if num_feats != self._num_frequency_blocks[0]:
        raise ValueError(
            "Invalid num_frequency_blocks, requires %d but gets %d, please"
            " check the input size and filter config are correct." %
            (self._num_frequency_blocks[0], num_feats))
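      # Slide a window of feature_size bins over the input, advancing by
      # frequency_skip bins each step; e.g. input_size=40, feature_size=8,
      # frequency_skip=2 gives (40 - 8) / 2 + 1 = 17 overlapping windows
      # (numbers are illustrative only).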
      block_inputs = []
      for f in range(num_feats):
        cur_input = array_ops.slice(
            inputs, [0, slice_offset + f * self._frequency_skip],
            [-1, self._feature_size])
        block_inputs.append(cur_input)
      freq_inputs.append(block_inputs)
    else:
      if len(self._start_freqindex_list) != len(self._end_freqindex_list):
        raise ValueError("Length of start and end freqindex_list"
                         " does not match %d %d",
                         len(self._start_freqindex_list),
                         len(self._end_freqindex_list))
      if len(self._num_frequency_blocks) != len(self._start_freqindex_list):
        raise ValueError("Length of num_frequency_blocks"
                         " is not equal to start_freqindex_list %d %d",
                         len(self._num_frequency_blocks),
                         len(self._start_freqindex_list))
      for b in range(len(self._start_freqindex_list)):
        start_index = self._start_freqindex_list[b]
        end_index = self._end_freqindex_list[b]
        cur_size = end_index - start_index
        block_feats = int(
            (cur_size - self._feature_size) / (self._frequency_skip)) + 1
        if block_feats != self._num_frequency_blocks[b]:
          raise ValueError(
              "Invalid num_frequency_blocks, requires %d but gets %d, please"
              " check the input size and filter config are correct." %
              (self._num_frequency_blocks[b], block_feats))
        block_inputs = []
        for f in range(block_feats):
          cur_input = array_ops.slice(
              inputs,
              [0, start_index + slice_offset + f * self._frequency_skip],
              [-1, self._feature_size])
          block_inputs.append(cur_input)
        freq_inputs.append(block_inputs)
    return freq_inputs


class BidirectionalGridLSTMCell(GridLSTMCell):
  """Bidirectional GridLstm cell.

  The bidirectional connection is only used in the frequency direction, so it
  does not affect the time direction's real-time processing that is required
  for online recognition systems.
  The current implementation uses different weights for the two directions.
  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               share_time_frequency_weights=False,
               cell_clip=None,
               initializer=None,
               num_unit_shards=1,
               forget_bias=1.0,
               feature_size=None,
               frequency_skip=None,
               num_frequency_blocks=None,
               start_freqindex_list=None,
               end_freqindex_list=None,
               couple_input_forget_gates=False,
               backward_slice_offset=0,
               reuse=None):
    """Initialize the parameters for an LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell
      use_peepholes: (optional) bool, default False. Set True to enable
        diagonal/peephole connections.
      share_time_frequency_weights: (optional) bool, default False. Set True to
        enable shared cell weights between time and frequency LSTMs.
      cell_clip: (optional) A float value, default None, if provided the cell
        state is clipped by this value prior to the cell output activation.
      initializer: (optional) The initializer to use for the weight and
        projection matrices, default None.
      num_unit_shards: (optional) int, default 1, How to split the weight
        matrix. If > 1, the weight matrix is stored across num_unit_shards.
      forget_bias: (optional) float, default 1.0, The initial bias of the
        forget gates, used to reduce the scale of forgetting at the beginning
        of the training.
      feature_size: (optional) int, default None, The size of the input feature
        the LSTM spans over.
      frequency_skip: (optional) int, default None, The amount the LSTM filter
        is shifted by in frequency.
      num_frequency_blocks: [required] A list of frequency blocks needed to
        cover the whole input feature splitting defined by start_freqindex_list
        and end_freqindex_list.
      start_freqindex_list: [optional], list of ints, default None, The
        starting frequency index for each frequency block.
      end_freqindex_list: [optional], list of ints, default None. The ending
        frequency index for each frequency block.
      couple_input_forget_gates: (optional) bool, default False, Whether to
        couple the input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce
        model parameters and computation cost.
      backward_slice_offset: (optional) int32, default 0, the starting offset to
        slice the feature for backward processing.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(BidirectionalGridLSTMCell, self).__init__(
        num_units, use_peepholes, share_time_frequency_weights, cell_clip,
        initializer, num_unit_shards, forget_bias, feature_size, frequency_skip,
        num_frequency_blocks, start_freqindex_list, end_freqindex_list,
        couple_input_forget_gates, True, reuse)
    self._backward_slice_offset = int(backward_slice_offset)
    state_names = ""
    for direction in ["fwd", "bwd"]:
      for block_index in range(len(self._num_frequency_blocks)):
        for freq_index in range(self._num_frequency_blocks[block_index]):
          name_prefix = "%s_state_f%02d_b%02d" % (direction, freq_index,
                                                  block_index)
          state_names += ("%s_c, %s_m," % (name_prefix, name_prefix))
    self._state_tuple_type = collections.namedtuple(
        "BidirectionalGridLSTMStateTuple", state_names.strip(","))
    self._state_size = self._state_tuple_type(*(
        [num_units, num_units] * self._total_blocks * 2))
    self._output_size = 2 * num_units * self._total_blocks * 2

  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, [batch, num_units].
      state: tuple of Tensors, 2D, [batch, state_size].

    Returns:
      A tuple containing:
      - A 2D, [batch, output_dim], Tensor representing the output of the LSTM
        after reading "inputs" when previous state was "state".
        Here output_dim is num_units.
      - A 2D, [batch, state_size], Tensor representing the new state of LSTM
        after reading "inputs" when previous state was "state".

    Raises:
      ValueError: if an input_size was specified and the provided inputs have
        a different dimension.
    """
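    # The forward and backward passes run the inherited _compute under separate
    # variable scopes ("fwd"/"bwd"), so the two directions use different
    # weights; the backward pass sees each block's frequency steps in reverse.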
    batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
    fwd_inputs = self._make_tf_features(inputs)
    if self._backward_slice_offset:
      bwd_inputs = self._make_tf_features(inputs, self._backward_slice_offset)
    else:
      bwd_inputs = fwd_inputs

    # Forward processing
    with vs.variable_scope("fwd"):
      fwd_m_out_lst = []
      fwd_state_out_lst = []
      for block in range(len(fwd_inputs)):
        fwd_m_out_lst_current, fwd_state_out_lst_current = self._compute(
            fwd_inputs[block],
            block,
            state,
            batch_size,
            state_prefix="fwd_state",
            state_is_tuple=True)
        fwd_m_out_lst.extend(fwd_m_out_lst_current)
        fwd_state_out_lst.extend(fwd_state_out_lst_current)
    # Backward processing
    bwd_m_out_lst = []
    bwd_state_out_lst = []
    with vs.variable_scope("bwd"):
      for block in range(len(bwd_inputs)):
        # Reverse the blocks
        bwd_inputs_reverse = bwd_inputs[block][::-1]
        bwd_m_out_lst_current, bwd_state_out_lst_current = self._compute(
            bwd_inputs_reverse,
            block,
            state,
            batch_size,
            state_prefix="bwd_state",
            state_is_tuple=True)
        bwd_m_out_lst.extend(bwd_m_out_lst_current)
        bwd_state_out_lst.extend(bwd_state_out_lst_current)
    state_out = self._state_tuple_type(*(fwd_state_out_lst + bwd_state_out_lst))
    # Outputs are always concatenated, as they are never used separately.
    m_out = array_ops.concat(fwd_m_out_lst + bwd_m_out_lst, 1)
    return m_out, state_out


# pylint: disable=protected-access
_Linear = core_rnn_cell._Linear  # pylint: disable=invalid-name

# pylint: enable=protected-access


class AttentionCellWrapper(rnn_cell_impl.RNNCell):
  """Basic attention cell wrapper.

  Implementation based on https://arxiv.org/abs/1409.0473.
  """

  def __init__(self,
               cell,
               attn_length,
               attn_size=None,
               attn_vec_size=None,
               input_size=None,
               state_is_tuple=True,
               reuse=None):
    """Create a cell with attention.

    Args:
      cell: an RNNCell, an attention is added to it.
      attn_length: integer, the size of an attention window.
      attn_size: integer, the size of an attention vector. Equal to
        cell.output_size by default.
      attn_vec_size: integer, the number of convolutional features calculated
        on attention state and a size of the hidden layer built from
        base cell state. Equal to attn_size by default.
      input_size: integer, the size of a hidden linear layer,
        built from inputs and attention. Derived from the input tensor
        by default.
      state_is_tuple: If True (the default), accepted and returned states are
        n-tuples, where `n = len(cells)`. If False, the states are all
        concatenated along the column axis.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.

    Raises:
      TypeError: if cell is not an RNNCell.
      ValueError: if cell returns a state tuple but the flag
        `state_is_tuple` is `False` or if attn_length is zero or less.
    """
    super(AttentionCellWrapper, self).__init__(_reuse=reuse)
    if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
      raise TypeError("The parameter cell is not RNNCell.")
    if nest.is_sequence(cell.state_size) and not state_is_tuple:
      raise ValueError(
          "Cell returns tuple of states, but the flag "
          "state_is_tuple is not set. State size is: %s" % str(cell.state_size))
    if attn_length <= 0:
      raise ValueError(
          "attn_length should be greater than zero, got %s" % str(attn_length))
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated. Use state_is_tuple=True.", self)
    if attn_size is None:
      attn_size = cell.output_size
    if attn_vec_size is None:
      attn_vec_size = attn_size
    self._state_is_tuple = state_is_tuple
    self._cell = cell
    self._attn_vec_size = attn_vec_size
    self._input_size = input_size
    self._attn_size = attn_size
    self._attn_length = attn_length
    self._reuse = reuse
    self._linear1 = None
    self._linear2 = None
    self._linear3 = None

  @property
  def state_size(self):
    size = (self._cell.state_size, self._attn_size,
            self._attn_size * self._attn_length)
    if self._state_is_tuple:
      return size
    else:
      return sum(list(size))

  @property
  def output_size(self):
    return self._attn_size

  def call(self, inputs, state):
    """Long short-term memory cell with attention (LSTMA)."""
    if self._state_is_tuple:
      state, attns, attn_states = state
    else:
      states = state
      state = array_ops.slice(states, [0, 0], [-1, self._cell.state_size])
      attns = array_ops.slice(states, [0, self._cell.state_size],
                              [-1, self._attn_size])
      attn_states = array_ops.slice(
          states, [0, self._cell.state_size + self._attn_size],
          [-1, self._attn_size * self._attn_length])
    attn_states = array_ops.reshape(attn_states,
                                    [-1, self._attn_length, self._attn_size])
    input_size = self._input_size
    if input_size is None:
      input_size = inputs.get_shape().as_list()[1]
    if self._linear1 is None:
      self._linear1 = _Linear([inputs, attns], input_size, True)
    inputs = self._linear1([inputs, attns])
    cell_output, new_state = self._cell(inputs, state)
    if self._state_is_tuple:
      new_state_cat = array_ops.concat(nest.flatten(new_state), 1)
    else:
      new_state_cat = new_state
    new_attns, new_attn_states = self._attention(new_state_cat, attn_states)
    with vs.variable_scope("attn_output_projection"):
      if self._linear2 is None:
        self._linear2 = _Linear([cell_output, new_attns], self._attn_size, True)
      output = self._linear2([cell_output, new_attns])
    new_attn_states = array_ops.concat(
        [new_attn_states, array_ops.expand_dims(output, 1)], 1)
    new_attn_states = array_ops.reshape(
        new_attn_states, [-1, self._attn_length * self._attn_size])
    new_state = (new_state, new_attns, new_attn_states)
    if not self._state_is_tuple:
      new_state = array_ops.concat(list(new_state), 1)
    return output, new_state

  def _attention(self, query, attn_states):
    conv2d = nn_ops.conv2d
    reduce_sum = math_ops.reduce_sum
    softmax = nn_ops.softmax
    tanh = math_ops.tanh

    with vs.variable_scope("attention"):
      k = vs.get_variable("attn_w",
                          [1, 1, self._attn_size, self._attn_vec_size])
      v = vs.get_variable("attn_v", [self._attn_vec_size])
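      # Bahdanau-style attention over the last attn_length cell outputs: the
      # scores are v . tanh(conv(hidden) + W * query), softmax-normalized, and
      # then used to form a weighted sum of the stored attention states.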
      hidden = array_ops.reshape(attn_states,
                                 [-1, self._attn_length, 1, self._attn_size])
      hidden_features = conv2d(hidden, k, [1, 1, 1, 1], "SAME")
      if self._linear3 is None:
        self._linear3 = _Linear(query, self._attn_vec_size, True)
      y = self._linear3(query)
      y = array_ops.reshape(y, [-1, 1, 1, self._attn_vec_size])
      s = reduce_sum(v * tanh(hidden_features + y), [2, 3])
      a = softmax(s)
      d = reduce_sum(
          array_ops.reshape(a, [-1, self._attn_length, 1, 1]) * hidden, [1, 2])
      new_attns = array_ops.reshape(d, [-1, self._attn_size])
      new_attn_states = array_ops.slice(attn_states, [0, 1, 0], [-1, -1, -1])
      return new_attns, new_attn_states


class HighwayWrapper(rnn_cell_impl.RNNCell):
  """RNNCell wrapper that adds highway connection on cell input and output.

  Based on:
    R. K. Srivastava, K. Greff, and J. Schmidhuber, "Highway networks",
    arXiv preprint arXiv:1505.00387, 2015.
    https://arxiv.org/abs/1505.00387
  """

  def __init__(self,
               cell,
               couple_carry_transform_gates=True,
               carry_bias_init=1.0):
    """Constructs a `HighwayWrapper` for `cell`.

    Args:
      cell: An instance of `RNNCell`.
      couple_carry_transform_gates: boolean, should the Carry and Transform gate
        be coupled.
      carry_bias_init: float, carry gates bias initialization.
    """
    self._cell = cell
    self._couple_carry_transform_gates = couple_carry_transform_gates
    self._carry_bias_init = carry_bias_init

  @property
  def state_size(self):
    return self._cell.state_size

  @property
  def output_size(self):
    return self._cell.output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      return self._cell.zero_state(batch_size, dtype)

  def _highway(self, inp, out):
    input_size = inp.get_shape().with_rank(2)[1].value
    carry_weight = vs.get_variable("carry_w", [input_size, input_size])
    carry_bias = vs.get_variable(
        "carry_b", [input_size],
        initializer=init_ops.constant_initializer(self._carry_bias_init))
    carry = math_ops.sigmoid(nn_ops.xw_plus_b(inp, carry_weight, carry_bias))
    if self._couple_carry_transform_gates:
      transform = 1 - carry
    else:
      transform_weight = vs.get_variable("transform_w",
                                         [input_size, input_size])
      transform_bias = vs.get_variable(
          "transform_b", [input_size],
          initializer=init_ops.constant_initializer(-self._carry_bias_init))
      transform = math_ops.sigmoid(
          nn_ops.xw_plus_b(inp, transform_weight, transform_bias))
    return inp * carry + out * transform

  def __call__(self, inputs, state, scope=None):
    """Run the cell and add its inputs to its outputs.

    Args:
      inputs: cell inputs.
      state: cell state.
      scope: optional cell scope.

    Returns:
      Tuple of cell outputs and new state.

    Raises:
      TypeError: If cell inputs and outputs have different structure (type).
      ValueError: If cell inputs and outputs have different structure (value).
    """
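    # Highway combination: y = carry * x + transform * cell(x), where the carry
    # and transform gates are sigmoids of the input (and transform = 1 - carry
    # when the gates are coupled); see _highway above.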
    outputs, new_state = self._cell(inputs, state, scope=scope)
    nest.assert_same_structure(inputs, outputs)

    # Ensure shapes match
    def assert_shape_match(inp, out):
      inp.get_shape().assert_is_compatible_with(out.get_shape())

    nest.map_structure(assert_shape_match, inputs, outputs)
    res_outputs = nest.map_structure(self._highway, inputs, outputs)
    return (res_outputs, new_state)


class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell):
  """LSTM unit with layer normalization and recurrent dropout.

  This class adds layer normalization and recurrent dropout to a
  basic LSTM unit. Layer normalization implementation is based on:

    https://arxiv.org/abs/1607.06450.

  "Layer Normalization"
  Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton

  and is applied before the internal nonlinearities.
  Recurrent dropout is based on:

    https://arxiv.org/abs/1603.05118

  "Recurrent Dropout without Memory Loss"
  Stanislau Semeniuta, Aliaksei Severyn, Erhardt Barth.
  """

  def __init__(self,
               num_units,
               forget_bias=1.0,
               input_size=None,
               activation=math_ops.tanh,
               layer_norm=True,
               norm_gain=1.0,
               norm_shift=0.0,
               dropout_keep_prob=1.0,
               dropout_prob_seed=None,
               reuse=None):
    """Initializes the basic LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
      input_size: Deprecated and unused.
      activation: Activation function of the inner states.
      layer_norm: If `True`, layer normalization will be applied.
      norm_gain: float, The layer normalization gain initial value. If
        `layer_norm` has been set to `False`, this argument will be ignored.
      norm_shift: float, The layer normalization shift initial value. If
        `layer_norm` has been set to `False`, this argument will be ignored.
      dropout_keep_prob: unit Tensor or float between 0 and 1 representing the
        recurrent dropout probability value. If float and 1.0, no dropout will
        be applied.
      dropout_prob_seed: (optional) integer, the randomness seed.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(LayerNormBasicLSTMCell, self).__init__(_reuse=reuse)

    if input_size is not None:
      logging.warn("%s: The input_size parameter is deprecated.", self)

    self._num_units = num_units
    self._activation = activation
    self._forget_bias = forget_bias
    self._keep_prob = dropout_keep_prob
    self._seed = dropout_prob_seed
    self._layer_norm = layer_norm
    self._norm_gain = norm_gain
    self._norm_shift = norm_shift
    self._reuse = reuse

  @property
  def state_size(self):
    return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)

  @property
  def output_size(self):
    return self._num_units

  def _norm(self, inp, scope, dtype=dtypes.float32):
    shape = inp.get_shape()[-1:]
    gamma_init = init_ops.constant_initializer(self._norm_gain)
    beta_init = init_ops.constant_initializer(self._norm_shift)
    with vs.variable_scope(scope):
      # Initialize beta and gamma for use by layer_norm.
      vs.get_variable("gamma", shape=shape, initializer=gamma_init, dtype=dtype)
      vs.get_variable("beta", shape=shape, initializer=beta_init, dtype=dtype)
    normalized = layers.layer_norm(inp, reuse=True, scope=scope)
    return normalized

  def _linear(self, args):
    out_size = 4 * self._num_units
    proj_size = args.get_shape()[-1]
    dtype = args.dtype
    weights = vs.get_variable("kernel", [proj_size, out_size], dtype=dtype)
    out = math_ops.matmul(args, weights)
    if not self._layer_norm:
      bias = vs.get_variable("bias", [out_size], dtype=dtype)
      out = nn_ops.bias_add(out, bias)
    return out

  def call(self, inputs, state):
    """LSTM cell with layer normalization and recurrent dropout."""
    c, h = state
    args = array_ops.concat([inputs, h], 1)
    concat = self._linear(args)
    dtype = args.dtype

    i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
    if self._layer_norm:
      i = self._norm(i, "input", dtype=dtype)
      j = self._norm(j, "transform", dtype=dtype)
      f = self._norm(f, "forget", dtype=dtype)
      o = self._norm(o, "output", dtype=dtype)

    g = self._activation(j)
    if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1:
      g = nn_ops.dropout(g, self._keep_prob, seed=self._seed)

    new_c = (
        c * math_ops.sigmoid(f + self._forget_bias) + math_ops.sigmoid(i) * g)
    if self._layer_norm:
      new_c = self._norm(new_c, "state", dtype=dtype)
    new_h = self._activation(new_c) * math_ops.sigmoid(o)

    new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
    return new_h, new_state


class NASCell(rnn_cell_impl.RNNCell):
  """Neural Architecture Search (NAS) recurrent network cell.

  This implements the recurrent cell from the paper:

    https://arxiv.org/abs/1611.01578

  Barret Zoph and Quoc V. Le.
  "Neural Architecture Search with Reinforcement Learning" Proc. ICLR 2017.

  The class uses an optional projection layer.
  """

  def __init__(self, num_units, num_proj=None, use_biases=False, reuse=None):
    """Initialize the parameters for a NAS cell.

    Args:
      num_units: int, The number of units in the NAS cell
      num_proj: (optional) int, The output dimensionality for the projection
        matrices. If None, no projection is performed.
      use_biases: (optional) bool, If True then use biases within the cell. This
        is False by default.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(NASCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._num_proj = num_proj
    self._use_biases = use_biases
    self._reuse = reuse

    if num_proj is not None:
      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
      self._output_size = num_proj
    else:
      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units)
      self._output_size = num_units

  @property
  def state_size(self):
    return self._state_size

  @property
  def output_size(self):
    return self._output_size

  def call(self, inputs, state):
    """Run one step of NAS Cell.

    Args:
      inputs: input Tensor, 2D, batch x num_units.
      state: This must be a tuple of state Tensors, both `2-D`, with column
        sizes `c_state` and `m_state`.

    Returns:
      A tuple containing:
      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
        NAS Cell after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - Tensor(s) representing the new state of NAS Cell after reading `inputs`
        when the previous state was `state`. Same type and shape(s) as `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
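    # The cell applies eight affine projections to the input and to the hidden
    # state, combines them pairwise through a small tree of sigmoid/tanh/relu
    # nodes, and injects the previous cell state partway through; see the
    # numbered layer comments below.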
1510 state: This must be a tuple of state Tensors, both `2-D`, with column 1511 sizes `c_state` and `m_state`. 1512 1513 Returns: 1514 A tuple containing: 1515 - A `2-D, [batch x output_dim]`, Tensor representing the output of the 1516 NAS Cell after reading `inputs` when previous state was `state`. 1517 Here output_dim is: 1518 num_proj if num_proj was set, 1519 num_units otherwise. 1520 - Tensor(s) representing the new state of NAS Cell after reading `inputs` 1521 when the previous state was `state`. Same type and shape(s) as `state`. 1522 1523 Raises: 1524 ValueError: If input size cannot be inferred from inputs via 1525 static shape inference. 1526 """ 1527 sigmoid = math_ops.sigmoid 1528 tanh = math_ops.tanh 1529 relu = nn_ops.relu 1530 1531 num_proj = self._num_units if self._num_proj is None else self._num_proj 1532 1533 (c_prev, m_prev) = state 1534 1535 dtype = inputs.dtype 1536 input_size = inputs.get_shape().with_rank(2)[1] 1537 if input_size.value is None: 1538 raise ValueError("Could not infer input size from inputs.get_shape()[-1]") 1539 # Variables for the NAS cell. W_m is all matrices multiplying the 1540 # hiddenstate and W_inputs is all matrices multiplying the inputs. 1541 concat_w_m = vs.get_variable("recurrent_kernel", 1542 [num_proj, 8 * self._num_units], dtype) 1543 concat_w_inputs = vs.get_variable( 1544 "kernel", [input_size.value, 8 * self._num_units], dtype) 1545 1546 m_matrix = math_ops.matmul(m_prev, concat_w_m) 1547 inputs_matrix = math_ops.matmul(inputs, concat_w_inputs) 1548 1549 if self._use_biases: 1550 b = vs.get_variable( 1551 "bias", 1552 shape=[8 * self._num_units], 1553 initializer=init_ops.zeros_initializer(), 1554 dtype=dtype) 1555 m_matrix = nn_ops.bias_add(m_matrix, b) 1556 1557 # The NAS cell branches into 8 different splits for both the hiddenstate 1558 # and the input 1559 m_matrix_splits = array_ops.split( 1560 axis=1, num_or_size_splits=8, value=m_matrix) 1561 inputs_matrix_splits = array_ops.split( 1562 axis=1, num_or_size_splits=8, value=inputs_matrix) 1563 1564 # First layer 1565 layer1_0 = sigmoid(inputs_matrix_splits[0] + m_matrix_splits[0]) 1566 layer1_1 = relu(inputs_matrix_splits[1] + m_matrix_splits[1]) 1567 layer1_2 = sigmoid(inputs_matrix_splits[2] + m_matrix_splits[2]) 1568 layer1_3 = relu(inputs_matrix_splits[3] * m_matrix_splits[3]) 1569 layer1_4 = tanh(inputs_matrix_splits[4] + m_matrix_splits[4]) 1570 layer1_5 = sigmoid(inputs_matrix_splits[5] + m_matrix_splits[5]) 1571 layer1_6 = tanh(inputs_matrix_splits[6] + m_matrix_splits[6]) 1572 layer1_7 = sigmoid(inputs_matrix_splits[7] + m_matrix_splits[7]) 1573 1574 # Second layer 1575 l2_0 = tanh(layer1_0 * layer1_1) 1576 l2_1 = tanh(layer1_2 + layer1_3) 1577 l2_2 = tanh(layer1_4 * layer1_5) 1578 l2_3 = sigmoid(layer1_6 + layer1_7) 1579 1580 # Inject the cell 1581 l2_0 = tanh(l2_0 + c_prev) 1582 1583 # Third layer 1584 l3_0_pre = l2_0 * l2_1 1585 new_c = l3_0_pre # create new cell 1586 l3_0 = l3_0_pre 1587 l3_1 = tanh(l2_2 + l2_3) 1588 1589 # Final layer 1590 new_m = tanh(l3_0 * l3_1) 1591 1592 # Projection layer if specified 1593 if self._num_proj is not None: 1594 concat_w_proj = vs.get_variable("projection_weights", 1595 [self._num_units, self._num_proj], dtype) 1596 new_m = math_ops.matmul(new_m, concat_w_proj) 1597 1598 new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_m) 1599 return new_m, new_state 1600 1601 1602class UGRNNCell(rnn_cell_impl.RNNCell): 1603 """Update Gate Recurrent Neural Network (UGRNN) cell. 1604 1605 Compromise between a LSTM/GRU and a vanilla RNN. 
There is only one 1606 gate, and that is to determine whether the unit should be 1607 integrating or computing instantaneously. This is the recurrent 1608 idea of the feedforward Highway Network. 1609 1610 This implements the recurrent cell from the paper: 1611 1612 https://arxiv.org/abs/1611.09913 1613 1614 Jasmine Collins, Jascha Sohl-Dickstein, and David Sussillo. 1615 "Capacity and Trainability in Recurrent Neural Networks" Proc. ICLR 2017. 1616 """ 1617 1618 def __init__(self, 1619 num_units, 1620 initializer=None, 1621 forget_bias=1.0, 1622 activation=math_ops.tanh, 1623 reuse=None): 1624 """Initialize the parameters for an UGRNN cell. 1625 1626 Args: 1627 num_units: int, The number of units in the UGRNN cell 1628 initializer: (optional) The initializer to use for the weight matrices. 1629 forget_bias: (optional) float, default 1.0, The initial bias of the 1630 forget gate, used to reduce the scale of forgetting at the beginning 1631 of the training. 1632 activation: (optional) Activation function of the inner states. 1633 Default is `tf.tanh`. 1634 reuse: (optional) Python boolean describing whether to reuse variables 1635 in an existing scope. If not `True`, and the existing scope already has 1636 the given variables, an error is raised. 1637 """ 1638 super(UGRNNCell, self).__init__(_reuse=reuse) 1639 self._num_units = num_units 1640 self._initializer = initializer 1641 self._forget_bias = forget_bias 1642 self._activation = activation 1643 self._reuse = reuse 1644 self._linear = None 1645 1646 @property 1647 def state_size(self): 1648 return self._num_units 1649 1650 @property 1651 def output_size(self): 1652 return self._num_units 1653 1654 def call(self, inputs, state): 1655 """Run one step of UGRNN. 1656 1657 Args: 1658 inputs: input Tensor, 2D, batch x input size. 1659 state: state Tensor, 2D, batch x num units. 1660 1661 Returns: 1662 new_output: batch x num units, Tensor representing the output of the UGRNN 1663 after reading `inputs` when previous state was `state`. Identical to 1664 `new_state`. 1665 new_state: batch x num units, Tensor representing the state of the UGRNN 1666 after reading `inputs` when previous state was `state`. 1667 1668 Raises: 1669 ValueError: If input size cannot be inferred from inputs via 1670 static shape inference. 1671 """ 1672 sigmoid = math_ops.sigmoid 1673 1674 input_size = inputs.get_shape().with_rank(2)[1] 1675 if input_size.value is None: 1676 raise ValueError("Could not infer input size from inputs.get_shape()[-1]") 1677 1678 with vs.variable_scope( 1679 vs.get_variable_scope(), initializer=self._initializer): 1680 cell_inputs = array_ops.concat([inputs, state], 1) 1681 if self._linear is None: 1682 self._linear = _Linear(cell_inputs, 2 * self._num_units, True) 1683 rnn_matrix = self._linear(cell_inputs) 1684 1685 [g_act, c_act] = array_ops.split( 1686 axis=1, num_or_size_splits=2, value=rnn_matrix) 1687 1688 c = self._activation(c_act) 1689 g = sigmoid(g_act + self._forget_bias) 1690 new_state = g * state + (1.0 - g) * c 1691 new_output = new_state 1692 1693 return new_output, new_state 1694 1695 1696class IntersectionRNNCell(rnn_cell_impl.RNNCell): 1697 """Intersection Recurrent Neural Network (+RNN) cell. 1698 1699 Architecture with coupled recurrent gate as well as coupled depth 1700 gate, designed to improve information flow through stacked RNNs. As the 1701 architecture uses depth gating, the dimensionality of the depth 1702 output (y) also should not change through depth (input size == output size). 
1703  To achieve this, the first layer of a stacked Intersection RNN projects
1704  the inputs to N (num units) dimensions. Therefore when initializing an
1705  IntersectionRNNCell, one should set `num_in_proj = N` for the first layer
1706  and use default settings for subsequent layers.
1707
1708  This implements the recurrent cell from the paper:
1709
1710    https://arxiv.org/abs/1611.09913
1711
1712  Jasmine Collins, Jascha Sohl-Dickstein, and David Sussillo.
1713  "Capacity and Trainability in Recurrent Neural Networks" Proc. ICLR 2017.
1714
1715  The Intersection RNN is built for use in deeply stacked
1716  RNNs so it may not achieve best performance with depth 1.
1717  """
1718
1719  def __init__(self,
1720               num_units,
1721               num_in_proj=None,
1722               initializer=None,
1723               forget_bias=1.0,
1724               y_activation=nn_ops.relu,
1725               reuse=None):
1726    """Initialize the parameters for an +RNN cell.
1727
1728    Args:
1729      num_units: int, The number of units in the +RNN cell
1730      num_in_proj: (optional) int, The input dimensionality for the RNN.
1731        If creating the first layer of an +RNN, this should be set to
1732        `num_units`. Otherwise, this should be set to `None` (default).
1733        If `None`, dimensionality of `inputs` should be equal to `num_units`,
1734        otherwise ValueError is thrown.
1735      initializer: (optional) The initializer to use for the weight matrices.
1736      forget_bias: (optional) float, default 1.0, The initial bias of the
1737        forget gates, used to reduce the scale of forgetting at the beginning
1738        of the training.
1739      y_activation: (optional) Activation function of the states passed
1740        through depth. Default is `tf.nn.relu`.
1741      reuse: (optional) Python boolean describing whether to reuse variables
1742        in an existing scope. If not `True`, and the existing scope already has
1743        the given variables, an error is raised.
1744    """
1745    super(IntersectionRNNCell, self).__init__(_reuse=reuse)
1746    self._num_units = num_units
1747    self._initializer = initializer
1748    self._forget_bias = forget_bias
1749    self._num_input_proj = num_in_proj
1750    self._y_activation = y_activation
1751    self._reuse = reuse
1752    self._linear1 = None
1753    self._linear2 = None
1754
1755  @property
1756  def state_size(self):
1757    return self._num_units
1758
1759  @property
1760  def output_size(self):
1761    return self._num_units
1762
1763  def call(self, inputs, state):
1764    """Run one step of the Intersection RNN.
1765
1766    Args:
1767      inputs: input Tensor, 2D, batch x input size.
1768      state: state Tensor, 2D, batch x num units.
1769
1770    Returns:
1771      new_y: batch x num units, Tensor representing the output of the +RNN
1772        after reading `inputs` when previous state was `state`.
1773      new_state: batch x num units, Tensor representing the state of the +RNN
1774        after reading `inputs` when previous state was `state`.
1775
1776    Raises:
1777      ValueError: If input size cannot be inferred from `inputs` via
1778        static shape inference.
1779      ValueError: If input size != output size (these must be equal when
1780        using the Intersection RNN).
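
    A stacking sketch (illustrative only; the `MultiRNNCell` wrapper and the
    sizes below are assumptions, not requirements of this class):

      first = IntersectionRNNCell(num_units=64, num_in_proj=64)
      rest = [IntersectionRNNCell(num_units=64) for _ in range(2)]
      stack = rnn_cell_impl.MultiRNNCell([first] + rest)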
1781 """ 1782 sigmoid = math_ops.sigmoid 1783 tanh = math_ops.tanh 1784 1785 input_size = inputs.get_shape().with_rank(2)[1] 1786 if input_size.value is None: 1787 raise ValueError("Could not infer input size from inputs.get_shape()[-1]") 1788 1789 with vs.variable_scope( 1790 vs.get_variable_scope(), initializer=self._initializer): 1791 # read-in projections (should be used for first layer in deep +RNN 1792 # to transform size of inputs from I --> N) 1793 if input_size.value != self._num_units: 1794 if self._num_input_proj: 1795 with vs.variable_scope("in_projection"): 1796 if self._linear1 is None: 1797 self._linear1 = _Linear(inputs, self._num_units, True) 1798 inputs = self._linear1(inputs) 1799 else: 1800 raise ValueError("Must have input size == output size for " 1801 "Intersection RNN. To fix, num_in_proj should " 1802 "be set to num_units at cell init.") 1803 1804 n_dim = i_dim = self._num_units 1805 cell_inputs = array_ops.concat([inputs, state], 1) 1806 if self._linear2 is None: 1807 self._linear2 = _Linear(cell_inputs, 2 * n_dim + 2 * i_dim, True) 1808 rnn_matrix = self._linear2(cell_inputs) 1809 1810 gh_act = rnn_matrix[:, :n_dim] # b x n 1811 h_act = rnn_matrix[:, n_dim:2 * n_dim] # b x n 1812 gy_act = rnn_matrix[:, 2 * n_dim:2 * n_dim + i_dim] # b x i 1813 y_act = rnn_matrix[:, 2 * n_dim + i_dim:2 * n_dim + 2 * i_dim] # b x i 1814 1815 h = tanh(h_act) 1816 y = self._y_activation(y_act) 1817 gh = sigmoid(gh_act + self._forget_bias) 1818 gy = sigmoid(gy_act + self._forget_bias) 1819 1820 new_state = gh * state + (1.0 - gh) * h # passed thru time 1821 new_y = gy * inputs + (1.0 - gy) * y # passed thru depth 1822 1823 return new_y, new_state 1824 1825 1826_REGISTERED_OPS = None 1827 1828 1829class CompiledWrapper(rnn_cell_impl.RNNCell): 1830 """Wraps step execution in an XLA JIT scope.""" 1831 1832 def __init__(self, cell, compile_stateful=False): 1833 """Create CompiledWrapper cell. 1834 1835 Args: 1836 cell: Instance of `RNNCell`. 1837 compile_stateful: Whether to compile stateful ops like initializers 1838 and random number generators (default: False). 1839 """ 1840 self._cell = cell 1841 self._compile_stateful = compile_stateful 1842 1843 @property 1844 def state_size(self): 1845 return self._cell.state_size 1846 1847 @property 1848 def output_size(self): 1849 return self._cell.output_size 1850 1851 def zero_state(self, batch_size, dtype): 1852 with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): 1853 return self._cell.zero_state(batch_size, dtype) 1854 1855 def __call__(self, inputs, state, scope=None): 1856 if self._compile_stateful: 1857 compile_ops = True 1858 else: 1859 1860 def compile_ops(node_def): 1861 global _REGISTERED_OPS 1862 if _REGISTERED_OPS is None: 1863 _REGISTERED_OPS = op_def_registry.get_registered_ops() 1864 return not _REGISTERED_OPS[node_def.op].is_stateful 1865 1866 with jit.experimental_jit_scope(compile_ops=compile_ops): 1867 return self._cell(inputs, state, scope=scope) 1868 1869 1870def _random_exp_initializer(minval, maxval, seed=None, dtype=dtypes.float32): 1871 """Returns an exponential distribution initializer. 1872 1873 Args: 1874 minval: float or a scalar float Tensor. With value > 0. Lower bound of the 1875 range of random values to generate. 1876 maxval: float or a scalar float Tensor. With value > minval. Upper bound of 1877 the range of random values to generate. 1878 seed: An integer. Used to create random seeds. 1879 dtype: The data type. 
1880
1881  Returns:
1882    An initializer that generates tensors with an exponential distribution.
1883  """
1884
1885  def _initializer(shape, dtype=dtype, partition_info=None):
1886    del partition_info  # Unused.
1887    return math_ops.exp(
1888        random_ops.random_uniform(
1889            shape, math_ops.log(minval), math_ops.log(maxval), dtype,
1890            seed=seed))
1891
1892  return _initializer
1893
1894
1895class PhasedLSTMCell(rnn_cell_impl.RNNCell):
1896  """Phased LSTM recurrent network cell.
1897
1898  https://arxiv.org/pdf/1610.09513v1.pdf
1899  """
1900
1901  def __init__(self,
1902               num_units,
1903               use_peepholes=False,
1904               leak=0.001,
1905               ratio_on=0.1,
1906               trainable_ratio_on=True,
1907               period_init_min=1.0,
1908               period_init_max=1000.0,
1909               reuse=None):
1910    """Initialize the Phased LSTM cell.
1911
1912    Args:
1913      num_units: int, The number of units in the Phased LSTM cell.
1914      use_peepholes: bool, set True to enable peephole connections.
1915      leak: float or scalar float Tensor with value in [0, 1]. Leak applied
1916        during training.
1917      ratio_on: float or scalar float Tensor with value in [0, 1]. Ratio of the
1918        period during which the gates are open.
1919      trainable_ratio_on: bool, whether ratio_on is trainable.
1920      period_init_min: float or scalar float Tensor. With value > 0.
1921        Minimum value of the initialized period.
1922        The period values are initialized by drawing from the distribution:
1923        e^U(log(period_init_min), log(period_init_max))
1924        Where U(.,.) is the uniform distribution.
1925      period_init_max: float or scalar float Tensor.
1926        With value > period_init_min. Maximum value of the initialized period.
1927      reuse: (optional) Python boolean describing whether to reuse variables
1928        in an existing scope. If not `True`, and the existing scope already has
1929        the given variables, an error is raised.
1930    """
1931    super(PhasedLSTMCell, self).__init__(_reuse=reuse)
1932    self._num_units = num_units
1933    self._use_peepholes = use_peepholes
1934    self._leak = leak
1935    self._ratio_on = ratio_on
1936    self._trainable_ratio_on = trainable_ratio_on
1937    self._period_init_min = period_init_min
1938    self._period_init_max = period_init_max
1939    self._reuse = reuse
1940    self._linear1 = None
1941    self._linear2 = None
1942    self._linear3 = None
1943
1944  @property
1945  def state_size(self):
1946    return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)
1947
1948  @property
1949  def output_size(self):
1950    return self._num_units
1951
1952  def _mod(self, x, y):
1953    """Modulo function that propagates x gradients."""
1954    return array_ops.stop_gradient(math_ops.mod(x, y) - x) + x
1955
1956  def _get_cycle_ratio(self, time, phase, period):
1957    """Compute the cycle ratio in the dtype of the time."""
1958    phase_casted = math_ops.cast(phase, dtype=time.dtype)
1959    period_casted = math_ops.cast(period, dtype=time.dtype)
1960    shifted_time = time - phase_casted
1961    cycle_ratio = self._mod(shifted_time, period_casted) / period_casted
1962    return math_ops.cast(cycle_ratio, dtype=dtypes.float32)
1963
1964  def call(self, inputs, state):
1965    """Phased LSTM Cell.
1966
1967    Args:
1968      inputs: A tuple of 2 Tensors.
1969        The first Tensor has shape [batch, 1], and type float32 or float64.
1970        It stores the time.
1971        The second Tensor has shape [batch, features_size], and type float32.
1972        It stores the features.
1973      state: rnn_cell_impl.LSTMStateTuple, state from previous timestep.
1974 1975 Returns: 1976 A tuple containing: 1977 - A Tensor of float32, and shape [batch_size, num_units], representing the 1978 output of the cell. 1979 - A rnn_cell_impl.LSTMStateTuple, containing 2 Tensors of float32, shape 1980 [batch_size, num_units], representing the new state and the output. 1981 """ 1982 (c_prev, h_prev) = state 1983 (time, x) = inputs 1984 1985 in_mask_gates = [x, h_prev] 1986 if self._use_peepholes: 1987 in_mask_gates.append(c_prev) 1988 1989 with vs.variable_scope("mask_gates"): 1990 if self._linear1 is None: 1991 self._linear1 = _Linear(in_mask_gates, 2 * self._num_units, True) 1992 1993 mask_gates = math_ops.sigmoid(self._linear1(in_mask_gates)) 1994 [input_gate, forget_gate] = array_ops.split( 1995 axis=1, num_or_size_splits=2, value=mask_gates) 1996 1997 with vs.variable_scope("new_input"): 1998 if self._linear2 is None: 1999 self._linear2 = _Linear([x, h_prev], self._num_units, True) 2000 new_input = math_ops.tanh(self._linear2([x, h_prev])) 2001 2002 new_c = (c_prev * forget_gate + input_gate * new_input) 2003 2004 in_out_gate = [x, h_prev] 2005 if self._use_peepholes: 2006 in_out_gate.append(new_c) 2007 2008 with vs.variable_scope("output_gate"): 2009 if self._linear3 is None: 2010 self._linear3 = _Linear(in_out_gate, self._num_units, True) 2011 output_gate = math_ops.sigmoid(self._linear3(in_out_gate)) 2012 2013 new_h = math_ops.tanh(new_c) * output_gate 2014 2015 period = vs.get_variable( 2016 "period", [self._num_units], 2017 initializer=_random_exp_initializer(self._period_init_min, 2018 self._period_init_max)) 2019 phase = vs.get_variable( 2020 "phase", [self._num_units], 2021 initializer=init_ops.random_uniform_initializer(0., 2022 period.initial_value)) 2023 ratio_on = vs.get_variable( 2024 "ratio_on", [self._num_units], 2025 initializer=init_ops.constant_initializer(self._ratio_on), 2026 trainable=self._trainable_ratio_on) 2027 2028 cycle_ratio = self._get_cycle_ratio(time, phase, period) 2029 2030 k_up = 2 * cycle_ratio / ratio_on 2031 k_down = 2 - k_up 2032 k_closed = self._leak * cycle_ratio 2033 2034 k = array_ops.where(cycle_ratio < ratio_on, k_down, k_closed) 2035 k = array_ops.where(cycle_ratio < 0.5 * ratio_on, k_up, k) 2036 2037 new_c = k * new_c + (1 - k) * c_prev 2038 new_h = k * new_h + (1 - k) * h_prev 2039 2040 new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h) 2041 2042 return new_h, new_state 2043 2044 2045class ConvLSTMCell(rnn_cell_impl.RNNCell): 2046 """Convolutional LSTM recurrent network cell. 2047 2048 https://arxiv.org/pdf/1506.04214v1.pdf 2049 """ 2050 2051 def __init__(self, 2052 conv_ndims, 2053 input_shape, 2054 output_channels, 2055 kernel_shape, 2056 use_bias=True, 2057 skip_connection=False, 2058 forget_bias=1.0, 2059 initializers=None, 2060 name="conv_lstm_cell"): 2061 """Construct ConvLSTMCell. 2062 Args: 2063 conv_ndims: Convolution dimensionality (1, 2 or 3). 2064 input_shape: Shape of the input as int tuple, excluding the batch size. 2065 output_channels: int, number of output channels of the conv LSTM. 2066 kernel_shape: Shape of kernel as in tuple (of size 1,2 or 3). 2067 use_bias: Use bias in convolutions. 2068 skip_connection: If set to `True`, concatenate the input to the 2069 output of the conv LSTM. Default: `False`. 2070 forget_bias: Forget bias. 2071 name: Name of the module. 2072 Raises: 2073 ValueError: If `skip_connection` is `True` and stride is different from 1 2074 or if `input_shape` is incompatible with `conv_ndims`. 
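
    A construction sketch (the spatial size and channel counts here are
    illustrative assumptions, not defaults):

      # 2-D conv LSTM over 32x32 feature maps with 3 input channels.
      cell = ConvLSTMCell(conv_ndims=2,
                          input_shape=[32, 32, 3],
                          output_channels=16,
                          kernel_shape=[3, 3])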
2075 """ 2076 super(ConvLSTMCell, self).__init__(name=name) 2077 2078 if conv_ndims != len(input_shape) - 1: 2079 raise ValueError("Invalid input_shape {} for conv_ndims={}.".format( 2080 input_shape, conv_ndims)) 2081 2082 self._conv_ndims = conv_ndims 2083 self._input_shape = input_shape 2084 self._output_channels = output_channels 2085 self._kernel_shape = kernel_shape 2086 self._use_bias = use_bias 2087 self._forget_bias = forget_bias 2088 self._skip_connection = skip_connection 2089 2090 self._total_output_channels = output_channels 2091 if self._skip_connection: 2092 self._total_output_channels += self._input_shape[-1] 2093 2094 state_size = tensor_shape.TensorShape( 2095 self._input_shape[:-1] + [self._output_channels]) 2096 self._state_size = rnn_cell_impl.LSTMStateTuple(state_size, state_size) 2097 self._output_size = tensor_shape.TensorShape( 2098 self._input_shape[:-1] + [self._total_output_channels]) 2099 2100 @property 2101 def output_size(self): 2102 return self._output_size 2103 2104 @property 2105 def state_size(self): 2106 return self._state_size 2107 2108 def call(self, inputs, state, scope=None): 2109 cell, hidden = state 2110 new_hidden = _conv([inputs, hidden], self._kernel_shape, 2111 4 * self._output_channels, self._use_bias) 2112 gates = array_ops.split( 2113 value=new_hidden, num_or_size_splits=4, axis=self._conv_ndims + 1) 2114 2115 input_gate, new_input, forget_gate, output_gate = gates 2116 new_cell = math_ops.sigmoid(forget_gate + self._forget_bias) * cell 2117 new_cell += math_ops.sigmoid(input_gate) * math_ops.tanh(new_input) 2118 output = math_ops.tanh(new_cell) * math_ops.sigmoid(output_gate) 2119 2120 if self._skip_connection: 2121 output = array_ops.concat([output, inputs], axis=-1) 2122 new_state = rnn_cell_impl.LSTMStateTuple(new_cell, output) 2123 return output, new_state 2124 2125 2126class Conv1DLSTMCell(ConvLSTMCell): 2127 """1D Convolutional LSTM recurrent network cell. 2128 2129 https://arxiv.org/pdf/1506.04214v1.pdf 2130 """ 2131 2132 def __init__(self, name="conv_1d_lstm_cell", **kwargs): 2133 """Construct Conv1DLSTM. See `ConvLSTMCell` for more details.""" 2134 super(Conv1DLSTMCell, self).__init__(conv_ndims=1, **kwargs) 2135 2136 2137class Conv2DLSTMCell(ConvLSTMCell): 2138 """2D Convolutional LSTM recurrent network cell. 2139 2140 https://arxiv.org/pdf/1506.04214v1.pdf 2141 """ 2142 2143 def __init__(self, name="conv_2d_lstm_cell", **kwargs): 2144 """Construct Conv2DLSTM. See `ConvLSTMCell` for more details.""" 2145 super(Conv2DLSTMCell, self).__init__(conv_ndims=2, **kwargs) 2146 2147 2148class Conv3DLSTMCell(ConvLSTMCell): 2149 """3D Convolutional LSTM recurrent network cell. 2150 2151 https://arxiv.org/pdf/1506.04214v1.pdf 2152 """ 2153 2154 def __init__(self, name="conv_3d_lstm_cell", **kwargs): 2155 """Construct Conv3DLSTM. See `ConvLSTMCell` for more details.""" 2156 super(Conv3DLSTMCell, self).__init__(conv_ndims=3, **kwargs) 2157 2158 2159def _conv(args, filter_size, num_features, bias, bias_start=0.0): 2160 """convolution: 2161 Args: 2162 args: a Tensor or a list of Tensors of dimension 3D, 4D or 5D, 2163 batch x n, Tensors. 2164 filter_size: int tuple of filter height and width. 2165 num_features: int, number of features. 2166 bias_start: starting value to initialize the bias; 0 by default. 2167 Returns: 2168 A 3D, 4D, or 5D Tensor with shape [batch ... num_features] 2169 Raises: 2170 ValueError: if some of the arguments has unspecified or wrong shape. 2171 """ 2172 2173 # Calculate the total size of arguments on dimension 1. 
2174 total_arg_size_depth = 0 2175 shapes = [a.get_shape().as_list() for a in args] 2176 shape_length = len(shapes[0]) 2177 for shape in shapes: 2178 if len(shape) not in [3, 4, 5]: 2179 raise ValueError("Conv Linear expects 3D, 4D " 2180 "or 5D arguments: %s" % str(shapes)) 2181 if len(shape) != len(shapes[0]): 2182 raise ValueError("Conv Linear expects all args " 2183 "to be of same Dimension: %s" % str(shapes)) 2184 else: 2185 total_arg_size_depth += shape[-1] 2186 dtype = [a.dtype for a in args][0] 2187 2188 # determine correct conv operation 2189 if shape_length == 3: 2190 conv_op = nn_ops.conv1d 2191 strides = 1 2192 elif shape_length == 4: 2193 conv_op = nn_ops.conv2d 2194 strides = shape_length * [1] 2195 elif shape_length == 5: 2196 conv_op = nn_ops.conv3d 2197 strides = shape_length * [1] 2198 2199 # Now the computation. 2200 kernel = vs.get_variable( 2201 "kernel", filter_size + [total_arg_size_depth, num_features], dtype=dtype) 2202 if len(args) == 1: 2203 res = conv_op(args[0], kernel, strides, padding="SAME") 2204 else: 2205 res = conv_op( 2206 array_ops.concat(axis=shape_length - 1, values=args), 2207 kernel, 2208 strides, 2209 padding="SAME") 2210 if not bias: 2211 return res 2212 bias_term = vs.get_variable( 2213 "biases", [num_features], 2214 dtype=dtype, 2215 initializer=init_ops.constant_initializer(bias_start, dtype=dtype)) 2216 return res + bias_term 2217 2218 2219class GLSTMCell(rnn_cell_impl.RNNCell): 2220 """Group LSTM cell (G-LSTM). 2221 2222 The implementation is based on: 2223 2224 https://arxiv.org/abs/1703.10722 2225 2226 O. Kuchaiev and B. Ginsburg 2227 "Factorization Tricks for LSTM Networks", ICLR 2017 workshop. 2228 """ 2229 2230 def __init__(self, 2231 num_units, 2232 initializer=None, 2233 num_proj=None, 2234 number_of_groups=1, 2235 forget_bias=1.0, 2236 activation=math_ops.tanh, 2237 reuse=None): 2238 """Initialize the parameters of G-LSTM cell. 2239 2240 Args: 2241 num_units: int, The number of units in the G-LSTM cell 2242 initializer: (optional) The initializer to use for the weight and 2243 projection matrices. 2244 num_proj: (optional) int, The output dimensionality for the projection 2245 matrices. If None, no projection is performed. 2246 number_of_groups: (optional) int, number of groups to use. 2247 If `number_of_groups` is 1, then it should be equivalent to LSTM cell 2248 forget_bias: Biases of the forget gate are initialized by default to 1 2249 in order to reduce the scale of forgetting at the beginning of 2250 the training. 2251 activation: Activation function of the inner states. 2252 reuse: (optional) Python boolean describing whether to reuse variables 2253 in an existing scope. If not `True`, and the existing scope already 2254 has the given variables, an error is raised. 2255 2256 Raises: 2257 ValueError: If `num_units` or `num_proj` is not divisible by 2258 `number_of_groups`. 
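
    A construction sketch (sizes chosen only for illustration; they just need
    to be divisible by `number_of_groups`):

      # 1024 cell units split into 4 groups of 256, projected down to 512.
      cell = GLSTMCell(num_units=1024, num_proj=512, number_of_groups=4)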
2259 """ 2260 super(GLSTMCell, self).__init__(_reuse=reuse) 2261 self._num_units = num_units 2262 self._initializer = initializer 2263 self._num_proj = num_proj 2264 self._forget_bias = forget_bias 2265 self._activation = activation 2266 self._number_of_groups = number_of_groups 2267 2268 if self._num_units % self._number_of_groups != 0: 2269 raise ValueError("num_units must be divisible by number_of_groups") 2270 if self._num_proj: 2271 if self._num_proj % self._number_of_groups != 0: 2272 raise ValueError("num_proj must be divisible by number_of_groups") 2273 self._group_shape = [ 2274 int(self._num_proj / self._number_of_groups), 2275 int(self._num_units / self._number_of_groups) 2276 ] 2277 else: 2278 self._group_shape = [ 2279 int(self._num_units / self._number_of_groups), 2280 int(self._num_units / self._number_of_groups) 2281 ] 2282 2283 if num_proj: 2284 self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj) 2285 self._output_size = num_proj 2286 else: 2287 self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units) 2288 self._output_size = num_units 2289 self._linear1 = [None] * number_of_groups 2290 self._linear2 = None 2291 2292 @property 2293 def state_size(self): 2294 return self._state_size 2295 2296 @property 2297 def output_size(self): 2298 return self._output_size 2299 2300 def _get_input_for_group(self, inputs, group_id, group_size): 2301 """Slices inputs into groups to prepare for processing by cell's groups 2302 2303 Args: 2304 inputs: cell input or it's previous state, 2305 a Tensor, 2D, [batch x num_units] 2306 group_id: group id, a Scalar, for which to prepare input 2307 group_size: size of the group 2308 2309 Returns: 2310 subset of inputs corresponding to group "group_id", 2311 a Tensor, 2D, [batch x num_units/number_of_groups] 2312 """ 2313 return array_ops.slice( 2314 input_=inputs, 2315 begin=[0, group_id * group_size], 2316 size=[self._batch_size, group_size], 2317 name=("GLSTM_group%d_input_generation" % group_id)) 2318 2319 def call(self, inputs, state): 2320 """Run one step of G-LSTM. 2321 2322 Args: 2323 inputs: input Tensor, 2D, [batch x num_units]. 2324 state: this must be a tuple of state Tensors, both `2-D`, 2325 with column sizes `c_state` and `m_state`. 2326 2327 Returns: 2328 A tuple containing: 2329 2330 - A `2-D, [batch x output_dim]`, Tensor representing the output of the 2331 G-LSTM after reading `inputs` when previous state was `state`. 2332 Here output_dim is: 2333 num_proj if num_proj was set, 2334 num_units otherwise. 2335 - LSTMStateTuple representing the new state of G-LSTM cell 2336 after reading `inputs` when the previous state was `state`. 2337 2338 Raises: 2339 ValueError: If input size cannot be inferred from inputs via 2340 static shape inference. 
2341 """ 2342 (c_prev, m_prev) = state 2343 2344 self._batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0] 2345 dtype = inputs.dtype 2346 scope = vs.get_variable_scope() 2347 with vs.variable_scope(scope, initializer=self._initializer): 2348 i_parts = [] 2349 j_parts = [] 2350 f_parts = [] 2351 o_parts = [] 2352 2353 for group_id in range(self._number_of_groups): 2354 with vs.variable_scope("group%d" % group_id): 2355 x_g_id = array_ops.concat( 2356 [ 2357 self._get_input_for_group(inputs, group_id, 2358 self._group_shape[0]), 2359 self._get_input_for_group(m_prev, group_id, 2360 self._group_shape[0]) 2361 ], 2362 axis=1) 2363 linear = self._linear1[group_id] 2364 if linear is None: 2365 linear = _Linear(x_g_id, 4 * self._group_shape[1], False) 2366 self._linear1[group_id] = linear 2367 R_k = linear(x_g_id) # pylint: disable=invalid-name 2368 i_k, j_k, f_k, o_k = array_ops.split(R_k, 4, 1) 2369 2370 i_parts.append(i_k) 2371 j_parts.append(j_k) 2372 f_parts.append(f_k) 2373 o_parts.append(o_k) 2374 2375 bi = vs.get_variable( 2376 name="bias_i", 2377 shape=[self._num_units], 2378 dtype=dtype, 2379 initializer=init_ops.constant_initializer(0.0, dtype=dtype)) 2380 bj = vs.get_variable( 2381 name="bias_j", 2382 shape=[self._num_units], 2383 dtype=dtype, 2384 initializer=init_ops.constant_initializer(0.0, dtype=dtype)) 2385 bf = vs.get_variable( 2386 name="bias_f", 2387 shape=[self._num_units], 2388 dtype=dtype, 2389 initializer=init_ops.constant_initializer(0.0, dtype=dtype)) 2390 bo = vs.get_variable( 2391 name="bias_o", 2392 shape=[self._num_units], 2393 dtype=dtype, 2394 initializer=init_ops.constant_initializer(0.0, dtype=dtype)) 2395 2396 i = nn_ops.bias_add(array_ops.concat(i_parts, axis=1), bi) 2397 j = nn_ops.bias_add(array_ops.concat(j_parts, axis=1), bj) 2398 f = nn_ops.bias_add(array_ops.concat(f_parts, axis=1), bf) 2399 o = nn_ops.bias_add(array_ops.concat(o_parts, axis=1), bo) 2400 2401 c = ( 2402 math_ops.sigmoid(f + self._forget_bias) * c_prev + 2403 math_ops.sigmoid(i) * math_ops.tanh(j)) 2404 m = math_ops.sigmoid(o) * self._activation(c) 2405 2406 if self._num_proj is not None: 2407 with vs.variable_scope("projection"): 2408 if self._linear2 is None: 2409 self._linear2 = _Linear(m, self._num_proj, False) 2410 m = self._linear2(m) 2411 2412 new_state = rnn_cell_impl.LSTMStateTuple(c, m) 2413 return m, new_state 2414 2415 2416class LayerNormLSTMCell(rnn_cell_impl.RNNCell): 2417 """Long short-term memory unit (LSTM) recurrent network cell. 2418 2419 The default non-peephole implementation is based on: 2420 2421 http://www.bioinf.jku.at/publications/older/2604.pdf 2422 2423 S. Hochreiter and J. Schmidhuber. 2424 "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. 2425 2426 The peephole implementation is based on: 2427 2428 https://research.google.com/pubs/archive/43905.pdf 2429 2430 Hasim Sak, Andrew Senior, and Francoise Beaufays. 2431 "Long short-term memory recurrent neural network architectures for 2432 large scale acoustic modeling." INTERSPEECH, 2014. 2433 2434 The class uses optional peep-hole connections, optional cell clipping, and 2435 an optional projection layer. 2436 2437 Layer normalization implementation is based on: 2438 2439 https://arxiv.org/abs/1607.06450. 2440 2441 "Layer Normalization" 2442 Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton 2443 2444 and is applied before the internal nonlinearities. 
2445 2446 """ 2447 2448 def __init__(self, 2449 num_units, 2450 use_peepholes=False, 2451 cell_clip=None, 2452 initializer=None, 2453 num_proj=None, 2454 proj_clip=None, 2455 forget_bias=1.0, 2456 activation=None, 2457 layer_norm=False, 2458 norm_gain=1.0, 2459 norm_shift=0.0, 2460 reuse=None): 2461 """Initialize the parameters for an LSTM cell. 2462 2463 Args: 2464 num_units: int, The number of units in the LSTM cell 2465 use_peepholes: bool, set True to enable diagonal/peephole connections. 2466 cell_clip: (optional) A float value, if provided the cell state is clipped 2467 by this value prior to the cell output activation. 2468 initializer: (optional) The initializer to use for the weight and 2469 projection matrices. 2470 num_proj: (optional) int, The output dimensionality for the projection 2471 matrices. If None, no projection is performed. 2472 proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is 2473 provided, then the projected values are clipped elementwise to within 2474 `[-proj_clip, proj_clip]`. 2475 forget_bias: Biases of the forget gate are initialized by default to 1 2476 in order to reduce the scale of forgetting at the beginning of 2477 the training. Must set it manually to `0.0` when restoring from 2478 CudnnLSTM trained checkpoints. 2479 activation: Activation function of the inner states. Default: `tanh`. 2480 layer_norm: If `True`, layer normalization will be applied. 2481 norm_gain: float, The layer normalization gain initial value. If 2482 `layer_norm` has been set to `False`, this argument will be ignored. 2483 norm_shift: float, The layer normalization shift initial value. If 2484 `layer_norm` has been set to `False`, this argument will be ignored. 2485 reuse: (optional) Python boolean describing whether to reuse variables 2486 in an existing scope. If not `True`, and the existing scope already has 2487 the given variables, an error is raised. 2488 2489 When restoring from CudnnLSTM-trained checkpoints, must use 2490 CudnnCompatibleLSTMCell instead. 2491 """ 2492 super(LayerNormLSTMCell, self).__init__(_reuse=reuse) 2493 2494 self._num_units = num_units 2495 self._use_peepholes = use_peepholes 2496 self._cell_clip = cell_clip 2497 self._initializer = initializer 2498 self._num_proj = num_proj 2499 self._proj_clip = proj_clip 2500 self._forget_bias = forget_bias 2501 self._activation = activation or math_ops.tanh 2502 self._layer_norm = layer_norm 2503 self._norm_gain = norm_gain 2504 self._norm_shift = norm_shift 2505 2506 if num_proj: 2507 self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_proj)) 2508 self._output_size = num_proj 2509 else: 2510 self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_units)) 2511 self._output_size = num_units 2512 2513 @property 2514 def state_size(self): 2515 return self._state_size 2516 2517 @property 2518 def output_size(self): 2519 return self._output_size 2520 2521 def _linear(self, 2522 args, 2523 output_size, 2524 bias, 2525 bias_initializer=None, 2526 kernel_initializer=None, 2527 layer_norm=False): 2528 """Linear map: sum_i(args[i] * W[i]), where W[i] is a Variable. 2529 2530 Args: 2531 args: a 2D Tensor or a list of 2D, batch x n, Tensors. 2532 output_size: int, second dimension of W[i]. 2533 bias: boolean, whether to add a bias term or not. 2534 bias_initializer: starting value to initialize the bias 2535 (default is all zeros). 2536 kernel_initializer: starting value to initialize the weight. 2537 layer_norm: boolean, whether to apply layer normalization. 
2538 2539 2540 Returns: 2541 A 2D Tensor with shape [batch x output_size] taking value 2542 sum_i(args[i] * W[i]), where each W[i] is a newly created Variable. 2543 2544 Raises: 2545 ValueError: if some of the arguments has unspecified or wrong shape. 2546 """ 2547 if args is None or (nest.is_sequence(args) and not args): 2548 raise ValueError("`args` must be specified") 2549 if not nest.is_sequence(args): 2550 args = [args] 2551 2552 # Calculate the total size of arguments on dimension 1. 2553 total_arg_size = 0 2554 shapes = [a.get_shape() for a in args] 2555 for shape in shapes: 2556 if shape.ndims != 2: 2557 raise ValueError("linear is expecting 2D arguments: %s" % shapes) 2558 if shape[1].value is None: 2559 raise ValueError("linear expects shape[1] to be provided for shape %s, " 2560 "but saw %s" % (shape, shape[1])) 2561 else: 2562 total_arg_size += shape[1].value 2563 2564 dtype = [a.dtype for a in args][0] 2565 2566 # Now the computation. 2567 scope = vs.get_variable_scope() 2568 with vs.variable_scope(scope) as outer_scope: 2569 weights = vs.get_variable( 2570 "kernel", [total_arg_size, output_size], 2571 dtype=dtype, 2572 initializer=kernel_initializer) 2573 if len(args) == 1: 2574 res = math_ops.matmul(args[0], weights) 2575 else: 2576 res = math_ops.matmul(array_ops.concat(args, 1), weights) 2577 if not bias: 2578 return res 2579 with vs.variable_scope(outer_scope) as inner_scope: 2580 inner_scope.set_partitioner(None) 2581 if bias_initializer is None: 2582 bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype) 2583 biases = vs.get_variable( 2584 "bias", [output_size], dtype=dtype, initializer=bias_initializer) 2585 2586 if not layer_norm: 2587 res = nn_ops.bias_add(res, biases) 2588 2589 return res 2590 2591 def call(self, inputs, state): 2592 """Run one step of LSTM. 2593 2594 Args: 2595 inputs: input Tensor, 2D, batch x num_units. 2596 state: this must be a tuple of state Tensors, 2597 both `2-D`, with column sizes `c_state` and 2598 `m_state`. 2599 2600 Returns: 2601 A tuple containing: 2602 2603 - A `2-D, [batch x output_dim]`, Tensor representing the output of the 2604 LSTM after reading `inputs` when previous state was `state`. 2605 Here output_dim is: 2606 num_proj if num_proj was set, 2607 num_units otherwise. 2608 - Tensor(s) representing the new state of LSTM after reading `inputs` when 2609 the previous state was `state`. Same type and shape(s) as `state`. 2610 2611 Raises: 2612 ValueError: If input size cannot be inferred from inputs via 2613 static shape inference. 
2614 """ 2615 sigmoid = math_ops.sigmoid 2616 2617 (c_prev, m_prev) = state 2618 2619 dtype = inputs.dtype 2620 input_size = inputs.get_shape().with_rank(2)[1] 2621 if input_size.value is None: 2622 raise ValueError("Could not infer input size from inputs.get_shape()[-1]") 2623 scope = vs.get_variable_scope() 2624 with vs.variable_scope(scope, initializer=self._initializer) as unit_scope: 2625 2626 # i = input_gate, j = new_input, f = forget_gate, o = output_gate 2627 lstm_matrix = self._linear( 2628 [inputs, m_prev], 2629 4 * self._num_units, 2630 bias=True, 2631 bias_initializer=None, 2632 layer_norm=self._layer_norm) 2633 i, j, f, o = array_ops.split( 2634 value=lstm_matrix, num_or_size_splits=4, axis=1) 2635 2636 if self._layer_norm: 2637 i = _norm(self._norm_gain, self._norm_shift, i, "input") 2638 j = _norm(self._norm_gain, self._norm_shift, j, "transform") 2639 f = _norm(self._norm_gain, self._norm_shift, f, "forget") 2640 o = _norm(self._norm_gain, self._norm_shift, o, "output") 2641 2642 # Diagonal connections 2643 if self._use_peepholes: 2644 with vs.variable_scope(unit_scope): 2645 w_f_diag = vs.get_variable( 2646 "w_f_diag", shape=[self._num_units], dtype=dtype) 2647 w_i_diag = vs.get_variable( 2648 "w_i_diag", shape=[self._num_units], dtype=dtype) 2649 w_o_diag = vs.get_variable( 2650 "w_o_diag", shape=[self._num_units], dtype=dtype) 2651 2652 if self._use_peepholes: 2653 c = ( 2654 sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev + 2655 sigmoid(i + w_i_diag * c_prev) * self._activation(j)) 2656 else: 2657 c = ( 2658 sigmoid(f + self._forget_bias) * c_prev + 2659 sigmoid(i) * self._activation(j)) 2660 2661 if self._layer_norm: 2662 c = _norm(self._norm_gain, self._norm_shift, c, "state") 2663 2664 if self._cell_clip is not None: 2665 # pylint: disable=invalid-unary-operand-type 2666 c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip) 2667 # pylint: enable=invalid-unary-operand-type 2668 if self._use_peepholes: 2669 m = sigmoid(o + w_o_diag * c) * self._activation(c) 2670 else: 2671 m = sigmoid(o) * self._activation(c) 2672 2673 if self._num_proj is not None: 2674 with vs.variable_scope("projection"): 2675 m = self._linear(m, self._num_proj, bias=False) 2676 2677 if self._proj_clip is not None: 2678 # pylint: disable=invalid-unary-operand-type 2679 m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip) 2680 # pylint: enable=invalid-unary-operand-type 2681 2682 new_state = (rnn_cell_impl.LSTMStateTuple(c, m)) 2683 return m, new_state 2684 2685 2686class SRUCell(rnn_cell_impl.LayerRNNCell): 2687 """SRU, Simple Recurrent Unit 2688 2689 Implementation based on 2690 Training RNNs as Fast as CNNs (cf. https://arxiv.org/abs/1709.02755). 2691 2692 This variation of RNN cell is characterized by the simplified data 2693 dependence 2694 between hidden states of two consecutive time steps. Traditionally, hidden 2695 states from a cell at time step t-1 needs to be multiplied with a matrix 2696 W_hh before being fed into the ensuing cell at time step t. 2697 This flavor of RNN replaces the matrix multiplication between h_{t-1} 2698 and W_hh with a pointwise multiplication, resulting in performance 2699 gain. 2700 2701 Args: 2702 num_units: int, The number of units in the SRU cell. 2703 activation: Nonlinearity to use. Default: `tanh`. 2704 reuse: (optional) Python boolean describing whether to reuse variables 2705 in an existing scope. If not `True`, and the existing scope already has 2706 the given variables, an error is raised. 
2707 name: (optional) String, the name of the layer. Layers with the same name 2708 will share weights, but to avoid mistakes we require reuse=True in such 2709 cases. 2710 """ 2711 2712 def __init__(self, num_units, activation=None, reuse=None, name=None): 2713 super(SRUCell, self).__init__(_reuse=reuse, name=name) 2714 self._num_units = num_units 2715 self._activation = activation or math_ops.tanh 2716 2717 # Restrict inputs to be 2-dimensional matrices 2718 self.input_spec = base_layer.InputSpec(ndim=2) 2719 2720 @property 2721 def state_size(self): 2722 return self._num_units 2723 2724 @property 2725 def output_size(self): 2726 return self._num_units 2727 2728 def build(self, inputs_shape): 2729 if inputs_shape[1].value is None: 2730 raise ValueError( 2731 "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape) 2732 2733 input_depth = inputs_shape[1].value 2734 2735 self._kernel = self.add_variable( 2736 rnn_cell_impl._WEIGHTS_VARIABLE_NAME, 2737 shape=[input_depth, 4 * self._num_units]) 2738 2739 self._bias = self.add_variable( 2740 rnn_cell_impl._BIAS_VARIABLE_NAME, 2741 shape=[2 * self._num_units], 2742 initializer=init_ops.constant_initializer(0.0, dtype=self.dtype)) 2743 2744 self._built = True 2745 2746 def call(self, inputs, state): 2747 """Simple recurrent unit (SRU) with num_units cells.""" 2748 2749 U = math_ops.matmul(inputs, self._kernel) 2750 x_bar, f_intermediate, r_intermediate, x_tx = array_ops.split( 2751 value=U, num_or_size_splits=4, axis=1) 2752 2753 f_r = math_ops.sigmoid( 2754 nn_ops.bias_add( 2755 array_ops.concat([f_intermediate, r_intermediate], 1), self._bias)) 2756 f, r = array_ops.split(value=f_r, num_or_size_splits=2, axis=1) 2757 2758 c = f * state + (1.0 - f) * x_bar 2759 h = r * self._activation(c) + (1.0 - r) * x_tx 2760 2761 return h, c 2762 2763 2764class WeightNormLSTMCell(rnn_cell_impl.RNNCell): 2765 """Weight normalized LSTM Cell. Adapted from `rnn_cell_impl.LSTMCell`. 2766 2767 The weight-norm implementation is based on: 2768 https://arxiv.org/abs/1602.07868 2769 Tim Salimans, Diederik P. Kingma. 2770 Weight Normalization: A Simple Reparameterization to Accelerate 2771 Training of Deep Neural Networks 2772 2773 The default LSTM implementation based on: 2774 http://www.bioinf.jku.at/publications/older/2604.pdf 2775 S. Hochreiter and J. Schmidhuber. 2776 "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. 2777 2778 The class uses optional peephole connections, optional cell clipping 2779 and an optional projection layer. 2780 2781 The optional peephole implementation is based on: 2782 https://research.google.com/pubs/archive/43905.pdf 2783 Hasim Sak, Andrew Senior, and Francoise Beaufays. 2784 "Long short-term memory recurrent neural network architectures for 2785 large scale acoustic modeling." INTERSPEECH, 2014. 2786 """ 2787 2788 def __init__(self, 2789 num_units, 2790 norm=True, 2791 use_peepholes=False, 2792 cell_clip=None, 2793 initializer=None, 2794 num_proj=None, 2795 proj_clip=None, 2796 forget_bias=1, 2797 activation=None, 2798 reuse=None): 2799 """Initialize the parameters of a weight-normalized LSTM cell. 2800 2801 Args: 2802 num_units: int, The number of units in the LSTM cell 2803 norm: If `True`, apply normalization to the weight matrices. If False, 2804 the result is identical to that obtained from `rnn_cell_impl.LSTMCell` 2805 use_peepholes: bool, set `True` to enable diagonal/peephole connections. 
2806 cell_clip: (optional) A float value, if provided the cell state is clipped 2807 by this value prior to the cell output activation. 2808 initializer: (optional) The initializer to use for the weight matrices. 2809 num_proj: (optional) int, The output dimensionality for the projection 2810 matrices. If None, no projection is performed. 2811 proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is 2812 provided, then the projected values are clipped elementwise to within 2813 `[-proj_clip, proj_clip]`. 2814 forget_bias: Biases of the forget gate are initialized by default to 1 2815 in order to reduce the scale of forgetting at the beginning of 2816 the training. 2817 activation: Activation function of the inner states. Default: `tanh`. 2818 reuse: (optional) Python boolean describing whether to reuse variables 2819 in an existing scope. If not `True`, and the existing scope already has 2820 the given variables, an error is raised. 2821 """ 2822 super(WeightNormLSTMCell, self).__init__(_reuse=reuse) 2823 2824 self._scope = "wn_lstm_cell" 2825 self._num_units = num_units 2826 self._norm = norm 2827 self._initializer = initializer 2828 self._use_peepholes = use_peepholes 2829 self._cell_clip = cell_clip 2830 self._num_proj = num_proj 2831 self._proj_clip = proj_clip 2832 self._activation = activation or math_ops.tanh 2833 self._forget_bias = forget_bias 2834 2835 self._weights_variable_name = "kernel" 2836 self._bias_variable_name = "bias" 2837 2838 if num_proj: 2839 self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj) 2840 self._output_size = num_proj 2841 else: 2842 self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units) 2843 self._output_size = num_units 2844 2845 @property 2846 def state_size(self): 2847 return self._state_size 2848 2849 @property 2850 def output_size(self): 2851 return self._output_size 2852 2853 def _normalize(self, weight, name): 2854 """Apply weight normalization. 2855 2856 Args: 2857 weight: a 2D tensor with known number of columns. 2858 name: string, variable name for the normalizer. 2859 Returns: 2860 A tensor with the same shape as `weight`. 2861 """ 2862 2863 output_size = weight.get_shape().as_list()[1] 2864 g = vs.get_variable(name, [output_size], dtype=weight.dtype) 2865 return nn_impl.l2_normalize(weight, dim=0) * g 2866 2867 def _linear(self, 2868 args, 2869 output_size, 2870 norm, 2871 bias, 2872 bias_initializer=None, 2873 kernel_initializer=None): 2874 """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. 2875 2876 Args: 2877 args: a 2D Tensor or a list of 2D, batch x n, Tensors. 2878 output_size: int, second dimension of W[i]. 2879 bias: boolean, whether to add a bias term or not. 2880 bias_initializer: starting value to initialize the bias 2881 (default is all zeros). 2882 kernel_initializer: starting value to initialize the weight. 2883 2884 Returns: 2885 A 2D Tensor with shape [batch x output_size] equal to 2886 sum_i(args[i] * W[i]), where W[i]s are newly created matrices. 2887 2888 Raises: 2889 ValueError: if some of the arguments has unspecified or wrong shape. 2890 """ 2891 if args is None or (nest.is_sequence(args) and not args): 2892 raise ValueError("`args` must be specified") 2893 if not nest.is_sequence(args): 2894 args = [args] 2895 2896 # Calculate the total size of arguments on dimension 1. 
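    # (When `norm` is set, each argument's slice of the kernel is re-normalized
    #  separately further below, so the per-argument sizes gathered here also
    #  delimit those slices.)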
2897 total_arg_size = 0 2898 shapes = [a.get_shape() for a in args] 2899 for shape in shapes: 2900 if shape.ndims != 2: 2901 raise ValueError("linear is expecting 2D arguments: %s" % shapes) 2902 if shape[1].value is None: 2903 raise ValueError("linear expects shape[1] to be provided for shape %s, " 2904 "but saw %s" % (shape, shape[1])) 2905 else: 2906 total_arg_size += shape[1].value 2907 2908 dtype = [a.dtype for a in args][0] 2909 2910 # Now the computation. 2911 scope = vs.get_variable_scope() 2912 with vs.variable_scope(scope) as outer_scope: 2913 weights = vs.get_variable( 2914 self._weights_variable_name, [total_arg_size, output_size], 2915 dtype=dtype, 2916 initializer=kernel_initializer) 2917 if norm: 2918 wn = [] 2919 st = 0 2920 with ops.control_dependencies(None): 2921 for i in range(len(args)): 2922 en = st + shapes[i][1].value 2923 wn.append( 2924 self._normalize(weights[st:en, :], name="norm_{}".format(i))) 2925 st = en 2926 2927 weights = array_ops.concat(wn, axis=0) 2928 2929 if len(args) == 1: 2930 res = math_ops.matmul(args[0], weights) 2931 else: 2932 res = math_ops.matmul(array_ops.concat(args, 1), weights) 2933 if not bias: 2934 return res 2935 2936 with vs.variable_scope(outer_scope) as inner_scope: 2937 inner_scope.set_partitioner(None) 2938 if bias_initializer is None: 2939 bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype) 2940 2941 biases = vs.get_variable( 2942 self._bias_variable_name, [output_size], 2943 dtype=dtype, 2944 initializer=bias_initializer) 2945 2946 return nn_ops.bias_add(res, biases) 2947 2948 def call(self, inputs, state): 2949 """Run one step of LSTM. 2950 2951 Args: 2952 inputs: input Tensor, 2D, batch x num_units. 2953 state: A tuple of state Tensors, both `2-D`, with column sizes 2954 `c_state` and `m_state`. 2955 2956 Returns: 2957 A tuple containing: 2958 2959 - A `2-D, [batch x output_dim]`, Tensor representing the output of the 2960 LSTM after reading `inputs` when previous state was `state`. 2961 Here output_dim is: 2962 num_proj if num_proj was set, 2963 num_units otherwise. 2964 - Tensor(s) representing the new state of LSTM after reading `inputs` when 2965 the previous state was `state`. Same type and shape(s) as `state`. 2966 2967 Raises: 2968 ValueError: If input size cannot be inferred from inputs via 2969 static shape inference. 
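
    A construction sketch (the sizes are illustrative assumptions):

      cell = WeightNormLSTMCell(num_units=256, norm=True, num_proj=128)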
2970 """ 2971 dtype = inputs.dtype 2972 num_units = self._num_units 2973 sigmoid = math_ops.sigmoid 2974 c, h = state 2975 2976 input_size = inputs.get_shape().with_rank(2)[1] 2977 if input_size.value is None: 2978 raise ValueError("Could not infer input size from inputs.get_shape()[-1]") 2979 2980 with vs.variable_scope(self._scope, initializer=self._initializer): 2981 2982 concat = self._linear( 2983 [inputs, h], 4 * num_units, norm=self._norm, bias=True) 2984 2985 # i = input_gate, j = new_input, f = forget_gate, o = output_gate 2986 i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1) 2987 2988 if self._use_peepholes: 2989 w_f_diag = vs.get_variable("w_f_diag", shape=[num_units], dtype=dtype) 2990 w_i_diag = vs.get_variable("w_i_diag", shape=[num_units], dtype=dtype) 2991 w_o_diag = vs.get_variable("w_o_diag", shape=[num_units], dtype=dtype) 2992 2993 new_c = ( 2994 c * sigmoid(f + self._forget_bias + w_f_diag * c) + 2995 sigmoid(i + w_i_diag * c) * self._activation(j)) 2996 else: 2997 new_c = ( 2998 c * sigmoid(f + self._forget_bias) + 2999 sigmoid(i) * self._activation(j)) 3000 3001 if self._cell_clip is not None: 3002 # pylint: disable=invalid-unary-operand-type 3003 new_c = clip_ops.clip_by_value(new_c, -self._cell_clip, self._cell_clip) 3004 # pylint: enable=invalid-unary-operand-type 3005 if self._use_peepholes: 3006 new_h = sigmoid(o + w_o_diag * new_c) * self._activation(new_c) 3007 else: 3008 new_h = sigmoid(o) * self._activation(new_c) 3009 3010 if self._num_proj is not None: 3011 with vs.variable_scope("projection"): 3012 new_h = self._linear( 3013 new_h, self._num_proj, norm=self._norm, bias=False) 3014 3015 if self._proj_clip is not None: 3016 # pylint: disable=invalid-unary-operand-type 3017 new_h = clip_ops.clip_by_value(new_h, -self._proj_clip, 3018 self._proj_clip) 3019 # pylint: enable=invalid-unary-operand-type 3020 3021 new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h) 3022 return new_h, new_state 3023