# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Module for constructing GridRNN cells"""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from collections import namedtuple
22import functools
23
24from tensorflow.python.framework import tensor_shape
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import math_ops
27from tensorflow.python.ops import nn
28from tensorflow.python.ops import variable_scope as vs
29
30from tensorflow.python.platform import tf_logging as logging
31from tensorflow.contrib import layers
32from tensorflow.contrib import rnn
33
34
class GridRNNCell(rnn.RNNCell):
  """Grid recurrent cell.

  This implementation is based on:

    http://arxiv.org/pdf/1507.01526v3.pdf

  This is the generic implementation of GridRNN. Users can specify an
  arbitrary number of dimensions, mark some of them as priority (section 3.2)
  or non-recurrent (section 3.3), and choose which dimensions receive input
  and emit output (section 3.4). Weight sharing can be enabled via the `tied`
  parameter, and the type of recurrent unit via `cell_fn`.
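
  Example (an illustrative sketch, assuming the TF 1.x contrib API):

    # A 2D grid whose first dimension receives input and emits output,
    # built from plain LSTM units.
    cell = GridRNNCell(num_units=16, num_dims=2,
                       input_dims=0, output_dims=0, priority_dims=0,
                       cell_fn=lambda n: rnn.LSTMCell(num_units=n))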
48  """
49
  def __init__(self,
               num_units,
               num_dims=1,
               input_dims=None,
               output_dims=None,
               priority_dims=None,
               non_recurrent_dims=None,
               tied=False,
               cell_fn=None,
               non_recurrent_fn=None,
               state_is_tuple=True,
               output_is_tuple=True):
    """Initialize the parameters of a Grid RNN cell.

    Args:
      num_units: int, The number of units in all dimensions of this GridRNN
        cell.
      num_dims: int, Number of dimensions of this grid.
      input_dims: int or list, List of dimensions which will receive input
        data.
      output_dims: int or list, List of dimensions from which the output will
        be recorded.
      priority_dims: int or list, List of dimensions to be considered as
        priority dimensions. If None, no dimension is prioritized.
      non_recurrent_dims: int or list, List of dimensions that are not
        recurrent. The transfer function for non-recurrent dimensions is
        specified via `non_recurrent_fn`, which defaults to
        `tensorflow.nn.relu`.
      tied: bool, Whether to share the weights among the dimensions of this
        GridRNN cell. If there are non-recurrent dimensions in the grid,
        weights are shared between each group of recurrent and non-recurrent
        dimensions.
      cell_fn: function, a function which returns the recurrent cell object.
        It must have the following signature:
        ```
        def cell_func(num_units):
          # ...
        ```
        and return an object of type `RNNCell`. If None, an LSTMCell with
        default parameters will be used.
        Note that if you use a custom RNNCell (with `cell_fn`), it is your
        responsibility to make sure the inner cell uses `state_is_tuple=True`.
      non_recurrent_fn: a tensorflow Op that will be the transfer function of
        the non-recurrent dimensions.
      state_is_tuple: If True, accepted and returned states are tuples of the
        states of the recurrent dimensions. If False, they are concatenated
        along the column axis. The latter behavior will soon be deprecated.
        Note that if you use a custom RNNCell (with `cell_fn`), it is your
        responsibility to make sure the inner cell uses `state_is_tuple=True`.
      output_is_tuple: If True, the output is a tuple of the outputs of the
        recurrent dimensions. If False, they are concatenated along the
        column axis. The latter behavior will soon be deprecated.

    Raises:
      TypeError: if cell_fn does not return an RNNCell instance.
    """
    if not state_is_tuple:
      logging.warning('%s: Using a concatenated state is slower and will '
                      'soon be deprecated.  Use state_is_tuple=True.', self)
    if not output_is_tuple:
      logging.warning('%s: Using a concatenated output is slower and will '
                      'soon be deprecated.  Use output_is_tuple=True.', self)

    if num_dims < 1:
      raise ValueError('dims must be >= 1: {}'.format(num_dims))

    self._config = _parse_rnn_config(num_dims, input_dims, output_dims,
                                     priority_dims, non_recurrent_dims,
                                     non_recurrent_fn or nn.relu, tied,
                                     num_units)

    self._state_is_tuple = state_is_tuple
    self._output_is_tuple = output_is_tuple

    if cell_fn is None:
      my_cell_fn = functools.partial(
          rnn.LSTMCell, num_units=num_units, state_is_tuple=state_is_tuple)
    else:
      my_cell_fn = lambda: cell_fn(num_units)
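    # If `tied`, a single inner-cell instance is reused for every dimension;
    # together with the variable-scope reuse in `_propagate`, this shares the
    # weights across dimensions.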
    if tied:
      self._cells = [my_cell_fn()] * num_dims
    else:
      self._cells = [my_cell_fn() for _ in range(num_dims)]
    if not isinstance(self._cells[0], rnn.RNNCell):
      raise TypeError('cell_fn must return an RNNCell instance, saw: %s' %
                      type(self._cells[0]))

    if self._output_is_tuple:
      self._output_size = tuple(self._cells[0].output_size
                                for _ in self._config.outputs)
    else:
      self._output_size = self._cells[0].output_size * len(self._config.outputs)

    if self._state_is_tuple:
      self._state_size = tuple(self._cells[0].state_size
                               for _ in self._config.recurrents)
    else:
      self._state_size = self._cell_state_size() * len(self._config.recurrents)

  @property
  def output_size(self):
    return self._output_size

  @property
  def state_size(self):
    return self._state_size

  def __call__(self, inputs, state, scope=None):
    """Run one step of GridRNN.

    Args:
      inputs: input Tensor, 2D, batch x input_size, or None.
      state: state Tensor, 2D, batch x state_size. Note that state_size =
        cell_state_size * recurrent_dims.
      scope: VariableScope for the created subgraph; defaults to "GridRNNCell".

    Returns:
      A tuple containing:

      - A 2D, batch x output_size, Tensor representing the output of the cell
        after reading "inputs" when the previous state was "state".
      - A 2D, batch x state_size, Tensor representing the new state of the cell
        after reading "inputs" when the previous state was "state".
    """
    conf = self._config
    dtype = inputs.dtype

    c_prev, m_prev, cell_output_size = self._extract_states(state)

    new_output = [None] * conf.num_dims
    new_state = [None] * conf.num_dims

    with vs.variable_scope(scope or type(self).__name__):  # GridRNNCell
      # project input, populate c_prev and m_prev
      self._project_input(inputs, c_prev, m_prev, cell_output_size > 0)

      # propagate along dimensions, first for non-priority dimensions
      # then priority dimensions
      _propagate(conf.non_priority, conf, self._cells, c_prev, m_prev,
                 new_output, new_state, True)
      _propagate(conf.priority, conf, self._cells,
                 c_prev, m_prev, new_output, new_state, False)

      # collect outputs and states
      output_tensors = [new_output[i] for i in self._config.outputs]
      if self._output_is_tuple:
        output = tuple(output_tensors)
      else:
        if output_tensors:
          output = array_ops.concat(output_tensors, 1)
        else:
          output = array_ops.zeros([0, 0], dtype)

      if self._state_is_tuple:
        states = tuple(new_state[i] for i in self._config.recurrents)
      else:
        # concat each state first, then flatten the whole thing
        state_tensors = [
            x for i in self._config.recurrents for x in new_state[i]
        ]
        if state_tensors:
          states = array_ops.concat(state_tensors, 1)
        else:
          states = array_ops.zeros([0, 0], dtype)

    return output, states

  def _extract_states(self, state):
    """Extract the cell and previous output tensors from the given state.

    Args:
      state: The RNN state.

    Returns:
      Tuple of the cell value, previous output, and cell_output_size.

    Raises:
      ValueError: If len(self._config.recurrents) != len(state).
    """
    conf = self._config

    # c_prev is `m` (cell value), and
    # m_prev is `h` (previous output) in the paper.
    # Keeping c and m here for consistency with the codebase
    c_prev = [None] * conf.num_dims
    m_prev = [None] * conf.num_dims

    # for LSTM   : state = memory cell + output, hence cell_output_size > 0
    # for GRU/RNN: state = output (whose size is equal to _num_units),
    #              hence cell_output_size = 0
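    # (e.g. an LSTMCell with num_units=8 has total state size 16, giving
    # cell_output_size = 16 - 8 = 8.)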
    total_cell_state_size = self._cell_state_size()
    cell_output_size = total_cell_state_size - conf.num_units

    if self._state_is_tuple:
      if len(conf.recurrents) != len(state):
        raise ValueError('Expected state as a tuple of {} '
                         'elements'.format(len(conf.recurrents)))

      for recurrent_dim, recurrent_state in zip(conf.recurrents, state):
        if cell_output_size > 0:
          c_prev[recurrent_dim], m_prev[recurrent_dim] = recurrent_state
        else:
          m_prev[recurrent_dim] = recurrent_state
    else:
      for recurrent_dim, start_idx in zip(conf.recurrents,
                                          range(0, self.state_size,
                                                total_cell_state_size)):
        if cell_output_size > 0:
          c_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx],
                                                  [-1, conf.num_units])
          m_prev[recurrent_dim] = array_ops.slice(
              state, [0, start_idx + conf.num_units], [-1, cell_output_size])
        else:
          m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx],
                                                  [-1, conf.num_units])
    return c_prev, m_prev, cell_output_size

  def _project_input(self, inputs, c_prev, m_prev, with_c):
    """Fills in c_prev and m_prev with projected input, for input dimensions.

    Args:
      inputs: inputs tensor
      c_prev: cell value
      m_prev: previous output
      with_c: boolean; whether to include project_c.

    Raises:
      ValueError: if len(self._config.inputs) != len(inputs)
    """
    conf = self._config

    if (inputs is not None and
        tensor_shape.dimension_value(inputs.shape.with_rank(2)[1]) > 0 and
        conf.inputs):
      if isinstance(inputs, tuple):
        if len(conf.inputs) != len(inputs):
          raise ValueError('Expected inputs as a tuple of {} '
                           'tensors'.format(len(conf.inputs)))
        input_splits = inputs
      else:
        input_splits = array_ops.split(
            value=inputs, num_or_size_splits=len(conf.inputs), axis=1)
      input_sz = tensor_shape.dimension_value(
          input_splits[0].shape.with_rank(2)[1])

      for i, j in enumerate(conf.inputs):
        input_project_m = vs.get_variable(
            'project_m_{}'.format(j), [input_sz, conf.num_units],
            dtype=inputs.dtype)
        m_prev[j] = math_ops.matmul(input_splits[i], input_project_m)

        if with_c:
          input_project_c = vs.get_variable(
              'project_c_{}'.format(j), [input_sz, conf.num_units],
              dtype=inputs.dtype)
          c_prev[j] = math_ops.matmul(input_splits[i], input_project_c)

  def _cell_state_size(self):
    """Total size of the state of the inner cell used in this grid.

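    For example, an LSTMCell with num_units=8 and state_is_tuple=True has
    state_size (8, 8), so this returns 16.
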
    Returns:
      Total size of the state of the inner cell.
316    """
317    state_sizes = self._cells[0].state_size
318    if isinstance(state_sizes, tuple):
319      return sum(state_sizes)
320    return state_sizes
321
322
323"""Specialized cells, for convenience
324"""
325
326
class Grid1BasicRNNCell(GridRNNCell):
  """1D BasicRNN cell."""

  def __init__(self, num_units, state_is_tuple=True, output_is_tuple=True):
    super(Grid1BasicRNNCell, self).__init__(
        num_units=num_units,
        num_dims=1,
        input_dims=0,
        output_dims=0,
        priority_dims=0,
        tied=False,
        cell_fn=lambda n: rnn.BasicRNNCell(num_units=n),
        state_is_tuple=state_is_tuple,
        output_is_tuple=output_is_tuple)


class Grid2BasicRNNCell(GridRNNCell):
  """2D BasicRNN cell.

  This creates a 2D cell which receives input and gives output in the first
  dimension.

  The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
  specified.
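
  Example (an illustrative sketch; passing `non_recurrent_fn` makes the first
  dimension feed-forward):

    cell = Grid2BasicRNNCell(num_units=16, non_recurrent_fn=nn.relu)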
351  """
352
353  def __init__(self,
354               num_units,
355               tied=False,
356               non_recurrent_fn=None,
357               state_is_tuple=True,
358               output_is_tuple=True):
359    super(Grid2BasicRNNCell, self).__init__(
360        num_units=num_units,
361        num_dims=2,
362        input_dims=0,
363        output_dims=0,
364        priority_dims=0,
365        tied=tied,
366        non_recurrent_dims=None if non_recurrent_fn is None else 0,
367        cell_fn=lambda n: rnn.BasicRNNCell(num_units=n),
368        non_recurrent_fn=non_recurrent_fn,
369        state_is_tuple=state_is_tuple,
370        output_is_tuple=output_is_tuple)
371
372
class Grid1BasicLSTMCell(GridRNNCell):
  """1D BasicLSTM cell."""

  def __init__(self,
               num_units,
               forget_bias=1,
               state_is_tuple=True,
               output_is_tuple=True):

    def cell_fn(n):
      return rnn.BasicLSTMCell(num_units=n, forget_bias=forget_bias)

    super(Grid1BasicLSTMCell, self).__init__(
        num_units=num_units,
        num_dims=1,
        input_dims=0,
        output_dims=0,
        priority_dims=0,
        tied=False,
        cell_fn=cell_fn,
        state_is_tuple=state_is_tuple,
        output_is_tuple=output_is_tuple)


class Grid2BasicLSTMCell(GridRNNCell):
  """2D BasicLSTM cell.

  This creates a 2D cell which receives input and gives output in the first
  dimension.

  The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
  specified.
  """

  def __init__(self,
               num_units,
               tied=False,
               non_recurrent_fn=None,
               forget_bias=1,
               state_is_tuple=True,
               output_is_tuple=True):

    def cell_fn(n):
      return rnn.BasicLSTMCell(num_units=n, forget_bias=forget_bias)

    super(Grid2BasicLSTMCell, self).__init__(
        num_units=num_units,
        num_dims=2,
        input_dims=0,
        output_dims=0,
        priority_dims=0,
        tied=tied,
        non_recurrent_dims=None if non_recurrent_fn is None else 0,
        cell_fn=cell_fn,
        non_recurrent_fn=non_recurrent_fn,
        state_is_tuple=state_is_tuple,
        output_is_tuple=output_is_tuple)


class Grid1LSTMCell(GridRNNCell):
  """1D LSTM cell.

  This is different from Grid1BasicLSTMCell because it gives the option to
  specify the forget bias and to enable peepholes.
  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               forget_bias=1.0,
               state_is_tuple=True,
               output_is_tuple=True):

    def cell_fn(n):
      return rnn.LSTMCell(
          num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes)

    super(Grid1LSTMCell, self).__init__(
        num_units=num_units,
        num_dims=1,
        input_dims=0,
        output_dims=0,
        priority_dims=0,
        cell_fn=cell_fn,
        state_is_tuple=state_is_tuple,
        output_is_tuple=output_is_tuple)


class Grid2LSTMCell(GridRNNCell):
  """2D LSTM cell.

  This creates a 2D cell which receives input and gives output in the first
  dimension.

  The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
  specified.
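
  Example (an illustrative sketch; `tied=True` shares weights between the two
  dimensions):

    cell = Grid2LSTMCell(num_units=16, tied=True, use_peepholes=True)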
464  """
465
466  def __init__(self,
467               num_units,
468               tied=False,
469               non_recurrent_fn=None,
470               use_peepholes=False,
471               forget_bias=1.0,
472               state_is_tuple=True,
473               output_is_tuple=True):
474
475    def cell_fn(n):
476      return rnn.LSTMCell(
477          num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes)
478
479    super(Grid2LSTMCell, self).__init__(
480        num_units=num_units,
481        num_dims=2,
482        input_dims=0,
483        output_dims=0,
484        priority_dims=0,
485        tied=tied,
486        non_recurrent_dims=None if non_recurrent_fn is None else 0,
487        cell_fn=cell_fn,
488        non_recurrent_fn=non_recurrent_fn,
489        state_is_tuple=state_is_tuple,
490        output_is_tuple=output_is_tuple)
491
492
class Grid3LSTMCell(GridRNNCell):
  """3D LSTM cell.

  This creates a 3D cell which receives input and gives output in the first
  dimension.

  The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
  specified.
  The second and third dimensions are LSTM.
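
  Example (an illustrative sketch):

    cell = Grid3LSTMCell(num_units=16, use_peepholes=True,
                         non_recurrent_fn=nn.relu)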
501  """
502
503  def __init__(self,
504               num_units,
505               tied=False,
506               non_recurrent_fn=None,
507               use_peepholes=False,
508               forget_bias=1.0,
509               state_is_tuple=True,
510               output_is_tuple=True):
511
512    def cell_fn(n):
513      return rnn.LSTMCell(
514          num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes)
515
516    super(Grid3LSTMCell, self).__init__(
517        num_units=num_units,
518        num_dims=3,
519        input_dims=0,
520        output_dims=0,
521        priority_dims=0,
522        tied=tied,
523        non_recurrent_dims=None if non_recurrent_fn is None else 0,
524        cell_fn=cell_fn,
525        non_recurrent_fn=non_recurrent_fn,
526        state_is_tuple=state_is_tuple,
527        output_is_tuple=output_is_tuple)
528
529
class Grid2GRUCell(GridRNNCell):
  """2D GRU cell.

  This creates a 2D cell which receives input and gives output in the first
  dimension.

  The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
  specified.
  """

  def __init__(self,
               num_units,
               tied=False,
               non_recurrent_fn=None,
               state_is_tuple=True,
               output_is_tuple=True):
    super(Grid2GRUCell, self).__init__(
        num_units=num_units,
        num_dims=2,
        input_dims=0,
        output_dims=0,
        priority_dims=0,
        tied=tied,
        non_recurrent_dims=None if non_recurrent_fn is None else 0,
        cell_fn=lambda n: rnn.GRUCell(num_units=n),
        non_recurrent_fn=non_recurrent_fn,
        state_is_tuple=state_is_tuple,
        output_is_tuple=output_is_tuple)


# Helpers

_GridRNNDimension = namedtuple('_GridRNNDimension', [
    'idx', 'is_input', 'is_output', 'is_priority', 'non_recurrent_fn'
])

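# Fields of `_GridRNNConfig`: `recurrents` lists the recurrent dimensions;
# `priority` and `non_priority` partition the dimensions by priority
# (section 3.2 of the paper).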
_GridRNNConfig = namedtuple('_GridRNNConfig',
                            ['num_dims', 'dims', 'inputs', 'outputs',
                             'recurrents', 'priority', 'non_priority', 'tied',
                             'num_units'])


def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims,
                      ls_non_recurrent_dims, non_recurrent_fn, tied, num_units):

  def check_dim_list(ls):
    if ls is None:
      ls = []
    if not isinstance(ls, (list, tuple)):
      ls = [ls]
    ls = sorted(set(ls))
    if any(d < 0 or d >= num_dims for d in ls):
      raise ValueError('Invalid dims: {}. Must be in [0, {})'.format(
          ls, num_dims))
    return ls

  input_dims = check_dim_list(ls_input_dims)
  output_dims = check_dim_list(ls_output_dims)
  priority_dims = check_dim_list(ls_priority_dims)
  non_recurrent_dims = check_dim_list(ls_non_recurrent_dims)

  rnn_dims = []
  for i in range(num_dims):
    rnn_dims.append(
        _GridRNNDimension(
            idx=i,
            is_input=(i in input_dims),
            is_output=(i in output_dims),
            is_priority=(i in priority_dims),
            non_recurrent_fn=non_recurrent_fn
            if i in non_recurrent_dims else None))
  return _GridRNNConfig(
      num_dims=num_dims,
      dims=rnn_dims,
      inputs=input_dims,
      outputs=output_dims,
      recurrents=[x for x in range(num_dims) if x not in non_recurrent_dims],
      priority=priority_dims,
      non_priority=[x for x in range(num_dims) if x not in priority_dims],
      tied=tied,
      num_units=num_units)


def _propagate(dim_indices, conf, cells, c_prev, m_prev, new_output, new_state,
               first_call):
  """Propagates through all the cells in dim_indices dimensions."""
  if len(dim_indices) == 0:
    return

  # Because of the way RNNCells are implemented, we take the last dimension
  # (H_{N-1}) out and feed it as the state of the RNN cell
  # (in `last_dim_output`).
  # The inputs of the cell (H_0 to H_{N-2}) are concatenated into `cell_inputs`.
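  # (e.g. with num_dims = 3: cell_inputs = concat(H_0, H_1) and the cell state
  # comes from H_2.)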
  if conf.num_dims > 1:
    ls_cell_inputs = [None] * (conf.num_dims - 1)
    for d in conf.dims[:-1]:
      if new_output[d.idx] is None:
        ls_cell_inputs[d.idx] = m_prev[d.idx]
      else:
        ls_cell_inputs[d.idx] = new_output[d.idx]
    cell_inputs = array_ops.concat(ls_cell_inputs, 1)
  else:
    cell_inputs = array_ops.zeros([m_prev[0].get_shape().as_list()[0], 0],
                                  m_prev[0].dtype)

  last_dim_output = (new_output[-1]
                     if new_output[-1] is not None else m_prev[-1])

  for i in dim_indices:
    d = conf.dims[i]
    if d.non_recurrent_fn:
      if conf.num_dims > 1:
        linear_args = array_ops.concat([cell_inputs, last_dim_output], 1)
      else:
        linear_args = last_dim_output
      with vs.variable_scope('non_recurrent' if conf.tied else
                             'non_recurrent/cell_{}'.format(i)):
        if conf.tied and not (first_call and i == dim_indices[0]):
          vs.get_variable_scope().reuse_variables()

        new_output[d.idx] = layers.fully_connected(
            linear_args,
            num_outputs=conf.num_units,
            activation_fn=d.non_recurrent_fn,
            weights_initializer=(vs.get_variable_scope().initializer or
                                 layers.initializers.xavier_initializer),
            weights_regularizer=vs.get_variable_scope().regularizer)
    else:
      if c_prev[i] is not None:
        cell_state = (c_prev[i], last_dim_output)
      else:
        # for GRU/RNN, the state is just the previous output
        cell_state = last_dim_output

      with vs.variable_scope('recurrent' if conf.tied else
                             'recurrent/cell_{}'.format(i)):
        if conf.tied and not (first_call and i == dim_indices[0]):
          vs.get_variable_scope().reuse_variables()
        cell = cells[i]
        new_output[d.idx], new_state[d.idx] = cell(cell_inputs, cell_state)
669