# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TfLite BasicRnnCell wrapper.

TODO(renjieliu): Find a better home for this one.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

from tensorflow.lite.python.op_hint import OpHint
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["lite.experimental.nn.TfLiteRNNCell"])
@deprecation.deprecated(
    None, "Use `keras.layers.RNN` instead for TF2.x.")
class TfLiteRNNCell(rnn_cell_impl.LayerRNNCell):
  """The most basic RNN cell.

  This is used only for TfLite: it provides conversion hints and arranges the
  variables in the format expected by the tflite ops.
  """

  def __init__(self,
               num_units,
               activation=None,
               reuse=None,
               name=None,
               dtype=None,
               **kwargs):
    """Initializes the parameters for an RNN cell.

    Args:
      num_units: int, The number of units in the RNN cell.
      activation: Nonlinearity to use. Default: `tanh`. Only the string
        `"tanh"` is currently accepted; any other value raises `ValueError`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. Raises an error if not `True` and the existing
        scope already has the given variables.
      name: String, the name of the layer. Layers with the same name will
        share weights, but to avoid mistakes we require reuse=True in such
        cases.
      dtype: Default dtype of the layer (default of `None` means use the type
        of the first input). Required when `build` is called before `call`.
      **kwargs: Dict, keyword named properties for common layer attributes,
        like `trainable` etc when constructing the cell from configs of
        get_config().

    Raises:
      ValueError: If the existing scope already has the given variables.
    """
    super(TfLiteRNNCell, self).__init__(
        _reuse=reuse, name=name, dtype=dtype, **kwargs)

    # Inputs must be Rank-2.
    self.input_spec = base_layer.InputSpec(ndim=2)

    self._tflite_wrapper = OpHint("UnidirectionalSequenceRnn")
    self._num_units = num_units
    if activation:
      if activation != "tanh":
        raise ValueError("activation other than tanh is not supported")
      self._activation = math_ops.tanh
    else:
      self._activation = math_ops.tanh

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def build(self, inputs_shape):
    """Builds the RNN cell.

    Args:
      inputs_shape: Rnn input tensor shape.

    Raises:
      ValueError: If the last dimension of the input shape is not known.
    """
    if inputs_shape[-1] is None:
      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" %
                       (inputs_shape,))

    input_depth = inputs_shape[-1]

    def add_variable_wrapped(name, shape, initializer, index):
      var = self.add_weight(name, shape=shape, initializer=initializer)
      return self._tflite_wrapper.add_input(
          var, name=name, index_override=index)

    self._input_weights = add_variable_wrapped(
        "input_weights", [self._num_units, input_depth], None, 1)
    self._recurrent_weights = add_variable_wrapped(
        "recurrent_weights", [self._num_units, self._num_units], None, 2)
    self._bias = add_variable_wrapped(
        "bias",
        shape=[self._num_units],
        initializer=init_ops.zeros_initializer(dtype=self.dtype),
        index=3)

    self.built = True

  def call(self, inputs, state):
    """Most basic RNN: output = new_state = act(W * input + U * state + B)."""
    inputs = self._tflite_wrapper.add_input(
        inputs, tag="input", name="input", aggregate="stack", index_override=0)
    state = self._tflite_wrapper.add_input(
        state,
        tag="hidden_state",
        name="hidden_state",
        aggregate="first",
        index_override=4)
    weights = array_ops.transpose(
        array_ops.concat([self._input_weights, self._recurrent_weights], 1))
    gate_inputs = math_ops.matmul(array_ops.concat([inputs, state], 1), weights)
    gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
    output = self._activation(gate_inputs)
    output = self._tflite_wrapper.add_output(
        output,
        tag="output",
        name="output",
        index_override=1,
        aggregate="stack")
    return output, output

  def get_config(self):
    config = {
        "num_units": self._num_units,
        "activation": "tanh",
        "reuse": self._reuse,
    }
    base_config = super(TfLiteRNNCell, self).get_config()
    return dict(
        itertools.chain(list(base_config.items()), list(config.items())))


@tf_export(v1=["lite.experimental.nn.TFLiteLSTMCell"])
@deprecation.deprecated(
    None, "Use `keras.layers.LSTM` instead.")
class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell):
  """Long short-term memory unit (LSTM) recurrent network cell.

  This is used only for TfLite: it provides conversion hints and arranges the
  variables in the format expected by the tflite ops (transposed and
  separated).

  The default non-peephole implementation is based on:

    https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf

  Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
  "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.

  The peephole implementation is based on:

    https://research.google.com/pubs/archive/43905.pdf

  Hasim Sak, Andrew Senior, and Francoise Beaufays.
185 "Long short-term memory recurrent neural network architectures for 186 large scale acoustic modeling." INTERSPEECH, 2014. 187 188 The class uses optional peep-hole connections, optional cell clipping, and 189 an optional projection layer. 190 191 Note that this cell is not optimized for performance. Please use 192 `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or 193 `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for 194 better performance on CPU. 195 """ 196 197 def __init__(self, 198 num_units, 199 use_peepholes=False, 200 cell_clip=None, 201 initializer=None, 202 num_proj=None, 203 proj_clip=None, 204 num_unit_shards=None, 205 num_proj_shards=None, 206 forget_bias=1.0, 207 state_is_tuple=True, 208 activation=None, 209 reuse=None, 210 name=None, 211 dtype=None): 212 """Initialize the parameters for an LSTM cell. 213 214 Args: 215 num_units: int, The number of units in the LSTM cell. 216 use_peepholes: bool, set True to enable diagonal/peephole connections. 217 cell_clip: (optional) A float value, if provided the cell state is clipped 218 by this value prior to the cell output activation. 219 initializer: (optional) The initializer to use for the weight and 220 projection matrices. 221 num_proj: (optional) int, The output dimensionality for the projection 222 matrices. If None, no projection is performed. 223 proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is 224 provided, then the projected values are clipped elementwise to within 225 `[-proj_clip, proj_clip]`. 226 num_unit_shards: Deprecated, will be removed by Jan. 2017. Use a 227 variable_scope partitioner instead. 228 num_proj_shards: Deprecated, will be removed by Jan. 2017. Use a 229 variable_scope partitioner instead. 230 forget_bias: Biases of the forget gate are initialized by default to 1 in 231 order to reduce the scale of forgetting at the beginning of the 232 training. Must set it manually to `0.0` when restoring from CudnnLSTM 233 trained checkpoints. 234 state_is_tuple: If True, accepted and returned states are 2-tuples of the 235 `c_state` and `m_state`. If False, they are concatenated along the 236 column axis. This latter behavior will soon be deprecated. 237 activation: Activation function of the inner states. Default: `tanh`. 238 reuse: (optional) Python boolean describing whether to reuse variables in 239 an existing scope. If not `True`, and the existing scope already has 240 the given variables, an error is raised. 241 name: String, the name of the layer. Layers with the same name will share 242 weights, but to avoid mistakes we require reuse=True in such cases. 243 dtype: Default dtype of the layer (default of `None` means use the type of 244 the first input). Required when `build` is called before `call`. When 245 restoring from CudnnLSTM-trained checkpoints, use 246 `CudnnCompatibleLSTMCell` instead. 247 """ 248 super(TFLiteLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype) 249 # TODO(raziel): decide if we want to just support tuples (yes please!). 250 if not state_is_tuple: 251 logging.warn( 252 "%s: Using a concatenated state is slower and will soon be " 253 "deprecated. Use state_is_tuple=True.", self) 254 if num_unit_shards is not None or num_proj_shards is not None: 255 logging.warn( 256 "%s: The num_unit_shards and proj_unit_shards parameters are " 257 "deprecated and will be removed in Jan 2017. " 258 "Use a variable scope with a partitioner instead.", self) 259 260 # Inputs must be 2-dimensional. 
    # TODO(raziel): layers stuff -- chop if un-layerizing Op.
    self.input_spec = base_layer.InputSpec(ndim=2)

    self._tflite_wrapper = OpHint("UnidirectionalSequenceLstm")

    self._num_units = num_units
    self._use_peepholes = use_peepholes
    self._cell_clip = cell_clip
    self._initializer = initializer
    self._num_proj = num_proj
    self._proj_clip = proj_clip
    self._num_unit_shards = num_unit_shards
    self._num_proj_shards = num_proj_shards
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    if activation:
      if activation != "tanh":
        raise ValueError("activation other than tanh is not supported")
      self._activation = math_ops.tanh
    else:
      self._activation = math_ops.tanh

    self._output_size = num_proj if num_proj else num_units
    self._state_size = (
        rnn_cell_impl.LSTMStateTuple(num_units, self._output_size)
        if state_is_tuple else num_units + self._output_size)

  @property
  def state_size(self):
    return self._state_size

  @property
  def output_size(self):
    return self._output_size

  def build(self, inputs_shape):
    """Builds the TfLite LSTM cell graph.

    Args:
      inputs_shape: The inputs_shape must be known, and is of
        [batch_size, input_size] shape.

    Raises:
      ValueError: if the inputs_shape is invalid.
    """
    if len(inputs_shape) != 2:
      raise ValueError(
          "inputs_shape must be 2-dimensional, saw shape: %s" % inputs_shape)
    input_depth = (
        inputs_shape[1]
        if isinstance(inputs_shape[1], int) else inputs_shape[1].value)
    if input_depth is None:
      raise ValueError("Invalid inputs_shape, saw shape: %s" % inputs_shape)

    maybe_partitioner = (
        partitioned_variables.fixed_size_partitioner(self._num_unit_shards)
        if self._num_unit_shards is not None else None)
    input_weight_shape = [self._num_units, input_depth]
    cell_weight_shape = [self._num_units, self._output_size]
    bias_shape = [self._num_units]

    def add_variable_wrapped(name, shape, initializer, index, partitioner):
      var = self.add_weight(
          name, shape=shape, initializer=initializer, partitioner=partitioner)
      return self._tflite_wrapper.add_input(
          var, name=name, index_override=index)

    weight_initializer = self._initializer
    if self.dtype is None:
      bias_initializer = init_ops.zeros_initializer
    else:
      bias_initializer = init_ops.zeros_initializer(dtype=self.dtype)

    forget_bias_initializer = init_ops.constant_initializer(self._forget_bias)

    self.input_to_input_w = add_variable_wrapped(
        "input_to_input_w", input_weight_shape, weight_initializer, 1,
        maybe_partitioner)
    self.input_to_forget_w = add_variable_wrapped(
        "input_to_forget_w", input_weight_shape, weight_initializer, 2,
        maybe_partitioner)
    self.input_to_cell_w = add_variable_wrapped(
        "input_to_cell_w", input_weight_shape, weight_initializer, 3,
        maybe_partitioner)
    self.input_to_output_w = add_variable_wrapped(
        "input_to_output_w", input_weight_shape, weight_initializer, 4,
        maybe_partitioner)
    self.cell_to_input_w = add_variable_wrapped(
        "cell_to_input_w", cell_weight_shape, weight_initializer, 5,
        maybe_partitioner)
    self.cell_to_forget_w = add_variable_wrapped(
        "cell_to_forget_w", cell_weight_shape, weight_initializer, 6,
        maybe_partitioner)
    self.cell_to_cell_w = add_variable_wrapped(
        "cell_to_cell_w", cell_weight_shape, weight_initializer, 7,
        maybe_partitioner)
    self.cell_to_output_w = add_variable_wrapped(
        "cell_to_output_w", cell_weight_shape, weight_initializer, 8,
        maybe_partitioner)

    self.input_bias = add_variable_wrapped(
        "input_bias", bias_shape, bias_initializer, 12, maybe_partitioner)
    self.forget_bias = add_variable_wrapped("forget_bias", bias_shape,
                                            forget_bias_initializer, 13,
                                            maybe_partitioner)
    self.cell_bias = add_variable_wrapped(
        "cell_bias", bias_shape, bias_initializer, 14, maybe_partitioner)
    self.output_bias = add_variable_wrapped(
        "output_bias", bias_shape, bias_initializer, 15, maybe_partitioner)

    # Peephole weights use indices 9, 10, 11.
    # f stands for forget, i stands for input and o stands for output.
    if self._use_peepholes:
      self._w_f_diag = add_variable_wrapped("w_f_diag", [self._num_units],
                                            self._initializer, 10,
                                            maybe_partitioner)
      self._w_i_diag = add_variable_wrapped("w_i_diag", [self._num_units],
                                            self._initializer, 9,
                                            maybe_partitioner)
      self._w_o_diag = add_variable_wrapped("w_o_diag", [self._num_units],
                                            self._initializer, 11,
                                            maybe_partitioner)

    # Index 16 for the projection kernel.
    if self._num_proj is not None:
      maybe_proj_partitioner = (
          partitioned_variables.fixed_size_partitioner(self._num_proj_shards)
          if self._num_proj_shards is not None else None)
      self._proj_kernel = add_variable_wrapped(
          "projection/kernel", [self._num_proj, self._num_units],
          self._initializer,
          16,
          partitioner=maybe_proj_partitioner)

    self.built = True

  def call(self, inputs, state):
    """Runs one step of LSTM.

    Args:
      inputs: input Tensor, 2-D, `[batch, input_size]`.
      state: if `state_is_tuple` is False, this must be a state Tensor, 2-D,
        `[batch, state_size]`. If `state_is_tuple` is True, this must be a
        tuple of state Tensors, both 2-D, with column sizes `c_state` and
        `m_state`.

    Returns:
      A tuple containing:

      - A 2-D, `[batch, output_dim]` Tensor representing the output of the
        LSTM after reading `inputs` when the previous state was `state`.
        Here output_dim is:
          num_proj if num_proj was set,
          num_units otherwise.
      - Tensor(s) representing the new state of the LSTM after reading
        `inputs` when the previous state was `state`. Same type and shape(s)
        as `state`.

    Raises:
      ValueError: If the input size cannot be inferred from inputs via
        static shape inference.
    """
    inputs = self._tflite_wrapper.add_input(
        inputs, tag="input", name="input", aggregate="stack", index_override=0)

    # Make sure inputs and the weights have the same type.
    assert inputs.dtype == self.input_to_input_w.dtype

    num_proj = self._num_units if self._num_proj is None else self._num_proj
    sigmoid = math_ops.sigmoid

    if self._state_is_tuple:
      (c_prev, m_prev) = state
    else:
      c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
      m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])

    # Note: For TfLite, the cell state is at index 19 while the activation
    # state is at index 18.
    c_prev = self._tflite_wrapper.add_input(
        c_prev,
        tag="c_prev",
        name="c_prev",
        aggregate="first",
        index_override=19)
    m_prev = self._tflite_wrapper.add_input(
        m_prev,
        tag="m_prev",
        name="m_prev",
        aggregate="first",
        index_override=18)

    input_size = inputs.shape.with_rank(2).dims[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.shape[-1]")

    inputs_and_m_prev = array_ops.concat([inputs, m_prev], axis=1)

    # i is the input gate pre-activation.
    # f is the forget gate pre-activation.
    # o is the output gate pre-activation.
    # j is the new cell candidate (the LSTM unit's input transform).
    # c is the new cell state.
    # m is the output.
    i = nn_ops.bias_add(
        math_ops.matmul(
            inputs_and_m_prev,
            array_ops.concat([self.input_to_input_w, self.cell_to_input_w],
                             axis=1),
            transpose_b=True), self.input_bias)
    f = nn_ops.bias_add(
        math_ops.matmul(
            inputs_and_m_prev,
            array_ops.concat([self.input_to_forget_w, self.cell_to_forget_w],
                             axis=1),
            transpose_b=True), self.forget_bias)
    o = nn_ops.bias_add(
        math_ops.matmul(
            inputs_and_m_prev,
            array_ops.concat([self.input_to_output_w, self.cell_to_output_w],
                             axis=1),
            transpose_b=True), self.output_bias)
    j = nn_ops.bias_add(
        math_ops.matmul(
            inputs_and_m_prev,
            array_ops.concat([self.input_to_cell_w, self.cell_to_cell_w],
                             axis=1),
            transpose_b=True), self.cell_bias)

    # Diagonal (peephole) connections.
    if self._use_peepholes:
      c = (
          sigmoid(f + self._w_f_diag * c_prev) * c_prev +
          sigmoid(i + self._w_i_diag * c_prev) * self._activation(j))
    else:
      c = (sigmoid(f) * c_prev + sigmoid(i) * self._activation(j))

    if self._cell_clip is not None:
      # pylint: disable=invalid-unary-operand-type
      c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
      # pylint: enable=invalid-unary-operand-type
    if self._use_peepholes:
      m = sigmoid(o + self._w_o_diag * c) * self._activation(c)
    else:
      m = sigmoid(o) * self._activation(c)

    if self._num_proj is not None:
      transposed_proj_kernel = array_ops.transpose(self._proj_kernel)
      m = math_ops.matmul(m, transposed_proj_kernel)

      if self._proj_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
        # pylint: enable=invalid-unary-operand-type

    c = self._tflite_wrapper.add_output(
        c, tag="c", name="c", aggregate="last", index_override=1)
    m = self._tflite_wrapper.add_output(
        m, tag="m", name="m", index_override=2, aggregate="stack")

    new_state = (
        rnn_cell_impl.LSTMStateTuple(c, m)
        if self._state_is_tuple else array_ops.concat([c, m], 1))
    return m, new_state

  def get_config(self):
    config = {
        "num_units": self._num_units,
        "use_peepholes": self._use_peepholes,
        "cell_clip": self._cell_clip,
        "num_proj": self._num_proj,
        "proj_clip": self._proj_clip,
        "num_unit_shards": self._num_unit_shards,
        "num_proj_shards": self._num_proj_shards,
        "forget_bias": self._forget_bias,
        "state_is_tuple": self._state_is_tuple,
        "activation": "tanh",
        "reuse": self._reuse,
    }
    base_config = super(TFLiteLSTMCell, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
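

# Usage sketch (illustrative only; the graph-mode setup below is an
# assumption about the surrounding TF 1.x environment, not something this
# module provides): both cells are standard `RNNCell` subclasses, so they can
# be driven by the v1 dynamic_rnn loop. For TFLite conversion, the OpHint-aware
# dynamic_rnn wrapper exported alongside these cells (if present in your
# build) is normally used instead, so the per-step hints are stitched into a
# single fused op.
#
#   import tensorflow.compat.v1 as tf
#   tf.disable_eager_execution()
#
#   batch_size, time_steps, input_size, num_units = 4, 28, 28, 16
#   inputs = tf.placeholder(tf.float32, [batch_size, time_steps, input_size])
#   cell = tf.lite.experimental.nn.TfLiteRNNCell(num_units)
#   outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)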