# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module implementing RNN Cells with pruning.

This module implements BasicLSTMCell and LSTMCell with pruning.
Code adapted from third_party/tensorflow/python/ops/rnn_cell_impl.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.model_pruning.python.layers import core_layers
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn_cell as tf_rnn


def _add_pruning_variables(cell, mask_shape):
  """Creates the pruning mask and threshold for `cell` and registers them.

  Sets `cell._mask`, `cell._threshold` and `cell._masked_kernel`, and adds the
  mask, masked kernel, threshold and kernel to the pruning collections defined
  in `core_layers`.

  Args:
    cell: An RNN cell whose `_kernel` variable has already been created and
      which provides `add_variable` and `dtype`.
    mask_shape: List of two ints, the shape of the mask. It must match the
      shape of `cell._kernel`, since the two are multiplied elementwise.
  """
  # pylint: disable=protected-access
  cell._mask = cell.add_variable(
      name="mask",
      shape=mask_shape,
      initializer=init_ops.ones_initializer(),
      trainable=False,
      dtype=cell.dtype)
  cell._threshold = cell.add_variable(
      name="threshold",
      shape=[],
      initializer=init_ops.zeros_initializer(),
      trainable=False,
      dtype=cell.dtype)
  # Add masked_weights in the weights namescope so as to make it easier
  # for the quantization library to add quant ops.
  cell._masked_kernel = math_ops.multiply(cell._mask, cell._kernel,
                                          core_layers.MASKED_WEIGHT_NAME)
  # Register each mask (and its companions) only once, even if build runs
  # again for a cell whose mask is already in the collection.
  if cell._mask not in ops.get_collection_ref(core_layers.MASK_COLLECTION):
    ops.add_to_collection(core_layers.MASK_COLLECTION, cell._mask)
    ops.add_to_collection(core_layers.MASKED_WEIGHT_COLLECTION,
                          cell._masked_kernel)
    ops.add_to_collection(core_layers.THRESHOLD_COLLECTION, cell._threshold)
    ops.add_to_collection(core_layers.WEIGHT_COLLECTION, cell._kernel)
  # pylint: enable=protected-access


