1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Module for constructing GridRNN cells""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from collections import namedtuple 22import functools 23 24from tensorflow.python.framework import tensor_shape 25from tensorflow.python.ops import array_ops 26from tensorflow.python.ops import math_ops 27from tensorflow.python.ops import nn 28from tensorflow.python.ops import variable_scope as vs 29 30from tensorflow.python.platform import tf_logging as logging 31from tensorflow.contrib import layers 32from tensorflow.contrib import rnn 33 34 35class GridRNNCell(rnn.RNNCell): 36 """Grid recurrent cell. 37 38 This implementation is based on: 39 40 http://arxiv.org/pdf/1507.01526v3.pdf 41 42 This is the generic implementation of GridRNN. Users can specify arbitrary 43 number of dimensions, 44 set some of them to be priority (section 3.2), non-recurrent (section 3.3) 45 and input/output dimensions (section 3.4). 46 Weight sharing can also be specified using the `tied` parameter. 47 Type of recurrent units can be specified via `cell_fn`. 48 """ 49 50 def __init__(self, 51 num_units, 52 num_dims=1, 53 input_dims=None, 54 output_dims=None, 55 priority_dims=None, 56 non_recurrent_dims=None, 57 tied=False, 58 cell_fn=None, 59 non_recurrent_fn=None, 60 state_is_tuple=True, 61 output_is_tuple=True): 62 """Initialize the parameters of a Grid RNN cell 63 64 Args: 65 num_units: int, The number of units in all dimensions of this GridRNN cell 66 num_dims: int, Number of dimensions of this grid. 67 input_dims: int or list, List of dimensions which will receive input data. 68 output_dims: int or list, List of dimensions from which the output will be 69 recorded. 70 priority_dims: int or list, List of dimensions to be considered as 71 priority dimensions. 72 If None, no dimension is prioritized. 73 non_recurrent_dims: int or list, List of dimensions that are not 74 recurrent. 75 The transfer function for non-recurrent dimensions is specified 76 via `non_recurrent_fn`, which is 77 default to be `tensorflow.nn.relu`. 78 tied: bool, Whether to share the weights among the dimensions of this 79 GridRNN cell. 80 If there are non-recurrent dimensions in the grid, weights are 81 shared between each group of recurrent and non-recurrent 82 dimensions. 83 cell_fn: function, a function which returns the recurrent cell object. 84 Has to be in the following signature: 85 ``` 86 def cell_func(num_units): 87 # ... 88 ``` 89 and returns an object of type `RNNCell`. If None, LSTMCell with 90 default parameters will be used. 91 Note that if you use a custom RNNCell (with `cell_fn`), it is your 92 responsibility to make sure the inner cell use `state_is_tuple=True`. 93 94 non_recurrent_fn: a tensorflow Op that will be the transfer function of 95 the non-recurrent dimensions 96 state_is_tuple: If True, accepted and returned states are tuples of the 97 states of the recurrent dimensions. If False, they are concatenated 98 along the column axis. The latter behavior will soon be deprecated. 99 100 Note that if you use a custom RNNCell (with `cell_fn`), it is your 101 responsibility to make sure the inner cell use `state_is_tuple=True`. 102 103 output_is_tuple: If True, the output is a tuple of the outputs of the 104 recurrent dimensions. If False, they are concatenated along the 105 column axis. The later behavior will soon be deprecated. 106 107 Raises: 108 TypeError: if cell_fn does not return an RNNCell instance. 109 """ 110 if not state_is_tuple: 111 logging.warning('%s: Using a concatenated state is slower and will ' 112 'soon be deprecated. Use state_is_tuple=True.', self) 113 if not output_is_tuple: 114 logging.warning('%s: Using a concatenated output is slower and will ' 115 'soon be deprecated. Use output_is_tuple=True.', self) 116 117 if num_dims < 1: 118 raise ValueError('dims must be >= 1: {}'.format(num_dims)) 119 120 self._config = _parse_rnn_config(num_dims, input_dims, output_dims, 121 priority_dims, non_recurrent_dims, 122 non_recurrent_fn or nn.relu, tied, 123 num_units) 124 125 self._state_is_tuple = state_is_tuple 126 self._output_is_tuple = output_is_tuple 127 128 if cell_fn is None: 129 my_cell_fn = functools.partial( 130 rnn.LSTMCell, num_units=num_units, state_is_tuple=state_is_tuple) 131 else: 132 my_cell_fn = lambda: cell_fn(num_units) 133 if tied: 134 self._cells = [my_cell_fn()] * num_dims 135 else: 136 self._cells = [my_cell_fn() for _ in range(num_dims)] 137 if not isinstance(self._cells[0], rnn.RNNCell): 138 raise TypeError('cell_fn must return an RNNCell instance, saw: %s' % 139 type(self._cells[0])) 140 141 if self._output_is_tuple: 142 self._output_size = tuple(self._cells[0].output_size 143 for _ in self._config.outputs) 144 else: 145 self._output_size = self._cells[0].output_size * len(self._config.outputs) 146 147 if self._state_is_tuple: 148 self._state_size = tuple(self._cells[0].state_size 149 for _ in self._config.recurrents) 150 else: 151 self._state_size = self._cell_state_size() * len(self._config.recurrents) 152 153 @property 154 def output_size(self): 155 return self._output_size 156 157 @property 158 def state_size(self): 159 return self._state_size 160 161 def __call__(self, inputs, state, scope=None): 162 """Run one step of GridRNN. 163 164 Args: 165 inputs: input Tensor, 2D, batch x input_size. Or None 166 state: state Tensor, 2D, batch x state_size. Note that state_size = 167 cell_state_size * recurrent_dims 168 scope: VariableScope for the created subgraph; defaults to "GridRNNCell". 169 170 Returns: 171 A tuple containing: 172 173 - A 2D, batch x output_size, Tensor representing the output of the cell 174 after reading "inputs" when previous state was "state". 175 - A 2D, batch x state_size, Tensor representing the new state of the cell 176 after reading "inputs" when previous state was "state". 177 """ 178 conf = self._config 179 dtype = inputs.dtype 180 181 c_prev, m_prev, cell_output_size = self._extract_states(state) 182 183 new_output = [None] * conf.num_dims 184 new_state = [None] * conf.num_dims 185 186 with vs.variable_scope(scope or type(self).__name__): # GridRNNCell 187 # project input, populate c_prev and m_prev 188 self._project_input(inputs, c_prev, m_prev, cell_output_size > 0) 189 190 # propagate along dimensions, first for non-priority dimensions 191 # then priority dimensions 192 _propagate(conf.non_priority, conf, self._cells, c_prev, m_prev, 193 new_output, new_state, True) 194 _propagate(conf.priority, conf, self._cells, 195 c_prev, m_prev, new_output, new_state, False) 196 197 # collect outputs and states 198 output_tensors = [new_output[i] for i in self._config.outputs] 199 if self._output_is_tuple: 200 output = tuple(output_tensors) 201 else: 202 if output_tensors: 203 output = array_ops.concat(output_tensors, 1) 204 else: 205 output = array_ops.zeros([0, 0], dtype) 206 207 if self._state_is_tuple: 208 states = tuple(new_state[i] for i in self._config.recurrents) 209 else: 210 # concat each state first, then flatten the whole thing 211 state_tensors = [ 212 x for i in self._config.recurrents for x in new_state[i] 213 ] 214 if state_tensors: 215 states = array_ops.concat(state_tensors, 1) 216 else: 217 states = array_ops.zeros([0, 0], dtype) 218 219 return output, states 220 221 def _extract_states(self, state): 222 """Extract the cell and previous output tensors from the given state. 223 224 Args: 225 state: The RNN state. 226 227 Returns: 228 Tuple of the cell value, previous output, and cell_output_size. 229 230 Raises: 231 ValueError: If len(self._config.recurrents) != len(state). 232 """ 233 conf = self._config 234 235 # c_prev is `m` (cell value), and 236 # m_prev is `h` (previous output) in the paper. 237 # Keeping c and m here for consistency with the codebase 238 c_prev = [None] * conf.num_dims 239 m_prev = [None] * conf.num_dims 240 241 # for LSTM : state = memory cell + output, hence cell_output_size > 0 242 # for GRU/RNN: state = output (whose size is equal to _num_units), 243 # hence cell_output_size = 0 244 total_cell_state_size = self._cell_state_size() 245 cell_output_size = total_cell_state_size - conf.num_units 246 247 if self._state_is_tuple: 248 if len(conf.recurrents) != len(state): 249 raise ValueError('Expected state as a tuple of {} ' 250 'element'.format(len(conf.recurrents))) 251 252 for recurrent_dim, recurrent_state in zip(conf.recurrents, state): 253 if cell_output_size > 0: 254 c_prev[recurrent_dim], m_prev[recurrent_dim] = recurrent_state 255 else: 256 m_prev[recurrent_dim] = recurrent_state 257 else: 258 for recurrent_dim, start_idx in zip(conf.recurrents, 259 range(0, self.state_size, 260 total_cell_state_size)): 261 if cell_output_size > 0: 262 c_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx], 263 [-1, conf.num_units]) 264 m_prev[recurrent_dim] = array_ops.slice( 265 state, [0, start_idx + conf.num_units], [-1, cell_output_size]) 266 else: 267 m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx], 268 [-1, conf.num_units]) 269 return c_prev, m_prev, cell_output_size 270 271 def _project_input(self, inputs, c_prev, m_prev, with_c): 272 """Fills in c_prev and m_prev with projected input, for input dimensions. 273 274 Args: 275 inputs: inputs tensor 276 c_prev: cell value 277 m_prev: previous output 278 with_c: boolean; whether to include project_c. 279 280 Raises: 281 ValueError: if len(self._config.input) != len(inputs) 282 """ 283 conf = self._config 284 285 if (inputs is not None and 286 tensor_shape.dimension_value(inputs.shape.with_rank(2)[1]) > 0 and 287 conf.inputs): 288 if isinstance(inputs, tuple): 289 if len(conf.inputs) != len(inputs): 290 raise ValueError('Expect inputs as a tuple of {} ' 291 'tensors'.format(len(conf.inputs))) 292 input_splits = inputs 293 else: 294 input_splits = array_ops.split( 295 value=inputs, num_or_size_splits=len(conf.inputs), axis=1) 296 input_sz = tensor_shape.dimension_value( 297 input_splits[0].shape.with_rank(2)[1]) 298 299 for i, j in enumerate(conf.inputs): 300 input_project_m = vs.get_variable( 301 'project_m_{}'.format(j), [input_sz, conf.num_units], 302 dtype=inputs.dtype) 303 m_prev[j] = math_ops.matmul(input_splits[i], input_project_m) 304 305 if with_c: 306 input_project_c = vs.get_variable( 307 'project_c_{}'.format(j), [input_sz, conf.num_units], 308 dtype=inputs.dtype) 309 c_prev[j] = math_ops.matmul(input_splits[i], input_project_c) 310 311 def _cell_state_size(self): 312 """Total size of the state of the inner cell used in this grid. 313 314 Returns: 315 Total size of the state of the inner cell. 316 """ 317 state_sizes = self._cells[0].state_size 318 if isinstance(state_sizes, tuple): 319 return sum(state_sizes) 320 return state_sizes 321 322 323"""Specialized cells, for convenience 324""" 325 326 327class Grid1BasicRNNCell(GridRNNCell): 328 """1D BasicRNN cell""" 329 330 def __init__(self, num_units, state_is_tuple=True, output_is_tuple=True): 331 super(Grid1BasicRNNCell, self).__init__( 332 num_units=num_units, 333 num_dims=1, 334 input_dims=0, 335 output_dims=0, 336 priority_dims=0, 337 tied=False, 338 cell_fn=lambda n: rnn.BasicRNNCell(num_units=n), 339 state_is_tuple=state_is_tuple, 340 output_is_tuple=output_is_tuple) 341 342 343class Grid2BasicRNNCell(GridRNNCell): 344 """2D BasicRNN cell 345 346 This creates a 2D cell which receives input and gives output in the first 347 dimension. 348 349 The first dimension can optionally be non-recurrent if `non_recurrent_fn` is 350 specified. 351 """ 352 353 def __init__(self, 354 num_units, 355 tied=False, 356 non_recurrent_fn=None, 357 state_is_tuple=True, 358 output_is_tuple=True): 359 super(Grid2BasicRNNCell, self).__init__( 360 num_units=num_units, 361 num_dims=2, 362 input_dims=0, 363 output_dims=0, 364 priority_dims=0, 365 tied=tied, 366 non_recurrent_dims=None if non_recurrent_fn is None else 0, 367 cell_fn=lambda n: rnn.BasicRNNCell(num_units=n), 368 non_recurrent_fn=non_recurrent_fn, 369 state_is_tuple=state_is_tuple, 370 output_is_tuple=output_is_tuple) 371 372 373class Grid1BasicLSTMCell(GridRNNCell): 374 """1D BasicLSTM cell.""" 375 376 def __init__(self, 377 num_units, 378 forget_bias=1, 379 state_is_tuple=True, 380 output_is_tuple=True): 381 def cell_fn(n): 382 return rnn.BasicLSTMCell(num_units=n, forget_bias=forget_bias) 383 super(Grid1BasicLSTMCell, self).__init__( 384 num_units=num_units, 385 num_dims=1, 386 input_dims=0, 387 output_dims=0, 388 priority_dims=0, 389 tied=False, 390 cell_fn=cell_fn, 391 state_is_tuple=state_is_tuple, 392 output_is_tuple=output_is_tuple) 393 394 395class Grid2BasicLSTMCell(GridRNNCell): 396 """2D BasicLSTM cell. 397 398 This creates a 2D cell which receives input and gives output in the first 399 dimension. 400 401 The first dimension can optionally be non-recurrent if `non_recurrent_fn` is 402 specified. 403 """ 404 405 def __init__(self, 406 num_units, 407 tied=False, 408 non_recurrent_fn=None, 409 forget_bias=1, 410 state_is_tuple=True, 411 output_is_tuple=True): 412 def cell_fn(n): 413 return rnn.BasicLSTMCell(num_units=n, forget_bias=forget_bias) 414 super(Grid2BasicLSTMCell, self).__init__( 415 num_units=num_units, 416 num_dims=2, 417 input_dims=0, 418 output_dims=0, 419 priority_dims=0, 420 tied=tied, 421 non_recurrent_dims=None if non_recurrent_fn is None else 0, 422 cell_fn=cell_fn, 423 non_recurrent_fn=non_recurrent_fn, 424 state_is_tuple=state_is_tuple, 425 output_is_tuple=output_is_tuple) 426 427 428class Grid1LSTMCell(GridRNNCell): 429 """1D LSTM cell. 430 431 This is different from Grid1BasicLSTMCell because it gives options to 432 specify the forget bias and enabling peepholes. 433 """ 434 435 def __init__(self, 436 num_units, 437 use_peepholes=False, 438 forget_bias=1.0, 439 state_is_tuple=True, 440 output_is_tuple=True): 441 442 def cell_fn(n): 443 return rnn.LSTMCell( 444 num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes) 445 446 super(Grid1LSTMCell, self).__init__( 447 num_units=num_units, 448 num_dims=1, 449 input_dims=0, 450 output_dims=0, 451 priority_dims=0, 452 cell_fn=cell_fn, 453 state_is_tuple=state_is_tuple, 454 output_is_tuple=output_is_tuple) 455 456 457class Grid2LSTMCell(GridRNNCell): 458 """2D LSTM cell. 459 460 This creates a 2D cell which receives input and gives output in the first 461 dimension. 462 The first dimension can optionally be non-recurrent if `non_recurrent_fn` is 463 specified. 464 """ 465 466 def __init__(self, 467 num_units, 468 tied=False, 469 non_recurrent_fn=None, 470 use_peepholes=False, 471 forget_bias=1.0, 472 state_is_tuple=True, 473 output_is_tuple=True): 474 475 def cell_fn(n): 476 return rnn.LSTMCell( 477 num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes) 478 479 super(Grid2LSTMCell, self).__init__( 480 num_units=num_units, 481 num_dims=2, 482 input_dims=0, 483 output_dims=0, 484 priority_dims=0, 485 tied=tied, 486 non_recurrent_dims=None if non_recurrent_fn is None else 0, 487 cell_fn=cell_fn, 488 non_recurrent_fn=non_recurrent_fn, 489 state_is_tuple=state_is_tuple, 490 output_is_tuple=output_is_tuple) 491 492 493class Grid3LSTMCell(GridRNNCell): 494 """3D BasicLSTM cell. 495 496 This creates a 2D cell which receives input and gives output in the first 497 dimension. 498 The first dimension can optionally be non-recurrent if `non_recurrent_fn` is 499 specified. 500 The second and third dimensions are LSTM. 501 """ 502 503 def __init__(self, 504 num_units, 505 tied=False, 506 non_recurrent_fn=None, 507 use_peepholes=False, 508 forget_bias=1.0, 509 state_is_tuple=True, 510 output_is_tuple=True): 511 512 def cell_fn(n): 513 return rnn.LSTMCell( 514 num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes) 515 516 super(Grid3LSTMCell, self).__init__( 517 num_units=num_units, 518 num_dims=3, 519 input_dims=0, 520 output_dims=0, 521 priority_dims=0, 522 tied=tied, 523 non_recurrent_dims=None if non_recurrent_fn is None else 0, 524 cell_fn=cell_fn, 525 non_recurrent_fn=non_recurrent_fn, 526 state_is_tuple=state_is_tuple, 527 output_is_tuple=output_is_tuple) 528 529 530class Grid2GRUCell(GridRNNCell): 531 """2D LSTM cell. 532 533 This creates a 2D cell which receives input and gives output in the first 534 dimension. 535 The first dimension can optionally be non-recurrent if `non_recurrent_fn` is 536 specified. 537 """ 538 539 def __init__(self, 540 num_units, 541 tied=False, 542 non_recurrent_fn=None, 543 state_is_tuple=True, 544 output_is_tuple=True): 545 super(Grid2GRUCell, self).__init__( 546 num_units=num_units, 547 num_dims=2, 548 input_dims=0, 549 output_dims=0, 550 priority_dims=0, 551 tied=tied, 552 non_recurrent_dims=None if non_recurrent_fn is None else 0, 553 cell_fn=lambda n: rnn.GRUCell(num_units=n), 554 non_recurrent_fn=non_recurrent_fn, 555 state_is_tuple=state_is_tuple, 556 output_is_tuple=output_is_tuple) 557 558 559# Helpers 560 561_GridRNNDimension = namedtuple('_GridRNNDimension', [ 562 'idx', 'is_input', 'is_output', 'is_priority', 'non_recurrent_fn' 563]) 564 565_GridRNNConfig = namedtuple('_GridRNNConfig', 566 ['num_dims', 'dims', 'inputs', 'outputs', 567 'recurrents', 'priority', 'non_priority', 'tied', 568 'num_units']) 569 570 571def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims, 572 ls_non_recurrent_dims, non_recurrent_fn, tied, num_units): 573 def check_dim_list(ls): 574 if ls is None: 575 ls = [] 576 if not isinstance(ls, (list, tuple)): 577 ls = [ls] 578 ls = sorted(set(ls)) 579 if any(_ < 0 or _ >= num_dims for _ in ls): 580 raise ValueError('Invalid dims: {}. Must be in [0, {})'.format(ls, 581 num_dims)) 582 return ls 583 584 input_dims = check_dim_list(ls_input_dims) 585 output_dims = check_dim_list(ls_output_dims) 586 priority_dims = check_dim_list(ls_priority_dims) 587 non_recurrent_dims = check_dim_list(ls_non_recurrent_dims) 588 589 rnn_dims = [] 590 for i in range(num_dims): 591 rnn_dims.append( 592 _GridRNNDimension( 593 idx=i, 594 is_input=(i in input_dims), 595 is_output=(i in output_dims), 596 is_priority=(i in priority_dims), 597 non_recurrent_fn=non_recurrent_fn 598 if i in non_recurrent_dims else None)) 599 return _GridRNNConfig( 600 num_dims=num_dims, 601 dims=rnn_dims, 602 inputs=input_dims, 603 outputs=output_dims, 604 recurrents=[x for x in range(num_dims) if x not in non_recurrent_dims], 605 priority=priority_dims, 606 non_priority=[x for x in range(num_dims) if x not in priority_dims], 607 tied=tied, 608 num_units=num_units) 609 610 611def _propagate(dim_indices, conf, cells, c_prev, m_prev, new_output, new_state, 612 first_call): 613 """Propagates through all the cells in dim_indices dimensions. 614 """ 615 if len(dim_indices) == 0: 616 return 617 618 # Because of the way RNNCells are implemented, we take the last dimension 619 # (H_{N-1}) out and feed it as the state of the RNN cell 620 # (in `last_dim_output`). 621 # The input of the cell (H_0 to H_{N-2}) are concatenated into `cell_inputs` 622 if conf.num_dims > 1: 623 ls_cell_inputs = [None] * (conf.num_dims - 1) 624 for d in conf.dims[:-1]: 625 if new_output[d.idx] is None: 626 ls_cell_inputs[d.idx] = m_prev[d.idx] 627 else: 628 ls_cell_inputs[d.idx] = new_output[d.idx] 629 cell_inputs = array_ops.concat(ls_cell_inputs, 1) 630 else: 631 cell_inputs = array_ops.zeros([m_prev[0].get_shape().as_list()[0], 0], 632 m_prev[0].dtype) 633 634 last_dim_output = (new_output[-1] 635 if new_output[-1] is not None else m_prev[-1]) 636 637 for i in dim_indices: 638 d = conf.dims[i] 639 if d.non_recurrent_fn: 640 if conf.num_dims > 1: 641 linear_args = array_ops.concat([cell_inputs, last_dim_output], 1) 642 else: 643 linear_args = last_dim_output 644 with vs.variable_scope('non_recurrent' if conf.tied else 645 'non_recurrent/cell_{}'.format(i)): 646 if conf.tied and not (first_call and i == dim_indices[0]): 647 vs.get_variable_scope().reuse_variables() 648 649 new_output[d.idx] = layers.fully_connected( 650 linear_args, 651 num_outputs=conf.num_units, 652 activation_fn=d.non_recurrent_fn, 653 weights_initializer=(vs.get_variable_scope().initializer or 654 layers.initializers.xavier_initializer), 655 weights_regularizer=vs.get_variable_scope().regularizer) 656 else: 657 if c_prev[i] is not None: 658 cell_state = (c_prev[i], last_dim_output) 659 else: 660 # for GRU/RNN, the state is just the previous output 661 cell_state = last_dim_output 662 663 with vs.variable_scope('recurrent' if conf.tied else 664 'recurrent/cell_{}'.format(i)): 665 if conf.tied and not (first_call and i == dim_indices[0]): 666 vs.get_variable_scope().reuse_variables() 667 cell = cells[i] 668 new_output[d.idx], new_state[d.idx] = cell(cell_inputs, cell_state) 669