# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""LSTM Block Cell ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

import six

from tensorflow.contrib.rnn.ops import gen_lstm_ops
from tensorflow.contrib.util import loader
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras.engine import input_spec
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.platform import resource_loader

_lstm_ops_so = loader.load_op_library(
    resource_loader.get_path_to_datafile("_lstm_ops.so"))

LayerRNNCell = rnn_cell_impl.LayerRNNCell  # pylint: disable=invalid-name


# pylint: disable=invalid-name
def _lstm_block_cell(x,
                     cs_prev,
                     h_prev,
                     w,
                     b,
                     wci=None,
                     wcf=None,
                     wco=None,
                     forget_bias=None,
                     cell_clip=None,
                     use_peephole=None,
                     name=None):
  r"""Computes the LSTM cell forward propagation for 1 time step.

  This implementation uses 1 weight matrix and 1 bias vector, and there's an
  optional peephole connection.

  This kernel op implements the following mathematical equations:

  ```python
  xh = [x, h_prev]
  [i, ci, f, o] = xh * w + b
  f = f + forget_bias

  if not use_peephole:
    wci = wcf = wco = 0

  i = sigmoid(cs_prev * wci + i)
  f = sigmoid(cs_prev * wcf + f)
  ci = tanh(ci)

  cs = ci .* i + cs_prev .* f
  cs = clip(cs, cell_clip)

  o = sigmoid(cs * wco + o)
  co = tanh(cs)
  h = co .* o
  ```

  Args:
    x: A `Tensor`. Must be one of the following types: `float32`.
      The input to the LSTM cell, shape (batch_size, num_inputs).
    cs_prev: A `Tensor`. Must have the same type as `x`.
      Value of the cell state at previous time step.
    h_prev: A `Tensor`. Must have the same type as `x`.
      Output of the previous cell at previous time step.
    w: A `Tensor`. Must have the same type as `x`. The weight matrix.
    b: A `Tensor`. Must have the same type as `x`. The bias vector.
    wci: A `Tensor`. Must have the same type as `x`.
      The weight matrix for input gate peephole connection.
    wcf: A `Tensor`. Must have the same type as `x`.
      The weight matrix for forget gate peephole connection.
    wco: A `Tensor`. Must have the same type as `x`.
      The weight matrix for output gate peephole connection.
    forget_bias: An optional `float`. Defaults to `1`. The forget gate bias.
    cell_clip: An optional `float`. Defaults to `-1` (no clipping).
      Value to clip the 'cs' value to. Disable by setting to negative value.
    use_peephole: An optional `bool`. Defaults to `False`.
      Whether to use peephole weights.
    name: A name for the operation (optional).

  Returns:
    A tuple of `Tensor` objects (i, cs, f, o, ci, co, h).
    i: A `Tensor`. Has the same type as `x`. The input gate.
    cs: A `Tensor`. Has the same type as `x`. The cell state before the tanh.
    f: A `Tensor`. Has the same type as `x`. The forget gate.
    o: A `Tensor`. Has the same type as `x`. The output gate.
    ci: A `Tensor`. Has the same type as `x`. The cell input.
    co: A `Tensor`. Has the same type as `x`. The cell after the tanh.
    h: A `Tensor`. Has the same type as `x`. The output h vector.

  Raises:
    ValueError: If cell_size is None.
  """
  if wci is None:
    cell_size = cs_prev.get_shape().with_rank(2).dims[1].value
    if cell_size is None:
      raise ValueError("cell_size from `cs_prev` should not be None.")
    wci = array_ops.constant(0, dtype=dtypes.float32, shape=[cell_size])
    wcf = wci
    wco = wci

  # pylint: disable=protected-access
  return gen_lstm_ops.lstm_block_cell(
      x=x,
      cs_prev=cs_prev,
      h_prev=h_prev,
      w=w,
      wci=wci,
      wcf=wcf,
      wco=wco,
      b=b,
      forget_bias=forget_bias,
      cell_clip=cell_clip if cell_clip is not None else -1,
      use_peephole=use_peephole,
      name=name)
  # pylint: enable=protected-access
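

# Purely illustrative NumPy sketch of the per-step equations documented above
# (peephole connections and cell clipping omitted). Nothing in this module
# uses it; `sigmoid`, `lstm_step_reference`, and the argument names below are
# placeholders for this comment only:
#
#   import numpy as np
#
#   def sigmoid(z):
#     return 1.0 / (1.0 + np.exp(-z))
#
#   def lstm_step_reference(x, cs_prev, h_prev, w, b, forget_bias=1.0):
#     xh = np.concatenate([x, h_prev], axis=1)
#     i, ci, f, o = np.split(np.dot(xh, w) + b, 4, axis=1)
#     i = sigmoid(i)
#     f = sigmoid(f + forget_bias)
#     ci = np.tanh(ci)
#     cs = ci * i + cs_prev * f
#     o = sigmoid(o)
#     co = np.tanh(cs)
#     h = co * o
#     return cs, h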


def _block_lstm(seq_len_max,
                x,
                w,
                b,
                cs_prev=None,
                h_prev=None,
                wci=None,
                wcf=None,
                wco=None,
                forget_bias=None,
                cell_clip=None,
                use_peephole=None,
                name=None):
  r"""Computes the LSTM forward propagation for all time steps in `x`.

  TODO(williamchan): add doc.

  Args:
    seq_len_max: A `Tensor` of type `int64`.
    x: A list of at least 1 `Tensor` objects of the same type.
    w: A `Tensor`. Must have the same type as `x`.
    b: A `Tensor`. Must have the same type as `x`.
    cs_prev: A `Tensor`. Must have the same type as `x`.
    h_prev: A `Tensor`. Must have the same type as `x`.
    wci: A `Tensor`. Must have the same type as `x`.
    wcf: A `Tensor`. Must have the same type as `x`.
    wco: A `Tensor`. Must have the same type as `x`.
    forget_bias: An optional `float`. Defaults to `1`.
    cell_clip: An optional `float`. Defaults to `-1` (no clipping).
    use_peephole: An optional `bool`. Defaults to `False`.
    name: A name for the operation (optional).

  Returns:
    A tuple of `Tensor` objects (i, cs, f, o, ci, co, h).
    i: A list with the same number of `Tensor` objects as `x` of `Tensor`
      objects of the same type as x.
    cs: A list with the same number of `Tensor` objects as `x` of `Tensor`
      objects of the same type as x.
    f: A list with the same number of `Tensor` objects as `x` of `Tensor`
      objects of the same type as x.
    o: A list with the same number of `Tensor` objects as `x` of `Tensor`
      objects of the same type as x.
    ci: A list with the same number of `Tensor` objects as `x` of `Tensor`
      objects of the same type as x.
    co: A list with the same number of `Tensor` objects as `x` of `Tensor`
      objects of the same type as x.
    h: A list with the same number of `Tensor` objects as `x` of `Tensor`
      objects of the same type as x.

  Raises:
    ValueError: If `b` does not have a valid shape.
  """
  dtype = x[0].dtype
  batch_size = x[0].get_shape().with_rank(2).dims[0].value
  cell_size4 = b.get_shape().with_rank(1).dims[0].value
  if cell_size4 is None:
    raise ValueError("`b` shape must not be None.")
  cell_size = cell_size4 // 4
  zero_state = None
  if cs_prev is None or h_prev is None:
    zero_state = array_ops.constant(
        0, dtype=dtype, shape=[batch_size, cell_size])
  if cs_prev is None:
    cs_prev = zero_state
  if h_prev is None:
    h_prev = zero_state
  if wci is None:
    wci = array_ops.constant(0, dtype=dtype, shape=[cell_size])
    wcf = wci
    wco = wci

  # pylint: disable=protected-access
  i, cs, f, o, ci, co, h = gen_lstm_ops.block_lstm(
      seq_len_max=seq_len_max,
      x=array_ops.stack(x),
      cs_prev=cs_prev,
      h_prev=h_prev,
      w=w,
      wci=wci,
      wcf=wcf,
      wco=wco,
      b=b,
      forget_bias=forget_bias,
      cell_clip=cell_clip if cell_clip is not None else -1,
      name=name,
      use_peephole=use_peephole)

  return array_ops.unstack(i), array_ops.unstack(cs), array_ops.unstack(
      f), array_ops.unstack(o), array_ops.unstack(ci), array_ops.unstack(
          co), array_ops.unstack(h)
  # pylint: enable=protected-access
  # pylint: enable=invalid-name


_lstm_block_cell_grad_outputs = ["cs_prev_grad", "dicfo"]


@ops.RegisterGradient("LSTMBlockCell")
def _LSTMBlockCellGrad(op, *grad):
  """Gradient for LSTMBlockCell."""
  (x, cs_prev, h_prev, w, wci, wcf, wco, b) = op.inputs
  (i, cs, f, o, ci, co, _) = op.outputs
  (_, cs_grad, _, _, _, _, h_grad) = grad

  batch_size = x.get_shape().with_rank(2).dims[0].value
  if batch_size is None:
    batch_size = -1
  input_size = x.get_shape().with_rank(2).dims[1].value
  if input_size is None:
    raise ValueError("input_size from `x` should not be None.")
  cell_size = cs_prev.get_shape().with_rank(2).dims[1].value
  if cell_size is None:
    raise ValueError("cell_size from `cs_prev` should not be None.")

  (cs_prev_grad, dicfo, wci_grad, wcf_grad,
   wco_grad) = gen_lstm_ops.lstm_block_cell_grad(
       x,
       cs_prev,
       h_prev,
       w,
       wci,
       wcf,
       wco,
       b,
       i,
       cs,
       f,
       o,
       ci,
       co,
       cs_grad,
       h_grad,
       use_peephole=op.get_attr("use_peephole"))

  # Backprop from dicfo to xh.
  xh_grad = math_ops.matmul(dicfo, w, transpose_b=True)

  x_grad = array_ops.slice(xh_grad, (0, 0), (batch_size, input_size))
  x_grad.get_shape().merge_with(x.get_shape())

  h_prev_grad = array_ops.slice(xh_grad, (0, input_size),
                                (batch_size, cell_size))
  h_prev_grad.get_shape().merge_with(h_prev.get_shape())

  # Backprop from dicfo to w.
  xh = array_ops.concat([x, h_prev], 1)
  w_grad = math_ops.matmul(xh, dicfo, transpose_a=True)
  w_grad.get_shape().merge_with(w.get_shape())

  # Backprop from dicfo to b.
  b_grad = nn_ops.bias_add_grad(dicfo)
  b_grad.get_shape().merge_with(b.get_shape())

  return (x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad,
          wco_grad, b_grad)


@ops.RegisterGradient("BlockLSTM")
def _BlockLSTMGrad(op, *grad):
  """Gradient for BlockLSTM."""
  seq_len_max, x, cs_prev, h_prev, w, wci, wcf, wco, b = op.inputs
  i, cs, f, o, ci, co, h = op.outputs

  cs_grad = grad[1]
  h_grad = grad[6]

  (x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad, wco_grad,
   b_grad) = gen_lstm_ops.block_lstm_grad(
       seq_len_max,
       x,
       cs_prev,
       h_prev,
       w,
       wci,
       wcf,
       wco,
       b,
       i,
       cs,
       f,
       o,
       ci,
       co,
       h,
       cs_grad,
       h_grad,
       use_peephole=op.get_attr("use_peephole"))

  return [
      None, x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad,
      wco_grad, b_grad
  ]


class LSTMBlockCell(LayerRNNCell):
  """Basic LSTM recurrent network cell.

  The implementation is based on: http://arxiv.org/abs/1409.2329.

  We add `forget_bias` (default: 1) to the biases of the forget gate in order
  to reduce the scale of forgetting at the beginning of training.

  Unlike `rnn_cell_impl.LSTMCell`, this is a monolithic op and should be much
  faster. The weight and bias matrices should be compatible as long as the
  variable scope matches.
  """

  def __init__(self,
               num_units,
               forget_bias=1.0,
               cell_clip=None,
               use_peephole=False,
               dtype=None,
               reuse=None,
               name="lstm_cell"):
    """Initialize the basic LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
      cell_clip: An optional `float`. Defaults to `-1` (no clipping).
      use_peephole: Whether to use peephole connections or not.
      dtype: the variable dtype of this layer. Defaults to `tf.float32`.
      reuse: (optional) boolean describing whether to reuse variables in an
        existing scope. If not `True`, and the existing scope already has the
        given variables, an error is raised.
      name: String, the name of the layer. Layers with the same name will
        share weights, but to avoid mistakes we require reuse=True in such
        cases. By default this is "lstm_cell", for variable-name compatibility
        with `tf.nn.rnn_cell.LSTMCell`.

    When restoring from CudnnLSTM-trained checkpoints, must use
    CudnnCompatibleLSTMBlockCell instead.
    """
    super(LSTMBlockCell, self).__init__(_reuse=reuse, dtype=dtype, name=name)
    self._num_units = num_units
    self._forget_bias = forget_bias
    self._use_peephole = use_peephole
    self._cell_clip = cell_clip if cell_clip is not None else -1
    self._names = {
        "W": "kernel",
        "b": "bias",
        "wci": "w_i_diag",
        "wcf": "w_f_diag",
        "wco": "w_o_diag",
        "scope": "lstm_cell"
    }
    # Inputs must be 2-dimensional.
    self.input_spec = input_spec.InputSpec(ndim=2)

  @property
  def state_size(self):
    return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)

  @property
  def output_size(self):
    return self._num_units

  def build(self, inputs_shape):
    if not inputs_shape.dims[1].value:
      raise ValueError(
          "Expecting inputs_shape[1] to be set: %s" % str(inputs_shape))
    input_size = inputs_shape.dims[1].value
    self._kernel = self.add_variable(
        self._names["W"], [input_size + self._num_units, self._num_units * 4])
    self._bias = self.add_variable(
        self._names["b"], [self._num_units * 4],
        initializer=init_ops.constant_initializer(0.0))
    if self._use_peephole:
      self._w_i_diag = self.add_variable(self._names["wci"], [self._num_units])
      self._w_f_diag = self.add_variable(self._names["wcf"], [self._num_units])
      self._w_o_diag = self.add_variable(self._names["wco"], [self._num_units])

    self.built = True

  def call(self, inputs, state):
    """Long short-term memory cell (LSTM)."""
    if len(state) != 2:
      raise ValueError("Expecting state to be a tuple with length 2.")

    if self._use_peephole:
      wci = self._w_i_diag
      wcf = self._w_f_diag
      wco = self._w_o_diag
    else:
      wci = wcf = wco = array_ops.zeros([self._num_units], dtype=self.dtype)

    (cs_prev, h_prev) = state
    (_, cs, _, _, _, _, h) = _lstm_block_cell(
        inputs,
        cs_prev,
        h_prev,
        self._kernel,
        self._bias,
        wci=wci,
        wcf=wcf,
        wco=wco,
        forget_bias=self._forget_bias,
        cell_clip=self._cell_clip,
        use_peephole=self._use_peephole)

    new_state = rnn_cell_impl.LSTMStateTuple(cs, h)
    return h, new_state
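

# Example usage sketch (illustrative only; the shapes and the `tf` alias are
# placeholders for this comment, and the cell is normally obtained via the
# public `tf.contrib.rnn.LSTMBlockCell` alias):
#
#   import tensorflow as tf
#
#   inputs = tf.placeholder(tf.float32, [8, 20, 64])  # [batch, time, depth]
#   cell = tf.contrib.rnn.LSTMBlockCell(num_units=128)
#   outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
#   # outputs: [8, 20, 128]; final_state: LSTMStateTuple(c, h), each [8, 128].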


@six.add_metaclass(abc.ABCMeta)
class LSTMBlockWrapper(base_layer.Layer):
  """This is a helper class that provides housekeeping for LSTM cells.

  This may be useful for alternative LSTM and similar types of cells.
  The subclasses must implement `_call_cell` method and `num_units` property.
  """

  @abc.abstractproperty
  def num_units(self):
    """Number of units in this cell (output dimension)."""
    pass

  @abc.abstractmethod
  def _call_cell(self, inputs, initial_cell_state, initial_output, dtype,
                 sequence_length):
    """Run this LSTM on inputs, starting from the given state.

    This method must be implemented by subclasses and does the actual work
    of calling the cell.

    Args:
      inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`
      initial_cell_state: initial value for cell state, shape `[batch_size,
        self._num_units]`
      initial_output: initial value of cell output, shape `[batch_size,
        self._num_units]`
      dtype: The data type for the initial state and expected output.
      sequence_length: Specifies the length of each sequence in inputs. An
        int32 or int64 vector (tensor) size [batch_size], values in
        [0, time_len) or None.

    Returns:
      A pair containing:

      - State: A `3-D` tensor of shape `[time_len, batch_size, output_size]`
      - Output: A `3-D` tensor of shape `[time_len, batch_size, output_size]`
    """
    pass

  def call(self, inputs, initial_state=None, dtype=None, sequence_length=None):
    """Run this LSTM on inputs, starting from the given state.

    Args:
      inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`.
      initial_state: a tuple `(initial_cell_state, initial_output)` with
        tensors of shape `[batch_size, self._num_units]`. If this is not
        provided, the cell is expected to create a zero initial state of type
        `dtype`.
      dtype: The data type for the initial state and expected output. Required
        if `initial_state` is not provided or RNN state has a heterogeneous
        dtype.
      sequence_length: Specifies the length of each sequence in inputs. An
        `int32` or `int64` vector (tensor) size `[batch_size]`, values in
        `[0, time_len)`. Defaults to `time_len` for each element.

    Returns:
      A pair containing:

      - Output: A `3-D` tensor of shape `[time_len, batch_size, output_size]`
        or a list of time_len tensors of shape `[batch_size, output_size]`,
        to match the type of the `inputs`.
      - Final state: a tuple `(cell_state, output)` matching `initial_state`.

    Raises:
      ValueError: in case of shape mismatches
    """
    is_list = isinstance(inputs, list)
    if is_list:
      inputs = array_ops.stack(inputs)
    inputs_shape = inputs.get_shape().with_rank(3)
    if not inputs_shape[2]:
      raise ValueError("Expecting inputs_shape[2] to be set: %s" % inputs_shape)
    batch_size = inputs_shape.dims[1].value
    if batch_size is None:
      batch_size = array_ops.shape(inputs)[1]
    time_len = inputs_shape.dims[0].value
    if time_len is None:
      time_len = array_ops.shape(inputs)[0]

    # Provide default values for initial_state and dtype
    if initial_state is None:
      if dtype is None:
        raise ValueError("Either initial_state or dtype needs to be specified")
      z = array_ops.zeros(
          array_ops.stack([batch_size, self.num_units]), dtype=dtype)
      initial_state = z, z
    else:
      if len(initial_state) != 2:
        raise ValueError(
            "Expecting initial_state to be a tuple with length 2 or None")
      if dtype is None:
        dtype = initial_state[0].dtype

    # create the actual cell
    if sequence_length is not None:
      sequence_length = ops.convert_to_tensor(sequence_length)
    initial_cell_state, initial_output = initial_state  # pylint: disable=unpacking-non-sequence
    cell_states, outputs = self._call_cell(
        inputs, initial_cell_state, initial_output, dtype, sequence_length)

    if sequence_length is not None:
      # Mask out the part beyond sequence_length
      mask = array_ops.transpose(
          array_ops.sequence_mask(sequence_length, time_len, dtype=dtype),
          [1, 0])
      mask = array_ops.tile(
          array_ops.expand_dims(mask, [-1]), [1, 1, self.num_units])
      outputs *= mask
      # Prepend initial states to cell_states and outputs for indexing to work
      # correctly, since we want to access the last valid state at
      # sequence_length - 1, which can even be -1, corresponding to the
      # initial state.
      mod_cell_states = array_ops.concat(
          [array_ops.expand_dims(initial_cell_state, [0]), cell_states], 0)
      mod_outputs = array_ops.concat(
          [array_ops.expand_dims(initial_output, [0]), outputs], 0)
      final_cell_state = self._gather_states(mod_cell_states, sequence_length,
                                             batch_size)
      final_output = self._gather_states(mod_outputs, sequence_length,
                                         batch_size)
    else:
      # No sequence_lengths used: final state is the last state
      final_cell_state = cell_states[-1]
      final_output = outputs[-1]

    if is_list:
      # Input was a list, so return a list
      outputs = array_ops.unstack(outputs)

    final_state = rnn_cell_impl.LSTMStateTuple(final_cell_state, final_output)
    return outputs, final_state

  def _gather_states(self, data, indices, batch_size):
    """Produce `out`, s.t. out(i, j) = data(indices(i), i, j)."""
    return array_ops.gather_nd(
        data, array_ops.stack([indices, math_ops.range(batch_size)], axis=1))


class LSTMBlockFusedCell(LSTMBlockWrapper):
  """FusedRNNCell implementation of LSTM.

  This is an extremely efficient LSTM implementation that uses a single TF op
  for the entire LSTM. It should be both faster and more memory-efficient than
  LSTMBlockCell defined above.

  The implementation is based on: http://arxiv.org/abs/1409.2329.

  We add forget_bias (default: 1) to the biases of the forget gate in order to
  reduce the scale of forgetting at the beginning of training.

  The variable naming is consistent with `rnn_cell_impl.LSTMCell`.
  """

  def __init__(self,
               num_units,
               forget_bias=1.0,
               cell_clip=None,
               use_peephole=False,
               reuse=None,
               dtype=None,
               name="lstm_fused_cell"):
    """Initialize the LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
      cell_clip: clip the cell to this value. Default is no cell clipping.
      use_peephole: Whether to use peephole connections or not.
      reuse: (optional) boolean describing whether to reuse variables in an
        existing scope. If not `True`, and the existing scope already has the
        given variables, an error is raised.
      dtype: the dtype of variables of this layer.
      name: String, the name of the layer. Layers with the same name will
        share weights, but to avoid mistakes we require reuse=True in such
        cases. By default this is "lstm_fused_cell".
    """
    super(LSTMBlockFusedCell, self).__init__(
        _reuse=reuse, name=name, dtype=dtype)
    self._num_units = num_units
    self._forget_bias = forget_bias
    self._cell_clip = cell_clip if cell_clip is not None else -1
    self._use_peephole = use_peephole

    # Inputs must be 3-dimensional.
    self.input_spec = input_spec.InputSpec(ndim=3)

  @property
  def num_units(self):
    """Number of units in this cell (output dimension)."""
    return self._num_units

  def build(self, input_shape):
    input_size = input_shape.dims[2].value
    self._kernel = self.add_variable(
        "kernel", [input_size + self._num_units, self._num_units * 4])
    self._bias = self.add_variable(
        "bias", [self._num_units * 4],
        initializer=init_ops.constant_initializer(0.0))
    if self._use_peephole:
      self._w_i_diag = self.add_variable("w_i_diag", [self._num_units])
      self._w_f_diag = self.add_variable("w_f_diag", [self._num_units])
      self._w_o_diag = self.add_variable("w_o_diag", [self._num_units])

    self.built = True

  def _call_cell(self,
                 inputs,
                 initial_cell_state=None,
                 initial_output=None,
                 dtype=None,
                 sequence_length=None):
    """Run this LSTM on inputs, starting from the given state.

    Args:
      inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`
      initial_cell_state: initial value for cell state, shape `[batch_size,
        self._num_units]`
      initial_output: initial value of cell output, shape `[batch_size,
        self._num_units]`
      dtype: The data type for the initial state and expected output.
      sequence_length: Specifies the length of each sequence in inputs. An
        `int32` or `int64` vector (tensor) size `[batch_size]`, values in
        `[0, time_len)` or None.

    Returns:
      A pair containing:

      - Cell state (cs): A `3-D` tensor of shape `[time_len, batch_size,
        output_size]`
      - Output (h): A `3-D` tensor of shape `[time_len, batch_size,
        output_size]`
    """

    inputs_shape = inputs.get_shape().with_rank(3)
    time_len = inputs_shape.dims[0].value
    if time_len is None:
      time_len = array_ops.shape(inputs)[0]

    if self._use_peephole:
      wci = self._w_i_diag
      wco = self._w_o_diag
      wcf = self._w_f_diag
    else:
      wci = wcf = wco = array_ops.zeros([self._num_units], dtype=dtype)

    if sequence_length is None:
      max_seq_len = math_ops.cast(time_len, dtypes.int64)
    else:
      max_seq_len = math_ops.cast(math_ops.reduce_max(sequence_length),
                                  dtypes.int64)

    _, cs, _, _, _, _, h = gen_lstm_ops.block_lstm(
        seq_len_max=max_seq_len,
        x=inputs,
        cs_prev=initial_cell_state,
        h_prev=initial_output,
        w=self._kernel,
        wci=wci,
        wcf=wcf,
        wco=wco,
        b=self._bias,
        forget_bias=self._forget_bias,
        cell_clip=self._cell_clip,
        use_peephole=self._use_peephole)
    return cs, h
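

# Example usage sketch for the fused cell (illustrative only; the shapes and
# the `tf` alias are placeholders for this comment, and the class is normally
# obtained via the public `tf.contrib.rnn.LSTMBlockFusedCell` alias). Unlike
# `LSTMBlockCell`, the fused cell consumes the whole time-major sequence in a
# single call:
#
#   import tensorflow as tf
#
#   inputs = tf.placeholder(tf.float32, [20, 8, 64])  # [time, batch, depth]
#   seq_len = tf.placeholder(tf.int32, [8])
#   cell = tf.contrib.rnn.LSTMBlockFusedCell(num_units=128)
#   outputs, (final_c, final_h) = cell(
#       inputs, dtype=tf.float32, sequence_length=seq_len)
#   # outputs: [20, 8, 128]; final_c and final_h: [8, 128].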