class MaskedBasicLSTMCell(tf_rnn.BasicLSTMCell):
  """Basic LSTM recurrent network cell with pruning.

  Overrides the call method of tensorflow BasicLSTMCell and injects the weight
  masks.

  The implementation is based on: http://arxiv.org/abs/1409.2329.

  We add forget_bias (default: 1) to the biases of the forget gate in order to
  reduce the scale of forgetting in the beginning of the training.

  It does not allow cell clipping, a projection layer, and does not
  use peep-hole connections: it is the basic baseline.

  For advanced models, please use the full `MaskedLSTMCell` defined later in
  this module.
  """

  def __init__(self,
               num_units,
               forget_bias=1.0,
               state_is_tuple=True,
               activation=None,
               reuse=None,
               name=None):
    """Initialize the basic LSTM cell with pruning.

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
        Must set to `0.0` manually when restoring from CudnnLSTM-trained
        checkpoints.
      state_is_tuple: If True, accepted and returned states are 2-tuples of
        the `c_state` and `m_state`.  If False, they are concatenated
        along the column axis.  The latter behavior will soon be deprecated.
      activation: Activation function of the inner states.  Default: `tanh`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already
        has the given variables, an error is raised.
      name: String, the name of the layer. Layers with the same name will
        share weights, but to avoid mistakes we require reuse=True in such
        cases.

      When restoring from CudnnLSTM-trained checkpoints, must use
      CudnnCompatibleLSTMCell instead.
    """
    super(MaskedBasicLSTMCell, self).__init__(
        num_units,
        forget_bias=forget_bias,
        state_is_tuple=state_is_tuple,
        activation=activation,
        reuse=reuse,
        name=name)

  def build(self, inputs_shape):
    """Builds the parent cell, then adds the pruning mask and threshold.

    Args:
      inputs_shape: Shape of the cell inputs; `inputs_shape.dims[1].value` is
        read as the input depth.
    """
    # Call the build method of the parent class.
    super(MaskedBasicLSTMCell, self).build(inputs_shape)

    # Keep the cell marked as not built while the pruning variables are added;
    # the flag is set again once everything has been created.
    self.built = False

    input_depth = inputs_shape.dims[1].value
    h_depth = self._num_units
    _add_pruning_variables(self, [input_depth + h_depth, 4 * h_depth])

    self.built = True

  def call(self, inputs, state):
    """Long short-term memory cell (LSTM) with masks for pruning.

    Args:
      inputs: `2-D` tensor with shape `[batch_size, input_size]`.
      state: An `LSTMStateTuple` of state tensors, each shaped
        `[batch_size, self.state_size]`, if `state_is_tuple` has been set to
        `True`.  Otherwise, a `Tensor` shaped
        `[batch_size, 2 * self.state_size]`.

    Returns:
      A pair containing the new hidden state, and the new state (either a
        `LSTMStateTuple` or a concatenated state, depending on
        `state_is_tuple`).
    """
    sigmoid = math_ops.sigmoid
    one = constant_op.constant(1, dtype=dtypes.int32)
    # Parameters of gates are concatenated into one multiply for efficiency.
    if self._state_is_tuple:
      c, h = state
    else:
      c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one)

    # The masked kernel is used in place of the raw kernel so that pruned
    # weights do not contribute to the gates.
    gate_inputs = math_ops.matmul(
        array_ops.concat([inputs, h], 1), self._masked_kernel)
    gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(
        value=gate_inputs, num_or_size_splits=4, axis=one)

    forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)
    # Note that using `add` and `multiply` instead of `+` and `*` gives a
    # performance improvement. So using those at the cost of readability.
    add = math_ops.add
    multiply = math_ops.multiply
    new_c = add(
        multiply(c, sigmoid(add(f, forget_bias_tensor))),
        multiply(sigmoid(i), self._activation(j)))
    new_h = multiply(self._activation(new_c), sigmoid(o))

    if self._state_is_tuple:
      new_state = tf_rnn.LSTMStateTuple(new_c, new_h)
    else:
      new_state = array_ops.concat([new_c, new_h], 1)
    return new_h, new_state


