1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Module implementing RNN Cells that used to be in core. 16 17@@EmbeddingWrapper 18@@InputProjectionWrapper 19@@OutputProjectionWrapper 20""" 21from __future__ import absolute_import 22from __future__ import division 23from __future__ import print_function 24 25import math 26 27from tensorflow.python.framework import constant_op 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import ops 30from tensorflow.python.ops import array_ops 31from tensorflow.python.ops import embedding_ops 32from tensorflow.python.ops import init_ops 33from tensorflow.python.ops import math_ops 34from tensorflow.python.ops import nn_ops 35from tensorflow.python.ops import rnn_cell_impl 36from tensorflow.python.ops import variable_scope as vs 37from tensorflow.python.platform import tf_logging as logging 38from tensorflow.python.util import nest 39 40 41# pylint: disable=protected-access,invalid-name 42RNNCell = rnn_cell_impl.RNNCell 43_WEIGHTS_VARIABLE_NAME = rnn_cell_impl._WEIGHTS_VARIABLE_NAME 44_BIAS_VARIABLE_NAME = rnn_cell_impl._BIAS_VARIABLE_NAME 45# pylint: enable=protected-access,invalid-name 46 47 48class _Linear(object): 49 """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. 50 51 Args: 52 args: a 2D Tensor or a list of 2D, batch, n, Tensors. 53 output_size: int, second dimension of weight variable. 54 dtype: data type for variables. 55 build_bias: boolean, whether to build a bias variable. 56 bias_initializer: starting value to initialize the bias 57 (default is all zeros). 58 kernel_initializer: starting value to initialize the weight. 59 60 Raises: 61 ValueError: if inputs_shape is wrong. 62 """ 63 64 def __init__(self, 65 args, 66 output_size, 67 build_bias, 68 bias_initializer=None, 69 kernel_initializer=None): 70 self._build_bias = build_bias 71 72 if args is None or (nest.is_sequence(args) and not args): 73 raise ValueError("`args` must be specified") 74 if not nest.is_sequence(args): 75 args = [args] 76 self._is_sequence = False 77 else: 78 self._is_sequence = True 79 80 # Calculate the total size of arguments on dimension 1. 81 total_arg_size = 0 82 shapes = [a.get_shape() for a in args] 83 for shape in shapes: 84 if shape.ndims != 2: 85 raise ValueError("linear is expecting 2D arguments: %s" % shapes) 86 if shape.dims[1].value is None: 87 raise ValueError("linear expects shape[1] to be provided for shape %s, " 88 "but saw %s" % (shape, shape[1])) 89 else: 90 total_arg_size += shape.dims[1].value 91 92 dtype = [a.dtype for a in args][0] 93 94 scope = vs.get_variable_scope() 95 with vs.variable_scope(scope) as outer_scope: 96 self._weights = vs.get_variable( 97 _WEIGHTS_VARIABLE_NAME, [total_arg_size, output_size], 98 dtype=dtype, 99 initializer=kernel_initializer) 100 if build_bias: 101 with vs.variable_scope(outer_scope) as inner_scope: 102 inner_scope.set_partitioner(None) 103 if bias_initializer is None: 104 bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype) 105 self._biases = vs.get_variable( 106 _BIAS_VARIABLE_NAME, [output_size], 107 dtype=dtype, 108 initializer=bias_initializer) 109 110 def __call__(self, args): 111 if not self._is_sequence: 112 args = [args] 113 114 if len(args) == 1: 115 res = math_ops.matmul(args[0], self._weights) 116 else: 117 # Explicitly creating a one for a minor performance improvement. 118 one = constant_op.constant(1, dtype=dtypes.int32) 119 res = math_ops.matmul(array_ops.concat(args, one), self._weights) 120 if self._build_bias: 121 res = nn_ops.bias_add(res, self._biases) 122 return res 123 124 125# TODO(xpan): Remove this function in a follow up. 126def _linear(args, 127 output_size, 128 bias, 129 bias_initializer=None, 130 kernel_initializer=None): 131 """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. 132 133 Args: 134 args: a 2D Tensor or a list of 2D, batch, n, Tensors. 135 output_size: int, second dimension of W[i]. 136 bias: boolean, whether to add a bias term or not. 137 bias_initializer: starting value to initialize the bias 138 (default is all zeros). 139 kernel_initializer: starting value to initialize the weight. 140 141 Returns: 142 A 2D Tensor with shape `[batch, output_size]` equal to 143 sum_i(args[i] * W[i]), where W[i]s are newly created matrices. 144 145 Raises: 146 ValueError: if some of the arguments has unspecified or wrong shape. 147 """ 148 if args is None or (nest.is_sequence(args) and not args): 149 raise ValueError("`args` must be specified") 150 if not nest.is_sequence(args): 151 args = [args] 152 153 # Calculate the total size of arguments on dimension 1. 154 total_arg_size = 0 155 shapes = [a.get_shape() for a in args] 156 for shape in shapes: 157 if shape.ndims != 2: 158 raise ValueError("linear is expecting 2D arguments: %s" % shapes) 159 if shape.dims[1].value is None: 160 raise ValueError("linear expects shape[1] to be provided for shape %s, " 161 "but saw %s" % (shape, shape[1])) 162 else: 163 total_arg_size += shape.dims[1].value 164 165 dtype = [a.dtype for a in args][0] 166 167 # Now the computation. 168 scope = vs.get_variable_scope() 169 with vs.variable_scope(scope) as outer_scope: 170 weights = vs.get_variable( 171 _WEIGHTS_VARIABLE_NAME, [total_arg_size, output_size], 172 dtype=dtype, 173 initializer=kernel_initializer) 174 if len(args) == 1: 175 res = math_ops.matmul(args[0], weights) 176 else: 177 res = math_ops.matmul(array_ops.concat(args, 1), weights) 178 if not bias: 179 return res 180 with vs.variable_scope(outer_scope) as inner_scope: 181 inner_scope.set_partitioner(None) 182 if bias_initializer is None: 183 bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype) 184 biases = vs.get_variable( 185 _BIAS_VARIABLE_NAME, [output_size], 186 dtype=dtype, 187 initializer=bias_initializer) 188 return nn_ops.bias_add(res, biases) 189 190 191class EmbeddingWrapper(RNNCell): 192 """Operator adding input embedding to the given cell. 193 194 Note: in many cases it may be more efficient to not use this wrapper, 195 but instead concatenate the whole sequence of your inputs in time, 196 do the embedding on this batch-concatenated sequence, then split it and 197 feed into your RNN. 198 """ 199 200 def __init__(self, 201 cell, 202 embedding_classes, 203 embedding_size, 204 initializer=None, 205 reuse=None): 206 """Create a cell with an added input embedding. 207 208 Args: 209 cell: an RNNCell, an embedding will be put before its inputs. 210 embedding_classes: integer, how many symbols will be embedded. 211 embedding_size: integer, the size of the vectors we embed into. 212 initializer: an initializer to use when creating the embedding; 213 if None, the initializer from variable scope or a default one is used. 214 reuse: (optional) Python boolean describing whether to reuse variables 215 in an existing scope. If not `True`, and the existing scope already has 216 the given variables, an error is raised. 217 218 Raises: 219 TypeError: if cell is not an RNNCell. 220 ValueError: if embedding_classes is not positive. 221 """ 222 super(EmbeddingWrapper, self).__init__(_reuse=reuse) 223 rnn_cell_impl.assert_like_rnncell("cell", cell) 224 if embedding_classes <= 0 or embedding_size <= 0: 225 raise ValueError("Both embedding_classes and embedding_size must be > 0: " 226 "%d, %d." % (embedding_classes, embedding_size)) 227 self._cell = cell 228 self._embedding_classes = embedding_classes 229 self._embedding_size = embedding_size 230 self._initializer = initializer 231 232 @property 233 def state_size(self): 234 return self._cell.state_size 235 236 @property 237 def output_size(self): 238 return self._cell.output_size 239 240 def zero_state(self, batch_size, dtype): 241 with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): 242 return self._cell.zero_state(batch_size, dtype) 243 244 def call(self, inputs, state): 245 """Run the cell on embedded inputs.""" 246 with ops.device("/cpu:0"): 247 if self._initializer: 248 initializer = self._initializer 249 elif vs.get_variable_scope().initializer: 250 initializer = vs.get_variable_scope().initializer 251 else: 252 # Default initializer for embeddings should have variance=1. 253 sqrt3 = math.sqrt(3) # Uniform(-sqrt(3), sqrt(3)) has variance=1. 254 initializer = init_ops.random_uniform_initializer(-sqrt3, sqrt3) 255 256 if isinstance(state, tuple): 257 data_type = state[0].dtype 258 else: 259 data_type = state.dtype 260 261 embedding = vs.get_variable( 262 "embedding", [self._embedding_classes, self._embedding_size], 263 initializer=initializer, 264 dtype=data_type) 265 embedded = embedding_ops.embedding_lookup(embedding, 266 array_ops.reshape(inputs, [-1])) 267 268 return self._cell(embedded, state) 269 270 271class InputProjectionWrapper(RNNCell): 272 """Operator adding an input projection to the given cell. 273 274 Note: in many cases it may be more efficient to not use this wrapper, 275 but instead concatenate the whole sequence of your inputs in time, 276 do the projection on this batch-concatenated sequence, then split it. 277 """ 278 279 def __init__(self, 280 cell, 281 num_proj, 282 activation=None, 283 input_size=None, 284 reuse=None): 285 """Create a cell with input projection. 286 287 Args: 288 cell: an RNNCell, a projection of inputs is added before it. 289 num_proj: Python integer. The dimension to project to. 290 activation: (optional) an optional activation function. 291 input_size: Deprecated and unused. 292 reuse: (optional) Python boolean describing whether to reuse variables 293 in an existing scope. If not `True`, and the existing scope already has 294 the given variables, an error is raised. 295 296 Raises: 297 TypeError: if cell is not an RNNCell. 298 """ 299 super(InputProjectionWrapper, self).__init__(_reuse=reuse) 300 if input_size is not None: 301 logging.warn("%s: The input_size parameter is deprecated.", self) 302 rnn_cell_impl.assert_like_rnncell("cell", cell) 303 self._cell = cell 304 self._num_proj = num_proj 305 self._activation = activation 306 self._linear = None 307 308 @property 309 def state_size(self): 310 return self._cell.state_size 311 312 @property 313 def output_size(self): 314 return self._cell.output_size 315 316 def zero_state(self, batch_size, dtype): 317 with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): 318 return self._cell.zero_state(batch_size, dtype) 319 320 def call(self, inputs, state): 321 """Run the input projection and then the cell.""" 322 # Default scope: "InputProjectionWrapper" 323 if self._linear is None: 324 self._linear = _Linear(inputs, self._num_proj, True) 325 projected = self._linear(inputs) 326 if self._activation: 327 projected = self._activation(projected) 328 return self._cell(projected, state) 329 330 331class OutputProjectionWrapper(RNNCell): 332 """Operator adding an output projection to the given cell. 333 334 Note: in many cases it may be more efficient to not use this wrapper, 335 but instead concatenate the whole sequence of your outputs in time, 336 do the projection on this batch-concatenated sequence, then split it 337 if needed or directly feed into a softmax. 338 """ 339 340 def __init__(self, cell, output_size, activation=None, reuse=None): 341 """Create a cell with output projection. 342 343 Args: 344 cell: an RNNCell, a projection to output_size is added to it. 345 output_size: integer, the size of the output after projection. 346 activation: (optional) an optional activation function. 347 reuse: (optional) Python boolean describing whether to reuse variables 348 in an existing scope. If not `True`, and the existing scope already has 349 the given variables, an error is raised. 350 351 Raises: 352 TypeError: if cell is not an RNNCell. 353 ValueError: if output_size is not positive. 354 """ 355 super(OutputProjectionWrapper, self).__init__(_reuse=reuse) 356 rnn_cell_impl.assert_like_rnncell("cell", cell) 357 if output_size < 1: 358 raise ValueError("Parameter output_size must be > 0: %d." % output_size) 359 self._cell = cell 360 self._output_size = output_size 361 self._activation = activation 362 self._linear = None 363 364 @property 365 def state_size(self): 366 return self._cell.state_size 367 368 @property 369 def output_size(self): 370 return self._output_size 371 372 def zero_state(self, batch_size, dtype): 373 with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): 374 return self._cell.zero_state(batch_size, dtype) 375 376 def call(self, inputs, state): 377 """Run the cell and output projection on inputs, starting from state.""" 378 output, res_state = self._cell(inputs, state) 379 if self._linear is None: 380 self._linear = _Linear(output, self._output_size, True) 381 projected = self._linear(output) 382 if self._activation: 383 projected = self._activation(projected) 384 return projected, res_state 385