1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Module for constructing RNN Cells.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import collections 21import math 22 23from tensorflow.contrib.compiler import jit 24from tensorflow.contrib.layers.python.layers import layers 25from tensorflow.contrib.rnn.python.ops import core_rnn_cell 26from tensorflow.python.framework import constant_op 27from tensorflow.python.framework import dtypes 28from tensorflow.python.framework import op_def_registry 29from tensorflow.python.framework import ops 30from tensorflow.python.framework import tensor_shape 31from tensorflow.python.keras import activations 32from tensorflow.python.keras import initializers 33from tensorflow.python.keras.engine import input_spec 34from tensorflow.python.ops import array_ops 35from tensorflow.python.ops import clip_ops 36from tensorflow.python.ops import control_flow_ops 37from tensorflow.python.ops import gen_array_ops 38from tensorflow.python.ops import init_ops 39from tensorflow.python.ops import math_ops 40from tensorflow.python.ops import nn_impl # pylint: disable=unused-import 41from tensorflow.python.ops import nn_ops 42from tensorflow.python.ops import partitioned_variables # pylint: disable=unused-import 43from tensorflow.python.ops import random_ops 44from tensorflow.python.ops import rnn_cell_impl 45from tensorflow.python.ops import variable_scope as vs 46from tensorflow.python.platform import tf_logging as logging 47from tensorflow.python.util import nest 48 49 50def _get_concat_variable(name, shape, dtype, num_shards): 51 """Get a sharded variable concatenated into one tensor.""" 52 sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards) 53 if len(sharded_variable) == 1: 54 return sharded_variable[0] 55 56 concat_name = name + "/concat" 57 concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0" 58 for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES): 59 if value.name == concat_full_name: 60 return value 61 62 concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name) 63 ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES, concat_variable) 64 return concat_variable 65 66 67def _get_sharded_variable(name, shape, dtype, num_shards): 68 """Get a list of sharded variables with the given dtype.""" 69 if num_shards > shape[0]: 70 raise ValueError("Too many shards: shape=%s, num_shards=%d" % (shape, 71 num_shards)) 72 unit_shard_size = int(math.floor(shape[0] / num_shards)) 73 remaining_rows = shape[0] - unit_shard_size * num_shards 74 75 shards = [] 76 for i in range(num_shards): 77 current_size = unit_shard_size 78 if i < remaining_rows: 79 current_size += 1 80 shards.append( 81 vs.get_variable( 82 name + "_%d" % i, [current_size] + shape[1:], dtype=dtype)) 83 return shards 84 85 86def _norm(g, b, inp, scope): 87 shape = inp.get_shape()[-1:] 88 gamma_init = init_ops.constant_initializer(g) 89 beta_init = init_ops.constant_initializer(b) 90 with vs.variable_scope(scope): 91 # Initialize beta and gamma for use by layer_norm. 92 vs.get_variable("gamma", shape=shape, initializer=gamma_init) 93 vs.get_variable("beta", shape=shape, initializer=beta_init) 94 normalized = layers.layer_norm(inp, reuse=True, scope=scope) 95 return normalized 96 97 98class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell): 99 """Long short-term memory unit (LSTM) recurrent network cell. 100 101 The default non-peephole implementation is based on: 102 103 https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf 104 105 Felix Gers, Jurgen Schmidhuber, and Fred Cummins. 106 "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999. 107 108 The peephole implementation is based on: 109 110 https://research.google.com/pubs/archive/43905.pdf 111 112 Hasim Sak, Andrew Senior, and Francoise Beaufays. 113 "Long short-term memory recurrent neural network architectures for 114 large scale acoustic modeling." INTERSPEECH, 2014. 115 116 The coupling of input and forget gate is based on: 117 118 http://arxiv.org/pdf/1503.04069.pdf 119 120 Greff et al. "LSTM: A Search Space Odyssey" 121 122 The class uses optional peep-hole connections, and an optional projection 123 layer. 124 Layer normalization implementation is based on: 125 126 https://arxiv.org/abs/1607.06450. 127 128 "Layer Normalization" 129 Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton 130 131 and is applied before the internal nonlinearities. 132 133 """ 134 135 def __init__(self, 136 num_units, 137 use_peepholes=False, 138 initializer=None, 139 num_proj=None, 140 proj_clip=None, 141 num_unit_shards=1, 142 num_proj_shards=1, 143 forget_bias=1.0, 144 state_is_tuple=True, 145 activation=math_ops.tanh, 146 reuse=None, 147 layer_norm=False, 148 norm_gain=1.0, 149 norm_shift=0.0): 150 """Initialize the parameters for an LSTM cell. 151 152 Args: 153 num_units: int, The number of units in the LSTM cell 154 use_peepholes: bool, set True to enable diagonal/peephole connections. 155 initializer: (optional) The initializer to use for the weight and 156 projection matrices. 157 num_proj: (optional) int, The output dimensionality for the projection 158 matrices. If None, no projection is performed. 159 proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is 160 provided, then the projected values are clipped elementwise to within 161 `[-proj_clip, proj_clip]`. 162 num_unit_shards: How to split the weight matrix. If >1, the weight 163 matrix is stored across num_unit_shards. 164 num_proj_shards: How to split the projection matrix. If >1, the 165 projection matrix is stored across num_proj_shards. 166 forget_bias: Biases of the forget gate are initialized by default to 1 167 in order to reduce the scale of forgetting at the beginning of 168 the training. 169 state_is_tuple: If True, accepted and returned states are 2-tuples of 170 the `c_state` and `m_state`. By default (False), they are concatenated 171 along the column axis. This default behavior will soon be deprecated. 172 activation: Activation function of the inner states. 173 reuse: (optional) Python boolean describing whether to reuse variables 174 in an existing scope. If not `True`, and the existing scope already has 175 the given variables, an error is raised. 176 layer_norm: If `True`, layer normalization will be applied. 177 norm_gain: float, The layer normalization gain initial value. If 178 `layer_norm` has been set to `False`, this argument will be ignored. 179 norm_shift: float, The layer normalization shift initial value. If 180 `layer_norm` has been set to `False`, this argument will be ignored. 181 """ 182 super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse) 183 if not state_is_tuple: 184 logging.warn("%s: Using a concatenated state is slower and will soon be " 185 "deprecated. Use state_is_tuple=True.", self) 186 self._num_units = num_units 187 self._use_peepholes = use_peepholes 188 self._initializer = initializer 189 self._num_proj = num_proj 190 self._proj_clip = proj_clip 191 self._num_unit_shards = num_unit_shards 192 self._num_proj_shards = num_proj_shards 193 self._forget_bias = forget_bias 194 self._state_is_tuple = state_is_tuple 195 self._activation = activation 196 self._reuse = reuse 197 self._layer_norm = layer_norm 198 self._norm_gain = norm_gain 199 self._norm_shift = norm_shift 200 201 if num_proj: 202 self._state_size = ( 203 rnn_cell_impl.LSTMStateTuple(num_units, num_proj) 204 if state_is_tuple else num_units + num_proj) 205 self._output_size = num_proj 206 else: 207 self._state_size = ( 208 rnn_cell_impl.LSTMStateTuple(num_units, num_units) 209 if state_is_tuple else 2 * num_units) 210 self._output_size = num_units 211 212 @property 213 def state_size(self): 214 return self._state_size 215 216 @property 217 def output_size(self): 218 return self._output_size 219 220 def call(self, inputs, state): 221 """Run one step of LSTM. 222 223 Args: 224 inputs: input Tensor, 2D, batch x num_units. 225 state: if `state_is_tuple` is False, this must be a state Tensor, 226 `2-D, batch x state_size`. If `state_is_tuple` is True, this must be a 227 tuple of state Tensors, both `2-D`, with column sizes `c_state` and 228 `m_state`. 229 230 Returns: 231 A tuple containing: 232 - A `2-D, [batch x output_dim]`, Tensor representing the output of the 233 LSTM after reading `inputs` when previous state was `state`. 234 Here output_dim is: 235 num_proj if num_proj was set, 236 num_units otherwise. 237 - Tensor(s) representing the new state of LSTM after reading `inputs` when 238 the previous state was `state`. Same type and shape(s) as `state`. 239 240 Raises: 241 ValueError: If input size cannot be inferred from inputs via 242 static shape inference. 243 """ 244 sigmoid = math_ops.sigmoid 245 246 num_proj = self._num_units if self._num_proj is None else self._num_proj 247 248 if self._state_is_tuple: 249 (c_prev, m_prev) = state 250 else: 251 c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units]) 252 m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj]) 253 254 dtype = inputs.dtype 255 input_size = inputs.get_shape().with_rank(2).dims[1] 256 if input_size.value is None: 257 raise ValueError("Could not infer input size from inputs.get_shape()[-1]") 258 concat_w = _get_concat_variable( 259 "W", 260 [input_size.value + num_proj, 3 * self._num_units], 261 dtype, 262 self._num_unit_shards) 263 264 b = vs.get_variable( 265 "B", 266 shape=[3 * self._num_units], 267 initializer=init_ops.zeros_initializer(), 268 dtype=dtype) 269 270 # j = new_input, f = forget_gate, o = output_gate 271 cell_inputs = array_ops.concat([inputs, m_prev], 1) 272 lstm_matrix = math_ops.matmul(cell_inputs, concat_w) 273 274 # If layer nomalization is applied, do not add bias 275 if not self._layer_norm: 276 lstm_matrix = nn_ops.bias_add(lstm_matrix, b) 277 278 j, f, o = array_ops.split(value=lstm_matrix, num_or_size_splits=3, axis=1) 279 280 # Apply layer normalization 281 if self._layer_norm: 282 j = _norm(self._norm_gain, self._norm_shift, j, "transform") 283 f = _norm(self._norm_gain, self._norm_shift, f, "forget") 284 o = _norm(self._norm_gain, self._norm_shift, o, "output") 285 286 # Diagonal connections 287 if self._use_peepholes: 288 w_f_diag = vs.get_variable( 289 "W_F_diag", shape=[self._num_units], dtype=dtype) 290 w_o_diag = vs.get_variable( 291 "W_O_diag", shape=[self._num_units], dtype=dtype) 292 293 if self._use_peepholes: 294 f_act = sigmoid(f + self._forget_bias + w_f_diag * c_prev) 295 else: 296 f_act = sigmoid(f + self._forget_bias) 297 c = (f_act * c_prev + (1 - f_act) * self._activation(j)) 298 299 # Apply layer normalization 300 if self._layer_norm: 301 c = _norm(self._norm_gain, self._norm_shift, c, "state") 302 303 if self._use_peepholes: 304 m = sigmoid(o + w_o_diag * c) * self._activation(c) 305 else: 306 m = sigmoid(o) * self._activation(c) 307 308 if self._num_proj is not None: 309 concat_w_proj = _get_concat_variable("W_P", 310 [self._num_units, self._num_proj], 311 dtype, self._num_proj_shards) 312 313 m = math_ops.matmul(m, concat_w_proj) 314 if self._proj_clip is not None: 315 # pylint: disable=invalid-unary-operand-type 316 m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip) 317 # pylint: enable=invalid-unary-operand-type 318 319 new_state = ( 320 rnn_cell_impl.LSTMStateTuple(c, m) 321 if self._state_is_tuple else array_ops.concat([c, m], 1)) 322 return m, new_state 323 324 325class TimeFreqLSTMCell(rnn_cell_impl.RNNCell): 326 """Time-Frequency Long short-term memory unit (LSTM) recurrent network cell. 327 328 This implementation is based on: 329 330 Tara N. Sainath and Bo Li 331 "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures 332 for LVCSR Tasks." submitted to INTERSPEECH, 2016. 333 334 It uses peep-hole connections and optional cell clipping. 335 """ 336 337 def __init__(self, 338 num_units, 339 use_peepholes=False, 340 cell_clip=None, 341 initializer=None, 342 num_unit_shards=1, 343 forget_bias=1.0, 344 feature_size=None, 345 frequency_skip=1, 346 reuse=None): 347 """Initialize the parameters for an LSTM cell. 348 349 Args: 350 num_units: int, The number of units in the LSTM cell 351 use_peepholes: bool, set True to enable diagonal/peephole connections. 352 cell_clip: (optional) A float value, if provided the cell state is clipped 353 by this value prior to the cell output activation. 354 initializer: (optional) The initializer to use for the weight and 355 projection matrices. 356 num_unit_shards: int, How to split the weight matrix. If >1, the weight 357 matrix is stored across num_unit_shards. 358 forget_bias: float, Biases of the forget gate are initialized by default 359 to 1 in order to reduce the scale of forgetting at the beginning 360 of the training. 361 feature_size: int, The size of the input feature the LSTM spans over. 362 frequency_skip: int, The amount the LSTM filter is shifted by in 363 frequency. 364 reuse: (optional) Python boolean describing whether to reuse variables 365 in an existing scope. If not `True`, and the existing scope already has 366 the given variables, an error is raised. 367 """ 368 super(TimeFreqLSTMCell, self).__init__(_reuse=reuse) 369 self._num_units = num_units 370 self._use_peepholes = use_peepholes 371 self._cell_clip = cell_clip 372 self._initializer = initializer 373 self._num_unit_shards = num_unit_shards 374 self._forget_bias = forget_bias 375 self._feature_size = feature_size 376 self._frequency_skip = frequency_skip 377 self._state_size = 2 * num_units 378 self._output_size = num_units 379 self._reuse = reuse 380 381 @property 382 def output_size(self): 383 return self._output_size 384 385 @property 386 def state_size(self): 387 return self._state_size 388 389 def call(self, inputs, state): 390 """Run one step of LSTM. 391 392 Args: 393 inputs: input Tensor, 2D, batch x num_units. 394 state: state Tensor, 2D, batch x state_size. 395 396 Returns: 397 A tuple containing: 398 - A 2D, batch x output_dim, Tensor representing the output of the LSTM 399 after reading "inputs" when previous state was "state". 400 Here output_dim is num_units. 401 - A 2D, batch x state_size, Tensor representing the new state of LSTM 402 after reading "inputs" when previous state was "state". 403 Raises: 404 ValueError: if an input_size was specified and the provided inputs have 405 a different dimension. 406 """ 407 sigmoid = math_ops.sigmoid 408 tanh = math_ops.tanh 409 410 freq_inputs = self._make_tf_features(inputs) 411 dtype = inputs.dtype 412 actual_input_size = freq_inputs[0].get_shape().as_list()[1] 413 414 concat_w = _get_concat_variable( 415 "W", [actual_input_size + 2 * self._num_units, 4 * self._num_units], 416 dtype, self._num_unit_shards) 417 418 b = vs.get_variable( 419 "B", 420 shape=[4 * self._num_units], 421 initializer=init_ops.zeros_initializer(), 422 dtype=dtype) 423 424 # Diagonal connections 425 if self._use_peepholes: 426 w_f_diag = vs.get_variable( 427 "W_F_diag", shape=[self._num_units], dtype=dtype) 428 w_i_diag = vs.get_variable( 429 "W_I_diag", shape=[self._num_units], dtype=dtype) 430 w_o_diag = vs.get_variable( 431 "W_O_diag", shape=[self._num_units], dtype=dtype) 432 433 # initialize the first freq state to be zero 434 m_prev_freq = array_ops.zeros( 435 [inputs.shape.dims[0].value or inputs.get_shape()[0], self._num_units], 436 dtype) 437 for fq in range(len(freq_inputs)): 438 c_prev = array_ops.slice(state, [0, 2 * fq * self._num_units], 439 [-1, self._num_units]) 440 m_prev = array_ops.slice(state, [0, (2 * fq + 1) * self._num_units], 441 [-1, self._num_units]) 442 # i = input_gate, j = new_input, f = forget_gate, o = output_gate 443 cell_inputs = array_ops.concat([freq_inputs[fq], m_prev, m_prev_freq], 1) 444 lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b) 445 i, j, f, o = array_ops.split( 446 value=lstm_matrix, num_or_size_splits=4, axis=1) 447 448 if self._use_peepholes: 449 c = ( 450 sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev + 451 sigmoid(i + w_i_diag * c_prev) * tanh(j)) 452 else: 453 c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * tanh(j)) 454 455 if self._cell_clip is not None: 456 # pylint: disable=invalid-unary-operand-type 457 c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip) 458 # pylint: enable=invalid-unary-operand-type 459 460 if self._use_peepholes: 461 m = sigmoid(o + w_o_diag * c) * tanh(c) 462 else: 463 m = sigmoid(o) * tanh(c) 464 m_prev_freq = m 465 if fq == 0: 466 state_out = array_ops.concat([c, m], 1) 467 m_out = m 468 else: 469 state_out = array_ops.concat([state_out, c, m], 1) 470 m_out = array_ops.concat([m_out, m], 1) 471 return m_out, state_out 472 473 def _make_tf_features(self, input_feat): 474 """Make the frequency features. 475 476 Args: 477 input_feat: input Tensor, 2D, batch x num_units. 478 479 Returns: 480 A list of frequency features, with each element containing: 481 - A 2D, batch x output_dim, Tensor representing the time-frequency feature 482 for that frequency index. Here output_dim is feature_size. 483 Raises: 484 ValueError: if input_size cannot be inferred from static shape inference. 485 """ 486 input_size = input_feat.get_shape().with_rank(2).dims[-1].value 487 if input_size is None: 488 raise ValueError("Cannot infer input_size from static shape inference.") 489 num_feats = int( 490 (input_size - self._feature_size) / (self._frequency_skip)) + 1 491 freq_inputs = [] 492 for f in range(num_feats): 493 cur_input = array_ops.slice(input_feat, [0, f * self._frequency_skip], 494 [-1, self._feature_size]) 495 freq_inputs.append(cur_input) 496 return freq_inputs 497 498 499class GridLSTMCell(rnn_cell_impl.RNNCell): 500 """Grid Long short-term memory unit (LSTM) recurrent network cell. 501 502 The default is based on: 503 Nal Kalchbrenner, Ivo Danihelka and Alex Graves 504 "Grid Long Short-Term Memory," Proc. ICLR 2016. 505 http://arxiv.org/abs/1507.01526 506 507 When peephole connections are used, the implementation is based on: 508 Tara N. Sainath and Bo Li 509 "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures 510 for LVCSR Tasks." submitted to INTERSPEECH, 2016. 511 512 The code uses optional peephole connections, shared_weights and cell clipping. 513 """ 514 515 def __init__(self, 516 num_units, 517 use_peepholes=False, 518 share_time_frequency_weights=False, 519 cell_clip=None, 520 initializer=None, 521 num_unit_shards=1, 522 forget_bias=1.0, 523 feature_size=None, 524 frequency_skip=None, 525 num_frequency_blocks=None, 526 start_freqindex_list=None, 527 end_freqindex_list=None, 528 couple_input_forget_gates=False, 529 state_is_tuple=True, 530 reuse=None): 531 """Initialize the parameters for an LSTM cell. 532 533 Args: 534 num_units: int, The number of units in the LSTM cell 535 use_peepholes: (optional) bool, default False. Set True to enable 536 diagonal/peephole connections. 537 share_time_frequency_weights: (optional) bool, default False. Set True to 538 enable shared cell weights between time and frequency LSTMs. 539 cell_clip: (optional) A float value, default None, if provided the cell 540 state is clipped by this value prior to the cell output activation. 541 initializer: (optional) The initializer to use for the weight and 542 projection matrices, default None. 543 num_unit_shards: (optional) int, default 1, How to split the weight 544 matrix. If > 1, the weight matrix is stored across num_unit_shards. 545 forget_bias: (optional) float, default 1.0, The initial bias of the 546 forget gates, used to reduce the scale of forgetting at the beginning 547 of the training. 548 feature_size: (optional) int, default None, The size of the input feature 549 the LSTM spans over. 550 frequency_skip: (optional) int, default None, The amount the LSTM filter 551 is shifted by in frequency. 552 num_frequency_blocks: [required] A list of frequency blocks needed to 553 cover the whole input feature splitting defined by start_freqindex_list 554 and end_freqindex_list. 555 start_freqindex_list: [optional], list of ints, default None, The 556 starting frequency index for each frequency block. 557 end_freqindex_list: [optional], list of ints, default None. The ending 558 frequency index for each frequency block. 559 couple_input_forget_gates: (optional) bool, default False, Whether to 560 couple the input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce 561 model parameters and computation cost. 562 state_is_tuple: If True, accepted and returned states are 2-tuples of 563 the `c_state` and `m_state`. By default (False), they are concatenated 564 along the column axis. This default behavior will soon be deprecated. 565 reuse: (optional) Python boolean describing whether to reuse variables 566 in an existing scope. If not `True`, and the existing scope already has 567 the given variables, an error is raised. 568 Raises: 569 ValueError: if the num_frequency_blocks list is not specified 570 """ 571 super(GridLSTMCell, self).__init__(_reuse=reuse) 572 if not state_is_tuple: 573 logging.warn("%s: Using a concatenated state is slower and will soon be " 574 "deprecated. Use state_is_tuple=True.", self) 575 self._num_units = num_units 576 self._use_peepholes = use_peepholes 577 self._share_time_frequency_weights = share_time_frequency_weights 578 self._couple_input_forget_gates = couple_input_forget_gates 579 self._state_is_tuple = state_is_tuple 580 self._cell_clip = cell_clip 581 self._initializer = initializer 582 self._num_unit_shards = num_unit_shards 583 self._forget_bias = forget_bias 584 self._feature_size = feature_size 585 self._frequency_skip = frequency_skip 586 self._start_freqindex_list = start_freqindex_list 587 self._end_freqindex_list = end_freqindex_list 588 self._num_frequency_blocks = num_frequency_blocks 589 self._total_blocks = 0 590 self._reuse = reuse 591 if self._num_frequency_blocks is None: 592 raise ValueError("Must specify num_frequency_blocks") 593 594 for block_index in range(len(self._num_frequency_blocks)): 595 self._total_blocks += int(self._num_frequency_blocks[block_index]) 596 if state_is_tuple: 597 state_names = "" 598 for block_index in range(len(self._num_frequency_blocks)): 599 for freq_index in range(self._num_frequency_blocks[block_index]): 600 name_prefix = "state_f%02d_b%02d" % (freq_index, block_index) 601 state_names += ("%s_c, %s_m," % (name_prefix, name_prefix)) 602 self._state_tuple_type = collections.namedtuple("GridLSTMStateTuple", 603 state_names.strip(",")) 604 self._state_size = self._state_tuple_type(*( 605 [num_units, num_units] * self._total_blocks)) 606 else: 607 self._state_tuple_type = None 608 self._state_size = num_units * self._total_blocks * 2 609 self._output_size = num_units * self._total_blocks * 2 610 611 @property 612 def output_size(self): 613 return self._output_size 614 615 @property 616 def state_size(self): 617 return self._state_size 618 619 @property 620 def state_tuple_type(self): 621 return self._state_tuple_type 622 623 def call(self, inputs, state): 624 """Run one step of LSTM. 625 626 Args: 627 inputs: input Tensor, 2D, [batch, feature_size]. 628 state: Tensor or tuple of Tensors, 2D, [batch, state_size], depends on the 629 flag self._state_is_tuple. 630 631 Returns: 632 A tuple containing: 633 - A 2D, [batch, output_dim], Tensor representing the output of the LSTM 634 after reading "inputs" when previous state was "state". 635 Here output_dim is num_units. 636 - A 2D, [batch, state_size], Tensor representing the new state of LSTM 637 after reading "inputs" when previous state was "state". 638 Raises: 639 ValueError: if an input_size was specified and the provided inputs have 640 a different dimension. 641 """ 642 batch_size = tensor_shape.dimension_value( 643 inputs.shape[0]) or array_ops.shape(inputs)[0] 644 freq_inputs = self._make_tf_features(inputs) 645 m_out_lst = [] 646 state_out_lst = [] 647 for block in range(len(freq_inputs)): 648 m_out_lst_current, state_out_lst_current = self._compute( 649 freq_inputs[block], 650 block, 651 state, 652 batch_size, 653 state_is_tuple=self._state_is_tuple) 654 m_out_lst.extend(m_out_lst_current) 655 state_out_lst.extend(state_out_lst_current) 656 if self._state_is_tuple: 657 state_out = self._state_tuple_type(*state_out_lst) 658 else: 659 state_out = array_ops.concat(state_out_lst, 1) 660 m_out = array_ops.concat(m_out_lst, 1) 661 return m_out, state_out 662 663 def _compute(self, 664 freq_inputs, 665 block, 666 state, 667 batch_size, 668 state_prefix="state", 669 state_is_tuple=True): 670 """Run the actual computation of one step LSTM. 671 672 Args: 673 freq_inputs: list of Tensors, 2D, [batch, feature_size]. 674 block: int, current frequency block index to process. 675 state: Tensor or tuple of Tensors, 2D, [batch, state_size], it depends on 676 the flag state_is_tuple. 677 batch_size: int32, batch size. 678 state_prefix: (optional) string, name prefix for states, defaults to 679 "state". 680 state_is_tuple: boolean, indicates whether the state is a tuple or Tensor. 681 682 Returns: 683 A tuple, containing: 684 - A list of [batch, output_dim] Tensors, representing the output of the 685 LSTM given the inputs and state. 686 - A list of [batch, state_size] Tensors, representing the LSTM state 687 values given the inputs and previous state. 688 """ 689 sigmoid = math_ops.sigmoid 690 tanh = math_ops.tanh 691 num_gates = 3 if self._couple_input_forget_gates else 4 692 dtype = freq_inputs[0].dtype 693 actual_input_size = freq_inputs[0].get_shape().as_list()[1] 694 695 concat_w_f = _get_concat_variable( 696 "W_f_%d" % block, 697 [actual_input_size + 2 * self._num_units, num_gates * self._num_units], 698 dtype, self._num_unit_shards) 699 b_f = vs.get_variable( 700 "B_f_%d" % block, 701 shape=[num_gates * self._num_units], 702 initializer=init_ops.zeros_initializer(), 703 dtype=dtype) 704 if not self._share_time_frequency_weights: 705 concat_w_t = _get_concat_variable("W_t_%d" % block, [ 706 actual_input_size + 2 * self._num_units, num_gates * self._num_units 707 ], dtype, self._num_unit_shards) 708 b_t = vs.get_variable( 709 "B_t_%d" % block, 710 shape=[num_gates * self._num_units], 711 initializer=init_ops.zeros_initializer(), 712 dtype=dtype) 713 714 if self._use_peepholes: 715 # Diagonal connections 716 if not self._couple_input_forget_gates: 717 w_f_diag_freqf = vs.get_variable( 718 "W_F_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype) 719 w_f_diag_freqt = vs.get_variable( 720 "W_F_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype) 721 w_i_diag_freqf = vs.get_variable( 722 "W_I_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype) 723 w_i_diag_freqt = vs.get_variable( 724 "W_I_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype) 725 w_o_diag_freqf = vs.get_variable( 726 "W_O_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype) 727 w_o_diag_freqt = vs.get_variable( 728 "W_O_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype) 729 if not self._share_time_frequency_weights: 730 if not self._couple_input_forget_gates: 731 w_f_diag_timef = vs.get_variable( 732 "W_F_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype) 733 w_f_diag_timet = vs.get_variable( 734 "W_F_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype) 735 w_i_diag_timef = vs.get_variable( 736 "W_I_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype) 737 w_i_diag_timet = vs.get_variable( 738 "W_I_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype) 739 w_o_diag_timef = vs.get_variable( 740 "W_O_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype) 741 w_o_diag_timet = vs.get_variable( 742 "W_O_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype) 743 744 # initialize the first freq state to be zero 745 m_prev_freq = array_ops.zeros([batch_size, self._num_units], dtype) 746 c_prev_freq = array_ops.zeros([batch_size, self._num_units], dtype) 747 for freq_index in range(len(freq_inputs)): 748 if state_is_tuple: 749 name_prefix = "%s_f%02d_b%02d" % (state_prefix, freq_index, block) 750 c_prev_time = getattr(state, name_prefix + "_c") 751 m_prev_time = getattr(state, name_prefix + "_m") 752 else: 753 c_prev_time = array_ops.slice( 754 state, [0, 2 * freq_index * self._num_units], [-1, self._num_units]) 755 m_prev_time = array_ops.slice( 756 state, [0, (2 * freq_index + 1) * self._num_units], 757 [-1, self._num_units]) 758 759 # i = input_gate, j = new_input, f = forget_gate, o = output_gate 760 cell_inputs = array_ops.concat( 761 [freq_inputs[freq_index], m_prev_time, m_prev_freq], 1) 762 763 # F-LSTM 764 lstm_matrix_freq = nn_ops.bias_add( 765 math_ops.matmul(cell_inputs, concat_w_f), b_f) 766 if self._couple_input_forget_gates: 767 i_freq, j_freq, o_freq = array_ops.split( 768 value=lstm_matrix_freq, num_or_size_splits=num_gates, axis=1) 769 f_freq = None 770 else: 771 i_freq, j_freq, f_freq, o_freq = array_ops.split( 772 value=lstm_matrix_freq, num_or_size_splits=num_gates, axis=1) 773 # T-LSTM 774 if self._share_time_frequency_weights: 775 i_time = i_freq 776 j_time = j_freq 777 f_time = f_freq 778 o_time = o_freq 779 else: 780 lstm_matrix_time = nn_ops.bias_add( 781 math_ops.matmul(cell_inputs, concat_w_t), b_t) 782 if self._couple_input_forget_gates: 783 i_time, j_time, o_time = array_ops.split( 784 value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1) 785 f_time = None 786 else: 787 i_time, j_time, f_time, o_time = array_ops.split( 788 value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1) 789 790 # F-LSTM c_freq 791 # input gate activations 792 if self._use_peepholes: 793 i_freq_g = sigmoid(i_freq + w_i_diag_freqf * c_prev_freq + 794 w_i_diag_freqt * c_prev_time) 795 else: 796 i_freq_g = sigmoid(i_freq) 797 # forget gate activations 798 if self._couple_input_forget_gates: 799 f_freq_g = 1.0 - i_freq_g 800 else: 801 if self._use_peepholes: 802 f_freq_g = sigmoid(f_freq + self._forget_bias + w_f_diag_freqf * 803 c_prev_freq + w_f_diag_freqt * c_prev_time) 804 else: 805 f_freq_g = sigmoid(f_freq + self._forget_bias) 806 # cell state 807 c_freq = f_freq_g * c_prev_freq + i_freq_g * tanh(j_freq) 808 if self._cell_clip is not None: 809 # pylint: disable=invalid-unary-operand-type 810 c_freq = clip_ops.clip_by_value(c_freq, -self._cell_clip, 811 self._cell_clip) 812 # pylint: enable=invalid-unary-operand-type 813 814 # T-LSTM c_freq 815 # input gate activations 816 if self._use_peepholes: 817 if self._share_time_frequency_weights: 818 i_time_g = sigmoid(i_time + w_i_diag_freqf * c_prev_freq + 819 w_i_diag_freqt * c_prev_time) 820 else: 821 i_time_g = sigmoid(i_time + w_i_diag_timef * c_prev_freq + 822 w_i_diag_timet * c_prev_time) 823 else: 824 i_time_g = sigmoid(i_time) 825 # forget gate activations 826 if self._couple_input_forget_gates: 827 f_time_g = 1.0 - i_time_g 828 else: 829 if self._use_peepholes: 830 if self._share_time_frequency_weights: 831 f_time_g = sigmoid(f_time + self._forget_bias + w_f_diag_freqf * 832 c_prev_freq + w_f_diag_freqt * c_prev_time) 833 else: 834 f_time_g = sigmoid(f_time + self._forget_bias + w_f_diag_timef * 835 c_prev_freq + w_f_diag_timet * c_prev_time) 836 else: 837 f_time_g = sigmoid(f_time + self._forget_bias) 838 # cell state 839 c_time = f_time_g * c_prev_time + i_time_g * tanh(j_time) 840 if self._cell_clip is not None: 841 # pylint: disable=invalid-unary-operand-type 842 c_time = clip_ops.clip_by_value(c_time, -self._cell_clip, 843 self._cell_clip) 844 # pylint: enable=invalid-unary-operand-type 845 846 # F-LSTM m_freq 847 if self._use_peepholes: 848 m_freq = sigmoid(o_freq + w_o_diag_freqf * c_freq + 849 w_o_diag_freqt * c_time) * tanh(c_freq) 850 else: 851 m_freq = sigmoid(o_freq) * tanh(c_freq) 852 853 # T-LSTM m_time 854 if self._use_peepholes: 855 if self._share_time_frequency_weights: 856 m_time = sigmoid(o_time + w_o_diag_freqf * c_freq + 857 w_o_diag_freqt * c_time) * tanh(c_time) 858 else: 859 m_time = sigmoid(o_time + w_o_diag_timef * c_freq + 860 w_o_diag_timet * c_time) * tanh(c_time) 861 else: 862 m_time = sigmoid(o_time) * tanh(c_time) 863 864 m_prev_freq = m_freq 865 c_prev_freq = c_freq 866 # Concatenate the outputs for T-LSTM and F-LSTM for each shift 867 if freq_index == 0: 868 state_out_lst = [c_time, m_time] 869 m_out_lst = [m_time, m_freq] 870 else: 871 state_out_lst.extend([c_time, m_time]) 872 m_out_lst.extend([m_time, m_freq]) 873 874 return m_out_lst, state_out_lst 875 876 def _make_tf_features(self, input_feat, slice_offset=0): 877 """Make the frequency features. 878 879 Args: 880 input_feat: input Tensor, 2D, [batch, num_units]. 881 slice_offset: (optional) Python int, default 0, the slicing offset is only 882 used for the backward processing in the BidirectionalGridLSTMCell. It 883 specifies a different starting point instead of always 0 to enable the 884 forward and backward processing look at different frequency blocks. 885 886 Returns: 887 A list of frequency features, with each element containing: 888 - A 2D, [batch, output_dim], Tensor representing the time-frequency 889 feature for that frequency index. Here output_dim is feature_size. 890 Raises: 891 ValueError: if input_size cannot be inferred from static shape inference. 892 """ 893 input_size = input_feat.get_shape().with_rank(2).dims[-1].value 894 if input_size is None: 895 raise ValueError("Cannot infer input_size from static shape inference.") 896 if slice_offset > 0: 897 # Padding to the end 898 inputs = array_ops.pad(input_feat, 899 array_ops.constant( 900 [0, 0, 0, slice_offset], 901 shape=[2, 2], 902 dtype=dtypes.int32), "CONSTANT") 903 elif slice_offset < 0: 904 # Padding to the front 905 inputs = array_ops.pad(input_feat, 906 array_ops.constant( 907 [0, 0, -slice_offset, 0], 908 shape=[2, 2], 909 dtype=dtypes.int32), "CONSTANT") 910 slice_offset = 0 911 else: 912 inputs = input_feat 913 freq_inputs = [] 914 if not self._start_freqindex_list: 915 if len(self._num_frequency_blocks) != 1: 916 raise ValueError("Length of num_frequency_blocks" 917 " is not 1, but instead is %d" % 918 len(self._num_frequency_blocks)) 919 num_feats = int( 920 (input_size - self._feature_size) / (self._frequency_skip)) + 1 921 if num_feats != self._num_frequency_blocks[0]: 922 raise ValueError( 923 "Invalid num_frequency_blocks, requires %d but gets %d, please" 924 " check the input size and filter config are correct." % 925 (self._num_frequency_blocks[0], num_feats)) 926 block_inputs = [] 927 for f in range(num_feats): 928 cur_input = array_ops.slice( 929 inputs, [0, slice_offset + f * self._frequency_skip], 930 [-1, self._feature_size]) 931 block_inputs.append(cur_input) 932 freq_inputs.append(block_inputs) 933 else: 934 if len(self._start_freqindex_list) != len(self._end_freqindex_list): 935 raise ValueError("Length of start and end freqindex_list" 936 " does not match %d %d", 937 len(self._start_freqindex_list), 938 len(self._end_freqindex_list)) 939 if len(self._num_frequency_blocks) != len(self._start_freqindex_list): 940 raise ValueError("Length of num_frequency_blocks" 941 " is not equal to start_freqindex_list %d %d", 942 len(self._num_frequency_blocks), 943 len(self._start_freqindex_list)) 944 for b in range(len(self._start_freqindex_list)): 945 start_index = self._start_freqindex_list[b] 946 end_index = self._end_freqindex_list[b] 947 cur_size = end_index - start_index 948 block_feats = int( 949 (cur_size - self._feature_size) / (self._frequency_skip)) + 1 950 if block_feats != self._num_frequency_blocks[b]: 951 raise ValueError( 952 "Invalid num_frequency_blocks, requires %d but gets %d, please" 953 " check the input size and filter config are correct." % 954 (self._num_frequency_blocks[b], block_feats)) 955 block_inputs = [] 956 for f in range(block_feats): 957 cur_input = array_ops.slice( 958 inputs, 959 [0, start_index + slice_offset + f * self._frequency_skip], 960 [-1, self._feature_size]) 961 block_inputs.append(cur_input) 962 freq_inputs.append(block_inputs) 963 return freq_inputs 964 965 966class BidirectionalGridLSTMCell(GridLSTMCell): 967 """Bidirectional GridLstm cell. 968 969 The bidirection connection is only used in the frequency direction, which 970 hence doesn't affect the time direction's real-time processing that is 971 required for online recognition systems. 972 The current implementation uses different weights for the two directions. 973 """ 974 975 def __init__(self, 976 num_units, 977 use_peepholes=False, 978 share_time_frequency_weights=False, 979 cell_clip=None, 980 initializer=None, 981 num_unit_shards=1, 982 forget_bias=1.0, 983 feature_size=None, 984 frequency_skip=None, 985 num_frequency_blocks=None, 986 start_freqindex_list=None, 987 end_freqindex_list=None, 988 couple_input_forget_gates=False, 989 backward_slice_offset=0, 990 reuse=None): 991 """Initialize the parameters for an LSTM cell. 992 993 Args: 994 num_units: int, The number of units in the LSTM cell 995 use_peepholes: (optional) bool, default False. Set True to enable 996 diagonal/peephole connections. 997 share_time_frequency_weights: (optional) bool, default False. Set True to 998 enable shared cell weights between time and frequency LSTMs. 999 cell_clip: (optional) A float value, default None, if provided the cell 1000 state is clipped by this value prior to the cell output activation. 1001 initializer: (optional) The initializer to use for the weight and 1002 projection matrices, default None. 1003 num_unit_shards: (optional) int, default 1, How to split the weight 1004 matrix. If > 1, the weight matrix is stored across num_unit_shards. 1005 forget_bias: (optional) float, default 1.0, The initial bias of the 1006 forget gates, used to reduce the scale of forgetting at the beginning 1007 of the training. 1008 feature_size: (optional) int, default None, The size of the input feature 1009 the LSTM spans over. 1010 frequency_skip: (optional) int, default None, The amount the LSTM filter 1011 is shifted by in frequency. 1012 num_frequency_blocks: [required] A list of frequency blocks needed to 1013 cover the whole input feature splitting defined by start_freqindex_list 1014 and end_freqindex_list. 1015 start_freqindex_list: [optional], list of ints, default None, The 1016 starting frequency index for each frequency block. 1017 end_freqindex_list: [optional], list of ints, default None. The ending 1018 frequency index for each frequency block. 1019 couple_input_forget_gates: (optional) bool, default False, Whether to 1020 couple the input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce 1021 model parameters and computation cost. 1022 backward_slice_offset: (optional) int32, default 0, the starting offset to 1023 slice the feature for backward processing. 1024 reuse: (optional) Python boolean describing whether to reuse variables 1025 in an existing scope. If not `True`, and the existing scope already has 1026 the given variables, an error is raised. 1027 """ 1028 super(BidirectionalGridLSTMCell, self).__init__( 1029 num_units, use_peepholes, share_time_frequency_weights, cell_clip, 1030 initializer, num_unit_shards, forget_bias, feature_size, frequency_skip, 1031 num_frequency_blocks, start_freqindex_list, end_freqindex_list, 1032 couple_input_forget_gates, True, reuse) 1033 self._backward_slice_offset = int(backward_slice_offset) 1034 state_names = "" 1035 for direction in ["fwd", "bwd"]: 1036 for block_index in range(len(self._num_frequency_blocks)): 1037 for freq_index in range(self._num_frequency_blocks[block_index]): 1038 name_prefix = "%s_state_f%02d_b%02d" % (direction, freq_index, 1039 block_index) 1040 state_names += ("%s_c, %s_m," % (name_prefix, name_prefix)) 1041 self._state_tuple_type = collections.namedtuple( 1042 "BidirectionalGridLSTMStateTuple", state_names.strip(",")) 1043 self._state_size = self._state_tuple_type(*( 1044 [num_units, num_units] * self._total_blocks * 2)) 1045 self._output_size = 2 * num_units * self._total_blocks * 2 1046 1047 def call(self, inputs, state): 1048 """Run one step of LSTM. 1049 1050 Args: 1051 inputs: input Tensor, 2D, [batch, num_units]. 1052 state: tuple of Tensors, 2D, [batch, state_size]. 1053 1054 Returns: 1055 A tuple containing: 1056 - A 2D, [batch, output_dim], Tensor representing the output of the LSTM 1057 after reading "inputs" when previous state was "state". 1058 Here output_dim is num_units. 1059 - A 2D, [batch, state_size], Tensor representing the new state of LSTM 1060 after reading "inputs" when previous state was "state". 1061 Raises: 1062 ValueError: if an input_size was specified and the provided inputs have 1063 a different dimension. 1064 """ 1065 batch_size = tensor_shape.dimension_value( 1066 inputs.shape[0]) or array_ops.shape(inputs)[0] 1067 fwd_inputs = self._make_tf_features(inputs) 1068 if self._backward_slice_offset: 1069 bwd_inputs = self._make_tf_features(inputs, self._backward_slice_offset) 1070 else: 1071 bwd_inputs = fwd_inputs 1072 1073 # Forward processing 1074 with vs.variable_scope("fwd"): 1075 fwd_m_out_lst = [] 1076 fwd_state_out_lst = [] 1077 for block in range(len(fwd_inputs)): 1078 fwd_m_out_lst_current, fwd_state_out_lst_current = self._compute( 1079 fwd_inputs[block], 1080 block, 1081 state, 1082 batch_size, 1083 state_prefix="fwd_state", 1084 state_is_tuple=True) 1085 fwd_m_out_lst.extend(fwd_m_out_lst_current) 1086 fwd_state_out_lst.extend(fwd_state_out_lst_current) 1087 # Backward processing 1088 bwd_m_out_lst = [] 1089 bwd_state_out_lst = [] 1090 with vs.variable_scope("bwd"): 1091 for block in range(len(bwd_inputs)): 1092 # Reverse the blocks 1093 bwd_inputs_reverse = bwd_inputs[block][::-1] 1094 bwd_m_out_lst_current, bwd_state_out_lst_current = self._compute( 1095 bwd_inputs_reverse, 1096 block, 1097 state, 1098 batch_size, 1099 state_prefix="bwd_state", 1100 state_is_tuple=True) 1101 bwd_m_out_lst.extend(bwd_m_out_lst_current) 1102 bwd_state_out_lst.extend(bwd_state_out_lst_current) 1103 state_out = self._state_tuple_type(*(fwd_state_out_lst + bwd_state_out_lst)) 1104 # Outputs are always concated as it is never used separately. 1105 m_out = array_ops.concat(fwd_m_out_lst + bwd_m_out_lst, 1) 1106 return m_out, state_out 1107 1108 1109# pylint: disable=protected-access 1110_Linear = core_rnn_cell._Linear # pylint: disable=invalid-name 1111 1112# pylint: enable=protected-access 1113 1114 1115class AttentionCellWrapper(rnn_cell_impl.RNNCell): 1116 """Basic attention cell wrapper. 1117 1118 Implementation based on https://arxiv.org/abs/1601.06733. 1119 """ 1120 1121 def __init__(self, 1122 cell, 1123 attn_length, 1124 attn_size=None, 1125 attn_vec_size=None, 1126 input_size=None, 1127 state_is_tuple=True, 1128 reuse=None): 1129 """Create a cell with attention. 1130 1131 Args: 1132 cell: an RNNCell, an attention is added to it. 1133 attn_length: integer, the size of an attention window. 1134 attn_size: integer, the size of an attention vector. Equal to 1135 cell.output_size by default. 1136 attn_vec_size: integer, the number of convolutional features calculated 1137 on attention state and a size of the hidden layer built from 1138 base cell state. Equal attn_size to by default. 1139 input_size: integer, the size of a hidden linear layer, 1140 built from inputs and attention. Derived from the input tensor 1141 by default. 1142 state_is_tuple: If True, accepted and returned states are n-tuples, where 1143 `n = len(cells)`. By default (False), the states are all 1144 concatenated along the column axis. 1145 reuse: (optional) Python boolean describing whether to reuse variables 1146 in an existing scope. If not `True`, and the existing scope already has 1147 the given variables, an error is raised. 1148 1149 Raises: 1150 TypeError: if cell is not an RNNCell. 1151 ValueError: if cell returns a state tuple but the flag 1152 `state_is_tuple` is `False` or if attn_length is zero or less. 1153 """ 1154 super(AttentionCellWrapper, self).__init__(_reuse=reuse) 1155 rnn_cell_impl.assert_like_rnncell("cell", cell) 1156 if nest.is_sequence(cell.state_size) and not state_is_tuple: 1157 raise ValueError( 1158 "Cell returns tuple of states, but the flag " 1159 "state_is_tuple is not set. State size is: %s" % str(cell.state_size)) 1160 if attn_length <= 0: 1161 raise ValueError( 1162 "attn_length should be greater than zero, got %s" % str(attn_length)) 1163 if not state_is_tuple: 1164 logging.warn("%s: Using a concatenated state is slower and will soon be " 1165 "deprecated. Use state_is_tuple=True.", self) 1166 if attn_size is None: 1167 attn_size = cell.output_size 1168 if attn_vec_size is None: 1169 attn_vec_size = attn_size 1170 self._state_is_tuple = state_is_tuple 1171 self._cell = cell 1172 self._attn_vec_size = attn_vec_size 1173 self._input_size = input_size 1174 self._attn_size = attn_size 1175 self._attn_length = attn_length 1176 self._reuse = reuse 1177 self._linear1 = None 1178 self._linear2 = None 1179 self._linear3 = None 1180 1181 @property 1182 def state_size(self): 1183 size = (self._cell.state_size, self._attn_size, 1184 self._attn_size * self._attn_length) 1185 if self._state_is_tuple: 1186 return size 1187 else: 1188 return sum(list(size)) 1189 1190 @property 1191 def output_size(self): 1192 return self._attn_size 1193 1194 def call(self, inputs, state): 1195 """Long short-term memory cell with attention (LSTMA).""" 1196 if self._state_is_tuple: 1197 state, attns, attn_states = state 1198 else: 1199 states = state 1200 state = array_ops.slice(states, [0, 0], [-1, self._cell.state_size]) 1201 attns = array_ops.slice(states, [0, self._cell.state_size], 1202 [-1, self._attn_size]) 1203 attn_states = array_ops.slice( 1204 states, [0, self._cell.state_size + self._attn_size], 1205 [-1, self._attn_size * self._attn_length]) 1206 attn_states = array_ops.reshape(attn_states, 1207 [-1, self._attn_length, self._attn_size]) 1208 input_size = self._input_size 1209 if input_size is None: 1210 input_size = inputs.get_shape().as_list()[1] 1211 if self._linear1 is None: 1212 self._linear1 = _Linear([inputs, attns], input_size, True) 1213 inputs = self._linear1([inputs, attns]) 1214 cell_output, new_state = self._cell(inputs, state) 1215 if self._state_is_tuple: 1216 new_state_cat = array_ops.concat(nest.flatten(new_state), 1) 1217 else: 1218 new_state_cat = new_state 1219 new_attns, new_attn_states = self._attention(new_state_cat, attn_states) 1220 with vs.variable_scope("attn_output_projection"): 1221 if self._linear2 is None: 1222 self._linear2 = _Linear([cell_output, new_attns], self._attn_size, True) 1223 output = self._linear2([cell_output, new_attns]) 1224 new_attn_states = array_ops.concat( 1225 [new_attn_states, array_ops.expand_dims(output, 1)], 1) 1226 new_attn_states = array_ops.reshape( 1227 new_attn_states, [-1, self._attn_length * self._attn_size]) 1228 new_state = (new_state, new_attns, new_attn_states) 1229 if not self._state_is_tuple: 1230 new_state = array_ops.concat(list(new_state), 1) 1231 return output, new_state 1232 1233 def _attention(self, query, attn_states): 1234 conv2d = nn_ops.conv2d 1235 reduce_sum = math_ops.reduce_sum 1236 softmax = nn_ops.softmax 1237 tanh = math_ops.tanh 1238 1239 with vs.variable_scope("attention"): 1240 k = vs.get_variable("attn_w", 1241 [1, 1, self._attn_size, self._attn_vec_size]) 1242 v = vs.get_variable("attn_v", [self._attn_vec_size]) 1243 hidden = array_ops.reshape(attn_states, 1244 [-1, self._attn_length, 1, self._attn_size]) 1245 hidden_features = conv2d(hidden, k, [1, 1, 1, 1], "SAME") 1246 if self._linear3 is None: 1247 self._linear3 = _Linear(query, self._attn_vec_size, True) 1248 y = self._linear3(query) 1249 y = array_ops.reshape(y, [-1, 1, 1, self._attn_vec_size]) 1250 s = reduce_sum(v * tanh(hidden_features + y), [2, 3]) 1251 a = softmax(s) 1252 d = reduce_sum( 1253 array_ops.reshape(a, [-1, self._attn_length, 1, 1]) * hidden, [1, 2]) 1254 new_attns = array_ops.reshape(d, [-1, self._attn_size]) 1255 new_attn_states = array_ops.slice(attn_states, [0, 1, 0], [-1, -1, -1]) 1256 return new_attns, new_attn_states 1257 1258 1259class HighwayWrapper(rnn_cell_impl.RNNCell): 1260 """RNNCell wrapper that adds highway connection on cell input and output. 1261 1262 Based on: 1263 R. K. Srivastava, K. Greff, and J. Schmidhuber, "Highway networks", 1264 arXiv preprint arXiv:1505.00387, 2015. 1265 https://arxiv.org/abs/1505.00387 1266 """ 1267 1268 def __init__(self, 1269 cell, 1270 couple_carry_transform_gates=True, 1271 carry_bias_init=1.0): 1272 """Constructs a `HighwayWrapper` for `cell`. 1273 1274 Args: 1275 cell: An instance of `RNNCell`. 1276 couple_carry_transform_gates: boolean, should the Carry and Transform gate 1277 be coupled. 1278 carry_bias_init: float, carry gates bias initialization. 1279 """ 1280 self._cell = cell 1281 self._couple_carry_transform_gates = couple_carry_transform_gates 1282 self._carry_bias_init = carry_bias_init 1283 1284 @property 1285 def state_size(self): 1286 return self._cell.state_size 1287 1288 @property 1289 def output_size(self): 1290 return self._cell.output_size 1291 1292 def zero_state(self, batch_size, dtype): 1293 with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): 1294 return self._cell.zero_state(batch_size, dtype) 1295 1296 def _highway(self, inp, out): 1297 input_size = inp.get_shape().with_rank(2).dims[1].value 1298 carry_weight = vs.get_variable("carry_w", [input_size, input_size]) 1299 carry_bias = vs.get_variable( 1300 "carry_b", [input_size], 1301 initializer=init_ops.constant_initializer(self._carry_bias_init)) 1302 carry = math_ops.sigmoid(nn_ops.xw_plus_b(inp, carry_weight, carry_bias)) 1303 if self._couple_carry_transform_gates: 1304 transform = 1 - carry 1305 else: 1306 transform_weight = vs.get_variable("transform_w", 1307 [input_size, input_size]) 1308 transform_bias = vs.get_variable( 1309 "transform_b", [input_size], 1310 initializer=init_ops.constant_initializer(-self._carry_bias_init)) 1311 transform = math_ops.sigmoid( 1312 nn_ops.xw_plus_b(inp, transform_weight, transform_bias)) 1313 return inp * carry + out * transform 1314 1315 def __call__(self, inputs, state, scope=None): 1316 """Run the cell and add its inputs to its outputs. 1317 1318 Args: 1319 inputs: cell inputs. 1320 state: cell state. 1321 scope: optional cell scope. 1322 1323 Returns: 1324 Tuple of cell outputs and new state. 1325 1326 Raises: 1327 TypeError: If cell inputs and outputs have different structure (type). 1328 ValueError: If cell inputs and outputs have different structure (value). 1329 """ 1330 outputs, new_state = self._cell(inputs, state, scope=scope) 1331 nest.assert_same_structure(inputs, outputs) 1332 1333 # Ensure shapes match 1334 def assert_shape_match(inp, out): 1335 inp.get_shape().assert_is_compatible_with(out.get_shape()) 1336 1337 nest.map_structure(assert_shape_match, inputs, outputs) 1338 res_outputs = nest.map_structure(self._highway, inputs, outputs) 1339 return (res_outputs, new_state) 1340 1341 1342class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell): 1343 """LSTM unit with layer normalization and recurrent dropout. 1344 1345 This class adds layer normalization and recurrent dropout to a 1346 basic LSTM unit. Layer normalization implementation is based on: 1347 1348 https://arxiv.org/abs/1607.06450. 1349 1350 "Layer Normalization" 1351 Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton 1352 1353 and is applied before the internal nonlinearities. 1354 Recurrent dropout is base on: 1355 1356 https://arxiv.org/abs/1603.05118 1357 1358 "Recurrent Dropout without Memory Loss" 1359 Stanislau Semeniuta, Aliaksei Severyn, Erhardt Barth. 1360 """ 1361 1362 def __init__(self, 1363 num_units, 1364 forget_bias=1.0, 1365 input_size=None, 1366 activation=math_ops.tanh, 1367 layer_norm=True, 1368 norm_gain=1.0, 1369 norm_shift=0.0, 1370 dropout_keep_prob=1.0, 1371 dropout_prob_seed=None, 1372 reuse=None): 1373 """Initializes the basic LSTM cell. 1374 1375 Args: 1376 num_units: int, The number of units in the LSTM cell. 1377 forget_bias: float, The bias added to forget gates (see above). 1378 input_size: Deprecated and unused. 1379 activation: Activation function of the inner states. 1380 layer_norm: If `True`, layer normalization will be applied. 1381 norm_gain: float, The layer normalization gain initial value. If 1382 `layer_norm` has been set to `False`, this argument will be ignored. 1383 norm_shift: float, The layer normalization shift initial value. If 1384 `layer_norm` has been set to `False`, this argument will be ignored. 1385 dropout_keep_prob: unit Tensor or float between 0 and 1 representing the 1386 recurrent dropout probability value. If float and 1.0, no dropout will 1387 be applied. 1388 dropout_prob_seed: (optional) integer, the randomness seed. 1389 reuse: (optional) Python boolean describing whether to reuse variables 1390 in an existing scope. If not `True`, and the existing scope already has 1391 the given variables, an error is raised. 1392 """ 1393 super(LayerNormBasicLSTMCell, self).__init__(_reuse=reuse) 1394 1395 if input_size is not None: 1396 logging.warn("%s: The input_size parameter is deprecated.", self) 1397 1398 self._num_units = num_units 1399 self._activation = activation 1400 self._forget_bias = forget_bias 1401 self._keep_prob = dropout_keep_prob 1402 self._seed = dropout_prob_seed 1403 self._layer_norm = layer_norm 1404 self._norm_gain = norm_gain 1405 self._norm_shift = norm_shift 1406 self._reuse = reuse 1407 1408 @property 1409 def state_size(self): 1410 return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units) 1411 1412 @property 1413 def output_size(self): 1414 return self._num_units 1415 1416 def _norm(self, inp, scope, dtype=dtypes.float32): 1417 shape = inp.get_shape()[-1:] 1418 gamma_init = init_ops.constant_initializer(self._norm_gain) 1419 beta_init = init_ops.constant_initializer(self._norm_shift) 1420 with vs.variable_scope(scope): 1421 # Initialize beta and gamma for use by layer_norm. 1422 vs.get_variable("gamma", shape=shape, initializer=gamma_init, dtype=dtype) 1423 vs.get_variable("beta", shape=shape, initializer=beta_init, dtype=dtype) 1424 normalized = layers.layer_norm(inp, reuse=True, scope=scope) 1425 return normalized 1426 1427 def _linear(self, args): 1428 out_size = 4 * self._num_units 1429 proj_size = args.get_shape()[-1] 1430 dtype = args.dtype 1431 weights = vs.get_variable("kernel", [proj_size, out_size], dtype=dtype) 1432 out = math_ops.matmul(args, weights) 1433 if not self._layer_norm: 1434 bias = vs.get_variable("bias", [out_size], dtype=dtype) 1435 out = nn_ops.bias_add(out, bias) 1436 return out 1437 1438 def call(self, inputs, state): 1439 """LSTM cell with layer normalization and recurrent dropout.""" 1440 c, h = state 1441 args = array_ops.concat([inputs, h], 1) 1442 concat = self._linear(args) 1443 dtype = args.dtype 1444 1445 i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1) 1446 if self._layer_norm: 1447 i = self._norm(i, "input", dtype=dtype) 1448 j = self._norm(j, "transform", dtype=dtype) 1449 f = self._norm(f, "forget", dtype=dtype) 1450 o = self._norm(o, "output", dtype=dtype) 1451 1452 g = self._activation(j) 1453 if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1: 1454 g = nn_ops.dropout(g, self._keep_prob, seed=self._seed) 1455 1456 new_c = ( 1457 c * math_ops.sigmoid(f + self._forget_bias) + math_ops.sigmoid(i) * g) 1458 if self._layer_norm: 1459 new_c = self._norm(new_c, "state", dtype=dtype) 1460 new_h = self._activation(new_c) * math_ops.sigmoid(o) 1461 1462 new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h) 1463 return new_h, new_state 1464 1465 1466class NASCell(rnn_cell_impl.LayerRNNCell): 1467 """Neural Architecture Search (NAS) recurrent network cell. 1468 1469 This implements the recurrent cell from the paper: 1470 1471 https://arxiv.org/abs/1611.01578 1472 1473 Barret Zoph and Quoc V. Le. 1474 "Neural Architecture Search with Reinforcement Learning" Proc. ICLR 2017. 1475 1476 The class uses an optional projection layer. 1477 """ 1478 1479 # NAS cell's architecture base. 1480 _NAS_BASE = 8 1481 1482 def __init__(self, num_units, num_proj=None, use_bias=False, reuse=None, 1483 **kwargs): 1484 """Initialize the parameters for a NAS cell. 1485 1486 Args: 1487 num_units: int, The number of units in the NAS cell. 1488 num_proj: (optional) int, The output dimensionality for the projection 1489 matrices. If None, no projection is performed. 1490 use_bias: (optional) bool, If True then use biases within the cell. This 1491 is False by default. 1492 reuse: (optional) Python boolean describing whether to reuse variables 1493 in an existing scope. If not `True`, and the existing scope already has 1494 the given variables, an error is raised. 1495 **kwargs: Additional keyword arguments. 1496 """ 1497 super(NASCell, self).__init__(_reuse=reuse, **kwargs) 1498 self._num_units = num_units 1499 self._num_proj = num_proj 1500 self._use_bias = use_bias 1501 self._reuse = reuse 1502 1503 if num_proj is not None: 1504 self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj) 1505 self._output_size = num_proj 1506 else: 1507 self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units) 1508 self._output_size = num_units 1509 1510 @property 1511 def state_size(self): 1512 return self._state_size 1513 1514 @property 1515 def output_size(self): 1516 return self._output_size 1517 1518 def build(self, inputs_shape): 1519 input_size = tensor_shape.dimension_value( 1520 tensor_shape.TensorShape(inputs_shape).with_rank(2)[1]) 1521 if input_size is None: 1522 raise ValueError("Could not infer input size from inputs.get_shape()[-1]") 1523 1524 num_proj = self._num_units if self._num_proj is None else self._num_proj 1525 1526 # Variables for the NAS cell. `recurrent_kernel` is all matrices multiplying 1527 # the hiddenstate and `kernel` is all matrices multiplying the inputs. 1528 self.recurrent_kernel = self.add_variable( 1529 "recurrent_kernel", [num_proj, self._NAS_BASE * self._num_units]) 1530 self.kernel = self.add_variable( 1531 "kernel", [input_size, self._NAS_BASE * self._num_units]) 1532 1533 if self._use_bias: 1534 self.bias = self.add_variable("bias", 1535 shape=[self._NAS_BASE * self._num_units], 1536 initializer=init_ops.zeros_initializer) 1537 1538 # Projection layer if specified 1539 if self._num_proj is not None: 1540 self.projection_weights = self.add_variable( 1541 "projection_weights", [self._num_units, self._num_proj]) 1542 1543 self.built = True 1544 1545 def call(self, inputs, state): 1546 """Run one step of NAS Cell. 1547 1548 Args: 1549 inputs: input Tensor, 2D, batch x num_units. 1550 state: This must be a tuple of state Tensors, both `2-D`, with column 1551 sizes `c_state` and `m_state`. 1552 1553 Returns: 1554 A tuple containing: 1555 - A `2-D, [batch x output_dim]`, Tensor representing the output of the 1556 NAS Cell after reading `inputs` when previous state was `state`. 1557 Here output_dim is: 1558 num_proj if num_proj was set, 1559 num_units otherwise. 1560 - Tensor(s) representing the new state of NAS Cell after reading `inputs` 1561 when the previous state was `state`. Same type and shape(s) as `state`. 1562 1563 Raises: 1564 ValueError: If input size cannot be inferred from inputs via 1565 static shape inference. 1566 """ 1567 sigmoid = math_ops.sigmoid 1568 tanh = math_ops.tanh 1569 relu = nn_ops.relu 1570 1571 (c_prev, m_prev) = state 1572 1573 m_matrix = math_ops.matmul(m_prev, self.recurrent_kernel) 1574 inputs_matrix = math_ops.matmul(inputs, self.kernel) 1575 1576 if self._use_bias: 1577 m_matrix = nn_ops.bias_add(m_matrix, self.bias) 1578 1579 # The NAS cell branches into 8 different splits for both the hiddenstate 1580 # and the input 1581 m_matrix_splits = array_ops.split( 1582 axis=1, num_or_size_splits=self._NAS_BASE, value=m_matrix) 1583 inputs_matrix_splits = array_ops.split( 1584 axis=1, num_or_size_splits=self._NAS_BASE, value=inputs_matrix) 1585 1586 # First layer 1587 layer1_0 = sigmoid(inputs_matrix_splits[0] + m_matrix_splits[0]) 1588 layer1_1 = relu(inputs_matrix_splits[1] + m_matrix_splits[1]) 1589 layer1_2 = sigmoid(inputs_matrix_splits[2] + m_matrix_splits[2]) 1590 layer1_3 = relu(inputs_matrix_splits[3] * m_matrix_splits[3]) 1591 layer1_4 = tanh(inputs_matrix_splits[4] + m_matrix_splits[4]) 1592 layer1_5 = sigmoid(inputs_matrix_splits[5] + m_matrix_splits[5]) 1593 layer1_6 = tanh(inputs_matrix_splits[6] + m_matrix_splits[6]) 1594 layer1_7 = sigmoid(inputs_matrix_splits[7] + m_matrix_splits[7]) 1595 1596 # Second layer 1597 l2_0 = tanh(layer1_0 * layer1_1) 1598 l2_1 = tanh(layer1_2 + layer1_3) 1599 l2_2 = tanh(layer1_4 * layer1_5) 1600 l2_3 = sigmoid(layer1_6 + layer1_7) 1601 1602 # Inject the cell 1603 l2_0 = tanh(l2_0 + c_prev) 1604 1605 # Third layer 1606 l3_0_pre = l2_0 * l2_1 1607 new_c = l3_0_pre # create new cell 1608 l3_0 = l3_0_pre 1609 l3_1 = tanh(l2_2 + l2_3) 1610 1611 # Final layer 1612 new_m = tanh(l3_0 * l3_1) 1613 1614 # Projection layer if specified 1615 if self._num_proj is not None: 1616 new_m = math_ops.matmul(new_m, self.projection_weights) 1617 1618 new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_m) 1619 return new_m, new_state 1620 1621 1622class UGRNNCell(rnn_cell_impl.RNNCell): 1623 """Update Gate Recurrent Neural Network (UGRNN) cell. 1624 1625 Compromise between a LSTM/GRU and a vanilla RNN. There is only one 1626 gate, and that is to determine whether the unit should be 1627 integrating or computing instantaneously. This is the recurrent 1628 idea of the feedforward Highway Network. 1629 1630 This implements the recurrent cell from the paper: 1631 1632 https://arxiv.org/abs/1611.09913 1633 1634 Jasmine Collins, Jascha Sohl-Dickstein, and David Sussillo. 1635 "Capacity and Trainability in Recurrent Neural Networks" Proc. ICLR 2017. 1636 """ 1637 1638 def __init__(self, 1639 num_units, 1640 initializer=None, 1641 forget_bias=1.0, 1642 activation=math_ops.tanh, 1643 reuse=None): 1644 """Initialize the parameters for an UGRNN cell. 1645 1646 Args: 1647 num_units: int, The number of units in the UGRNN cell 1648 initializer: (optional) The initializer to use for the weight matrices. 1649 forget_bias: (optional) float, default 1.0, The initial bias of the 1650 forget gate, used to reduce the scale of forgetting at the beginning 1651 of the training. 1652 activation: (optional) Activation function of the inner states. 1653 Default is `tf.tanh`. 1654 reuse: (optional) Python boolean describing whether to reuse variables 1655 in an existing scope. If not `True`, and the existing scope already has 1656 the given variables, an error is raised. 1657 """ 1658 super(UGRNNCell, self).__init__(_reuse=reuse) 1659 self._num_units = num_units 1660 self._initializer = initializer 1661 self._forget_bias = forget_bias 1662 self._activation = activation 1663 self._reuse = reuse 1664 self._linear = None 1665 1666 @property 1667 def state_size(self): 1668 return self._num_units 1669 1670 @property 1671 def output_size(self): 1672 return self._num_units 1673 1674 def call(self, inputs, state): 1675 """Run one step of UGRNN. 1676 1677 Args: 1678 inputs: input Tensor, 2D, batch x input size. 1679 state: state Tensor, 2D, batch x num units. 1680 1681 Returns: 1682 new_output: batch x num units, Tensor representing the output of the UGRNN 1683 after reading `inputs` when previous state was `state`. Identical to 1684 `new_state`. 1685 new_state: batch x num units, Tensor representing the state of the UGRNN 1686 after reading `inputs` when previous state was `state`. 1687 1688 Raises: 1689 ValueError: If input size cannot be inferred from inputs via 1690 static shape inference. 1691 """ 1692 sigmoid = math_ops.sigmoid 1693 1694 input_size = inputs.get_shape().with_rank(2).dims[1] 1695 if input_size.value is None: 1696 raise ValueError("Could not infer input size from inputs.get_shape()[-1]") 1697 1698 with vs.variable_scope( 1699 vs.get_variable_scope(), initializer=self._initializer): 1700 cell_inputs = array_ops.concat([inputs, state], 1) 1701 if self._linear is None: 1702 self._linear = _Linear(cell_inputs, 2 * self._num_units, True) 1703 rnn_matrix = self._linear(cell_inputs) 1704 1705 [g_act, c_act] = array_ops.split( 1706 axis=1, num_or_size_splits=2, value=rnn_matrix) 1707 1708 c = self._activation(c_act) 1709 g = sigmoid(g_act + self._forget_bias) 1710 new_state = g * state + (1.0 - g) * c 1711 new_output = new_state 1712 1713 return new_output, new_state 1714 1715 1716class IntersectionRNNCell(rnn_cell_impl.RNNCell): 1717 """Intersection Recurrent Neural Network (+RNN) cell. 1718 1719 Architecture with coupled recurrent gate as well as coupled depth 1720 gate, designed to improve information flow through stacked RNNs. As the 1721 architecture uses depth gating, the dimensionality of the depth 1722 output (y) also should not change through depth (input size == output size). 1723 To achieve this, the first layer of a stacked Intersection RNN projects 1724 the inputs to N (num units) dimensions. Therefore when initializing an 1725 IntersectionRNNCell, one should set `num_in_proj = N` for the first layer 1726 and use default settings for subsequent layers. 1727 1728 This implements the recurrent cell from the paper: 1729 1730 https://arxiv.org/abs/1611.09913 1731 1732 Jasmine Collins, Jascha Sohl-Dickstein, and David Sussillo. 1733 "Capacity and Trainability in Recurrent Neural Networks" Proc. ICLR 2017. 1734 1735 The Intersection RNN is built for use in deeply stacked 1736 RNNs so it may not achieve best performance with depth 1. 1737 """ 1738 1739 def __init__(self, 1740 num_units, 1741 num_in_proj=None, 1742 initializer=None, 1743 forget_bias=1.0, 1744 y_activation=nn_ops.relu, 1745 reuse=None): 1746 """Initialize the parameters for an +RNN cell. 1747 1748 Args: 1749 num_units: int, The number of units in the +RNN cell 1750 num_in_proj: (optional) int, The input dimensionality for the RNN. 1751 If creating the first layer of an +RNN, this should be set to 1752 `num_units`. Otherwise, this should be set to `None` (default). 1753 If `None`, dimensionality of `inputs` should be equal to `num_units`, 1754 otherwise ValueError is thrown. 1755 initializer: (optional) The initializer to use for the weight matrices. 1756 forget_bias: (optional) float, default 1.0, The initial bias of the 1757 forget gates, used to reduce the scale of forgetting at the beginning 1758 of the training. 1759 y_activation: (optional) Activation function of the states passed 1760 through depth. Default is 'tf.nn.relu`. 1761 reuse: (optional) Python boolean describing whether to reuse variables 1762 in an existing scope. If not `True`, and the existing scope already has 1763 the given variables, an error is raised. 1764 """ 1765 super(IntersectionRNNCell, self).__init__(_reuse=reuse) 1766 self._num_units = num_units 1767 self._initializer = initializer 1768 self._forget_bias = forget_bias 1769 self._num_input_proj = num_in_proj 1770 self._y_activation = y_activation 1771 self._reuse = reuse 1772 self._linear1 = None 1773 self._linear2 = None 1774 1775 @property 1776 def state_size(self): 1777 return self._num_units 1778 1779 @property 1780 def output_size(self): 1781 return self._num_units 1782 1783 def call(self, inputs, state): 1784 """Run one step of the Intersection RNN. 1785 1786 Args: 1787 inputs: input Tensor, 2D, batch x input size. 1788 state: state Tensor, 2D, batch x num units. 1789 1790 Returns: 1791 new_y: batch x num units, Tensor representing the output of the +RNN 1792 after reading `inputs` when previous state was `state`. 1793 new_state: batch x num units, Tensor representing the state of the +RNN 1794 after reading `inputs` when previous state was `state`. 1795 1796 Raises: 1797 ValueError: If input size cannot be inferred from `inputs` via 1798 static shape inference. 1799 ValueError: If input size != output size (these must be equal when 1800 using the Intersection RNN). 1801 """ 1802 sigmoid = math_ops.sigmoid 1803 tanh = math_ops.tanh 1804 1805 input_size = inputs.get_shape().with_rank(2).dims[1] 1806 if input_size.value is None: 1807 raise ValueError("Could not infer input size from inputs.get_shape()[-1]") 1808 1809 with vs.variable_scope( 1810 vs.get_variable_scope(), initializer=self._initializer): 1811 # read-in projections (should be used for first layer in deep +RNN 1812 # to transform size of inputs from I --> N) 1813 if input_size.value != self._num_units: 1814 if self._num_input_proj: 1815 with vs.variable_scope("in_projection"): 1816 if self._linear1 is None: 1817 self._linear1 = _Linear(inputs, self._num_units, True) 1818 inputs = self._linear1(inputs) 1819 else: 1820 raise ValueError("Must have input size == output size for " 1821 "Intersection RNN. To fix, num_in_proj should " 1822 "be set to num_units at cell init.") 1823 1824 n_dim = i_dim = self._num_units 1825 cell_inputs = array_ops.concat([inputs, state], 1) 1826 if self._linear2 is None: 1827 self._linear2 = _Linear(cell_inputs, 2 * n_dim + 2 * i_dim, True) 1828 rnn_matrix = self._linear2(cell_inputs) 1829 1830 gh_act = rnn_matrix[:, :n_dim] # b x n 1831 h_act = rnn_matrix[:, n_dim:2 * n_dim] # b x n 1832 gy_act = rnn_matrix[:, 2 * n_dim:2 * n_dim + i_dim] # b x i 1833 y_act = rnn_matrix[:, 2 * n_dim + i_dim:2 * n_dim + 2 * i_dim] # b x i 1834 1835 h = tanh(h_act) 1836 y = self._y_activation(y_act) 1837 gh = sigmoid(gh_act + self._forget_bias) 1838 gy = sigmoid(gy_act + self._forget_bias) 1839 1840 new_state = gh * state + (1.0 - gh) * h # passed thru time 1841 new_y = gy * inputs + (1.0 - gy) * y # passed thru depth 1842 1843 return new_y, new_state 1844 1845 1846_REGISTERED_OPS = None 1847 1848 1849class CompiledWrapper(rnn_cell_impl.RNNCell): 1850 """Wraps step execution in an XLA JIT scope.""" 1851 1852 def __init__(self, cell, compile_stateful=False): 1853 """Create CompiledWrapper cell. 1854 1855 Args: 1856 cell: Instance of `RNNCell`. 1857 compile_stateful: Whether to compile stateful ops like initializers 1858 and random number generators (default: False). 1859 """ 1860 self._cell = cell 1861 self._compile_stateful = compile_stateful 1862 1863 @property 1864 def state_size(self): 1865 return self._cell.state_size 1866 1867 @property 1868 def output_size(self): 1869 return self._cell.output_size 1870 1871 def zero_state(self, batch_size, dtype): 1872 with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): 1873 return self._cell.zero_state(batch_size, dtype) 1874 1875 def __call__(self, inputs, state, scope=None): 1876 if self._compile_stateful: 1877 compile_ops = True 1878 else: 1879 1880 def compile_ops(node_def): 1881 global _REGISTERED_OPS 1882 if _REGISTERED_OPS is None: 1883 _REGISTERED_OPS = op_def_registry.get_registered_ops() 1884 return not _REGISTERED_OPS[node_def.op].is_stateful 1885 1886 with jit.experimental_jit_scope(compile_ops=compile_ops): 1887 return self._cell(inputs, state, scope=scope) 1888 1889 1890def _random_exp_initializer(minval, maxval, seed=None, dtype=dtypes.float32): 1891 """Returns an exponential distribution initializer. 1892 1893 Args: 1894 minval: float or a scalar float Tensor. With value > 0. Lower bound of the 1895 range of random values to generate. 1896 maxval: float or a scalar float Tensor. With value > minval. Upper bound of 1897 the range of random values to generate. 1898 seed: An integer. Used to create random seeds. 1899 dtype: The data type. 1900 1901 Returns: 1902 An initializer that generates tensors with an exponential distribution. 1903 """ 1904 1905 def _initializer(shape, dtype=dtype, partition_info=None): 1906 del partition_info # Unused. 1907 return math_ops.exp( 1908 random_ops.random_uniform( 1909 shape, math_ops.log(minval), math_ops.log(maxval), dtype, 1910 seed=seed)) 1911 1912 return _initializer 1913 1914 1915class PhasedLSTMCell(rnn_cell_impl.RNNCell): 1916 """Phased LSTM recurrent network cell. 1917 1918 https://arxiv.org/pdf/1610.09513v1.pdf 1919 """ 1920 1921 def __init__(self, 1922 num_units, 1923 use_peepholes=False, 1924 leak=0.001, 1925 ratio_on=0.1, 1926 trainable_ratio_on=True, 1927 period_init_min=1.0, 1928 period_init_max=1000.0, 1929 reuse=None): 1930 """Initialize the Phased LSTM cell. 1931 1932 Args: 1933 num_units: int, The number of units in the Phased LSTM cell. 1934 use_peepholes: bool, set True to enable peephole connections. 1935 leak: float or scalar float Tensor with value in [0, 1]. Leak applied 1936 during training. 1937 ratio_on: float or scalar float Tensor with value in [0, 1]. Ratio of the 1938 period during which the gates are open. 1939 trainable_ratio_on: bool, weather ratio_on is trainable. 1940 period_init_min: float or scalar float Tensor. With value > 0. 1941 Minimum value of the initialized period. 1942 The period values are initialized by drawing from the distribution: 1943 e^U(log(period_init_min), log(period_init_max)) 1944 Where U(.,.) is the uniform distribution. 1945 period_init_max: float or scalar float Tensor. 1946 With value > period_init_min. Maximum value of the initialized period. 1947 reuse: (optional) Python boolean describing whether to reuse variables 1948 in an existing scope. If not `True`, and the existing scope already has 1949 the given variables, an error is raised. 1950 """ 1951 super(PhasedLSTMCell, self).__init__(_reuse=reuse) 1952 self._num_units = num_units 1953 self._use_peepholes = use_peepholes 1954 self._leak = leak 1955 self._ratio_on = ratio_on 1956 self._trainable_ratio_on = trainable_ratio_on 1957 self._period_init_min = period_init_min 1958 self._period_init_max = period_init_max 1959 self._reuse = reuse 1960 self._linear1 = None 1961 self._linear2 = None 1962 self._linear3 = None 1963 1964 @property 1965 def state_size(self): 1966 return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units) 1967 1968 @property 1969 def output_size(self): 1970 return self._num_units 1971 1972 def _mod(self, x, y): 1973 """Modulo function that propagates x gradients.""" 1974 return array_ops.stop_gradient(math_ops.mod(x, y) - x) + x 1975 1976 def _get_cycle_ratio(self, time, phase, period): 1977 """Compute the cycle ratio in the dtype of the time.""" 1978 phase_casted = math_ops.cast(phase, dtype=time.dtype) 1979 period_casted = math_ops.cast(period, dtype=time.dtype) 1980 shifted_time = time - phase_casted 1981 cycle_ratio = self._mod(shifted_time, period_casted) / period_casted 1982 return math_ops.cast(cycle_ratio, dtype=dtypes.float32) 1983 1984 def call(self, inputs, state): 1985 """Phased LSTM Cell. 1986 1987 Args: 1988 inputs: A tuple of 2 Tensor. 1989 The first Tensor has shape [batch, 1], and type float32 or float64. 1990 It stores the time. 1991 The second Tensor has shape [batch, features_size], and type float32. 1992 It stores the features. 1993 state: rnn_cell_impl.LSTMStateTuple, state from previous timestep. 1994 1995 Returns: 1996 A tuple containing: 1997 - A Tensor of float32, and shape [batch_size, num_units], representing the 1998 output of the cell. 1999 - A rnn_cell_impl.LSTMStateTuple, containing 2 Tensors of float32, shape 2000 [batch_size, num_units], representing the new state and the output. 2001 """ 2002 (c_prev, h_prev) = state 2003 (time, x) = inputs 2004 2005 in_mask_gates = [x, h_prev] 2006 if self._use_peepholes: 2007 in_mask_gates.append(c_prev) 2008 2009 with vs.variable_scope("mask_gates"): 2010 if self._linear1 is None: 2011 self._linear1 = _Linear(in_mask_gates, 2 * self._num_units, True) 2012 2013 mask_gates = math_ops.sigmoid(self._linear1(in_mask_gates)) 2014 [input_gate, forget_gate] = array_ops.split( 2015 axis=1, num_or_size_splits=2, value=mask_gates) 2016 2017 with vs.variable_scope("new_input"): 2018 if self._linear2 is None: 2019 self._linear2 = _Linear([x, h_prev], self._num_units, True) 2020 new_input = math_ops.tanh(self._linear2([x, h_prev])) 2021 2022 new_c = (c_prev * forget_gate + input_gate * new_input) 2023 2024 in_out_gate = [x, h_prev] 2025 if self._use_peepholes: 2026 in_out_gate.append(new_c) 2027 2028 with vs.variable_scope("output_gate"): 2029 if self._linear3 is None: 2030 self._linear3 = _Linear(in_out_gate, self._num_units, True) 2031 output_gate = math_ops.sigmoid(self._linear3(in_out_gate)) 2032 2033 new_h = math_ops.tanh(new_c) * output_gate 2034 2035 period = vs.get_variable( 2036 "period", [self._num_units], 2037 initializer=_random_exp_initializer(self._period_init_min, 2038 self._period_init_max)) 2039 phase = vs.get_variable( 2040 "phase", [self._num_units], 2041 initializer=init_ops.random_uniform_initializer(0., 2042 period.initial_value)) 2043 ratio_on = vs.get_variable( 2044 "ratio_on", [self._num_units], 2045 initializer=init_ops.constant_initializer(self._ratio_on), 2046 trainable=self._trainable_ratio_on) 2047 2048 cycle_ratio = self._get_cycle_ratio(time, phase, period) 2049 2050 k_up = 2 * cycle_ratio / ratio_on 2051 k_down = 2 - k_up 2052 k_closed = self._leak * cycle_ratio 2053 2054 k = array_ops.where(cycle_ratio < ratio_on, k_down, k_closed) 2055 k = array_ops.where(cycle_ratio < 0.5 * ratio_on, k_up, k) 2056 2057 new_c = k * new_c + (1 - k) * c_prev 2058 new_h = k * new_h + (1 - k) * h_prev 2059 2060 new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h) 2061 2062 return new_h, new_state 2063 2064 2065class ConvLSTMCell(rnn_cell_impl.RNNCell): 2066 """Convolutional LSTM recurrent network cell. 2067 2068 https://arxiv.org/pdf/1506.04214v1.pdf 2069 """ 2070 2071 def __init__(self, 2072 conv_ndims, 2073 input_shape, 2074 output_channels, 2075 kernel_shape, 2076 use_bias=True, 2077 skip_connection=False, 2078 forget_bias=1.0, 2079 initializers=None, 2080 name="conv_lstm_cell"): 2081 """Construct ConvLSTMCell. 2082 2083 Args: 2084 conv_ndims: Convolution dimensionality (1, 2 or 3). 2085 input_shape: Shape of the input as int tuple, excluding the batch size. 2086 output_channels: int, number of output channels of the conv LSTM. 2087 kernel_shape: Shape of kernel as an int tuple (of size 1, 2 or 3). 2088 use_bias: (bool) Use bias in convolutions. 2089 skip_connection: If set to `True`, concatenate the input to the 2090 output of the conv LSTM. Default: `False`. 2091 forget_bias: Forget bias. 2092 initializers: Unused. 2093 name: Name of the module. 2094 2095 Raises: 2096 ValueError: If `skip_connection` is `True` and stride is different from 1 2097 or if `input_shape` is incompatible with `conv_ndims`. 2098 """ 2099 super(ConvLSTMCell, self).__init__(name=name) 2100 2101 if conv_ndims != len(input_shape) - 1: 2102 raise ValueError("Invalid input_shape {} for conv_ndims={}.".format( 2103 input_shape, conv_ndims)) 2104 2105 self._conv_ndims = conv_ndims 2106 self._input_shape = input_shape 2107 self._output_channels = output_channels 2108 self._kernel_shape = list(kernel_shape) 2109 self._use_bias = use_bias 2110 self._forget_bias = forget_bias 2111 self._skip_connection = skip_connection 2112 2113 self._total_output_channels = output_channels 2114 if self._skip_connection: 2115 self._total_output_channels += self._input_shape[-1] 2116 2117 state_size = tensor_shape.TensorShape( 2118 self._input_shape[:-1] + [self._output_channels]) 2119 self._state_size = rnn_cell_impl.LSTMStateTuple(state_size, state_size) 2120 self._output_size = tensor_shape.TensorShape( 2121 self._input_shape[:-1] + [self._total_output_channels]) 2122 2123 @property 2124 def output_size(self): 2125 return self._output_size 2126 2127 @property 2128 def state_size(self): 2129 return self._state_size 2130 2131 def call(self, inputs, state, scope=None): 2132 cell, hidden = state 2133 new_hidden = _conv([inputs, hidden], self._kernel_shape, 2134 4 * self._output_channels, self._use_bias) 2135 gates = array_ops.split( 2136 value=new_hidden, num_or_size_splits=4, axis=self._conv_ndims + 1) 2137 2138 input_gate, new_input, forget_gate, output_gate = gates 2139 new_cell = math_ops.sigmoid(forget_gate + self._forget_bias) * cell 2140 new_cell += math_ops.sigmoid(input_gate) * math_ops.tanh(new_input) 2141 output = math_ops.tanh(new_cell) * math_ops.sigmoid(output_gate) 2142 2143 if self._skip_connection: 2144 output = array_ops.concat([output, inputs], axis=-1) 2145 new_state = rnn_cell_impl.LSTMStateTuple(new_cell, output) 2146 return output, new_state 2147 2148 2149class Conv1DLSTMCell(ConvLSTMCell): 2150 """1D Convolutional LSTM recurrent network cell. 2151 2152 https://arxiv.org/pdf/1506.04214v1.pdf 2153 """ 2154 2155 def __init__(self, name="conv_1d_lstm_cell", **kwargs): 2156 """Construct Conv1DLSTM. See `ConvLSTMCell` for more details.""" 2157 super(Conv1DLSTMCell, self).__init__(conv_ndims=1, name=name, **kwargs) 2158 2159 2160class Conv2DLSTMCell(ConvLSTMCell): 2161 """2D Convolutional LSTM recurrent network cell. 2162 2163 https://arxiv.org/pdf/1506.04214v1.pdf 2164 """ 2165 2166 def __init__(self, name="conv_2d_lstm_cell", **kwargs): 2167 """Construct Conv2DLSTM. See `ConvLSTMCell` for more details.""" 2168 super(Conv2DLSTMCell, self).__init__(conv_ndims=2, name=name, **kwargs) 2169 2170 2171class Conv3DLSTMCell(ConvLSTMCell): 2172 """3D Convolutional LSTM recurrent network cell. 2173 2174 https://arxiv.org/pdf/1506.04214v1.pdf 2175 """ 2176 2177 def __init__(self, name="conv_3d_lstm_cell", **kwargs): 2178 """Construct Conv3DLSTM. See `ConvLSTMCell` for more details.""" 2179 super(Conv3DLSTMCell, self).__init__(conv_ndims=3, name=name, **kwargs) 2180 2181 2182def _conv(args, filter_size, num_features, bias, bias_start=0.0): 2183 """Convolution. 2184 2185 Args: 2186 args: a Tensor or a list of Tensors of dimension 3D, 4D or 5D, 2187 batch x n, Tensors. 2188 filter_size: int tuple of filter shape (of size 1, 2 or 3). 2189 num_features: int, number of features. 2190 bias: Whether to use biases in the convolution layer. 2191 bias_start: starting value to initialize the bias; 0 by default. 2192 2193 Returns: 2194 A 3D, 4D, or 5D Tensor with shape [batch ... num_features] 2195 2196 Raises: 2197 ValueError: if some of the arguments has unspecified or wrong shape. 2198 """ 2199 2200 # Calculate the total size of arguments on dimension 1. 2201 total_arg_size_depth = 0 2202 shapes = [a.get_shape().as_list() for a in args] 2203 shape_length = len(shapes[0]) 2204 for shape in shapes: 2205 if len(shape) not in [3, 4, 5]: 2206 raise ValueError("Conv Linear expects 3D, 4D " 2207 "or 5D arguments: %s" % str(shapes)) 2208 if len(shape) != len(shapes[0]): 2209 raise ValueError("Conv Linear expects all args " 2210 "to be of same Dimension: %s" % str(shapes)) 2211 else: 2212 total_arg_size_depth += shape[-1] 2213 dtype = [a.dtype for a in args][0] 2214 2215 # determine correct conv operation 2216 if shape_length == 3: 2217 conv_op = nn_ops.conv1d 2218 strides = 1 2219 elif shape_length == 4: 2220 conv_op = nn_ops.conv2d 2221 strides = shape_length * [1] 2222 elif shape_length == 5: 2223 conv_op = nn_ops.conv3d 2224 strides = shape_length * [1] 2225 2226 # Now the computation. 2227 kernel = vs.get_variable( 2228 "kernel", filter_size + [total_arg_size_depth, num_features], dtype=dtype) 2229 if len(args) == 1: 2230 res = conv_op(args[0], kernel, strides, padding="SAME") 2231 else: 2232 res = conv_op( 2233 array_ops.concat(axis=shape_length - 1, values=args), 2234 kernel, 2235 strides, 2236 padding="SAME") 2237 if not bias: 2238 return res 2239 bias_term = vs.get_variable( 2240 "biases", [num_features], 2241 dtype=dtype, 2242 initializer=init_ops.constant_initializer(bias_start, dtype=dtype)) 2243 return res + bias_term 2244 2245 2246class GLSTMCell(rnn_cell_impl.RNNCell): 2247 """Group LSTM cell (G-LSTM). 2248 2249 The implementation is based on: 2250 2251 https://arxiv.org/abs/1703.10722 2252 2253 O. Kuchaiev and B. Ginsburg 2254 "Factorization Tricks for LSTM Networks", ICLR 2017 workshop. 2255 2256 In brief, a G-LSTM cell consists of one LSTM sub-cell per group, where each 2257 sub-cell operates on an evenly-sized sub-vector of the input and produces an 2258 evenly-sized sub-vector of the output. For example, a G-LSTM cell with 128 2259 units and 4 groups consists of 4 LSTMs sub-cells with 32 units each. If that 2260 G-LSTM cell is fed a 200-dim input, then each sub-cell receives a 50-dim part 2261 of the input and produces a 32-dim part of the output. 2262 """ 2263 2264 def __init__(self, 2265 num_units, 2266 initializer=None, 2267 num_proj=None, 2268 number_of_groups=1, 2269 forget_bias=1.0, 2270 activation=math_ops.tanh, 2271 reuse=None): 2272 """Initialize the parameters of G-LSTM cell. 2273 2274 Args: 2275 num_units: int, The number of units in the G-LSTM cell 2276 initializer: (optional) The initializer to use for the weight and 2277 projection matrices. 2278 num_proj: (optional) int, The output dimensionality for the projection 2279 matrices. If None, no projection is performed. 2280 number_of_groups: (optional) int, number of groups to use. 2281 If `number_of_groups` is 1, then it should be equivalent to LSTM cell 2282 forget_bias: Biases of the forget gate are initialized by default to 1 2283 in order to reduce the scale of forgetting at the beginning of 2284 the training. 2285 activation: Activation function of the inner states. 2286 reuse: (optional) Python boolean describing whether to reuse variables 2287 in an existing scope. If not `True`, and the existing scope already 2288 has the given variables, an error is raised. 2289 2290 Raises: 2291 ValueError: If `num_units` or `num_proj` is not divisible by 2292 `number_of_groups`. 2293 """ 2294 super(GLSTMCell, self).__init__(_reuse=reuse) 2295 self._num_units = num_units 2296 self._initializer = initializer 2297 self._num_proj = num_proj 2298 self._forget_bias = forget_bias 2299 self._activation = activation 2300 self._number_of_groups = number_of_groups 2301 2302 if self._num_units % self._number_of_groups != 0: 2303 raise ValueError("num_units must be divisible by number_of_groups") 2304 if self._num_proj: 2305 if self._num_proj % self._number_of_groups != 0: 2306 raise ValueError("num_proj must be divisible by number_of_groups") 2307 self._group_shape = [ 2308 int(self._num_proj / self._number_of_groups), 2309 int(self._num_units / self._number_of_groups) 2310 ] 2311 else: 2312 self._group_shape = [ 2313 int(self._num_units / self._number_of_groups), 2314 int(self._num_units / self._number_of_groups) 2315 ] 2316 2317 if num_proj: 2318 self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj) 2319 self._output_size = num_proj 2320 else: 2321 self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units) 2322 self._output_size = num_units 2323 self._linear1 = [None] * number_of_groups 2324 self._linear2 = None 2325 2326 @property 2327 def state_size(self): 2328 return self._state_size 2329 2330 @property 2331 def output_size(self): 2332 return self._output_size 2333 2334 def _get_input_for_group(self, inputs, group_id, group_size): 2335 """Slices inputs into groups to prepare for processing by cell's groups. 2336 2337 Args: 2338 inputs: cell input or it's previous state, 2339 a Tensor, 2D, [batch x num_units] 2340 group_id: group id, a Scalar, for which to prepare input 2341 group_size: size of the group 2342 2343 Returns: 2344 subset of inputs corresponding to group "group_id", 2345 a Tensor, 2D, [batch x num_units/number_of_groups] 2346 """ 2347 return array_ops.slice( 2348 input_=inputs, 2349 begin=[0, group_id * group_size], 2350 size=[self._batch_size, group_size], 2351 name=("GLSTM_group%d_input_generation" % group_id)) 2352 2353 def call(self, inputs, state): 2354 """Run one step of G-LSTM. 2355 2356 Args: 2357 inputs: input Tensor, 2D, [batch x num_inputs]. num_inputs must be 2358 statically-known and evenly divisible into groups. The innermost 2359 vectors of the inputs are split into evenly-sized sub-vectors and fed 2360 into the per-group LSTM sub-cells. 2361 state: this must be a tuple of state Tensors, both `2-D`, with column 2362 sizes `c_state` and `m_state`. 2363 2364 Returns: 2365 A tuple containing: 2366 2367 - A `2-D, [batch x output_dim]`, Tensor representing the output of the 2368 G-LSTM after reading `inputs` when previous state was `state`. 2369 Here output_dim is: 2370 num_proj if num_proj was set, 2371 num_units otherwise. 2372 - LSTMStateTuple representing the new state of G-LSTM cell 2373 after reading `inputs` when the previous state was `state`. 2374 2375 Raises: 2376 ValueError: If input size cannot be inferred from inputs via 2377 static shape inference, or if the input shape is incompatible 2378 with the number of groups. 2379 """ 2380 (c_prev, m_prev) = state 2381 2382 self._batch_size = tensor_shape.dimension_value( 2383 inputs.shape[0]) or array_ops.shape(inputs)[0] 2384 2385 # If the input size is statically-known, calculate and validate its group 2386 # size. Otherwise, use the output group size. 2387 input_size = tensor_shape.dimension_value(inputs.shape[1]) 2388 if input_size is None: 2389 raise ValueError("input size must be statically known") 2390 if input_size % self._number_of_groups != 0: 2391 raise ValueError( 2392 "input size (%d) must be divisible by number_of_groups (%d)" % 2393 (input_size, self._number_of_groups)) 2394 input_group_size = int(input_size / self._number_of_groups) 2395 2396 dtype = inputs.dtype 2397 scope = vs.get_variable_scope() 2398 with vs.variable_scope(scope, initializer=self._initializer): 2399 i_parts = [] 2400 j_parts = [] 2401 f_parts = [] 2402 o_parts = [] 2403 2404 for group_id in range(self._number_of_groups): 2405 with vs.variable_scope("group%d" % group_id): 2406 x_g_id = array_ops.concat( 2407 [ 2408 self._get_input_for_group(inputs, group_id, input_group_size), 2409 self._get_input_for_group(m_prev, group_id, 2410 self._group_shape[0]) 2411 ], 2412 axis=1) 2413 linear = self._linear1[group_id] 2414 if linear is None: 2415 linear = _Linear(x_g_id, 4 * self._group_shape[1], False) 2416 self._linear1[group_id] = linear 2417 R_k = linear(x_g_id) # pylint: disable=invalid-name 2418 i_k, j_k, f_k, o_k = array_ops.split(R_k, 4, 1) 2419 2420 i_parts.append(i_k) 2421 j_parts.append(j_k) 2422 f_parts.append(f_k) 2423 o_parts.append(o_k) 2424 2425 bi = vs.get_variable( 2426 name="bias_i", 2427 shape=[self._num_units], 2428 dtype=dtype, 2429 initializer=init_ops.constant_initializer(0.0, dtype=dtype)) 2430 bj = vs.get_variable( 2431 name="bias_j", 2432 shape=[self._num_units], 2433 dtype=dtype, 2434 initializer=init_ops.constant_initializer(0.0, dtype=dtype)) 2435 bf = vs.get_variable( 2436 name="bias_f", 2437 shape=[self._num_units], 2438 dtype=dtype, 2439 initializer=init_ops.constant_initializer(0.0, dtype=dtype)) 2440 bo = vs.get_variable( 2441 name="bias_o", 2442 shape=[self._num_units], 2443 dtype=dtype, 2444 initializer=init_ops.constant_initializer(0.0, dtype=dtype)) 2445 2446 i = nn_ops.bias_add(array_ops.concat(i_parts, axis=1), bi) 2447 j = nn_ops.bias_add(array_ops.concat(j_parts, axis=1), bj) 2448 f = nn_ops.bias_add(array_ops.concat(f_parts, axis=1), bf) 2449 o = nn_ops.bias_add(array_ops.concat(o_parts, axis=1), bo) 2450 2451 c = ( 2452 math_ops.sigmoid(f + self._forget_bias) * c_prev + 2453 math_ops.sigmoid(i) * math_ops.tanh(j)) 2454 m = math_ops.sigmoid(o) * self._activation(c) 2455 2456 if self._num_proj is not None: 2457 with vs.variable_scope("projection"): 2458 if self._linear2 is None: 2459 self._linear2 = _Linear(m, self._num_proj, False) 2460 m = self._linear2(m) 2461 2462 new_state = rnn_cell_impl.LSTMStateTuple(c, m) 2463 return m, new_state 2464 2465 2466class LayerNormLSTMCell(rnn_cell_impl.RNNCell): 2467 """Long short-term memory unit (LSTM) recurrent network cell. 2468 2469 The default non-peephole implementation is based on: 2470 2471 https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf 2472 2473 Felix Gers, Jurgen Schmidhuber, and Fred Cummins. 2474 "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999. 2475 2476 The peephole implementation is based on: 2477 2478 https://research.google.com/pubs/archive/43905.pdf 2479 2480 Hasim Sak, Andrew Senior, and Francoise Beaufays. 2481 "Long short-term memory recurrent neural network architectures for 2482 large scale acoustic modeling." INTERSPEECH, 2014. 2483 2484 The class uses optional peep-hole connections, optional cell clipping, and 2485 an optional projection layer. 2486 2487 Layer normalization implementation is based on: 2488 2489 https://arxiv.org/abs/1607.06450. 2490 2491 "Layer Normalization" 2492 Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton 2493 2494 and is applied before the internal nonlinearities. 2495 2496 """ 2497 2498 def __init__(self, 2499 num_units, 2500 use_peepholes=False, 2501 cell_clip=None, 2502 initializer=None, 2503 num_proj=None, 2504 proj_clip=None, 2505 forget_bias=1.0, 2506 activation=None, 2507 layer_norm=False, 2508 norm_gain=1.0, 2509 norm_shift=0.0, 2510 reuse=None): 2511 """Initialize the parameters for an LSTM cell. 2512 2513 Args: 2514 num_units: int, The number of units in the LSTM cell 2515 use_peepholes: bool, set True to enable diagonal/peephole connections. 2516 cell_clip: (optional) A float value, if provided the cell state is clipped 2517 by this value prior to the cell output activation. 2518 initializer: (optional) The initializer to use for the weight and 2519 projection matrices. 2520 num_proj: (optional) int, The output dimensionality for the projection 2521 matrices. If None, no projection is performed. 2522 proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is 2523 provided, then the projected values are clipped elementwise to within 2524 `[-proj_clip, proj_clip]`. 2525 forget_bias: Biases of the forget gate are initialized by default to 1 2526 in order to reduce the scale of forgetting at the beginning of 2527 the training. Must set it manually to `0.0` when restoring from 2528 CudnnLSTM trained checkpoints. 2529 activation: Activation function of the inner states. Default: `tanh`. 2530 layer_norm: If `True`, layer normalization will be applied. 2531 norm_gain: float, The layer normalization gain initial value. If 2532 `layer_norm` has been set to `False`, this argument will be ignored. 2533 norm_shift: float, The layer normalization shift initial value. If 2534 `layer_norm` has been set to `False`, this argument will be ignored. 2535 reuse: (optional) Python boolean describing whether to reuse variables 2536 in an existing scope. If not `True`, and the existing scope already has 2537 the given variables, an error is raised. 2538 2539 When restoring from CudnnLSTM-trained checkpoints, must use 2540 CudnnCompatibleLSTMCell instead. 2541 """ 2542 super(LayerNormLSTMCell, self).__init__(_reuse=reuse) 2543 2544 self._num_units = num_units 2545 self._use_peepholes = use_peepholes 2546 self._cell_clip = cell_clip 2547 self._initializer = initializer 2548 self._num_proj = num_proj 2549 self._proj_clip = proj_clip 2550 self._forget_bias = forget_bias 2551 self._activation = activation or math_ops.tanh 2552 self._layer_norm = layer_norm 2553 self._norm_gain = norm_gain 2554 self._norm_shift = norm_shift 2555 2556 if num_proj: 2557 self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_proj)) 2558 self._output_size = num_proj 2559 else: 2560 self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_units)) 2561 self._output_size = num_units 2562 2563 @property 2564 def state_size(self): 2565 return self._state_size 2566 2567 @property 2568 def output_size(self): 2569 return self._output_size 2570 2571 def _linear(self, 2572 args, 2573 output_size, 2574 bias, 2575 bias_initializer=None, 2576 kernel_initializer=None, 2577 layer_norm=False): 2578 """Linear map: sum_i(args[i] * W[i]), where W[i] is a Variable. 2579 2580 Args: 2581 args: a 2D Tensor or a list of 2D, batch x n, Tensors. 2582 output_size: int, second dimension of W[i]. 2583 bias: boolean, whether to add a bias term or not. 2584 bias_initializer: starting value to initialize the bias 2585 (default is all zeros). 2586 kernel_initializer: starting value to initialize the weight. 2587 layer_norm: boolean, whether to apply layer normalization. 2588 2589 2590 Returns: 2591 A 2D Tensor with shape [batch x output_size] taking value 2592 sum_i(args[i] * W[i]), where each W[i] is a newly created Variable. 2593 2594 Raises: 2595 ValueError: if some of the arguments has unspecified or wrong shape. 2596 """ 2597 if args is None or (nest.is_sequence(args) and not args): 2598 raise ValueError("`args` must be specified") 2599 if not nest.is_sequence(args): 2600 args = [args] 2601 2602 # Calculate the total size of arguments on dimension 1. 2603 total_arg_size = 0 2604 shapes = [a.get_shape() for a in args] 2605 for shape in shapes: 2606 if shape.ndims != 2: 2607 raise ValueError("linear is expecting 2D arguments: %s" % shapes) 2608 if tensor_shape.dimension_value(shape[1]) is None: 2609 raise ValueError("linear expects shape[1] to be provided for shape %s, " 2610 "but saw %s" % (shape, shape[1])) 2611 else: 2612 total_arg_size += tensor_shape.dimension_value(shape[1]) 2613 2614 dtype = [a.dtype for a in args][0] 2615 2616 # Now the computation. 2617 scope = vs.get_variable_scope() 2618 with vs.variable_scope(scope) as outer_scope: 2619 weights = vs.get_variable( 2620 "kernel", [total_arg_size, output_size], 2621 dtype=dtype, 2622 initializer=kernel_initializer) 2623 if len(args) == 1: 2624 res = math_ops.matmul(args[0], weights) 2625 else: 2626 res = math_ops.matmul(array_ops.concat(args, 1), weights) 2627 if not bias: 2628 return res 2629 with vs.variable_scope(outer_scope) as inner_scope: 2630 inner_scope.set_partitioner(None) 2631 if bias_initializer is None: 2632 bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype) 2633 biases = vs.get_variable( 2634 "bias", [output_size], dtype=dtype, initializer=bias_initializer) 2635 2636 if not layer_norm: 2637 res = nn_ops.bias_add(res, biases) 2638 2639 return res 2640 2641 def call(self, inputs, state): 2642 """Run one step of LSTM. 2643 2644 Args: 2645 inputs: input Tensor, 2D, batch x num_units. 2646 state: this must be a tuple of state Tensors, 2647 both `2-D`, with column sizes `c_state` and 2648 `m_state`. 2649 2650 Returns: 2651 A tuple containing: 2652 2653 - A `2-D, [batch x output_dim]`, Tensor representing the output of the 2654 LSTM after reading `inputs` when previous state was `state`. 2655 Here output_dim is: 2656 num_proj if num_proj was set, 2657 num_units otherwise. 2658 - Tensor(s) representing the new state of LSTM after reading `inputs` when 2659 the previous state was `state`. Same type and shape(s) as `state`. 2660 2661 Raises: 2662 ValueError: If input size cannot be inferred from inputs via 2663 static shape inference. 2664 """ 2665 sigmoid = math_ops.sigmoid 2666 2667 (c_prev, m_prev) = state 2668 2669 dtype = inputs.dtype 2670 input_size = inputs.get_shape().with_rank(2).dims[1] 2671 if input_size.value is None: 2672 raise ValueError("Could not infer input size from inputs.get_shape()[-1]") 2673 scope = vs.get_variable_scope() 2674 with vs.variable_scope(scope, initializer=self._initializer) as unit_scope: 2675 2676 # i = input_gate, j = new_input, f = forget_gate, o = output_gate 2677 lstm_matrix = self._linear( 2678 [inputs, m_prev], 2679 4 * self._num_units, 2680 bias=True, 2681 bias_initializer=None, 2682 layer_norm=self._layer_norm) 2683 i, j, f, o = array_ops.split( 2684 value=lstm_matrix, num_or_size_splits=4, axis=1) 2685 2686 if self._layer_norm: 2687 i = _norm(self._norm_gain, self._norm_shift, i, "input") 2688 j = _norm(self._norm_gain, self._norm_shift, j, "transform") 2689 f = _norm(self._norm_gain, self._norm_shift, f, "forget") 2690 o = _norm(self._norm_gain, self._norm_shift, o, "output") 2691 2692 # Diagonal connections 2693 if self._use_peepholes: 2694 with vs.variable_scope(unit_scope): 2695 w_f_diag = vs.get_variable( 2696 "w_f_diag", shape=[self._num_units], dtype=dtype) 2697 w_i_diag = vs.get_variable( 2698 "w_i_diag", shape=[self._num_units], dtype=dtype) 2699 w_o_diag = vs.get_variable( 2700 "w_o_diag", shape=[self._num_units], dtype=dtype) 2701 2702 if self._use_peepholes: 2703 c = ( 2704 sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev + 2705 sigmoid(i + w_i_diag * c_prev) * self._activation(j)) 2706 else: 2707 c = ( 2708 sigmoid(f + self._forget_bias) * c_prev + 2709 sigmoid(i) * self._activation(j)) 2710 2711 if self._layer_norm: 2712 c = _norm(self._norm_gain, self._norm_shift, c, "state") 2713 2714 if self._cell_clip is not None: 2715 # pylint: disable=invalid-unary-operand-type 2716 c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip) 2717 # pylint: enable=invalid-unary-operand-type 2718 if self._use_peepholes: 2719 m = sigmoid(o + w_o_diag * c) * self._activation(c) 2720 else: 2721 m = sigmoid(o) * self._activation(c) 2722 2723 if self._num_proj is not None: 2724 with vs.variable_scope("projection"): 2725 m = self._linear(m, self._num_proj, bias=False) 2726 2727 if self._proj_clip is not None: 2728 # pylint: disable=invalid-unary-operand-type 2729 m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip) 2730 # pylint: enable=invalid-unary-operand-type 2731 2732 new_state = (rnn_cell_impl.LSTMStateTuple(c, m)) 2733 return m, new_state 2734 2735 2736class SRUCell(rnn_cell_impl.LayerRNNCell): 2737 """SRU, Simple Recurrent Unit. 2738 2739 Implementation based on 2740 Training RNNs as Fast as CNNs (cf. https://arxiv.org/abs/1709.02755). 2741 2742 This variation of RNN cell is characterized by the simplified data 2743 dependence 2744 between hidden states of two consecutive time steps. Traditionally, hidden 2745 states from a cell at time step t-1 needs to be multiplied with a matrix 2746 W_hh before being fed into the ensuing cell at time step t. 2747 This flavor of RNN replaces the matrix multiplication between h_{t-1} 2748 and W_hh with a pointwise multiplication, resulting in performance 2749 gain. 2750 2751 Args: 2752 num_units: int, The number of units in the SRU cell. 2753 activation: Nonlinearity to use. Default: `tanh`. 2754 reuse: (optional) Python boolean describing whether to reuse variables 2755 in an existing scope. If not `True`, and the existing scope already has 2756 the given variables, an error is raised. 2757 name: (optional) String, the name of the layer. Layers with the same name 2758 will share weights, but to avoid mistakes we require reuse=True in such 2759 cases. 2760 **kwargs: Additional keyword arguments. 2761 """ 2762 2763 def __init__(self, num_units, activation=None, reuse=None, name=None, 2764 **kwargs): 2765 super(SRUCell, self).__init__(_reuse=reuse, name=name, **kwargs) 2766 self._num_units = num_units 2767 self._activation = activation or math_ops.tanh 2768 2769 # Restrict inputs to be 2-dimensional matrices 2770 self.input_spec = input_spec.InputSpec(ndim=2) 2771 2772 @property 2773 def state_size(self): 2774 return self._num_units 2775 2776 @property 2777 def output_size(self): 2778 return self._num_units 2779 2780 def build(self, inputs_shape): 2781 if tensor_shape.dimension_value(inputs_shape[1]) is None: 2782 raise ValueError( 2783 "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape) 2784 2785 input_depth = tensor_shape.dimension_value(inputs_shape[1]) 2786 2787 # pylint: disable=protected-access 2788 self._kernel = self.add_variable( 2789 rnn_cell_impl._WEIGHTS_VARIABLE_NAME, 2790 shape=[input_depth, 4 * self._num_units]) 2791 # pylint: enable=protected-access 2792 self._bias = self.add_variable( 2793 rnn_cell_impl._BIAS_VARIABLE_NAME, # pylint: disable=protected-access 2794 shape=[2 * self._num_units], 2795 initializer=init_ops.zeros_initializer) 2796 2797 self._built = True 2798 2799 def call(self, inputs, state): 2800 """Simple recurrent unit (SRU) with num_units cells.""" 2801 2802 U = math_ops.matmul(inputs, self._kernel) # pylint: disable=invalid-name 2803 x_bar, f_intermediate, r_intermediate, x_tx = array_ops.split( 2804 value=U, num_or_size_splits=4, axis=1) 2805 2806 f_r = math_ops.sigmoid( 2807 nn_ops.bias_add( 2808 array_ops.concat([f_intermediate, r_intermediate], 1), self._bias)) 2809 f, r = array_ops.split(value=f_r, num_or_size_splits=2, axis=1) 2810 2811 c = f * state + (1.0 - f) * x_bar 2812 h = r * self._activation(c) + (1.0 - r) * x_tx 2813 2814 return h, c 2815 2816 2817class WeightNormLSTMCell(rnn_cell_impl.RNNCell): 2818 """Weight normalized LSTM Cell. Adapted from `rnn_cell_impl.LSTMCell`. 2819 2820 The weight-norm implementation is based on: 2821 https://arxiv.org/abs/1602.07868 2822 Tim Salimans, Diederik P. Kingma. 2823 Weight Normalization: A Simple Reparameterization to Accelerate 2824 Training of Deep Neural Networks 2825 2826 The default LSTM implementation based on: 2827 2828 https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf 2829 2830 Felix Gers, Jurgen Schmidhuber, and Fred Cummins. 2831 "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999. 2832 2833 The class uses optional peephole connections, optional cell clipping 2834 and an optional projection layer. 2835 2836 The optional peephole implementation is based on: 2837 https://research.google.com/pubs/archive/43905.pdf 2838 Hasim Sak, Andrew Senior, and Francoise Beaufays. 2839 "Long short-term memory recurrent neural network architectures for 2840 large scale acoustic modeling." INTERSPEECH, 2014. 2841 """ 2842 2843 def __init__(self, 2844 num_units, 2845 norm=True, 2846 use_peepholes=False, 2847 cell_clip=None, 2848 initializer=None, 2849 num_proj=None, 2850 proj_clip=None, 2851 forget_bias=1, 2852 activation=None, 2853 reuse=None): 2854 """Initialize the parameters of a weight-normalized LSTM cell. 2855 2856 Args: 2857 num_units: int, The number of units in the LSTM cell 2858 norm: If `True`, apply normalization to the weight matrices. If False, 2859 the result is identical to that obtained from `rnn_cell_impl.LSTMCell` 2860 use_peepholes: bool, set `True` to enable diagonal/peephole connections. 2861 cell_clip: (optional) A float value, if provided the cell state is clipped 2862 by this value prior to the cell output activation. 2863 initializer: (optional) The initializer to use for the weight matrices. 2864 num_proj: (optional) int, The output dimensionality for the projection 2865 matrices. If None, no projection is performed. 2866 proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is 2867 provided, then the projected values are clipped elementwise to within 2868 `[-proj_clip, proj_clip]`. 2869 forget_bias: Biases of the forget gate are initialized by default to 1 2870 in order to reduce the scale of forgetting at the beginning of 2871 the training. 2872 activation: Activation function of the inner states. Default: `tanh`. 2873 reuse: (optional) Python boolean describing whether to reuse variables 2874 in an existing scope. If not `True`, and the existing scope already has 2875 the given variables, an error is raised. 2876 """ 2877 super(WeightNormLSTMCell, self).__init__(_reuse=reuse) 2878 2879 self._scope = "wn_lstm_cell" 2880 self._num_units = num_units 2881 self._norm = norm 2882 self._initializer = initializer 2883 self._use_peepholes = use_peepholes 2884 self._cell_clip = cell_clip 2885 self._num_proj = num_proj 2886 self._proj_clip = proj_clip 2887 self._activation = activation or math_ops.tanh 2888 self._forget_bias = forget_bias 2889 2890 self._weights_variable_name = "kernel" 2891 self._bias_variable_name = "bias" 2892 2893 if num_proj: 2894 self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj) 2895 self._output_size = num_proj 2896 else: 2897 self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units) 2898 self._output_size = num_units 2899 2900 @property 2901 def state_size(self): 2902 return self._state_size 2903 2904 @property 2905 def output_size(self): 2906 return self._output_size 2907 2908 def _normalize(self, weight, name): 2909 """Apply weight normalization. 2910 2911 Args: 2912 weight: a 2D tensor with known number of columns. 2913 name: string, variable name for the normalizer. 2914 Returns: 2915 A tensor with the same shape as `weight`. 2916 """ 2917 2918 output_size = weight.get_shape().as_list()[1] 2919 g = vs.get_variable(name, [output_size], dtype=weight.dtype) 2920 return nn_impl.l2_normalize(weight, axis=0) * g 2921 2922 def _linear(self, 2923 args, 2924 output_size, 2925 norm, 2926 bias, 2927 bias_initializer=None, 2928 kernel_initializer=None): 2929 """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. 2930 2931 Args: 2932 args: a 2D Tensor or a list of 2D, batch x n, Tensors. 2933 output_size: int, second dimension of W[i]. 2934 norm: bool, whether to normalize the weights. 2935 bias: boolean, whether to add a bias term or not. 2936 bias_initializer: starting value to initialize the bias 2937 (default is all zeros). 2938 kernel_initializer: starting value to initialize the weight. 2939 2940 Returns: 2941 A 2D Tensor with shape [batch x output_size] equal to 2942 sum_i(args[i] * W[i]), where W[i]s are newly created matrices. 2943 2944 Raises: 2945 ValueError: if some of the arguments has unspecified or wrong shape. 2946 """ 2947 if args is None or (nest.is_sequence(args) and not args): 2948 raise ValueError("`args` must be specified") 2949 if not nest.is_sequence(args): 2950 args = [args] 2951 2952 # Calculate the total size of arguments on dimension 1. 2953 total_arg_size = 0 2954 shapes = [a.get_shape() for a in args] 2955 for shape in shapes: 2956 if shape.ndims != 2: 2957 raise ValueError("linear is expecting 2D arguments: %s" % shapes) 2958 if tensor_shape.dimension_value(shape[1]) is None: 2959 raise ValueError("linear expects shape[1] to be provided for shape %s, " 2960 "but saw %s" % (shape, shape[1])) 2961 else: 2962 total_arg_size += tensor_shape.dimension_value(shape[1]) 2963 2964 dtype = [a.dtype for a in args][0] 2965 2966 # Now the computation. 2967 scope = vs.get_variable_scope() 2968 with vs.variable_scope(scope) as outer_scope: 2969 weights = vs.get_variable( 2970 self._weights_variable_name, [total_arg_size, output_size], 2971 dtype=dtype, 2972 initializer=kernel_initializer) 2973 if norm: 2974 wn = [] 2975 st = 0 2976 with ops.control_dependencies(None): 2977 for i in range(len(args)): 2978 en = st + tensor_shape.dimension_value(shapes[i][1]) 2979 wn.append( 2980 self._normalize(weights[st:en, :], name="norm_{}".format(i))) 2981 st = en 2982 2983 weights = array_ops.concat(wn, axis=0) 2984 2985 if len(args) == 1: 2986 res = math_ops.matmul(args[0], weights) 2987 else: 2988 res = math_ops.matmul(array_ops.concat(args, 1), weights) 2989 if not bias: 2990 return res 2991 2992 with vs.variable_scope(outer_scope) as inner_scope: 2993 inner_scope.set_partitioner(None) 2994 if bias_initializer is None: 2995 bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype) 2996 2997 biases = vs.get_variable( 2998 self._bias_variable_name, [output_size], 2999 dtype=dtype, 3000 initializer=bias_initializer) 3001 3002 return nn_ops.bias_add(res, biases) 3003 3004 def call(self, inputs, state): 3005 """Run one step of LSTM. 3006 3007 Args: 3008 inputs: input Tensor, 2D, batch x num_units. 3009 state: A tuple of state Tensors, both `2-D`, with column sizes 3010 `c_state` and `m_state`. 3011 3012 Returns: 3013 A tuple containing: 3014 3015 - A `2-D, [batch x output_dim]`, Tensor representing the output of the 3016 LSTM after reading `inputs` when previous state was `state`. 3017 Here output_dim is: 3018 num_proj if num_proj was set, 3019 num_units otherwise. 3020 - Tensor(s) representing the new state of LSTM after reading `inputs` when 3021 the previous state was `state`. Same type and shape(s) as `state`. 3022 3023 Raises: 3024 ValueError: If input size cannot be inferred from inputs via 3025 static shape inference. 3026 """ 3027 dtype = inputs.dtype 3028 num_units = self._num_units 3029 sigmoid = math_ops.sigmoid 3030 c, h = state 3031 3032 input_size = inputs.get_shape().with_rank(2).dims[1] 3033 if input_size.value is None: 3034 raise ValueError("Could not infer input size from inputs.get_shape()[-1]") 3035 3036 with vs.variable_scope(self._scope, initializer=self._initializer): 3037 3038 concat = self._linear( 3039 [inputs, h], 4 * num_units, norm=self._norm, bias=True) 3040 3041 # i = input_gate, j = new_input, f = forget_gate, o = output_gate 3042 i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1) 3043 3044 if self._use_peepholes: 3045 w_f_diag = vs.get_variable("w_f_diag", shape=[num_units], dtype=dtype) 3046 w_i_diag = vs.get_variable("w_i_diag", shape=[num_units], dtype=dtype) 3047 w_o_diag = vs.get_variable("w_o_diag", shape=[num_units], dtype=dtype) 3048 3049 new_c = ( 3050 c * sigmoid(f + self._forget_bias + w_f_diag * c) + 3051 sigmoid(i + w_i_diag * c) * self._activation(j)) 3052 else: 3053 new_c = ( 3054 c * sigmoid(f + self._forget_bias) + 3055 sigmoid(i) * self._activation(j)) 3056 3057 if self._cell_clip is not None: 3058 # pylint: disable=invalid-unary-operand-type 3059 new_c = clip_ops.clip_by_value(new_c, -self._cell_clip, self._cell_clip) 3060 # pylint: enable=invalid-unary-operand-type 3061 if self._use_peepholes: 3062 new_h = sigmoid(o + w_o_diag * new_c) * self._activation(new_c) 3063 else: 3064 new_h = sigmoid(o) * self._activation(new_c) 3065 3066 if self._num_proj is not None: 3067 with vs.variable_scope("projection"): 3068 new_h = self._linear( 3069 new_h, self._num_proj, norm=self._norm, bias=False) 3070 3071 if self._proj_clip is not None: 3072 # pylint: disable=invalid-unary-operand-type 3073 new_h = clip_ops.clip_by_value(new_h, -self._proj_clip, 3074 self._proj_clip) 3075 # pylint: enable=invalid-unary-operand-type 3076 3077 new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h) 3078 return new_h, new_state 3079 3080 3081class IndRNNCell(rnn_cell_impl.LayerRNNCell): 3082 """Independently Recurrent Neural Network (IndRNN) cell 3083 (cf. https://arxiv.org/abs/1803.04831). 3084 3085 Args: 3086 num_units: int, The number of units in the RNN cell. 3087 activation: Nonlinearity to use. Default: `tanh`. 3088 reuse: (optional) Python boolean describing whether to reuse variables 3089 in an existing scope. If not `True`, and the existing scope already has 3090 the given variables, an error is raised. 3091 name: String, the name of the layer. Layers with the same name will 3092 share weights, but to avoid mistakes we require reuse=True in such 3093 cases. 3094 dtype: Default dtype of the layer (default of `None` means use the type 3095 of the first input). Required when `build` is called before `call`. 3096 """ 3097 3098 def __init__(self, 3099 num_units, 3100 activation=None, 3101 reuse=None, 3102 name=None, 3103 dtype=None): 3104 super(IndRNNCell, self).__init__(_reuse=reuse, name=name, dtype=dtype) 3105 3106 # Inputs must be 2-dimensional. 3107 self.input_spec = input_spec.InputSpec(ndim=2) 3108 3109 self._num_units = num_units 3110 self._activation = activation or math_ops.tanh 3111 3112 @property 3113 def state_size(self): 3114 return self._num_units 3115 3116 @property 3117 def output_size(self): 3118 return self._num_units 3119 3120 def build(self, inputs_shape): 3121 if tensor_shape.dimension_value(inputs_shape[1]) is None: 3122 raise ValueError( 3123 "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape) 3124 3125 input_depth = tensor_shape.dimension_value(inputs_shape[1]) 3126 # pylint: disable=protected-access 3127 self._kernel_w = self.add_variable( 3128 "%s_w" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, 3129 shape=[input_depth, self._num_units]) 3130 self._kernel_u = self.add_variable( 3131 "%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, 3132 shape=[1, self._num_units], 3133 initializer=init_ops.random_uniform_initializer( 3134 minval=-1, maxval=1, dtype=self.dtype)) 3135 self._bias = self.add_variable( 3136 rnn_cell_impl._BIAS_VARIABLE_NAME, 3137 shape=[self._num_units], 3138 initializer=init_ops.zeros_initializer(dtype=self.dtype)) 3139 # pylint: enable=protected-access 3140 3141 self.built = True 3142 3143 def call(self, inputs, state): 3144 """IndRNN: output = new_state = act(W * input + u * state + B).""" 3145 3146 gate_inputs = math_ops.matmul(inputs, self._kernel_w) + ( 3147 state * self._kernel_u) 3148 gate_inputs = nn_ops.bias_add(gate_inputs, self._bias) 3149 output = self._activation(gate_inputs) 3150 return output, output 3151 3152 3153class IndyGRUCell(rnn_cell_impl.LayerRNNCell): 3154 r"""Independently Gated Recurrent Unit cell. 3155 3156 Based on IndRNNs (https://arxiv.org/abs/1803.04831) and similar to GRUCell, 3157 yet with the \\(U_r\\), \\(U_z\\), and \\(U\\) matrices in equations 5, 6, and 3158 8 of http://arxiv.org/abs/1406.1078 respectively replaced by diagonal 3159 matrices, i.e. a Hadamard product with a single vector: 3160 3161 $$r_j = \sigma\left([\mathbf W_r\mathbf x]_j + 3162 [\mathbf u_r\circ \mathbf h_{(t-1)}]_j\right)$$ 3163 $$z_j = \sigma\left([\mathbf W_z\mathbf x]_j + 3164 [\mathbf u_z\circ \mathbf h_{(t-1)}]_j\right)$$ 3165 $$\tilde{h}^{(t)}_j = \phi\left([\mathbf W \mathbf x]_j + 3166 [\mathbf u \circ \mathbf r \circ \mathbf h_{(t-1)}]_j\right)$$ 3167 3168 where \\(\circ\\) denotes the Hadamard operator. This means that each IndyGRU 3169 node sees only its own state, as opposed to seeing all states in the same 3170 layer. 3171 3172 Args: 3173 num_units: int, The number of units in the GRU cell. 3174 activation: Nonlinearity to use. Default: `tanh`. 3175 reuse: (optional) Python boolean describing whether to reuse variables 3176 in an existing scope. If not `True`, and the existing scope already has 3177 the given variables, an error is raised. 3178 kernel_initializer: (optional) The initializer to use for the weight 3179 matrices applied to the input. 3180 bias_initializer: (optional) The initializer to use for the bias. 3181 name: String, the name of the layer. Layers with the same name will 3182 share weights, but to avoid mistakes we require reuse=True in such 3183 cases. 3184 dtype: Default dtype of the layer (default of `None` means use the type 3185 of the first input). Required when `build` is called before `call`. 3186 """ 3187 3188 def __init__(self, 3189 num_units, 3190 activation=None, 3191 reuse=None, 3192 kernel_initializer=None, 3193 bias_initializer=None, 3194 name=None, 3195 dtype=None): 3196 super(IndyGRUCell, self).__init__(_reuse=reuse, name=name, dtype=dtype) 3197 3198 # Inputs must be 2-dimensional. 3199 self.input_spec = input_spec.InputSpec(ndim=2) 3200 3201 self._num_units = num_units 3202 self._activation = activation or math_ops.tanh 3203 self._kernel_initializer = kernel_initializer 3204 self._bias_initializer = bias_initializer 3205 3206 @property 3207 def state_size(self): 3208 return self._num_units 3209 3210 @property 3211 def output_size(self): 3212 return self._num_units 3213 3214 def build(self, inputs_shape): 3215 if tensor_shape.dimension_value(inputs_shape[1]) is None: 3216 raise ValueError( 3217 "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape) 3218 3219 input_depth = tensor_shape.dimension_value(inputs_shape[1]) 3220 # pylint: disable=protected-access 3221 self._gate_kernel_w = self.add_variable( 3222 "gates/%s_w" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, 3223 shape=[input_depth, 2 * self._num_units], 3224 initializer=self._kernel_initializer) 3225 self._gate_kernel_u = self.add_variable( 3226 "gates/%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, 3227 shape=[1, 2 * self._num_units], 3228 initializer=init_ops.random_uniform_initializer( 3229 minval=-1, maxval=1, dtype=self.dtype)) 3230 self._gate_bias = self.add_variable( 3231 "gates/%s" % rnn_cell_impl._BIAS_VARIABLE_NAME, 3232 shape=[2 * self._num_units], 3233 initializer=(self._bias_initializer 3234 if self._bias_initializer is not None else 3235 init_ops.constant_initializer(1.0, dtype=self.dtype))) 3236 self._candidate_kernel_w = self.add_variable( 3237 "candidate/%s" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, 3238 shape=[input_depth, self._num_units], 3239 initializer=self._kernel_initializer) 3240 self._candidate_kernel_u = self.add_variable( 3241 "candidate/%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, 3242 shape=[1, self._num_units], 3243 initializer=init_ops.random_uniform_initializer( 3244 minval=-1, maxval=1, dtype=self.dtype)) 3245 self._candidate_bias = self.add_variable( 3246 "candidate/%s" % rnn_cell_impl._BIAS_VARIABLE_NAME, 3247 shape=[self._num_units], 3248 initializer=(self._bias_initializer 3249 if self._bias_initializer is not None else 3250 init_ops.zeros_initializer(dtype=self.dtype))) 3251 # pylint: enable=protected-access 3252 3253 self.built = True 3254 3255 def call(self, inputs, state): 3256 """Recurrently independent Gated Recurrent Unit (GRU) with nunits cells.""" 3257 3258 gate_inputs = math_ops.matmul(inputs, self._gate_kernel_w) + ( 3259 gen_array_ops.tile(state, [1, 2]) * self._gate_kernel_u) 3260 gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias) 3261 3262 value = math_ops.sigmoid(gate_inputs) 3263 r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1) 3264 3265 r_state = r * state 3266 3267 candidate = math_ops.matmul(inputs, self._candidate_kernel_w) + ( 3268 r_state * self._candidate_kernel_u) 3269 candidate = nn_ops.bias_add(candidate, self._candidate_bias) 3270 3271 c = self._activation(candidate) 3272 new_h = u * state + (1 - u) * c 3273 return new_h, new_h 3274 3275 3276class IndyLSTMCell(rnn_cell_impl.LayerRNNCell): 3277 r"""Basic IndyLSTM recurrent network cell. 3278 3279 Based on IndRNNs (https://arxiv.org/abs/1803.04831) and similar to 3280 BasicLSTMCell, yet with the \\(U_f\\), \\(U_i\\), \\(U_o\\) and \\(U_c\\) 3281 matrices in the regular LSTM equations replaced by diagonal matrices, i.e. a 3282 Hadamard product with a single vector: 3283 3284 $$f_t = \sigma_g\left(W_f x_t + u_f \circ h_{t-1} + b_f\right)$$ 3285 $$i_t = \sigma_g\left(W_i x_t + u_i \circ h_{t-1} + b_i\right)$$ 3286 $$o_t = \sigma_g\left(W_o x_t + u_o \circ h_{t-1} + b_o\right)$$ 3287 $$c_t = f_t \circ c_{t-1} + 3288 i_t \circ \sigma_c\left(W_c x_t + u_c \circ h_{t-1} + b_c\right)$$ 3289 3290 where \\(\circ\\) denotes the Hadamard operator. This means that each IndyLSTM 3291 node sees only its own state \\(h\\) and \\(c\\), as opposed to seeing all 3292 states in the same layer. 3293 3294 We add forget_bias (default: 1) to the biases of the forget gate in order to 3295 reduce the scale of forgetting in the beginning of the training. 3296 3297 It does not allow cell clipping, a projection layer, and does not 3298 use peep-hole connections: it is the basic baseline. 3299 """ 3300 3301 def __init__(self, 3302 num_units, 3303 forget_bias=1.0, 3304 activation=None, 3305 reuse=None, 3306 kernel_initializer=None, 3307 bias_initializer=None, 3308 name=None, 3309 dtype=None): 3310 """Initialize the IndyLSTM cell. 3311 3312 Args: 3313 num_units: int, The number of units in the LSTM cell. 3314 forget_bias: float, The bias added to forget gates (see above). 3315 Must set to `0.0` manually when restoring from CudnnLSTM-trained 3316 checkpoints. 3317 activation: Activation function of the inner states. Default: `tanh`. 3318 reuse: (optional) Python boolean describing whether to reuse variables 3319 in an existing scope. If not `True`, and the existing scope already has 3320 the given variables, an error is raised. 3321 kernel_initializer: (optional) The initializer to use for the weight 3322 matrix applied to the inputs. 3323 bias_initializer: (optional) The initializer to use for the bias. 3324 name: String, the name of the layer. Layers with the same name will 3325 share weights, but to avoid mistakes we require reuse=True in such 3326 cases. 3327 dtype: Default dtype of the layer (default of `None` means use the type 3328 of the first input). Required when `build` is called before `call`. 3329 """ 3330 super(IndyLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype) 3331 3332 # Inputs must be 2-dimensional. 3333 self.input_spec = input_spec.InputSpec(ndim=2) 3334 3335 self._num_units = num_units 3336 self._forget_bias = forget_bias 3337 self._activation = activation or math_ops.tanh 3338 self._kernel_initializer = kernel_initializer 3339 self._bias_initializer = bias_initializer 3340 3341 @property 3342 def state_size(self): 3343 return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units) 3344 3345 @property 3346 def output_size(self): 3347 return self._num_units 3348 3349 def build(self, inputs_shape): 3350 if tensor_shape.dimension_value(inputs_shape[1]) is None: 3351 raise ValueError( 3352 "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape) 3353 3354 input_depth = tensor_shape.dimension_value(inputs_shape[1]) 3355 # pylint: disable=protected-access 3356 self._kernel_w = self.add_variable( 3357 "%s_w" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, 3358 shape=[input_depth, 4 * self._num_units], 3359 initializer=self._kernel_initializer) 3360 self._kernel_u = self.add_variable( 3361 "%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, 3362 shape=[1, 4 * self._num_units], 3363 initializer=init_ops.random_uniform_initializer( 3364 minval=-1, maxval=1, dtype=self.dtype)) 3365 self._bias = self.add_variable( 3366 rnn_cell_impl._BIAS_VARIABLE_NAME, 3367 shape=[4 * self._num_units], 3368 initializer=(self._bias_initializer 3369 if self._bias_initializer is not None else 3370 init_ops.zeros_initializer(dtype=self.dtype))) 3371 # pylint: enable=protected-access 3372 3373 self.built = True 3374 3375 def call(self, inputs, state): 3376 """Independent Long short-term memory cell (IndyLSTM). 3377 3378 Args: 3379 inputs: `2-D` tensor with shape `[batch_size, input_size]`. 3380 state: An `LSTMStateTuple` of state tensors, each shaped 3381 `[batch_size, num_units]`. 3382 3383 Returns: 3384 A pair containing the new hidden state, and the new state (a 3385 `LSTMStateTuple`). 3386 """ 3387 sigmoid = math_ops.sigmoid 3388 one = constant_op.constant(1, dtype=dtypes.int32) 3389 c, h = state 3390 3391 gate_inputs = math_ops.matmul(inputs, self._kernel_w) 3392 gate_inputs += gen_array_ops.tile(h, [1, 4]) * self._kernel_u 3393 gate_inputs = nn_ops.bias_add(gate_inputs, self._bias) 3394 3395 # i = input_gate, j = new_input, f = forget_gate, o = output_gate 3396 i, j, f, o = array_ops.split( 3397 value=gate_inputs, num_or_size_splits=4, axis=one) 3398 3399 forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype) 3400 # Note that using `add` and `multiply` instead of `+` and `*` gives a 3401 # performance improvement. So using those at the cost of readability. 3402 add = math_ops.add 3403 multiply = math_ops.multiply 3404 new_c = add( 3405 multiply(c, sigmoid(add(f, forget_bias_tensor))), 3406 multiply(sigmoid(i), self._activation(j))) 3407 new_h = multiply(self._activation(new_c), sigmoid(o)) 3408 3409 new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h) 3410 return new_h, new_state 3411 3412 3413NTMControllerState = collections.namedtuple( 3414 "NTMControllerState", 3415 ("controller_state", "read_vector_list", "w_list", "M", "time")) 3416 3417 3418class NTMCell(rnn_cell_impl.LayerRNNCell): 3419 """Neural Turing Machine Cell with RNN controller. 3420 3421 Implementation based on: 3422 https://arxiv.org/abs/1807.08518 3423 Mark Collier, Joeran Beel 3424 3425 which is in turn based on the source code of: 3426 https://github.com/snowkylin/ntm 3427 3428 and of course the original NTM paper: 3429 Neural Turing Machines 3430 https://arxiv.org/abs/1410.5401 3431 A Graves, G Wayne, I Danihelka 3432 """ 3433 3434 def __init__(self, 3435 controller, 3436 memory_size, 3437 memory_vector_dim, 3438 read_head_num, 3439 write_head_num, 3440 shift_range=1, 3441 output_dim=None, 3442 clip_value=20, 3443 dtype=dtypes.float32, 3444 name=None): 3445 """Initialize the NTM Cell. 3446 3447 Args: 3448 controller: an RNNCell, the RNN controller. 3449 memory_size: int, The number of memory locations in the NTM memory 3450 matrix 3451 memory_vector_dim: int, The dimensionality of each location in the NTM 3452 memory matrix 3453 read_head_num: int, The number of read heads from the controller into 3454 memory 3455 write_head_num: int, The number of write heads from the controller into 3456 memory 3457 shift_range: int, The number of places to the left/right it is possible 3458 to iterate the previous address to in a single step 3459 output_dim: int, The number of dimensions to make a linear projection of 3460 the NTM controller outputs to. If None, no linear projection is 3461 applied 3462 clip_value: float, The maximum absolute value the controller parameters 3463 are clipped to 3464 dtype: Default dtype of the layer (default of `None` means use the type 3465 of the first input). Required when `build` is called before `call`. 3466 name: String, the name of the layer. Layers with the same name will 3467 share weights, but to avoid mistakes we require reuse=True in such 3468 cases. 3469 """ 3470 super(NTMCell, self).__init__(dtype=dtype, name=name) 3471 3472 rnn_cell_impl.assert_like_rnncell("NTM RNN controller cell", controller) 3473 3474 self.controller = controller 3475 self.memory_size = memory_size 3476 self.memory_vector_dim = memory_vector_dim 3477 self.read_head_num = read_head_num 3478 self.write_head_num = write_head_num 3479 self.clip_value = clip_value 3480 3481 self.output_dim = output_dim 3482 self.shift_range = shift_range 3483 3484 self.num_parameters_per_head = ( 3485 self.memory_vector_dim + 2 * self.shift_range + 4) 3486 self.num_heads = self.read_head_num + self.write_head_num 3487 self.total_parameter_num = ( 3488 self.num_parameters_per_head * self.num_heads + 3489 self.memory_vector_dim * 2 * self.write_head_num) 3490 3491 @property 3492 def state_size(self): 3493 return NTMControllerState( 3494 controller_state=self.controller.state_size, 3495 read_vector_list=[ 3496 self.memory_vector_dim for _ in range(self.read_head_num) 3497 ], 3498 w_list=[ 3499 self.memory_size 3500 for _ in range(self.read_head_num + self.write_head_num) 3501 ], 3502 M=tensor_shape.TensorShape([self.memory_size * self.memory_vector_dim]), 3503 time=tensor_shape.TensorShape([])) 3504 3505 @property 3506 def output_size(self): 3507 return self.output_dim 3508 3509 def build(self, inputs_shape): 3510 if self.output_dim is None: 3511 if inputs_shape[1].value is None: 3512 raise ValueError( 3513 "Expected inputs.shape[-1] to be known, saw shape: %s" % 3514 inputs_shape) 3515 else: 3516 self.output_dim = inputs_shape[1].value 3517 3518 def _create_linear_initializer(input_size, dtype=dtypes.float32): 3519 stddev = 1.0 / math.sqrt(input_size) 3520 return init_ops.truncated_normal_initializer(stddev=stddev, dtype=dtype) 3521 3522 self._params_kernel = self.add_variable( 3523 "parameters_kernel", 3524 shape=[self.controller.output_size, self.total_parameter_num], 3525 initializer=_create_linear_initializer(self.controller.output_size)) 3526 3527 self._params_bias = self.add_variable( 3528 "parameters_bias", 3529 shape=[self.total_parameter_num], 3530 initializer=init_ops.constant_initializer(0.0, dtype=self.dtype)) 3531 3532 self._output_kernel = self.add_variable( 3533 "output_kernel", 3534 shape=[ 3535 self.controller.output_size + 3536 self.memory_vector_dim * self.read_head_num, self.output_dim 3537 ], 3538 initializer=_create_linear_initializer(self.controller.output_size + 3539 self.memory_vector_dim * 3540 self.read_head_num)) 3541 3542 self._output_bias = self.add_variable( 3543 "output_bias", 3544 shape=[self.output_dim], 3545 initializer=init_ops.constant_initializer(0.0, dtype=self.dtype)) 3546 3547 self._init_read_vectors = [ 3548 self.add_variable( 3549 "initial_read_vector_%d" % i, 3550 shape=[1, self.memory_vector_dim], 3551 initializer=initializers.glorot_uniform()) 3552 for i in range(self.read_head_num) 3553 ] 3554 3555 self._init_address_weights = [ 3556 self.add_variable( 3557 "initial_address_weights_%d" % i, 3558 shape=[1, self.memory_size], 3559 initializer=initializers.glorot_uniform()) 3560 for i in range(self.read_head_num + self.write_head_num) 3561 ] 3562 3563 self._M = self.add_variable( 3564 "memory", 3565 shape=[self.memory_size, self.memory_vector_dim], 3566 initializer=init_ops.constant_initializer(1e-6, dtype=self.dtype)) 3567 3568 self.built = True 3569 3570 def call(self, x, prev_state): 3571 # Addressing Mechanisms (Sec 3.3) 3572 3573 def _prev_read_vector_list_initial_value(): 3574 return [ 3575 self._expand( 3576 math_ops.tanh( 3577 array_ops.squeeze( 3578 math_ops.matmul( 3579 array_ops.ones([1, 1]), self._init_read_vectors[i]))), 3580 dim=0, 3581 N=x.shape[0].value or array_ops.shape(x)[0]) 3582 for i in range(self.read_head_num) 3583 ] 3584 3585 prev_read_vector_list = control_flow_ops.cond( 3586 math_ops.equal(prev_state.time, 3587 0), _prev_read_vector_list_initial_value, lambda: 3588 prev_state.read_vector_list) 3589 if self.read_head_num == 1: 3590 prev_read_vector_list = [prev_read_vector_list] 3591 3592 controller_input = array_ops.concat([x] + prev_read_vector_list, axis=1) 3593 controller_output, controller_state = self.controller( 3594 controller_input, prev_state.controller_state) 3595 3596 parameters = math_ops.matmul(controller_output, self._params_kernel) 3597 parameters = nn_ops.bias_add(parameters, self._params_bias) 3598 parameters = clip_ops.clip_by_value(parameters, -self.clip_value, 3599 self.clip_value) 3600 head_parameter_list = array_ops.split( 3601 parameters[:, :self.num_parameters_per_head * self.num_heads], 3602 self.num_heads, 3603 axis=1) 3604 erase_add_list = array_ops.split( 3605 parameters[:, self.num_parameters_per_head * self.num_heads:], 3606 2 * self.write_head_num, 3607 axis=1) 3608 3609 def _prev_w_list_initial_value(): 3610 return [ 3611 self._expand( 3612 nn_ops.softmax( 3613 array_ops.squeeze( 3614 math_ops.matmul( 3615 array_ops.ones([1, 1]), 3616 self._init_address_weights[i]))), 3617 dim=0, 3618 N=x.shape[0].value or array_ops.shape(x)[0]) 3619 for i in range(self.read_head_num + self.write_head_num) 3620 ] 3621 3622 prev_w_list = control_flow_ops.cond( 3623 math_ops.equal(prev_state.time, 0), 3624 _prev_w_list_initial_value, lambda: prev_state.w_list) 3625 if (self.read_head_num + self.write_head_num) == 1: 3626 prev_w_list = [prev_w_list] 3627 3628 prev_M = control_flow_ops.cond( 3629 math_ops.equal(prev_state.time, 0), lambda: self._expand( 3630 self._M, dim=0, N=x.shape[0].value or array_ops.shape(x)[0]), 3631 lambda: prev_state.M) 3632 3633 w_list = [] 3634 for i, head_parameter in enumerate(head_parameter_list): 3635 k = math_ops.tanh(head_parameter[:, 0:self.memory_vector_dim]) 3636 beta = nn_ops.softplus(head_parameter[:, self.memory_vector_dim]) 3637 g = math_ops.sigmoid(head_parameter[:, self.memory_vector_dim + 1]) 3638 s = nn_ops.softmax(head_parameter[:, self.memory_vector_dim + 3639 2:(self.memory_vector_dim + 2 + 3640 (self.shift_range * 2 + 1))]) 3641 gamma = nn_ops.softplus(head_parameter[:, -1]) + 1 3642 w = self._addressing(k, beta, g, s, gamma, prev_M, prev_w_list[i]) 3643 w_list.append(w) 3644 3645 # Reading (Sec 3.1) 3646 3647 read_w_list = w_list[:self.read_head_num] 3648 read_vector_list = [] 3649 for i in range(self.read_head_num): 3650 read_vector = math_ops.reduce_sum( 3651 array_ops.expand_dims(read_w_list[i], dim=2) * prev_M, axis=1) 3652 read_vector_list.append(read_vector) 3653 3654 # Writing (Sec 3.2) 3655 3656 write_w_list = w_list[self.read_head_num:] 3657 M = prev_M 3658 for i in range(self.write_head_num): 3659 w = array_ops.expand_dims(write_w_list[i], axis=2) 3660 erase_vector = array_ops.expand_dims( 3661 math_ops.sigmoid(erase_add_list[i * 2]), axis=1) 3662 add_vector = array_ops.expand_dims( 3663 math_ops.tanh(erase_add_list[i * 2 + 1]), axis=1) 3664 erase_M = array_ops.ones_like(M) - math_ops.matmul(w, erase_vector) 3665 M = M * erase_M + math_ops.matmul(w, add_vector) 3666 3667 output = math_ops.matmul( 3668 array_ops.concat([controller_output] + read_vector_list, axis=1), 3669 self._output_kernel) 3670 output = nn_ops.bias_add(output, self._output_bias) 3671 output = clip_ops.clip_by_value(output, -self.clip_value, self.clip_value) 3672 3673 return output, NTMControllerState( 3674 controller_state=controller_state, 3675 read_vector_list=read_vector_list, 3676 w_list=w_list, 3677 M=M, 3678 time=prev_state.time + 1) 3679 3680 def _expand(self, x, dim, N): 3681 return array_ops.concat([array_ops.expand_dims(x, dim) for _ in range(N)], 3682 axis=dim) 3683 3684 def _addressing(self, k, beta, g, s, gamma, prev_M, prev_w): 3685 # Sec 3.3.1 Focusing by Content 3686 3687 k = array_ops.expand_dims(k, axis=2) 3688 inner_product = math_ops.matmul(prev_M, k) 3689 k_norm = math_ops.sqrt( 3690 math_ops.reduce_sum(math_ops.square(k), axis=1, keepdims=True)) 3691 M_norm = math_ops.sqrt( 3692 math_ops.reduce_sum(math_ops.square(prev_M), axis=2, keepdims=True)) 3693 norm_product = M_norm * k_norm 3694 3695 # eq (6) 3696 K = array_ops.squeeze(inner_product / (norm_product + 1e-8)) 3697 3698 K_amplified = math_ops.exp(array_ops.expand_dims(beta, axis=1) * K) 3699 3700 # eq (5) 3701 w_c = K_amplified / math_ops.reduce_sum(K_amplified, axis=1, keepdims=True) 3702 3703 # Sec 3.3.2 Focusing by Location 3704 3705 g = array_ops.expand_dims(g, axis=1) 3706 3707 # eq (7) 3708 w_g = g * w_c + (1 - g) * prev_w 3709 3710 s = array_ops.concat([ 3711 s[:, :self.shift_range + 1], 3712 array_ops.zeros([ 3713 s.shape[0].value or array_ops.shape(s)[0], self.memory_size - 3714 (self.shift_range * 2 + 1) 3715 ]), s[:, -self.shift_range:] 3716 ], 3717 axis=1) 3718 t = array_ops.concat( 3719 [array_ops.reverse(s, axis=[1]), 3720 array_ops.reverse(s, axis=[1])], 3721 axis=1) 3722 s_matrix = array_ops.stack([ 3723 t[:, self.memory_size - i - 1:self.memory_size * 2 - i - 1] 3724 for i in range(self.memory_size) 3725 ], 3726 axis=1) 3727 3728 # eq (8) 3729 w_ = math_ops.reduce_sum( 3730 array_ops.expand_dims(w_g, axis=1) * s_matrix, axis=2) 3731 w_sharpen = math_ops.pow(w_, array_ops.expand_dims(gamma, axis=1)) 3732 3733 # eq (9) 3734 w = w_sharpen / math_ops.reduce_sum(w_sharpen, axis=1, keepdims=True) 3735 3736 return w 3737 3738 def zero_state(self, batch_size, dtype): 3739 read_vector_list = [ 3740 array_ops.zeros([batch_size, self.memory_vector_dim]) 3741 for _ in range(self.read_head_num) 3742 ] 3743 3744 w_list = [ 3745 array_ops.zeros([batch_size, self.memory_size]) 3746 for _ in range(self.read_head_num + self.write_head_num) 3747 ] 3748 3749 controller_init_state = self.controller.zero_state(batch_size, dtype) 3750 3751 M = array_ops.zeros([batch_size, self.memory_size, self.memory_vector_dim]) 3752 3753 return NTMControllerState( 3754 controller_state=controller_init_state, 3755 read_vector_list=read_vector_list, 3756 w_list=w_list, 3757 M=M, 3758 time=0) 3759 3760 3761class MinimalRNNCell(rnn_cell_impl.LayerRNNCell): 3762 """MinimalRNN cell. 3763 3764 The implementation is based on: 3765 3766 https://arxiv.org/pdf/1806.05394v2.pdf 3767 3768 Minmin Chen, Jeffrey Pennington, Samuel S. Schoenholz. 3769 "Dynamical Isometry and a Mean Field Theory of RNNs: Gating Enables Signal 3770 Propagation in Recurrent Neural Networks." ICML, 2018. 3771 3772 A MinimalRNN cell first projects the input to the hidden space. The new 3773 hidden state is then calculated as a weighted sum of the projected input and 3774 the previous hidden state, using a single update gate. 3775 """ 3776 3777 def __init__(self, 3778 units, 3779 activation="tanh", 3780 kernel_initializer="glorot_uniform", 3781 bias_initializer="ones", 3782 name=None, 3783 dtype=None, 3784 **kwargs): 3785 """Initialize the parameters for a MinimalRNN cell. 3786 3787 Args: 3788 units: int, The number of units in the MinimalRNN cell. 3789 activation: Nonlinearity to use in the feedforward network. Default: 3790 `tanh`. 3791 kernel_initializer: The initializer to use for the weight in the update 3792 gate and feedforward network. Default: `glorot_uniform`. 3793 bias_initializer: The initializer to use for the bias in the update 3794 gate. Default: `ones`. 3795 name: String, the name of the cell. 3796 dtype: Default dtype of the cell. 3797 **kwargs: Dict, keyword named properties for common cell attributes. 3798 """ 3799 super(MinimalRNNCell, self).__init__(name=name, dtype=dtype, **kwargs) 3800 3801 # Inputs must be 2-dimensional. 3802 self.input_spec = input_spec.InputSpec(ndim=2) 3803 3804 self.units = units 3805 self.activation = activations.get(activation) 3806 self.kernel_initializer = initializers.get(kernel_initializer) 3807 self.bias_initializer = initializers.get(bias_initializer) 3808 3809 @property 3810 def state_size(self): 3811 return self.units 3812 3813 @property 3814 def output_size(self): 3815 return self.units 3816 3817 def build(self, inputs_shape): 3818 if inputs_shape[-1] is None: 3819 raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" 3820 % str(inputs_shape)) 3821 3822 input_size = inputs_shape[-1] 3823 # pylint: disable=protected-access 3824 # self._kernel contains W_x, W, V 3825 self.kernel = self.add_weight( 3826 name=rnn_cell_impl._WEIGHTS_VARIABLE_NAME, 3827 shape=[input_size + 2 * self.units, self.units], 3828 initializer=self.kernel_initializer) 3829 self.bias = self.add_weight( 3830 name=rnn_cell_impl._BIAS_VARIABLE_NAME, 3831 shape=[self.units], 3832 initializer=self.bias_initializer) 3833 # pylint: enable=protected-access 3834 3835 self.built = True 3836 3837 def call(self, inputs, state): 3838 """Run one step of MinimalRNN. 3839 3840 Args: 3841 inputs: input Tensor, must be 2-D, `[batch, input_size]`. 3842 state: state Tensor, must be 2-D, `[batch, state_size]`. 3843 3844 Returns: 3845 A tuple containing: 3846 3847 - Output: A `2-D` tensor with shape `[batch_size, state_size]`. 3848 - New state: A `2-D` tensor with shape `[batch_size, state_size]`. 3849 3850 Raises: 3851 ValueError: If input size cannot be inferred from inputs via 3852 static shape inference. 3853 """ 3854 input_size = inputs.get_shape()[1] 3855 if tensor_shape.dimension_value(input_size) is None: 3856 raise ValueError("Could not infer input size from inputs.get_shape()[-1]") 3857 3858 feedforward_weight, gate_weight = array_ops.split( 3859 value=self.kernel, 3860 num_or_size_splits=[tensor_shape.dimension_value(input_size), 3861 2 * self.units], 3862 axis=0) 3863 3864 feedforward = math_ops.matmul(inputs, feedforward_weight) 3865 feedforward = self.activation(feedforward) 3866 3867 gate_inputs = math_ops.matmul( 3868 array_ops.concat([feedforward, state], 1), gate_weight) 3869 gate_inputs = nn_ops.bias_add(gate_inputs, self.bias) 3870 u = math_ops.sigmoid(gate_inputs) 3871 3872 new_h = u * state + (1 - u) * feedforward 3873 return new_h, new_h 3874 3875 3876class CFNCell(rnn_cell_impl.LayerRNNCell): 3877 """Chaos Free Network cell. 3878 3879 The implementation is based on: 3880 3881 https://openreview.net/pdf?id=S1dIzvclg 3882 3883 Thomas Laurent, James von Brecht. 3884 "A recurrent neural network without chaos." ICLR, 2017. 3885 3886 A CFN cell first projects the input to the hidden space. The hidden state 3887 goes through a contractive mapping. The new hidden state is then calculated 3888 as a linear combination of the projected input and the contracted previous 3889 hidden state, using decoupled input and forget gates. 3890 """ 3891 3892 def __init__(self, 3893 units, 3894 activation="tanh", 3895 kernel_initializer="glorot_uniform", 3896 bias_initializer="ones", 3897 name=None, 3898 dtype=None, 3899 **kwargs): 3900 """Initialize the parameters for a CFN cell. 3901 3902 Args: 3903 units: int, The number of units in the CFN cell. 3904 activation: Nonlinearity to use. Default: `tanh`. 3905 kernel_initializer: Initializer for the `kernel` weights 3906 matrix. Default: `glorot_uniform`. 3907 bias_initializer: The initializer to use for the bias in the 3908 gates. Default: `ones`. 3909 name: String, the name of the cell. 3910 dtype: Default dtype of the cell. 3911 **kwargs: Dict, keyword named properties for common cell attributes. 3912 """ 3913 super(CFNCell, self).__init__(name=name, dtype=dtype, **kwargs) 3914 3915 # Inputs must be 2-dimensional. 3916 self.input_spec = input_spec.InputSpec(ndim=2) 3917 3918 self.units = units 3919 self.activation = activations.get(activation) 3920 self.kernel_initializer = initializers.get(kernel_initializer) 3921 self.bias_initializer = initializers.get(bias_initializer) 3922 3923 @property 3924 def state_size(self): 3925 return self.units 3926 3927 @property 3928 def output_size(self): 3929 return self.units 3930 3931 def build(self, inputs_shape): 3932 if inputs_shape[-1] is None: 3933 raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" 3934 % str(inputs_shape)) 3935 3936 input_size = inputs_shape[-1] 3937 # pylint: disable=protected-access 3938 # `self.kernel` contains V_{\theta}, V_{\eta}, W. 3939 # `self.recurrent_kernel` contains U_{\theta}, U_{\eta}. 3940 # `self.bias` contains b_{\theta}, b_{\eta}. 3941 self.kernel = self.add_weight( 3942 shape=[input_size, 3 * self.units], 3943 name=rnn_cell_impl._WEIGHTS_VARIABLE_NAME, 3944 initializer=self.kernel_initializer) 3945 self.recurrent_kernel = self.add_weight( 3946 shape=[self.units, 2 * self.units], 3947 name="recurrent_%s" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, 3948 initializer=self.kernel_initializer) 3949 self.bias = self.add_weight( 3950 shape=[2 * self.units], 3951 name=rnn_cell_impl._BIAS_VARIABLE_NAME, 3952 initializer=self.bias_initializer) 3953 # pylint: enable=protected-access 3954 3955 self.built = True 3956 3957 def call(self, inputs, state): 3958 """Run one step of CFN. 3959 3960 Args: 3961 inputs: input Tensor, must be 2-D, `[batch, input_size]`. 3962 state: state Tensor, must be 2-D, `[batch, state_size]`. 3963 3964 Returns: 3965 A tuple containing: 3966 3967 - Output: A `2-D` tensor with shape `[batch_size, state_size]`. 3968 - New state: A `2-D` tensor with shape `[batch_size, state_size]`. 3969 3970 Raises: 3971 ValueError: If input size cannot be inferred from inputs via 3972 static shape inference. 3973 """ 3974 input_size = inputs.get_shape()[-1] 3975 if tensor_shape.dimension_value(input_size) is None: 3976 raise ValueError("Could not infer input size from inputs.get_shape()[-1]") 3977 3978 # The variable names u, v, w, b are consistent with the notations in the 3979 # original paper. 3980 v, w = array_ops.split( 3981 value=self.kernel, 3982 num_or_size_splits=[2 * self.units, self.units], 3983 axis=1) 3984 u = self.recurrent_kernel 3985 b = self.bias 3986 3987 gates = math_ops.matmul(state, u) + math_ops.matmul(inputs, v) 3988 gates = nn_ops.bias_add(gates, b) 3989 gates = math_ops.sigmoid(gates) 3990 theta, eta = array_ops.split(value=gates, 3991 num_or_size_splits=2, 3992 axis=1) 3993 3994 proj_input = math_ops.matmul(inputs, w) 3995 3996 # The input gate is (1 - eta), which is different from the original paper. 3997 # This is for the propose of initialization. With the default 3998 # bias_initializer `ones`, the input gate is initialized to a small number. 3999 new_h = theta * self.activation(state) + (1 - eta) * self.activation( 4000 proj_input) 4001 4002 return new_h, new_h 4003