class MaskedLSTMCell(tf_rnn.LSTMCell):
  """LSTMCell with pruning.

  Overrides the call method of tensorflow LSTMCell and injects the weight masks.
  Masks are applied to only the weight matrix of the LSTM and not the
  projection matrix.
  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               cell_clip=None,
               initializer=None,
               num_proj=None,
               proj_clip=None,
               num_unit_shards=None,
               num_proj_shards=None,
               forget_bias=1.0,
               state_is_tuple=True,
               activation=None,
               reuse=None):
    """Initialize the parameters for an LSTM cell with masks for pruning.

    Args:
      num_units: int, The number of units in the LSTM cell
      use_peepholes: bool, set True to enable diagonal/peephole connections.
      cell_clip: (optional) A float value, if provided the cell state is clipped
        by this value prior to the cell output activation.
      initializer: (optional) The initializer to use for the weight and
        projection matrices.
      num_proj: (optional) int, The output dimensionality for the projection
        matrices.  If None, no projection is performed.
      proj_clip: (optional) A float value.  If `num_proj > 0` and `proj_clip` is
        provided, then the projected values are clipped elementwise to within
        `[-proj_clip, proj_clip]`.
      num_unit_shards: Deprecated, will be removed by Jan. 2017.
        Use a variable_scope partitioner instead.
      num_proj_shards: Deprecated, will be removed by Jan. 2017.
        Use a variable_scope partitioner instead.
      forget_bias: Biases of the forget gate are initialized by default to 1
        in order to reduce the scale of forgetting at the beginning of
        the training. Must set it manually to `0.0` when restoring from
        CudnnLSTM trained checkpoints.
      state_is_tuple: If True, accepted and returned states are 2-tuples of
        the `c_state` and `m_state`.  If False, they are concatenated
        along the column axis.  This latter behavior will soon be deprecated.
      activation: Activation function of the inner states.  Default: `tanh`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already
        has the given variables, an error is raised.

      When restoring from CudnnLSTM-trained checkpoints, must use
      CudnnCompatibleLSTMCell instead.
    """
    super(MaskedLSTMCell, self).__init__(
        num_units,
        use_peepholes=use_peepholes,
        cell_clip=cell_clip,
        initializer=initializer,
        num_proj=num_proj,
        proj_clip=proj_clip,
        num_unit_shards=num_unit_shards,
        num_proj_shards=num_proj_shards,
        forget_bias=forget_bias,
        state_is_tuple=state_is_tuple,
        activation=activation,
        reuse=reuse)

  def build(self, inputs_shape):
    """Builds the parent cell, then adds the pruning mask and threshold.

    Args:
      inputs_shape: Shape of the cell inputs; `inputs_shape.dims[1].value` is
        read as the input depth.
    """
    # Call the build method of the parent class.
    super(MaskedLSTMCell, self).build(inputs_shape)

    # Keep the cell marked as not built while the pruning variables are added;
    # the flag is set again once everything has been created.
    self.built = False

    input_depth = inputs_shape.dims[1].value
    # `call` multiplies concat([inputs, m_prev]) by the masked kernel, and
    # m_prev has `num_proj` columns when a projection layer is configured.
    # The mask must therefore use the projected depth for its row count,
    # while the four gates always span 4 * num_units columns.
    h_depth = self._num_units if self._num_proj is None else self._num_proj
    _add_pruning_variables(
        self, [input_depth + h_depth, 4 * self._num_units])

    self.built = True

  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, `[batch, input_size]`.
      state: if `state_is_tuple` is False, this must be a state Tensor,
        `2-D, [batch, state_size]`.  If `state_is_tuple` is True, this must be a
        tuple of state Tensors, both `2-D`, with column sizes `c_state` and
        `m_state`.

    Returns:
      A tuple containing:

      - A `2-D, [batch, output_dim]`, Tensor representing the output of the
        LSTM after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - Tensor(s) representing the new state of LSTM after reading `inputs` when
        the previous state was `state`.  Same type and shape(s) as `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    num_proj = self._num_units if self._num_proj is None else self._num_proj
    sigmoid = math_ops.sigmoid

    if self._state_is_tuple:
      (c_prev, m_prev) = state
    else:
      # The concatenated state holds the cell state in the first `num_units`
      # columns followed by the output state in the next `num_proj` columns.
      c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
      m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])

    input_size = inputs.get_shape().with_rank(2).dims[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    # The masked kernel is used in place of the raw kernel so that pruned
    # weights do not contribute to the gates.
    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    lstm_matrix = math_ops.matmul(
        array_ops.concat([inputs, m_prev], 1), self._masked_kernel)
    lstm_matrix = nn_ops.bias_add(lstm_matrix, self._bias)

    i, j, f, o = array_ops.split(
        value=lstm_matrix, num_or_size_splits=4, axis=1)
    # Diagonal connections
    if self._use_peepholes:
      c = (
          sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev +
          sigmoid(i + self._w_i_diag * c_prev) * self._activation(j))
    else:
      c = (
          sigmoid(f + self._forget_bias) * c_prev +
          sigmoid(i) * self._activation(j))

    if self._cell_clip is not None:
      # pylint: disable=invalid-unary-operand-type
      c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
      # pylint: enable=invalid-unary-operand-type
    if self._use_peepholes:
      m = sigmoid(o + self._w_o_diag * c) * self._activation(c)
    else:
      m = sigmoid(o) * self._activation(c)

    if self._num_proj is not None:
      # The projection matrix is applied unmasked; only the gate kernel is
      # pruned.
      m = math_ops.matmul(m, self._proj_kernel)

      if self._proj_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
        # pylint: enable=invalid-unary-operand-type

    new_state = (
        tf_rnn.LSTMStateTuple(c, m)
        if self._state_is_tuple else array_ops.concat([c, m], 1))
    return m, new_state