# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Module implementing RNN Cells with pruning.
16
17This module implements BasicLSTMCell and LSTMCell with pruning.
18Code adapted from third_party/tensorflow/python/ops/rnn_cell_impl.py
19"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.model_pruning.python.layers import core_layers
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn_cell as tf_rnn


class MaskedBasicLSTMCell(tf_rnn.BasicLSTMCell):
  """Basic LSTM recurrent network cell with pruning.

  Overrides the call method of the TensorFlow BasicLSTMCell and injects the
  weight masks.

  The implementation is based on: http://arxiv.org/abs/1409.2329.

  We add forget_bias (default: 1) to the biases of the forget gate in order to
  reduce the scale of forgetting at the beginning of training.

  It does not allow cell clipping or a projection layer, and it does not use
  peephole connections: it is the basic baseline.

  For advanced models, please use the full `MaskedLSTMCell` that follows.
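
  A minimal usage sketch (illustrative only, not from the original docs;
  assumes the TF 1.x contrib environment this module was written for, and
  hypothetical shapes):

    cell = MaskedBasicLSTMCell(num_units=128)
    inputs = array_ops.placeholder(dtypes.float32, [32, 64])
    state = cell.zero_state(32, dtypes.float32)
    output, new_state = cell(inputs, state)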
53  """
54
55  def __init__(self,
56               num_units,
57               forget_bias=1.0,
58               state_is_tuple=True,
59               activation=None,
60               reuse=None,
61               name=None):
62    """Initialize the basic LSTM cell with pruning.
63
64    Args:
65      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
        Must be set to `0.0` manually when restoring from CudnnLSTM-trained
        checkpoints.
      state_is_tuple: If True, accepted and returned states are 2-tuples of
        the `c_state` and `m_state`.  If False, they are concatenated
        along the column axis.  The latter behavior will soon be deprecated.
      activation: Activation function of the inner states.  Default: `tanh`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.
      name: String, the name of the layer. Layers with the same name will
        share weights, but to avoid mistakes we require reuse=True in such
        cases.

      When restoring from CudnnLSTM-trained checkpoints,
      `CudnnCompatibleLSTMCell` must be used instead.
82    """
83    super(MaskedBasicLSTMCell, self).__init__(
84        num_units,
85        forget_bias=forget_bias,
86        state_is_tuple=state_is_tuple,
87        activation=activation,
88        reuse=reuse,
89        name=name)
90
91  def build(self, inputs_shape):
92    # Call the build method of the parent class.
93    super(MaskedBasicLSTMCell, self).build(inputs_shape)
94
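    # The parent build() marks the layer as built; reset the flag so the
    # pruning-specific mask and threshold variables can be added below
    # before the layer is marked as built again.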
    self.built = False

    input_depth = inputs_shape.dims[1].value
    h_depth = self._num_units
    self._mask = self.add_variable(
        name="mask",
        shape=[input_depth + h_depth, 4 * h_depth],
        initializer=init_ops.ones_initializer(),
        trainable=False,
        dtype=self.dtype)
    self._threshold = self.add_variable(
        name="threshold",
        shape=[],
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        dtype=self.dtype)
    # Add masked_weights in the weights name scope so as to make it easier
    # for the quantization library to add quant ops.
    self._masked_kernel = math_ops.multiply(self._mask, self._kernel,
                                            core_layers.MASKED_WEIGHT_NAME)
    if self._mask not in ops.get_collection_ref(core_layers.MASK_COLLECTION):
      ops.add_to_collection(core_layers.MASK_COLLECTION, self._mask)
      ops.add_to_collection(core_layers.MASKED_WEIGHT_COLLECTION,
                            self._masked_kernel)
      ops.add_to_collection(core_layers.THRESHOLD_COLLECTION, self._threshold)
      ops.add_to_collection(core_layers.WEIGHT_COLLECTION, self._kernel)

    self.built = True

  def call(self, inputs, state):
    """Long short-term memory cell (LSTM) with masks for pruning.

    Args:
      inputs: `2-D` tensor with shape `[batch_size, input_size]`.
      state: An `LSTMStateTuple` of state tensors, each shaped
        `[batch_size, self.state_size]`, if `state_is_tuple` has been set to
        `True`.  Otherwise, a `Tensor` shaped
        `[batch_size, 2 * self.state_size]`.

    Returns:
      A pair containing the new hidden state and the new state (either an
        `LSTMStateTuple` or a concatenated state, depending on
        `state_is_tuple`).
    """
    sigmoid = math_ops.sigmoid
    one = constant_op.constant(1, dtype=dtypes.int32)
    # Parameters of gates are concatenated into one multiply for efficiency.
    if self._state_is_tuple:
      c, h = state
    else:
      c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one)

    gate_inputs = math_ops.matmul(
        array_ops.concat([inputs, h], 1), self._masked_kernel)
    gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(
        value=gate_inputs, num_or_size_splits=4, axis=one)

    forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)
    # Note that using `add` and `multiply` instead of `+` and `*` gives a
    # performance improvement, so we use them at the cost of some readability.
    add = math_ops.add
    multiply = math_ops.multiply
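    # The standard LSTM update, computed with the masked kernel:
    #   new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * activation(j)
    #   new_h = activation(new_c) * sigmoid(o)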
    new_c = add(
        multiply(c, sigmoid(add(f, forget_bias_tensor))),
        multiply(sigmoid(i), self._activation(j)))
    new_h = multiply(self._activation(new_c), sigmoid(o))

    if self._state_is_tuple:
      new_state = tf_rnn.LSTMStateTuple(new_c, new_h)
    else:
      new_state = array_ops.concat([new_c, new_h], 1)
    return new_h, new_state


class MaskedLSTMCell(tf_rnn.LSTMCell):
  """LSTMCell with pruning.

  Overrides the call method of the TensorFlow LSTMCell and injects the weight
  masks. Masks are applied only to the weight matrix of the LSTM, not to the
  projection matrix.
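
  A minimal usage sketch (illustrative only, not from the original docs;
  assumes the TF 1.x contrib environment this module was written for, and
  hypothetical shapes):

    cell = MaskedLSTMCell(num_units=128, num_proj=64)
    inputs = array_ops.placeholder(dtypes.float32, [32, 100])
    state = cell.zero_state(32, dtypes.float32)
    output, new_state = cell(inputs, state)  # output has shape [32, 64]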
178  """
179
180  def __init__(self,
181               num_units,
182               use_peepholes=False,
183               cell_clip=None,
184               initializer=None,
185               num_proj=None,
186               proj_clip=None,
187               num_unit_shards=None,
188               num_proj_shards=None,
189               forget_bias=1.0,
190               state_is_tuple=True,
191               activation=None,
192               reuse=None):
193    """Initialize the parameters for an LSTM cell with masks for pruning.
194
195    Args:
      num_units: int, The number of units in the LSTM cell.
      use_peepholes: bool, set True to enable diagonal/peephole connections.
      cell_clip: (optional) A float value. If provided, the cell state is
        clipped by this value prior to the cell output activation.
      initializer: (optional) The initializer to use for the weight and
        projection matrices.
      num_proj: (optional) int, The output dimensionality for the projection
        matrices.  If None, no projection is performed.
      proj_clip: (optional) A float value.  If `num_proj > 0` and `proj_clip` is
        provided, then the projected values are clipped elementwise to within
        `[-proj_clip, proj_clip]`.
      num_unit_shards: Deprecated, will be removed by Jan. 2017.
        Use a variable_scope partitioner instead.
      num_proj_shards: Deprecated, will be removed by Jan. 2017.
        Use a variable_scope partitioner instead.
      forget_bias: Biases of the forget gate are initialized by default to 1
        in order to reduce the scale of forgetting at the beginning of
        training. Must be set manually to `0.0` when restoring from
        CudnnLSTM-trained checkpoints.
      state_is_tuple: If True, accepted and returned states are 2-tuples of
        the `c_state` and `m_state`.  If False, they are concatenated
        along the column axis.  This latter behavior will soon be deprecated.
      activation: Activation function of the inner states.  Default: `tanh`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.

      When restoring from CudnnLSTM-trained checkpoints,
      `CudnnCompatibleLSTMCell` must be used instead.
    """
    super(MaskedLSTMCell, self).__init__(
        num_units,
        use_peepholes=use_peepholes,
        cell_clip=cell_clip,
        initializer=initializer,
        num_proj=num_proj,
        proj_clip=proj_clip,
        num_unit_shards=num_unit_shards,
        num_proj_shards=num_proj_shards,
        forget_bias=forget_bias,
        state_is_tuple=state_is_tuple,
        activation=activation,
        reuse=reuse)

  def build(self, inputs_shape):
    # Call the build method of the parent class.
    super(MaskedLSTMCell, self).build(inputs_shape)

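    # The parent build() marks the layer as built; reset the flag so the
    # pruning-specific mask and threshold variables can be added below
    # before the layer is marked as built again.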
    self.built = False

    input_depth = inputs_shape.dims[1].value
    h_depth = self._num_units
    self._mask = self.add_variable(
        name="mask",
        shape=[input_depth + h_depth, 4 * h_depth],
        initializer=init_ops.ones_initializer(),
        trainable=False,
        dtype=self.dtype)
    self._threshold = self.add_variable(
        name="threshold",
        shape=[],
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        dtype=self.dtype)
    # Add masked_weights in the weights name scope so as to make it easier
    # for the quantization library to add quant ops.
    self._masked_kernel = math_ops.multiply(self._mask, self._kernel,
                                            core_layers.MASKED_WEIGHT_NAME)
    if self._mask not in ops.get_collection_ref(core_layers.MASK_COLLECTION):
      ops.add_to_collection(core_layers.MASK_COLLECTION, self._mask)
      ops.add_to_collection(core_layers.MASKED_WEIGHT_COLLECTION,
                            self._masked_kernel)
      ops.add_to_collection(core_layers.THRESHOLD_COLLECTION, self._threshold)
      ops.add_to_collection(core_layers.WEIGHT_COLLECTION, self._kernel)

    self.built = True

  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2-D, `[batch, input_size]`.
      state: if `state_is_tuple` is False, this must be a state Tensor,
        2-D, `[batch, state_size]`.  If `state_is_tuple` is True, this must be a
        tuple of state Tensors, both 2-D, with column sizes `c_state` and
        `m_state`.

    Returns:
      A tuple containing:

      - A 2-D Tensor, `[batch, output_dim]`, representing the output of the
        LSTM after reading `inputs` when the previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - Tensor(s) representing the new state of the LSTM after reading `inputs`
        when the previous state was `state`.  Same type and shape(s) as `state`.

    Raises:
      ValueError: If the input size cannot be inferred from inputs via
        static shape inference.
    """
    num_proj = self._num_units if self._num_proj is None else self._num_proj
    sigmoid = math_ops.sigmoid

    if self._state_is_tuple:
      (c_prev, m_prev) = state
    else:
      c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
      m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])

    input_size = inputs.get_shape().with_rank(2).dims[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    lstm_matrix = math_ops.matmul(
        array_ops.concat([inputs, m_prev], 1), self._masked_kernel)
    lstm_matrix = nn_ops.bias_add(lstm_matrix, self._bias)

    i, j, f, o = array_ops.split(
        value=lstm_matrix, num_or_size_splits=4, axis=1)
    # Diagonal connections
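    # With peepholes, c_prev also feeds the input and forget gates:
    #   c = sigmoid(f + forget_bias + w_f_diag * c_prev) * c_prev
    #       + sigmoid(i + w_i_diag * c_prev) * activation(j)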
    if self._use_peepholes:
      c = (
          sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev +
          sigmoid(i + self._w_i_diag * c_prev) * self._activation(j))
    else:
      c = (
          sigmoid(f + self._forget_bias) * c_prev +
          sigmoid(i) * self._activation(j))

    if self._cell_clip is not None:
      # pylint: disable=invalid-unary-operand-type
      c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
      # pylint: enable=invalid-unary-operand-type
    if self._use_peepholes:
      m = sigmoid(o + self._w_o_diag * c) * self._activation(c)
    else:
      m = sigmoid(o) * self._activation(c)

    if self._num_proj is not None:
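      # Note: only the LSTM kernel is masked; the projection kernel is
      # intentionally left unpruned (see the class docstring).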
      m = math_ops.matmul(m, self._proj_kernel)

      if self._proj_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
        # pylint: enable=invalid-unary-operand-type

    new_state = (
        tf_rnn.LSTMStateTuple(c, m)
        if self._state_is_tuple else array_ops.concat([c, m], 1))
    return m, new_state