# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TfLite BasicRnnCell wrapper.

TODO(renjieliu): Find a better home for this one.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools

from tensorflow.lite.python.op_hint import OpHint
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["lite.experimental.nn.TfLiteRNNCell"])
@deprecation.deprecated(
    None, "Use `keras.layers.RNN` instead for TF2.x.")
class TfLiteRNNCell(rnn_cell_impl.LayerRNNCell):
43  """The most basic RNN cell.
44
45  This is used only for TfLite, it provides hints and it also makes the
46  variables in the desired for the tflite ops.
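
  A minimal usage sketch (illustrative only: the placeholder shape and the
  `tf.compat.v1.nn.dynamic_rnn` driver below are assumptions, not requirements
  of this class):

  ```python
  import tensorflow.compat.v1 as tf

  # Hypothetical [batch, time, feature] float32 input.
  inputs = tf.placeholder(tf.float32, [8, 10, 20])
  cell = tf.lite.experimental.nn.TfLiteRNNCell(num_units=64)
  outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
  ```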
47  """
48
49  def __init__(self,
50               num_units,
51               activation=None,
52               reuse=None,
53               name=None,
54               dtype=None,
55               **kwargs):
56    """Initializes the parameters for an RNN cell.
57
58    Args:
59      num_units: int, The number of units in the RNN cell.
      activation: Nonlinearity to use. Default: `tanh`. It can also be a
        string from among the Keras activation function names.
      reuse: (optional) Python boolean describing whether to reuse variables in
        an existing scope. Raises an error if not `True` and the existing scope
        already has the given variables.
      name: String, the name of the layer. Layers with the same name will share
        weights, but to avoid mistakes we require reuse=True in such cases.
      dtype: Default dtype of the layer (default of `None` means use the type of
        the first input). Required when `build` is called before `call`.
      **kwargs: Dict, keyword named properties for common layer attributes, like
        `trainable` etc when constructing the cell from configs of get_config().

    Raises:
      ValueError: If the existing scope already has the given variables.
    """
    super(TfLiteRNNCell, self).__init__(
        _reuse=reuse, name=name, dtype=dtype, **kwargs)

    # Inputs must be Rank-2.
    self.input_spec = base_layer.InputSpec(ndim=2)

    self._tflite_wrapper = OpHint("UnidirectionalSequenceRnn")
    self._num_units = num_units
    if activation:
      if activation != "tanh":
        raise ValueError("activation other than tanh is not supported")
      self._activation = math_ops.tanh
    else:
      self._activation = math_ops.tanh

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def build(self, inputs_shape):
    """Builds the RNN cell.

    Args:
      inputs_shape: RNN input tensor shape.

    Raises:
      ValueError: If the last dimension of the input shape is not known.
    """
    if inputs_shape[-1] is None:
      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" %
                       (inputs_shape,))

    input_depth = inputs_shape[-1]

    def add_variable_wrapped(name, shape, initializer, index):
      var = self.add_weight(name, shape=shape, initializer=initializer)
      return self._tflite_wrapper.add_input(
          var, name=name, index_override=index)

    self._input_weights = add_variable_wrapped(
        "input_weights", [self._num_units, input_depth], None, 1)
    self._recurrent_weights = add_variable_wrapped(
        "recurrent_weights", [self._num_units, self._num_units], None, 2)
    self._bias = add_variable_wrapped(
        "bias",
        shape=[self._num_units],
        initializer=init_ops.zeros_initializer(dtype=self.dtype),
        index=3)

    self.built = True

  def call(self, inputs, state):
    """Most basic RNN: output = new_state = act(W * input + U * state + B)."""
    inputs = self._tflite_wrapper.add_input(
        inputs, tag="input", name="input", aggregate="stack", index_override=0)
    state = self._tflite_wrapper.add_input(
        state,
        tag="hidden_state",
        name="hidden_state",
        aggregate="first",
        index_override=4)
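    # The kernels are stored in the [num_units, depth] layout the TFLite op
    # expects, so concatenate them along the depth axis and transpose before
    # multiplying by the concatenated [inputs, state].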
    weights = array_ops.transpose(
        array_ops.concat([self._input_weights, self._recurrent_weights], 1))
    gate_inputs = math_ops.matmul(array_ops.concat([inputs, state], 1), weights)
    gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
    output = self._activation(gate_inputs)
    output = self._tflite_wrapper.add_output(
        output,
        tag="output",
        name="output",
        index_override=1,
        aggregate="stack")
    return output, output

  def get_config(self):
    config = {
        "num_units": self._num_units,
        "activation": "tanh",
        "reuse": self._reuse,
    }
    base_config = super(TfLiteRNNCell, self).get_config()
    return dict(
        itertools.chain(list(base_config.items()), list(config.items())))


@tf_export(v1=["lite.experimental.nn.TFLiteLSTMCell"])
@deprecation.deprecated(
    None, "Use `keras.layers.LSTM` instead.")
class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell):
  """Long short-term memory unit (LSTM) recurrent network cell.

  This is used only for TfLite; it provides OpHints and arranges the
  variables in the layout the TFLite ops expect (transposed and separated).

  The default non-peephole implementation is based on:

    https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf

  Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
  "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.

  The peephole implementation is based on:

    https://research.google.com/pubs/archive/43905.pdf

  Hasim Sak, Andrew Senior, and Francoise Beaufays.
  "Long short-term memory recurrent neural network architectures for
   large scale acoustic modeling." INTERSPEECH, 2014.

  The class uses optional peep-hole connections, optional cell clipping, and
  an optional projection layer.

  Note that this cell is not optimized for performance. Please use
  `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or
  `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for
  better performance on CPU.
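
  A minimal single-step usage sketch (illustrative only: the input shape and
  the constructor arguments below are assumptions, not requirements of this
  class):

  ```python
  import tensorflow.compat.v1 as tf

  # Hypothetical [batch, input_size] float32 input for one time step.
  inputs = tf.placeholder(tf.float32, [8, 20])
  cell = tf.lite.experimental.nn.TFLiteLSTMCell(
      num_units=64, use_peepholes=True, num_proj=32)
  state = cell.zero_state(batch_size=8, dtype=tf.float32)
  output, new_state = cell(inputs, state)
  ```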
195  """
196
197  def __init__(self,
198               num_units,
199               use_peepholes=False,
200               cell_clip=None,
201               initializer=None,
202               num_proj=None,
203               proj_clip=None,
204               num_unit_shards=None,
205               num_proj_shards=None,
206               forget_bias=1.0,
207               state_is_tuple=True,
208               activation=None,
209               reuse=None,
210               name=None,
211               dtype=None):
212    """Initialize the parameters for an LSTM cell.
213
214    Args:
215      num_units: int, The number of units in the LSTM cell.
216      use_peepholes: bool, set True to enable diagonal/peephole connections.
217      cell_clip: (optional) A float value, if provided the cell state is clipped
218        by this value prior to the cell output activation.
219      initializer: (optional) The initializer to use for the weight and
220        projection matrices.
221      num_proj: (optional) int, The output dimensionality for the projection
222        matrices.  If None, no projection is performed.
223      proj_clip: (optional) A float value.  If `num_proj > 0` and `proj_clip` is
224        provided, then the projected values are clipped elementwise to within
225        `[-proj_clip, proj_clip]`.
226      num_unit_shards: Deprecated, will be removed by Jan. 2017. Use a
227        variable_scope partitioner instead.
228      num_proj_shards: Deprecated, will be removed by Jan. 2017. Use a
229        variable_scope partitioner instead.
230      forget_bias: Biases of the forget gate are initialized by default to 1 in
231        order to reduce the scale of forgetting at the beginning of the
232        training. Must set it manually to `0.0` when restoring from CudnnLSTM
233        trained checkpoints.
234      state_is_tuple: If True, accepted and returned states are 2-tuples of the
235        `c_state` and `m_state`.  If False, they are concatenated along the
236        column axis.  This latter behavior will soon be deprecated.
237      activation: Activation function of the inner states.  Default: `tanh`.
238      reuse: (optional) Python boolean describing whether to reuse variables in
239        an existing scope.  If not `True`, and the existing scope already has
240        the given variables, an error is raised.
241      name: String, the name of the layer. Layers with the same name will share
242        weights, but to avoid mistakes we require reuse=True in such cases.
243      dtype: Default dtype of the layer (default of `None` means use the type of
244        the first input). Required when `build` is called before `call`.  When
245        restoring from CudnnLSTM-trained checkpoints, use
246        `CudnnCompatibleLSTMCell` instead.
247    """
    super(TFLiteLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
    # TODO(raziel): decide if we want to just support tuples (yes please!).
    if not state_is_tuple:
      logging.warn(
          "%s: Using a concatenated state is slower and will soon be "
          "deprecated.  Use state_is_tuple=True.", self)
    if num_unit_shards is not None or num_proj_shards is not None:
      logging.warn(
          "%s: The num_unit_shards and num_proj_shards parameters are "
          "deprecated and will be removed in Jan 2017.  "
          "Use a variable scope with a partitioner instead.", self)

    # Inputs must be 2-dimensional.
    # TODO(raziel): layers stuff -- chop if un-layerizing Op.
    self.input_spec = base_layer.InputSpec(ndim=2)

    self._tflite_wrapper = OpHint("UnidirectionalSequenceLstm")

    self._num_units = num_units
    self._use_peepholes = use_peepholes
    self._cell_clip = cell_clip
    self._initializer = initializer
    self._num_proj = num_proj
    self._proj_clip = proj_clip
    self._num_unit_shards = num_unit_shards
    self._num_proj_shards = num_proj_shards
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    if activation:
      if activation != "tanh":
        raise ValueError("activation other than tanh is not supported")
      self._activation = math_ops.tanh
    else:
      self._activation = math_ops.tanh

    self._output_size = num_proj if num_proj else num_units
    self._state_size = (
        rnn_cell_impl.LSTMStateTuple(num_units, self._output_size)
        if state_is_tuple else num_units + self._output_size)

  @property
  def state_size(self):
    return self._state_size

  @property
  def output_size(self):
    return self._output_size

  def build(self, inputs_shape):
    """Builds the TfLite LSTM cell graph.

    Args:
      inputs_shape: The input shape. Must be statically known and of the form
        [batch_size, input_size].

    Raises:
      ValueError: if the inputs_shape is invalid.
    """
    if len(inputs_shape) != 2:
      raise ValueError(
          "inputs_shape must be 2-dimensional, saw shape: %s" % inputs_shape)
    input_depth = (
        inputs_shape[1]
        if isinstance(inputs_shape[1], int) else inputs_shape[1].value)
    if input_depth is None:
      raise ValueError("Invalid inputs_shape, saw shape: %s" % inputs_shape)

    maybe_partitioner = (
        partitioned_variables.fixed_size_partitioner(self._num_unit_shards)
        if self._num_unit_shards is not None else None)
    input_weight_shape = [self._num_units, input_depth]
    cell_weight_shape = [self._num_units, self._output_size]
    bias_shape = [self._num_units]

    def add_variable_wrapped(name, shape, initializer, index, partitioner):
      var = self.add_weight(
          name, shape=shape, initializer=initializer, partitioner=partitioner)
      return self._tflite_wrapper.add_input(
          var, name=name, index_override=index)

    weight_initializer = self._initializer
    if self.dtype is None:
      bias_initializer = init_ops.zeros_initializer
    else:
      bias_initializer = init_ops.zeros_initializer(dtype=self.dtype)

    forget_bias_initializer = init_ops.constant_initializer(self._forget_bias)

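    # The index_override passed to add_variable_wrapped pins each variable to
    # a fixed input position of the hinted op: input/recurrent kernels at 1-8,
    # peepholes at 9-11, gate biases at 12-15, and the projection kernel at 16.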
    self.input_to_input_w = add_variable_wrapped(
        "input_to_input_w", input_weight_shape, weight_initializer, 1,
        maybe_partitioner)
    self.input_to_forget_w = add_variable_wrapped(
        "input_to_forget_w", input_weight_shape, weight_initializer, 2,
        maybe_partitioner)
    self.input_to_cell_w = add_variable_wrapped(
        "input_to_cell_w", input_weight_shape, weight_initializer, 3,
        maybe_partitioner)
    self.input_to_output_w = add_variable_wrapped(
        "input_to_output_w", input_weight_shape, weight_initializer, 4,
        maybe_partitioner)
    self.cell_to_input_w = add_variable_wrapped(
        "cell_to_input_w", cell_weight_shape, weight_initializer, 5,
        maybe_partitioner)
    self.cell_to_forget_w = add_variable_wrapped(
        "cell_to_forget_w", cell_weight_shape, weight_initializer, 6,
        maybe_partitioner)
    self.cell_to_cell_w = add_variable_wrapped(
        "cell_to_cell_w", cell_weight_shape, weight_initializer, 7,
        maybe_partitioner)
    self.cell_to_output_w = add_variable_wrapped(
        "cell_to_output_w", cell_weight_shape, weight_initializer, 8,
        maybe_partitioner)

    self.input_bias = add_variable_wrapped(
        "input_bias", bias_shape, bias_initializer, 12, maybe_partitioner)
    self.forget_bias = add_variable_wrapped("forget_bias", bias_shape,
                                            forget_bias_initializer, 13,
                                            maybe_partitioner)
    self.cell_bias = add_variable_wrapped(
        "cell_bias", bias_shape, bias_initializer, 14, maybe_partitioner)
    self.output_bias = add_variable_wrapped(
        "output_bias", bias_shape, bias_initializer, 15, maybe_partitioner)

    # index 9, 10, 11.
    # f stands for forget, i stands for input and o stands for output.
    if self._use_peepholes:
      self._w_f_diag = add_variable_wrapped("w_f_diag", [self._num_units],
                                            self._initializer, 10,
                                            maybe_partitioner)
      self._w_i_diag = add_variable_wrapped("w_i_diag", [self._num_units],
                                            self._initializer, 9,
                                            maybe_partitioner)
      self._w_o_diag = add_variable_wrapped("w_o_diag", [self._num_units],
                                            self._initializer, 11,
                                            maybe_partitioner)

    # index 16 for proj kernel.
    if self._num_proj is not None:
      maybe_proj_partitioner = (
          partitioned_variables.fixed_size_partitioner(self._num_proj_shards)
          if self._num_proj_shards is not None else None)
      self._proj_kernel = add_variable_wrapped(
          "projection/kernel", [self._num_proj, self._num_units],
          self._initializer,
          16,
          partitioner=maybe_proj_partitioner)

    self.built = True

  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2-D, `[batch, input_size]`.
      state: if `state_is_tuple` is False, this must be a state Tensor, `2-D,
        [batch, state_size]`.  If `state_is_tuple` is True, this must be a tuple
        of state Tensors, both `2-D`, with column sizes `c_state` and `m_state`.

    Returns:
      A tuple containing:

      - A `2-D, [batch, output_dim]`, Tensor representing the output of the
        LSTM after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - Tensor(s) representing the new state of LSTM after reading `inputs` when
        the previous state was `state`.  Same type and shape(s) as `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    inputs = self._tflite_wrapper.add_input(
        inputs, tag="input", name="input", aggregate="stack", index_override=0)

    # Make sure inputs and the weights have the same dtype.
    assert inputs.dtype == self.input_to_input_w.dtype

    num_proj = self._num_units if self._num_proj is None else self._num_proj
    sigmoid = math_ops.sigmoid

    if self._state_is_tuple:
      (c_prev, m_prev) = state
    else:
      c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
      m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])

    # Note: For TfLite, the cell state is at index 19 while the activation
    # state is at index 18.
    c_prev = self._tflite_wrapper.add_input(
        c_prev,
        tag="c_prev",
        name="c_prev",
        aggregate="first",
        index_override=19)
    m_prev = self._tflite_wrapper.add_input(
        m_prev,
        tag="m_prev",
        name="m_prev",
        aggregate="first",
        index_override=18)

    input_size = inputs.shape.with_rank(2).dims[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.shape[-1]")

    inputs_and_m_prev = array_ops.concat([inputs, m_prev], axis=1)

    # i: input gate.
    # f: forget gate.
    # o: output gate.
    # j: candidate cell input (new content of the LSTM unit).
    # c: new cell state.
    # m: cell output.
    i = nn_ops.bias_add(
        math_ops.matmul(
            inputs_and_m_prev,
            array_ops.concat([self.input_to_input_w, self.cell_to_input_w],
                             axis=1),
            transpose_b=True), self.input_bias)
    f = nn_ops.bias_add(
        math_ops.matmul(
            inputs_and_m_prev,
            array_ops.concat([self.input_to_forget_w, self.cell_to_forget_w],
                             axis=1),
            transpose_b=True), self.forget_bias)
    o = nn_ops.bias_add(
        math_ops.matmul(
            inputs_and_m_prev,
            array_ops.concat([self.input_to_output_w, self.cell_to_output_w],
                             axis=1),
            transpose_b=True), self.output_bias)
    j = nn_ops.bias_add(
        math_ops.matmul(
            inputs_and_m_prev,
            array_ops.concat([self.input_to_cell_w, self.cell_to_cell_w],
                             axis=1),
            transpose_b=True), self.cell_bias)

    # Diagonal connections
    if self._use_peepholes:
      c = (
          sigmoid(f + self._w_f_diag * c_prev) * c_prev +
          sigmoid(i + self._w_i_diag * c_prev) * self._activation(j))
    else:
      c = (sigmoid(f) * c_prev + sigmoid(i) * self._activation(j))

    if self._cell_clip is not None:
      # pylint: disable=invalid-unary-operand-type
      c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
      # pylint: enable=invalid-unary-operand-type
    if self._use_peepholes:
      m = sigmoid(o + self._w_o_diag * c) * self._activation(c)
    else:
      m = sigmoid(o) * self._activation(c)

    if self._num_proj is not None:
      transposed_proj_kernel = array_ops.transpose(self._proj_kernel)
      m = math_ops.matmul(m, transposed_proj_kernel)

      if self._proj_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
        # pylint: enable=invalid-unary-operand-type

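    # Expose the cell state (only the last step, aggregate="last") and the
    # per-step output (stacked over time, aggregate="stack") through the hint
    # wrapper.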
    c = self._tflite_wrapper.add_output(
        c, tag="c", name="c", aggregate="last", index_override=1)
    m = self._tflite_wrapper.add_output(
        m, tag="m", name="m", index_override=2, aggregate="stack")

    new_state = (
        rnn_cell_impl.LSTMStateTuple(c, m)
        if self._state_is_tuple else array_ops.concat([c, m], 1))
    return m, new_state

  def get_config(self):
    config = {
        "num_units": self._num_units,
        "use_peepholes": self._use_peepholes,
        "cell_clip": self._cell_clip,
        "num_proj": self._num_proj,
        "proj_clip": self._proj_clip,
        "num_unit_shards": self._num_unit_shards,
        "num_proj_shards": self._num_proj_shards,
        "forget_bias": self._forget_bias,
        "state_is_tuple": self._state_is_tuple,
        "activation": "tanh",
        "reuse": self._reuse,
    }
    base_config = super(TFLiteLSTMCell, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
