# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Module implementing RNN Cells that used to be in core.
16
17@@EmbeddingWrapper
18@@InputProjectionWrapper
19@@OutputProjectionWrapper
20"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest


# pylint: disable=protected-access,invalid-name
RNNCell = rnn_cell_impl.RNNCell
_WEIGHTS_VARIABLE_NAME = rnn_cell_impl._WEIGHTS_VARIABLE_NAME
_BIAS_VARIABLE_NAME = rnn_cell_impl._BIAS_VARIABLE_NAME
# pylint: enable=protected-access,invalid-name


class _Linear(object):
  """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.

  Args:
    args: a 2D Tensor or a list of 2D Tensors of shape `[batch, n]`.
    output_size: int, second dimension of the weight variable.
    build_bias: boolean, whether to build a bias variable.
    bias_initializer: starting value to initialize the bias
      (default is all zeros).
    kernel_initializer: starting value to initialize the weight.

  Raises:
    ValueError: if any argument is not 2D or has an unknown second dimension.
  """

  def __init__(self,
               args,
               output_size,
               build_bias,
               bias_initializer=None,
               kernel_initializer=None):
    self._build_bias = build_bias

    if args is None or (nest.is_sequence(args) and not args):
      raise ValueError("`args` must be specified")
    if not nest.is_sequence(args):
      args = [args]
      self._is_sequence = False
    else:
      self._is_sequence = True

    # Calculate the total size of arguments on dimension 1.
    total_arg_size = 0
    shapes = [a.get_shape() for a in args]
    for shape in shapes:
      if shape.ndims != 2:
        raise ValueError("linear is expecting 2D arguments: %s" % shapes)
      if shape.dims[1].value is None:
        raise ValueError("linear expects shape[1] to be provided for shape %s, "
                         "but saw %s" % (shape, shape[1]))
      else:
        total_arg_size += shape.dims[1].value

    dtype = [a.dtype for a in args][0]

    scope = vs.get_variable_scope()
    with vs.variable_scope(scope) as outer_scope:
      self._weights = vs.get_variable(
          _WEIGHTS_VARIABLE_NAME, [total_arg_size, output_size],
          dtype=dtype,
          initializer=kernel_initializer)
      if build_bias:
        # Create the bias in a scope with partitioning disabled: a small 1D
        # bias vector should not be sharded by an outer partitioner.
        with vs.variable_scope(outer_scope) as inner_scope:
          inner_scope.set_partitioner(None)
          if bias_initializer is None:
            bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
          self._biases = vs.get_variable(
              _BIAS_VARIABLE_NAME, [output_size],
              dtype=dtype,
              initializer=bias_initializer)

  def __call__(self, args):
    if not self._is_sequence:
      args = [args]

    if len(args) == 1:
      res = math_ops.matmul(args[0], self._weights)
    else:
      # Explicitly creating the concat axis as a constant is a minor
      # performance improvement.
      one = constant_op.constant(1, dtype=dtypes.int32)
      res = math_ops.matmul(array_ops.concat(args, one), self._weights)
    if self._build_bias:
      res = nn_ops.bias_add(res, self._biases)
    return res

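
# Example usage of `_Linear` (a minimal sketch; the tensors, scope name and
# sizes below are illustrative, not part of this module):
#
#   x = array_ops.zeros([32, 100])
#   y = array_ops.zeros([32, 20])
#   with vs.variable_scope("proj"):
#     linear = _Linear([x, y], output_size=50, build_bias=True)
#     out = linear([x, y])  # concat on axis 1, then matmul: shape [32, 50]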

# TODO(xpan): Remove this function in a follow-up.
def _linear(args,
            output_size,
            bias,
            bias_initializer=None,
            kernel_initializer=None):
  """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.

  Args:
    args: a 2D Tensor or a list of 2D Tensors of shape `[batch, n]`.
    output_size: int, second dimension of W[i].
    bias: boolean, whether to add a bias term or not.
    bias_initializer: starting value to initialize the bias
      (default is all zeros).
    kernel_initializer: starting value to initialize the weight.

  Returns:
    A 2D Tensor with shape `[batch, output_size]` equal to
    sum_i(args[i] * W[i]), where the W[i] are newly created matrices.

  Raises:
    ValueError: if some of the arguments have unspecified or wrong shapes.
  """
  if args is None or (nest.is_sequence(args) and not args):
    raise ValueError("`args` must be specified")
  if not nest.is_sequence(args):
    args = [args]

  # Calculate the total size of arguments on dimension 1.
  total_arg_size = 0
  shapes = [a.get_shape() for a in args]
  for shape in shapes:
    if shape.ndims != 2:
      raise ValueError("linear is expecting 2D arguments: %s" % shapes)
    if shape.dims[1].value is None:
      raise ValueError("linear expects shape[1] to be provided for shape %s, "
                       "but saw %s" % (shape, shape[1]))
    else:
      total_arg_size += shape.dims[1].value

  dtype = [a.dtype for a in args][0]

  # Now the computation.
  scope = vs.get_variable_scope()
  with vs.variable_scope(scope) as outer_scope:
    weights = vs.get_variable(
        _WEIGHTS_VARIABLE_NAME, [total_arg_size, output_size],
        dtype=dtype,
        initializer=kernel_initializer)
    if len(args) == 1:
      res = math_ops.matmul(args[0], weights)
    else:
      res = math_ops.matmul(array_ops.concat(args, 1), weights)
    if not bias:
      return res
    with vs.variable_scope(outer_scope) as inner_scope:
      inner_scope.set_partitioner(None)
      if bias_initializer is None:
        bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
      biases = vs.get_variable(
          _BIAS_VARIABLE_NAME, [output_size],
          dtype=dtype,
          initializer=bias_initializer)
    return nn_ops.bias_add(res, biases)

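
# Example usage of `_linear` (a sketch; this is the function-style equivalent
# of the `_Linear` example above, with illustrative shapes and scope name):
#
#   x = array_ops.zeros([8, 16])
#   with vs.variable_scope("proj2"):
#     out = _linear(x, output_size=4, bias=True)  # shape [8, 4]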

class EmbeddingWrapper(RNNCell):
  """Operator adding input embedding to the given cell.

  Note: in many cases it may be more efficient not to use this wrapper,
  but instead to concatenate the whole sequence of your inputs in time,
  do the embedding on this batch-concatenated sequence, then split it and
  feed it into your RNN.
  """

  def __init__(self,
               cell,
               embedding_classes,
               embedding_size,
               initializer=None,
               reuse=None):
    """Create a cell with an added input embedding.

    Args:
      cell: an RNNCell, an embedding will be put before its inputs.
      embedding_classes: integer, how many symbols will be embedded.
      embedding_size: integer, the size of the vectors we embed into.
      initializer: an initializer to use when creating the embedding;
        if None, the initializer from variable scope or a default one is used.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already
        has the given variables, an error is raised.

    Raises:
      TypeError: if cell is not an RNNCell.
      ValueError: if embedding_classes is not positive.
    """
    super(EmbeddingWrapper, self).__init__(_reuse=reuse)
    rnn_cell_impl.assert_like_rnncell("cell", cell)
    if embedding_classes <= 0 or embedding_size <= 0:
      raise ValueError("Both embedding_classes and embedding_size must be > 0: "
                       "%d, %d." % (embedding_classes, embedding_size))
    self._cell = cell
    self._embedding_classes = embedding_classes
    self._embedding_size = embedding_size
    self._initializer = initializer

  @property
  def state_size(self):
    return self._cell.state_size

  @property
  def output_size(self):
    return self._cell.output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      return self._cell.zero_state(batch_size, dtype)

  def call(self, inputs, state):
    """Run the cell on embedded inputs."""
    # Pin the embedding table and lookup to CPU, where large embedding
    # variables are typically kept.
    with ops.device("/cpu:0"):
      if self._initializer:
        initializer = self._initializer
      elif vs.get_variable_scope().initializer:
        initializer = vs.get_variable_scope().initializer
      else:
        # Default initializer for embeddings should have variance=1.
        sqrt3 = math.sqrt(3)  # Uniform(-sqrt(3), sqrt(3)) has variance=1.
        initializer = init_ops.random_uniform_initializer(-sqrt3, sqrt3)

      if isinstance(state, tuple):
        data_type = state[0].dtype
      else:
        data_type = state.dtype

      embedding = vs.get_variable(
          "embedding", [self._embedding_classes, self._embedding_size],
          initializer=initializer,
          dtype=data_type)
      embedded = embedding_ops.embedding_lookup(embedding,
                                                array_ops.reshape(inputs, [-1]))

      return self._cell(embedded, state)

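
# Example usage of `EmbeddingWrapper` (a minimal sketch, assuming TF 1.x
# graph mode; the GRUCell, sizes and placeholder are illustrative):
#
#   cell = EmbeddingWrapper(rnn_cell_impl.GRUCell(64),
#                           embedding_classes=10000, embedding_size=128)
#   ids = array_ops.placeholder(dtypes.int32, [32])  # one symbol id per row
#   state = cell.zero_state(32, dtypes.float32)
#   output, next_state = cell(ids, state)  # embeds ids, then runs the GRU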

class InputProjectionWrapper(RNNCell):
  """Operator adding an input projection to the given cell.

  Note: in many cases it may be more efficient not to use this wrapper,
  but instead to concatenate the whole sequence of your inputs in time,
  do the projection on this batch-concatenated sequence, then split it.
  """

  def __init__(self,
               cell,
               num_proj,
               activation=None,
               input_size=None,
               reuse=None):
    """Create a cell with input projection.

    Args:
      cell: an RNNCell, a projection of inputs is added before it.
      num_proj: Python integer.  The dimension to project to.
      activation: (optional) activation function applied to the projection.
      input_size: Deprecated and unused.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already
        has the given variables, an error is raised.

    Raises:
      TypeError: if cell is not an RNNCell.
    """
    super(InputProjectionWrapper, self).__init__(_reuse=reuse)
    if input_size is not None:
      logging.warn("%s: The input_size parameter is deprecated.", self)
    rnn_cell_impl.assert_like_rnncell("cell", cell)
    self._cell = cell
    self._num_proj = num_proj
    self._activation = activation
    self._linear = None

  @property
  def state_size(self):
    return self._cell.state_size

  @property
  def output_size(self):
    return self._cell.output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      return self._cell.zero_state(batch_size, dtype)

  def call(self, inputs, state):
    """Run the input projection and then the cell."""
    # Default scope: "InputProjectionWrapper"
    if self._linear is None:
      self._linear = _Linear(inputs, self._num_proj, True)
    projected = self._linear(inputs)
    if self._activation:
      projected = self._activation(projected)
    return self._cell(projected, state)

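
# Example usage of `InputProjectionWrapper` (a minimal sketch; the cell and
# dimensions are illustrative):
#
#   cell = InputProjectionWrapper(rnn_cell_impl.GRUCell(64), num_proj=64)
#   inputs = array_ops.placeholder(dtypes.float32, [32, 300])
#   state = cell.zero_state(32, dtypes.float32)
#   output, next_state = cell(inputs, state)  # projects 300 -> 64, then GRU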

class OutputProjectionWrapper(RNNCell):
  """Operator adding an output projection to the given cell.

  Note: in many cases it may be more efficient not to use this wrapper,
  but instead to concatenate the whole sequence of your outputs in time,
  do the projection on this batch-concatenated sequence, then split it
  if needed or feed it directly into a softmax.
  """

  def __init__(self, cell, output_size, activation=None, reuse=None):
    """Create a cell with output projection.

    Args:
      cell: an RNNCell, a projection to output_size is added to it.
      output_size: integer, the size of the output after projection.
      activation: (optional) activation function applied to the projection.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already
        has the given variables, an error is raised.

    Raises:
      TypeError: if cell is not an RNNCell.
      ValueError: if output_size is not positive.
    """
    super(OutputProjectionWrapper, self).__init__(_reuse=reuse)
    rnn_cell_impl.assert_like_rnncell("cell", cell)
    if output_size < 1:
      raise ValueError("Parameter output_size must be > 0: %d." % output_size)
    self._cell = cell
    self._output_size = output_size
    self._activation = activation
    self._linear = None

  @property
  def state_size(self):
    return self._cell.state_size

  @property
  def output_size(self):
    return self._output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      return self._cell.zero_state(batch_size, dtype)

  def call(self, inputs, state):
    """Run the cell and output projection on inputs, starting from state."""
    output, res_state = self._cell(inputs, state)
    if self._linear is None:
      self._linear = _Linear(output, self._output_size, True)
    projected = self._linear(output)
    if self._activation:
      projected = self._activation(projected)
    return projected, res_state
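
# Example usage of `OutputProjectionWrapper` (a minimal sketch; the cell and
# sizes are illustrative):
#
#   cell = OutputProjectionWrapper(rnn_cell_impl.GRUCell(128), output_size=10)
#   inputs = array_ops.placeholder(dtypes.float32, [32, 50])
#   state = cell.zero_state(32, dtypes.float32)
#   logits, next_state = cell(inputs, state)  # GRU output 128 -> logits 10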