# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SequenceQueueingStateSaver and wrappers.

Please see the reading data how-to for context.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import numbers

import six

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.summary import summary
from tensorflow.python.training import queue_runner

# pylint: disable=protected-access
_restore_sparse = sparse_ops._take_many_sparse_from_tensors_map
_store_sparse = sparse_ops._add_many_sparse_to_tensors_map
# pylint: enable=protected-access


class _SequenceInputWrapper(object):
  """A wrapper object for storing sequence-related input.

  The _SequenceInputWrapper accepts four objects:

  length: A scalar int containing the length of the input sequence.
  key: A scalar string containing the unique key of the input sequence.
  sequences: A dict mapping labels, like `input`, to tensors
    whose initial index dimension is at least size `length`.
  context: A dict mapping labels, like `global_target`, to tensors
    that represent data across the entire example.
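
  For illustration only, a minimal construction might look like the sketch
  below (this private class is normally created by the state saver; the
  tensor values here are made up):

  ```python
  length = tf.constant(20, dtype=tf.int32)
  key = tf.constant("example_000001")
  sequences = {"input": tf.zeros([20, 8], dtype=tf.float32)}
  context = {"global_target": tf.constant(3, dtype=tf.int32)}
  wrapper = _SequenceInputWrapper(length, key, sequences, context)
  ```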
63 """ 64 65 def __init__(self, length, key, sequences, context): 66 length = ops.convert_to_tensor(length, name="length") 67 key = ops.convert_to_tensor(key, name="key") 68 if not isinstance(sequences, dict): 69 raise TypeError("sequences must be a dict") 70 if not isinstance(context, dict): 71 raise TypeError("context must be a dict") 72 if not sequences: 73 raise ValueError("must have at least one sequence tensor") 74 for k in sequences.keys(): 75 if not isinstance(k, six.string_types): 76 raise TypeError("sequence key must be string: %s" % k) 77 if ":" in k: 78 raise ValueError("sequence key may not have a colon: '%s'" % k) 79 for k in context.keys(): 80 if not isinstance(k, six.string_types): 81 raise TypeError("context key must be string: %s" % k) 82 if ":" in k: 83 raise ValueError("context key may not have a colon: '%s'" % k) 84 sequences = dict((k, ops.convert_to_tensor( 85 v, name="sequence_%s" % k)) for k, v in sequences.items()) 86 context = dict((k, ops.convert_to_tensor( 87 v, name="context_%s" % k)) for k, v in context.items()) 88 self._length = length 89 self._key = key 90 self._sequences = sequences 91 self._context = context 92 93 @property 94 def length(self): 95 return self._length 96 97 @property 98 def key(self): 99 return self._key 100 101 @property 102 def sequences(self): 103 return self._sequences 104 105 @property 106 def context(self): 107 return self._context 108 109 110def _check_multiple_of(value, multiple_of): 111 """Checks that value `value` is a non-zero multiple of `multiple_of`. 112 113 Args: 114 value: an int32 scalar Tensor. 115 multiple_of: an int or int32 scalar Tensor. 116 117 Returns: 118 new_value: an int32 scalar Tensor matching `value`, but which includes an 119 assertion that `value` is a multiple of `multiple_of`. 120 """ 121 assert isinstance(value, ops.Tensor) 122 with ops.control_dependencies([ 123 control_flow_ops.Assert( 124 math_ops.logical_and( 125 math_ops.equal(math_ops.mod(value, multiple_of), 0), 126 math_ops.not_equal(value, 0)), [ 127 string_ops.string_join([ 128 "Tensor %s should be a multiple of: " % value.name, 129 string_ops.as_string(multiple_of), ", but saw value: ", 130 string_ops.as_string(value), 131 ". Consider setting pad=True." 132 ]) 133 ]) 134 ]): 135 new_value = array_ops.identity(value, name="multiple_of_checked") 136 return new_value 137 138 139def _check_rank(value, expected_rank): 140 """Check the rank of Tensor `value`, via shape inference and assertions. 141 142 Args: 143 value: A Tensor, possibly with shape associated shape information. 144 expected_rank: int32 scalar (optionally a `Tensor`). 145 146 Returns: 147 new_value: A Tensor matching `value`. Accessing this tensor tests 148 assertions on its rank. If expected_rank is not a `Tensor`, then 149 new_value's shape's rank has been set. 150 151 Raises: 152 ValueError: if `expected_rank` is not a `Tensor` and the rank of `value` 153 is known and is not equal to `expected_rank`. 
154 """ 155 assert isinstance(value, ops.Tensor) 156 with ops.control_dependencies([ 157 control_flow_ops.Assert( 158 math_ops.equal(expected_rank, array_ops.rank(value)), [ 159 string_ops.string_join([ 160 "Rank of tensor %s should be: " % value.name, 161 string_ops.as_string(expected_rank), ", shape received:" 162 ]), array_ops.shape(value) 163 ]) 164 ]): 165 new_value = array_ops.identity(value, name="rank_checked") 166 if isinstance(expected_rank, ops.Tensor): 167 expected_rank_value = tensor_util.constant_value(expected_rank) 168 if expected_rank_value is not None: 169 expected_rank = int(expected_rank_value) 170 if not isinstance(expected_rank, ops.Tensor): 171 try: 172 new_value.set_shape(new_value.get_shape().with_rank(expected_rank)) 173 except ValueError as e: 174 raise ValueError("Rank check failed for %s: %s" % (value.name, str(e))) 175 return new_value 176 177 178def _check_shape(value, expected_shape): 179 """Check the shape of Tensor `value`, via shape inference and assertions. 180 181 Args: 182 value: A Tensor, possibly with shape associated shape information. 183 expected_shape: a `TensorShape`, list of `int32`, or a vector `Tensor`. 184 185 Returns: 186 new_value: A Tensor matching `value`. Accessing this tensor tests 187 assertions on its shape. If expected_shape is not a `Tensor`, then 188 new_value's shape has been set. 189 190 Raises: 191 ValueError: if `expected_shape` is not a `Tensor` and the shape of `value` 192 is known and is not equal to `expected_shape`. 193 """ 194 assert isinstance(value, ops.Tensor) 195 if isinstance(expected_shape, tensor_shape.TensorShape): 196 expected_shape = expected_shape.as_list() 197 if isinstance(expected_shape, ops.Tensor): 198 expected_shape_value = tensor_util.constant_value(expected_shape) 199 if expected_shape_value is not None: 200 expected_shape = [int(d) for d in expected_shape_value] 201 if isinstance(expected_shape, ops.Tensor): 202 value = _check_rank(value, array_ops.size(expected_shape)) 203 else: 204 value = _check_rank(value, len(expected_shape)) 205 with ops.control_dependencies([ 206 control_flow_ops.Assert( 207 math_ops.reduce_all( 208 math_ops.equal(expected_shape, array_ops.shape(value))), [ 209 string_ops.string_join([ 210 "Shape of tensor %s should be: " % value.name, 211 string_ops.as_string(expected_shape), 212 ", shape received: ", 213 string_ops.as_string(array_ops.shape(value)) 214 ]) 215 ]) 216 ]): 217 new_value = array_ops.identity(value, name="shape_checked") 218 if not isinstance(expected_shape, ops.Tensor): 219 try: 220 new_value.set_shape(new_value.get_shape().merge_with(expected_shape)) 221 except ValueError as e: 222 raise ValueError("Shape check failed for %s: %s" % (value.name, str(e))) 223 return new_value 224 225 226def _check_dimensions(value, dimensions, expected_sizes, debug_prefix): 227 """Check the dimensions of Tensor `value`, via shape inference and assertions. 228 229 Args: 230 value: A Tensor, with optional / partial shape associated shape information. 231 dimensions: An int list, the dimensions to check. 232 expected_sizes: list of mixed ints and int32 scalar tensors. 233 Optionally also a vector `Tensor`. 234 debug_prefix: A string, used for naming ops and printing debugging messages. 235 236 Returns: 237 new_value: A Tensor matching `value`. Accessing this tensor tests 238 assertions on its shape. If expected_sizes is not a `Tensor`, then 239 new_value's shape has been set for all `dimensions[i]` where 240 `expected_sizes[i]` is not a `Tensor`. 

  Raises:
    TypeError: if any of the input contains invalid types:
      if `value` is not a `Tensor`.
      if `dimensions` is not a `list` or `tuple`.
    ValueError: if input has incorrect sizes or inferred shapes do not match:
      if `dimensions` contains repeated dimensions.
      if `expected_sizes` is not a `Tensor` and its length does not match that
        of `dimensions`.
      if `value`'s shape has a well-defined rank, and one of the values in
        `dimensions` is equal to or above this rank.
      if `value`'s shape is well defined for some `dimensions[i]`, and
        `expected_sizes[i]` is not a `Tensor`, and these two values do
        not match.
  """

  if not isinstance(dimensions, (list, tuple)):
    raise TypeError("dimensions must be a list or tuple")
  if len(set(dimensions)) != len(dimensions):
    raise ValueError("dimensions are not unique: %s" % dimensions)
  if not isinstance(value, ops.Tensor):
    raise TypeError("value is not a Tensor: %s" % value)
  value_shape = value.get_shape()
  if not isinstance(expected_sizes, ops.Tensor):
    if len(dimensions) != len(expected_sizes):
      raise ValueError("len(dimensions) != len(expected_sizes): %d vs. %d" %
                       (len(dimensions), len(expected_sizes)))
    if value_shape.ndims is not None:
      if value_shape.ndims <= max(dimensions):
        raise ValueError(
            "%s: rank of input is not greater than max(dimensions): "
            "%d vs. %d" % (debug_prefix, value.get_shape().ndims,
                           max(dimensions)))
      value_dims = value_shape.as_list()
      for d, s in zip(dimensions, expected_sizes):
        if not isinstance(s, ops.Tensor):
          value_dims[d] = s
      try:
        value.set_shape(value.get_shape().merge_with(value_dims))
      except ValueError as e:
        raise ValueError("Dimensions check failed for %s: %s" %
                         (debug_prefix, str(e)))
  with ops.control_dependencies([
      control_flow_ops.Assert(
          math_ops.equal(expected_size, array_ops.shape(value)[dimension]), [
              string_ops.string_join([
                  "Dimension %d of tensor labeled %s should be: " %
                  (dimension, debug_prefix),
                  string_ops.as_string(expected_size), ", shape received: ",
                  string_ops.as_string(array_ops.shape(value))
              ])
          ]) for (dimension, expected_size) in zip(dimensions, expected_sizes)
  ]):
    new_value = array_ops.identity(
        value, name="dims_checked_%s" % debug_prefix)
    return new_value


def _prepare_sequence_inputs(inputs, states):
  """Convert input to tensors and validate shape information.

  Args:
    inputs: A `_SequenceInputWrapper` instance.
    states: A dictionary mapping state names to input constants or tensors.

  Returns:
    The tuple (length, key, sorted_states, sorted_sequences, sorted_context),
    where each value has been checked for valid shape, and the sorted_* dicts
    are instances of OrderedDict, with key-value pairs sorted by key.

  Raises:
    ValueError: if the shapes of inputs.context.values(), states.values(),
      or inputs.sequences.values() are not fully defined (with the exception
      of the first dimension of any `Tensor` in inputs.sequences.values()).
    TypeError: if the dtype of length is not int32.
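
  For illustration, a sketch of a call (all names and values are made up):

  ```python
  inputs = _SequenceInputWrapper(
      length=20,
      key="example_000001",
      sequences={"input": tf.zeros([20, 8], dtype=tf.float32)},
      context={"label": tf.constant(1, dtype=tf.int32)})
  (length, key, sorted_states, sorted_sequences,
   sorted_context) = _prepare_sequence_inputs(
       inputs, states={"lstm_state": tf.zeros([8], dtype=tf.float32)})
  ```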
315 """ 316 # Convert state initial values to tensors 317 states = dict((k, ops.convert_to_tensor( 318 v, name="state_%s" % k)) for k, v in states.items()) 319 320 def _assert_fully_defined(label, dict_, ignore_first_dimension=False): 321 start_dimension = 1 if ignore_first_dimension else 0 322 for k, v in dict_.items(): 323 if not v.get_shape()[start_dimension:].is_fully_defined(): 324 raise ValueError("Shape for %s %s is not fully defined %s: %s" % 325 (label, k, "(ignoring first dimension)" if 326 ignore_first_dimension else "", v.get_shape())) 327 328 _assert_fully_defined("state", states) 329 _assert_fully_defined("context", inputs.context) 330 # Sequences' first dimension (time) may be variable 331 _assert_fully_defined( 332 "sequence", inputs.sequences, ignore_first_dimension=True) 333 334 # Get dictionaries' dtypes ordered by name - ordering is important 335 # when switching between dicts and tuples for passing to Barrier. 336 def _sort_by_name(d): 337 return collections.OrderedDict(sorted(d.items(), key=lambda k_v: k_v[0])) 338 339 sorted_sequences = _sort_by_name(inputs.sequences) 340 sorted_context = _sort_by_name(inputs.context) 341 sorted_states = _sort_by_name(states) 342 343 length = _check_rank(inputs.length, 0) 344 key = _check_rank(inputs.key, 0) 345 346 if length.dtype != dtypes.int32: 347 raise TypeError("length dtype must be int32, but received: %s" % 348 length.dtype) 349 if key.dtype != dtypes.string: 350 raise TypeError("key dtype must be string, but received: %s" % key.dtype) 351 352 return (length, key, sorted_states, sorted_sequences, sorted_context) 353 354 355# NextQueuedSequenceBatch works closely with 356# SequenceQueueingStateSaver and requires access to its private properties 357# pylint: disable=protected-access 358class NextQueuedSequenceBatch(object): 359 """NextQueuedSequenceBatch stores deferred SequenceQueueingStateSaver data. 360 361 This class is instantiated by `SequenceQueueingStateSaver` and is accessible 362 via its `next_batch` property. 363 """ 364 365 def __init__(self, state_saver): 366 self._state_saver = state_saver 367 368 @property 369 def total_length(self): 370 """The lengths of the original (non-truncated) unrolled examples. 371 372 Returns: 373 An integer vector of length `batch_size`, the total lengths. 374 """ 375 return self._state_saver._received_total_length 376 377 @property 378 def length(self): 379 """The lengths of the given truncated unrolled examples. 380 381 For initial iterations, for which `sequence * num_unroll < length`, 382 this number is `num_unroll`. For the remainder, 383 this number is between `0` and `num_unroll`. 384 385 Returns: 386 An integer vector of length `batch_size`, the lengths. 387 """ 388 return self._state_saver._received_length 389 390 @property 391 def batch_size(self): 392 """The batch_size of the given batch. 393 394 Usually, this is the batch_size requested when initializing the SQSS, but 395 if allow_small_batch=True this will become smaller when inputs are 396 exhausted. 397 398 Returns: 399 A scalar integer tensor, the batch_size 400 """ 401 return self._state_saver._received_batch_size 402 403 @property 404 def insertion_index(self): 405 """The insertion indices of the examples (when they were first added). 406 407 These indices start with the value -2**63 and increase with every 408 call to the prefetch op. 
    Each whole example gets its own insertion index, and this is used to
    prioritize the example so that its truncated segments appear in
    adjacent iterations, even if new examples are inserted by the
    prefetch op between iterations.

    Returns:
      An int64 vector of length `batch_size`, the insertion indices.
    """
    return self._state_saver._received_indices

  @property
  def key(self):
    """The key names of the given truncated unrolled examples.

    The format of the key is:

    ```python
    "%05d_of_%05d:%s" % (sequence, sequence_count, original_key)
    ```

    where `original_key` is the unique key read in by the prefetcher.

    Returns:
      A string vector of length `batch_size`, the keys.
    """
    return self._state_saver._received_keys

  @property
  def next_key(self):
    """The key names of the next (in iteration) truncated unrolled examples.

    The format of the key is:

    ```python
    "%05d_of_%05d:%s" % (sequence + 1, sequence_count, original_key)
    ```

    if `sequence + 1 < sequence_count`, otherwise:

    ```python
    "STOP:%s" % original_key
    ```

    where `original_key` is the unique key read in by the prefetcher.

    Returns:
      A string vector of length `batch_size`, the keys.
    """
    return self._state_saver._received_next_key

  @property
  def sequence(self):
    """An int32 vector, length `batch_size`: the sequence index of each entry.

    When an input is split up, the sequence values
    ```
    0, 1, ..., sequence_count - 1
    ```
    are assigned to each split.

    Returns:
      An int32 vector `Tensor`.
    """
    return self._state_saver._received_sequence

  @property
  def sequence_count(self):
    """An int32 vector, length `batch_size`: the sequence count of each entry.

    When an input is split up, the number of splits is equal to:
    `padded_length / num_unroll`.  This is the sequence_count.

    Returns:
      An int32 vector `Tensor`.
    """
    return self._state_saver._received_sequence_count

  @property
  def context(self):
    """A dict mapping keys of `input_context` to batched context.

    Returns:
      A dict mapping keys of `input_context` to tensors.
      If we had at input:

      ```python
      context["name"].get_shape() == [d1, d2, ...]
      ```

      then for this property:

      ```python
      context["name"].get_shape() == [batch_size, d1, d2, ...]
      ```

    """
    return self._state_saver._received_context

  @property
  def sequences(self):
    """A dict mapping keys of `input_sequences` to split and rebatched data.

    Returns:
      A dict mapping keys of `input_sequences` to tensors.
      If we had at input:

      ```python
      sequences["name"].get_shape() == [None, d1, d2, ...]
      ```

      where `None` meant the sequence time was dynamic, then for this
      property:

      ```python
      sequences["name"].get_shape() == [batch_size, num_unroll, d1, d2, ...]
      ```

    """
    return self._state_saver._received_sequences

  def state(self, state_name):
    """Returns batched state tensors.

    Args:
      state_name: string, matches a key provided in `initial_states`.

    Returns:
      A `Tensor`: a batched set of states, either initial states (if this is
      the first run of the given example), or a value as stored during
      a previous iteration via `save_state` control flow.
      Its type is the same as `initial_states["state_name"].dtype`.
      If we had at input:

      ```python
      initial_states[state_name].get_shape() == [d1, d2, ...]
      ```

      then

      ```python
      state(state_name).get_shape() == [batch_size, d1, d2, ...]
      ```

    Raises:
      KeyError: if `state_name` does not match any of the initial states
        declared in `initial_states`.
    """
    return self._state_saver._received_states[state_name]

  def save_state(self, state_name, value, name=None):
    """Returns an op to save the current batch of state `state_name`.

    Args:
      state_name: string, matches a key provided in `initial_states`.
      value: A `Tensor`.
        Its type must match that of `initial_states[state_name].dtype`.
        If we had at input:

        ```python
        initial_states[state_name].get_shape() == [d1, d2, ...]
        ```

        then the shape of `value` must match:

        ```python
        tf.shape(value) == [batch_size, d1, d2, ...]
        ```

      name: string (optional).  The name scope for newly created ops.

    Returns:
      A control flow op that stores the new state of each entry into
      the state saver.  This op must be run for every iteration that
      accesses data from the state saver (otherwise the state saver
      will never progress through its states and run out of capacity).

    Raises:
      KeyError: if `state_name` does not match any of the initial states
        declared in `initial_states`.
    """
    if state_name not in self._state_saver._received_states.keys():
      raise KeyError("state was not declared: %s" % state_name)
    default_name = "InputQueueingStateSaver_SaveState"
    with ops.name_scope(name, default_name, values=[value]):
      # Place all operations on the CPU.  Barriers and queues are only
      # implemented for CPU, but all the other book-keeping operations
      # (reshape, shape, range, ...) would be placed on GPUs if available,
      # unless we explicitly tie them to CPU.
      with ops.colocate_with(self._state_saver._capacity_queue.queue_ref):
        indices_where_not_done = array_ops.reshape(
            array_ops.where(
                math_ops.logical_not(self._state_saver._sequence_is_done)),
            [-1])
        keeping_next_key = array_ops.gather(
            self._state_saver._received_next_key, indices_where_not_done)
        value = _check_shape(
            array_ops.identity(
                value, name="convert_%s" % state_name),
            array_ops.shape(self._state_saver._received_states[state_name]))
        keeping_state = array_ops.gather(value, indices_where_not_done)
        return self._state_saver._barrier.insert_many(
            self._state_saver._get_barrier_index("state", state_name),
            keeping_next_key,
            keeping_state,
            name="BarrierInsertState_%s" % state_name)


# pylint: enable=protected-access


class SequenceQueueingStateSaver(object):
  """SequenceQueueingStateSaver provides access to stateful values from input.

  This class is meant to be used instead of, e.g., a `Queue`, for splitting
  variable-length sequence inputs into segments of sequences with fixed length
  and batching them into mini-batches.  It maintains contexts and state for a
  sequence across the segments.  It can be used in conjunction with a
  `QueueRunner` (see the example below).

  The `SequenceQueueingStateSaver` (SQSS) accepts one example at a time via
  the inputs `input_length`, `input_key`, `input_sequences` (a dict),
  `input_context` (a dict), and `initial_states` (a dict).
  The sequences, values in `input_sequences`, may have variable first
  dimension (the `padded_length`), though this dimension must always be a
  multiple of `num_unroll`.  All other dimensions must be fixed and accessible
  via `get_shape` calls.  The length prior to padding can be recorded in
  `input_length`.  The context values in `input_context` must all have fixed
  and well defined dimensions.  The initial state values must all have fixed
  and well defined dimensions.

  The SQSS splits the sequences of an input example into segments of length
  `num_unroll`.  Across examples, minibatches of size `batch_size` are formed.
  These minibatches contain a segment of the sequences, copy the context
  values, and maintain state, length, and key information of the original
  input examples.  In the first segment of an example the state is still the
  initial state.  It can then be updated, and the updated state values are
  accessible in subsequent segments of the same example.  After each segment
  `batch.save_state()` must be called, which is done by the
  `state_saving_rnn`.  Without this call, the dequeue op associated with the
  SQSS will not run.

  Internally, SQSS has a queue for the input examples.  Its `capacity` is
  configurable.  If set smaller than `batch_size` then the dequeue op will
  block indefinitely.  A small multiple of `batch_size` is a good rule of
  thumb to prevent that queue from becoming a bottleneck and slowing down
  training.  If set too large (and note that it defaults to unbounded) memory
  consumption goes up.  Moreover, when iterating over the same input examples
  multiple times reusing the same `key` the `capacity` must be smaller than
  the number of examples.

  The prefetcher, which reads one unrolled, variable-length input sequence at
  a time, is accessible via `prefetch_op`.  The underlying `Barrier` object
  is accessible via `barrier`.  Processed minibatches, as well as
  state read and write capabilities, are accessible via `next_batch`.
  Specifically, `next_batch` provides access to all of the minibatched
  data, including the following (see `NextQueuedSequenceBatch` for details):

  * `total_length`, `length`, `insertion_index`, `key`, `next_key`,
  * `sequence` (each minibatch entry's time segment index),
  * `sequence_count` (the total time segment count for each minibatch entry),
  * `context` (a dict of the copied minibatched context values),
  * `sequences` (a dict of the split minibatched variable-length sequences),
  * `state` (to access the states of the current segments of these entries)
  * `save_state` (to save the states for the next segments of these entries)

  Example usage:

  ```python
  batch_size = 32
  num_unroll = 20
  lstm_size = 8
  cell = tf.contrib.rnn.BasicLSTMCell(num_units=lstm_size)
  initial_state_values = tf.zeros(cell.state_size, dtype=tf.float32)

  raw_data = get_single_input_from_input_reader()
  length, key, sequences, context = my_parser(raw_data)
  assert "input" in sequences.keys()
  assert "label" in context.keys()
  initial_states = {"lstm_state": initial_state_values}

  stateful_reader = tf.contrib.training.SequenceQueueingStateSaver(
      batch_size, num_unroll,
      input_length=length, input_key=key, input_sequences=sequences,
      input_context=context, initial_states=initial_states,
      capacity=batch_size*100)

  batch = stateful_reader.next_batch
  inputs = batch.sequences["input"]
  context_label = batch.context["label"]

  inputs_by_time = tf.split(value=inputs, num_or_size_splits=num_unroll, axis=1)
  assert len(inputs_by_time) == num_unroll

  lstm_output, _ = tf.contrib.rnn.static_state_saving_rnn(
      cell,
      inputs_by_time,
      state_saver=batch,
      state_name="lstm_state")

  # Start a prefetcher in the background
  sess = tf.Session()
  num_threads = 3
  queue_runner = tf.train.QueueRunner(
      stateful_reader, [stateful_reader.prefetch_op] * num_threads)
  tf.train.add_queue_runner(queue_runner)
  tf.train.start_queue_runners(sess=sess)

  while True:
    # Step through batches, perform training or inference...
    sess.run([lstm_output])
  ```

  **Note**: Usually the barrier is given to a QueueRunner as in the
  example above.  The QueueRunner will close the barrier if the prefetch_op
  receives an OutOfRange error from upstream input queues (i.e., reaches
  the end of the input).  If the barrier is closed no further new examples
  are added to the SQSS.  The underlying barrier might, however, still
  contain further unroll-steps of examples that have not undergone all
  iterations.  To gracefully finish all examples, the flag
  `allow_small_batch` must be set to true, which causes the SQSS to issue
  progressively smaller mini-batches with the remaining examples.
  """

  def __init__(self,
               batch_size,
               num_unroll,
               input_length,
               input_key,
               input_sequences,
               input_context,
               initial_states,
               capacity=None,
               allow_small_batch=False,
               name=None):
    """Creates the SequenceQueueingStateSaver.

    Args:
      batch_size: int or int32 scalar `Tensor`, how large minibatches should
        be when accessing the `state()` method and `context`, `sequences`,
        etc., properties.
      num_unroll: Python integer, how many time steps to unroll at a time.
        The input sequences of length `k` are then split into `k / num_unroll`
        many segments.
      input_length: An int32 scalar `Tensor`, the length of the sequence prior
        to padding.  This value may be at most `padded_length` for any given
        input (see below for the definition of `padded_length`).
        Batched and total lengths of the current iteration are made accessible
        via the `length` and `total_length` properties.  The shape of
        input_length (scalar) must be fully specified.
      input_key: A string scalar `Tensor`, the **unique** key for the given
        input.  This is used to keep track of the split minibatch elements
        of this input.  Batched keys of the current iteration are made
        accessible via the `key` property.  The shape of `input_key` (scalar)
        must be fully specified.
      input_sequences: A dict mapping string names to `Tensor` values.  The
        values must all have matching first dimension, called `padded_length`.
        The `SequenceQueueingStateSaver` will split these tensors along
        this first dimension into minibatch elements of dimension
        `num_unroll`.  Batched and segmented sequences of the current
        iteration are made accessible via the `sequences` property.

        **Note**: `padded_length` may be dynamic, and may vary from input
        to input, but must always be a multiple of `num_unroll`.  The
        remainder of the shape (other than the first dimension) must be
        fully specified.
      input_context: A dict mapping string names to `Tensor` values.  The
        values are treated as "global" across all time splits of the given
        input, and will be copied across for all minibatch elements
        accordingly.  Batched and copied context of the current iteration
        are made accessible via the `context` property.

        **Note**: All input_context values must have fully defined shapes.
      initial_states: A dict mapping string state names to multi-dimensional
        values (e.g. constants or tensors).  This input defines the set of
        states that will be kept track of during computing iterations, and
        which can be accessed via the `state` and `save_state` methods.

        **Note**: All initial_state values must have fully defined shapes.
      capacity: The max capacity of the SQSS in number of examples.  Needs to
        be at least `batch_size`.  Defaults to unbounded.
      allow_small_batch: If true, the SQSS will return smaller batches when
        there aren't enough input examples to fill a whole batch and the end
        of the input has been reached (i.e., the underlying barrier has been
        closed).
      name: An op name string (optional).

    Raises:
      TypeError: if any of the inputs is not an expected type.
      ValueError: if any of the input values is inconsistent, e.g. if
        not enough shape information is available from inputs to build
        the state saver.
    """
    if capacity is not None and isinstance(batch_size, ops.Tensor):
      with ops.control_dependencies([check_ops.assert_greater_equal(
          math_ops.cast(capacity, dtype=dtypes.int64),
          math_ops.cast(batch_size, dtype=dtypes.int64),
          message="capacity needs to be >= batch_size.")]):
        input_key = array_ops.identity(input_key)
    elif capacity is not None and capacity < batch_size:
      raise ValueError("capacity %d needs to be >= batch_size %d" % (
          capacity, batch_size))
    # The barrier is ignorant of the number of actual examples, since a long
    # example that requires many iterations produces more elements in the
    # barrier than a short example.
    # Furthermore, we don't have an upper bound
    # on the length of examples, and hence have to keep the capacity of the
    # barrier at infinite to avoid deadlock.  Instead we have to keep track of
    # the number of active examples in this class, and block the prefetch_op
    # when capacity is reached.  To this end, we employ a FIFOQueue in which
    # we store one token (its value doesn't matter) for each input example,
    # and dequeue a token for each completed example.  Since the capacity of
    # this queue is limited the enqueue operation will block if capacity is
    # reached.
    self._capacity_queue = data_flow_ops.FIFOQueue(
        capacity=capacity, dtypes=[dtypes.int32], shapes=[[]])
    # Place all operations on the CPU.  Barriers and queues are only
    # implemented for CPU, but all the other book-keeping operations
    # (reshape, shape, range, ...) would be placed on GPUs if available,
    # unless we explicitly tie them to CPU.
    with ops.colocate_with(self._capacity_queue.queue_ref):
      if not isinstance(initial_states, dict):
        raise TypeError("initial_states must be a dictionary")
      if not initial_states:
        raise ValueError(
            "initial_states may not be empty: at least one state variable is "
            "required to properly enqueue split sequences to run in separate "
            "iterations")
      for k in initial_states:
        if not isinstance(k, six.string_types):
          raise TypeError("state name must be a string: %s" % k)
        if ":" in k:
          raise ValueError("state name may not have a colon: '%s'" % k)

      op_vars = ([input_length, input_key] + list(input_sequences.values()) +
                 list(input_context.values()))
      with ops.name_scope(name, "InputQueueingStateSaver", op_vars) as scope:
        inputs = _SequenceInputWrapper(input_length, input_key,
                                       input_sequences, input_context)
        self._batch_size = batch_size
        self._num_unroll = num_unroll
        self._name = scope

        # This step makes sure all shapes are well defined.  We can now
        # use get_shape() on any tensor in the output of this function
        # and get a fully-defined shape.
        (self._length, self._key, self._sorted_states,
         self._sorted_sequences,
         self._sorted_context) = _prepare_sequence_inputs(inputs,
                                                          initial_states)
        self._padded_length = array_ops.identity(
            array_ops.shape(six.next(six.itervalues(
                self._sorted_sequences)))[0],
            name="padded_length")  # The name is useful for debugging
        self._padded_length = _check_multiple_of(self._padded_length,
                                                 self._num_unroll)

        # All sequences must share the same first dimension (padded_length).
        self._sorted_sequences = collections.OrderedDict(
            (k, _check_dimensions(
                v, [0], [self._padded_length],
                debug_prefix="sorted_sequences_%s" % k))
            for k, v in self._sorted_sequences.items())
        self._uninitialized_states = self._sorted_states

        # Once this is set, self._get_barrier_*_index are available for use.
        self._store_index_maps(self._sorted_sequences, self._sorted_context,
                               self._sorted_states)

        # Make sure that the length is <= the padded_length
        with ops.control_dependencies([
            control_flow_ops.Assert(
                math_ops.less_equal(self._length, self._padded_length), [
                    "Input length should be <= length from sequences:",
                    self._length, " vs. ", self._padded_length
                ])
        ]):
          self._length = array_ops.identity(self._length)

        # Only create barrier; enqueue and dequeue operations happen when
        # you access prefetch_op and next_batch.
        self._create_barrier()
        self._scope = scope
    self._allow_small_batch = allow_small_batch
    self._prefetch_op = None
    self._next_batch = None

  @property
  def name(self):
    return self._name

  @property
  def barrier(self):
    return self._barrier

  @property
  def batch_size(self):
    return self._batch_size

  @property
  def num_unroll(self):
    return self._num_unroll

  @property
  def prefetch_op(self):
    """The op used to prefetch new data into the state saver.

    Running it once enqueues one new input example into the state saver.
    The first time this gets called, it additionally creates the prefetch_op.
    Subsequent calls simply return the previously created `prefetch_op`.

    It should be run in a separate thread via e.g. a `QueueRunner`.

    Returns:
      An `Operation` that performs prefetching.
    """
    if not self._prefetch_op:
      with ops.name_scope(None), ops.name_scope(
          self._scope, values=[self._barrier.barrier_ref]):
        self._create_prefetch_op()
    return self._prefetch_op

  @property
  def next_batch(self):
    """The `NextQueuedSequenceBatch` providing access to batched output data.

    Also provides access to the `state` and `save_state` methods.
    The first time this gets called, it additionally prepares barrier reads
    and creates `NextQueuedSequenceBatch` / next_batch objects.  Subsequent
    calls simply return the previously created `next_batch`.

    In order to access data in `next_batch` without blocking, the
    `prefetch_op` must have been run at least `batch_size` times (ideally in
    a separate thread, or launched via a `QueueRunner`).  After processing a
    segment in `next_batch()`, `batch.save_state()` must be called, which is
    done by the `state_saving_rnn`.  Without this call, the dequeue op
    associated with the SQSS will not run.

    Returns:
      A cached `NextQueuedSequenceBatch` instance.
    """
    # This is needed to prevent errors if next_batch is called before
    # prefetch_op is created.
    if not self._prefetch_op:
      with ops.name_scope(None), ops.name_scope(
          self._scope, values=[self._barrier.barrier_ref]):
        self._create_prefetch_op()
    if not self._next_batch:
      with ops.name_scope(None), ops.name_scope(
          self._scope, values=[self._barrier.barrier_ref]):
        self._prepare_barrier_reads()
    return self._next_batch

  def close(self, cancel_pending_enqueues=False, name=None):
    """Closes the barrier and the FIFOQueue.

    This operation signals that no more segments of new sequences will be
    enqueued.  New segments of already inserted sequences may still be
    enqueued and dequeued if there is a sufficient number filling a batch or
    allow_small_batch is true.  Otherwise dequeue operations will fail
    immediately.

    Args:
      cancel_pending_enqueues: (Optional.) A boolean, defaulting to
        `False`.  If `True`, all pending enqueues to the underlying queues
        will be cancelled, and completing already started sequences is not
        possible.
      name: Optional name for the op.

    Returns:
      The operation that closes the barrier and the FIFOQueue.
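
    For illustration, a sketch of a graceful shutdown (assuming
    `state_saver` is an SQSS instance and `sess` an active `tf.Session`):

    ```python
    close_op = state_saver.close()
    sess.run(close_op)
    ```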
970 """ 971 with ops.name_scope(name, "SQSSClose", [self._prefetch_op]) as name: 972 barrier_close = self.barrier.close(cancel_pending_enqueues, 973 "BarrierClose") 974 fifo_queue_close = self._capacity_queue.close(cancel_pending_enqueues, 975 "FIFOClose") 976 return control_flow_ops.group(barrier_close, fifo_queue_close, name=name) 977 978 def _store_index_maps(self, sequences, context, states): 979 """Prepares the internal dictionaries _name_to_index and _index_to_name. 980 981 These dictionaries are used to keep track of indices into the barrier. 982 983 Args: 984 sequences: `OrderedDict` of string, `Tensor` pairs. 985 context: `OrderedDict` of string, `Tensor` pairs. 986 states: `OrderedDict` of string, `Tensor` pairs. 987 """ 988 assert isinstance(sequences, dict) 989 assert isinstance(context, dict) 990 assert isinstance(states, dict) 991 self._name_to_index = { 992 name: ix 993 for (ix, name) in enumerate([ 994 "__length", "__total_length", "__next_key", "__sequence", 995 "__sequence_count" 996 ] + ["__sequence__%s" % k for k in sequences.keys()] + [ 997 "__context__%s" % k for k in context.keys() 998 ] + ["__state__%s" % k for k in states.keys()])} 999 self._index_to_name = [ 1000 name 1001 for (name, _) in sorted( 1002 self._name_to_index.items(), key=lambda n_ix: n_ix[1]) 1003 ] 1004 1005 def _get_barrier_length_index(self): 1006 return self._name_to_index["__length"] 1007 1008 def _get_barrier_total_length_index(self): 1009 return self._name_to_index["__total_length"] 1010 1011 def _get_barrier_next_key_index(self): 1012 return self._name_to_index["__next_key"] 1013 1014 def _get_barrier_sequence_index(self): 1015 return self._name_to_index["__sequence"] 1016 1017 def _get_barrier_sequence_count_index(self): 1018 return self._name_to_index["__sequence_count"] 1019 1020 def _get_barrier_index(self, index_type, name): 1021 assert index_type in ("sequence", "context", "state") 1022 key = "__%s__%s" % (index_type, name) 1023 assert key in self._name_to_index, ( 1024 "Requested a name not in the value type %s: %s" % (index_type, name)) 1025 return self._name_to_index[key] 1026 1027 def _create_barrier(self): 1028 """Create the barrier. 1029 1030 This method initializes the Barrier object with the right types and shapes. 1031 """ 1032 # Create the barrier 1033 sequence_dtypes = [v.dtype for k, v in self._sorted_sequences.items()] 1034 context_dtypes = [v.dtype for k, v in self._sorted_context.items()] 1035 state_dtypes = [v.dtype for k, v in self._sorted_states.items()] 1036 types = ([ 1037 dtypes.int32, # length 1038 dtypes.int32, # total_length 1039 dtypes.string, # next_keys 1040 dtypes.int32, # sequence 1041 dtypes.int32 1042 ] # expanded_sequence_count 1043 + sequence_dtypes + context_dtypes + state_dtypes) 1044 sequence_shapes = [ 1045 [self._num_unroll] + self._sorted_sequences[k].get_shape().as_list()[1:] 1046 for k in self._sorted_sequences.keys() 1047 ] 1048 context_shapes = [ 1049 self._sorted_context[k].get_shape().as_list() 1050 for k in self._sorted_context.keys() 1051 ] 1052 state_shapes = [ 1053 self._sorted_states[k].get_shape().as_list() 1054 for k in self._sorted_states.keys() 1055 ] 1056 shapes = ([ 1057 (), # length 1058 (), # total_length 1059 (), # next_keys 1060 (), # sequence 1061 () 1062 ] # expanded_sequence_count 1063 + sequence_shapes + context_shapes + state_shapes) 1064 1065 self._barrier = data_flow_ops.Barrier(types=types, shapes=shapes) 1066 1067 def _create_prefetch_op(self): 1068 """Group insert_many ops and create prefetch_op. 

    This method implements the "meat" of the logic underlying the
    `SequenceQueueingStateSaver`.  It performs dynamic reshaping of
    sequences, copying of context, and initial insertion of these values,
    as well as the key, next_key, sequence, sequence_count, and initial
    states into the barrier.
    """
    # Step 1: identify how many barrier entries to split this input
    # into, store the result as a scalar
    sequence_count = math_ops.div(self._padded_length, self._num_unroll)
    sequence_count_vec = array_ops.expand_dims(sequence_count, 0)

    # The final unrolled sequence's length is num_unroll only in
    # the case that num_unroll divides it evenly.
    ones = array_ops.ones(sequence_count_vec, dtype=dtypes.int32)
    sequence = math_ops.range(sequence_count)
    expanded_length = math_ops.maximum(
        0, self._length - self._num_unroll * sequence)
    expanded_length = math_ops.minimum(self._num_unroll, expanded_length)
    expanded_total_length = self._length * ones
    expanded_sequence_count = sequence_count * ones
    current_keys = string_ops.string_join(
        [
            string_ops.as_string(sequence, width=5, fill="0"), "_of_",
            string_ops.as_string(sequence_count, width=5, fill="0"), ":",
            self._key
        ],
        name="StringJoinCurrentKeys")
    next_keys = array_ops.concat(
        [
            array_ops.slice(current_keys, [1], [-1]),
            array_ops.expand_dims(
                string_ops.string_join(
                    ["STOP:", self._key], name="StringJoinStop"),
                0)
        ],
        0,
        name="concat_next_keys")
    reshaped_sequences = collections.OrderedDict((
        k,
        _check_dimensions(
            # Reshape sequences to sequence_count rows
            array_ops.reshape(
                v,
                array_ops.concat(
                    [
                        array_ops.expand_dims(sequence_count, 0),
                        array_ops.expand_dims(self._num_unroll, 0),
                        v.get_shape().as_list()[1:]
                    ],
                    0,
                    name="concat_sequences_%s" % k),
                name="reshape_sequences_%s" % k),
            [0, 1] + list(range(2, v.get_shape().ndims + 1)),
            [sequence_count, self._num_unroll] + v.get_shape().as_list()[1:],
            debug_prefix="reshaped_sequences_%s" % k))
        for k, v in self._sorted_sequences.items())
    expanded_context = collections.OrderedDict((
        k,
        _check_dimensions(
            # Copy context to be sequence_count rows
            array_ops.tile(
                array_ops.expand_dims(v, 0),
                array_ops.concat(
                    [
                        array_ops.expand_dims(sequence_count, 0),
                        [1] * v.get_shape().ndims
                    ],
                    0,
                    name="concat_context_%s" % k),
                name="tile_context_%s" % k),
            [0] + list(range(1, v.get_shape().ndims + 1)),
            [sequence_count] + v.get_shape().as_list(),
            debug_prefix="expanded_context_%s" % k))
        for k, v in self._sorted_context.items())

    # Storing into the barrier, for each current_key:
    #   sequence_ix, sequence_count, next_key, length,
    #   context... (copied), sequences... (truncated)
    # Also storing into the barrier for the first key
    #   states (using initial_states).
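    # Note: the next_key of segment i equals the current_key of segment
    # i + 1 (and "STOP:<key>" for the final segment); save_state later
    # inserts updated states under next_key, which is how state flows from
    # one segment to the next.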
    insert_sequence_op = self._barrier.insert_many(
        self._get_barrier_sequence_index(),
        current_keys,
        sequence,
        name="BarrierInsertSequence")
    insert_sequence_count_op = self._barrier.insert_many(
        self._get_barrier_sequence_count_index(),
        current_keys,
        expanded_sequence_count,
        name="BarrierInsertSequenceCount")
    insert_next_key_op = self._barrier.insert_many(
        self._get_barrier_next_key_index(),
        current_keys,
        next_keys,
        name="BarrierInsertNextKey")
    insert_length_op = self._barrier.insert_many(
        self._get_barrier_length_index(),
        current_keys,
        expanded_length,
        name="BarrierInsertLength")
    insert_total_length_op = self._barrier.insert_many(
        self._get_barrier_total_length_index(),
        current_keys,
        expanded_total_length,
        name="BarrierInsertTotalLength")
    insert_context_ops = dict(
        (name, self._barrier.insert_many(
            self._get_barrier_index("context", name),
            current_keys,
            value,
            name="BarrierInsertContext_%s" % name))
        for (name, value) in expanded_context.items())
    insert_sequences_ops = dict(
        (name, self._barrier.insert_many(
            self._get_barrier_index("sequence", name),
            current_keys,
            value,
            name="BarrierInsertSequences_%s" % name))
        for (name, value) in reshaped_sequences.items())

    # An op that blocks if we reached capacity in number of active examples.
    TOKEN_WITH_IGNORED_VALUE = 21051976  # pylint: disable=invalid-name
    insert_capacity_token_op = self._capacity_queue.enqueue(
        (TOKEN_WITH_IGNORED_VALUE,))

    # Insert just the initial state.  Specifically force this to run
    # the insert sequence op *first* so that the Barrier receives
    # an insert with *all* the segments and the segments all get the same
    # index.
    with ops.control_dependencies(
        [insert_sequence_op, insert_capacity_token_op]):
      insert_initial_state_ops = dict(
          (name, self._barrier.insert_many(
              self._get_barrier_index("state", name),
              array_ops.stack([current_keys[0]]),
              array_ops.stack([value]),
              name="BarrierInitialInsertState_%s" % name))
          for (name, value) in self._uninitialized_states.items())

    all_inserts = ([
        insert_capacity_token_op, insert_sequence_op, insert_sequence_count_op,
        insert_next_key_op, insert_length_op, insert_total_length_op
    ] + list(insert_initial_state_ops.values()) +
                   list(insert_context_ops.values()) +
                   list(insert_sequences_ops.values()))

    self._prefetch_op = control_flow_ops.group(
        *all_inserts, name="StateSaverPrefetchGroup")

  def _prepare_barrier_reads(self):
    """Creates ops for reading the barrier, as used by properties like `length`.
    """
    # Ops for reading from the barrier.  These ops must be run in a
    # different thread than the prefetcher op to avoid blocking.
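    # take_many returns the tuple (insertion_indices, keys, values), where
    # values holds one tensor per barrier component, ordered by the component
    # indices stored in self._name_to_index.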
    received = self._barrier.take_many(
        self._batch_size, self._allow_small_batch, name="BarrierTakeMany")

    self._received_indices = received[0]
    self._received_keys = received[1]
    received_values = received[2]

    self._received_sequence = received_values[
        self._get_barrier_sequence_index()]
    self._received_sequence_count = received_values[
        self._get_barrier_sequence_count_index()]
    self._received_next_key = received_values[
        self._get_barrier_next_key_index()]
    self._received_length = received_values[self._get_barrier_length_index()]
    self._received_total_length = received_values[
        self._get_barrier_total_length_index()]
    self._received_context = collections.OrderedDict(
        (name, received_values[self._get_barrier_index("context", name)])
        for name in self._sorted_context.keys())
    self._received_sequences = collections.OrderedDict(
        (name, received_values[self._get_barrier_index("sequence", name)])
        for name in self._sorted_sequences.keys())

    self._received_batch_size = array_ops.squeeze(
        array_ops.shape(self._received_length))

    # Which examples are we done with?
    self._sequence_is_done = (
        self._received_sequence + 1 >= self._received_sequence_count)

    # Compute the number of finished sequences and dequeue as many tokens from
    # the capacity queue.
    finished_sequences = (math_ops.reduce_sum(
        math_ops.cast(self._sequence_is_done, dtypes.int32)))
    # TODO(ebrevdo): convert to dequeue_up_to when FIFOQueue supports it.
    dequeue_op = self._capacity_queue.dequeue_many(finished_sequences)

    # Tie the dequeue_op to the received_state, such that it is definitely
    # carried out.
    with ops.control_dependencies([dequeue_op]):
      self._received_states = collections.OrderedDict(
          (name, array_ops.identity(received_values[self._get_barrier_index(
              "state", name)])) for name in self._sorted_states.keys())
    self._next_batch = NextQueuedSequenceBatch(self)


def batch_sequences_with_states(input_key,
                                input_sequences,
                                input_context,
                                input_length,
                                initial_states,
                                num_unroll,
                                batch_size,
                                num_threads=3,
                                capacity=1000,
                                allow_small_batch=True,
                                pad=True,
                                make_keys_unique=False,
                                make_keys_unique_seed=None,
                                name=None):
  """Creates batches of segments of sequential input.

  This method creates a `SequenceQueueingStateSaver` (SQSS) and adds it to
  the queue runners.  It returns a `NextQueuedSequenceBatch`.

  It accepts one example at a time identified by a unique `input_key`.
  `input_sequences` is a dict with values that are tensors with time as first
  dimension.  This time dimension must be the same across those tensors of an
  example.  It can vary across examples, although it always has to be a
  multiple of `num_unroll`.  Hence, padding may be necessary, and it is
  turned on by default via `pad=True`.

  `input_length` is a scalar `Tensor` or an int recording the time dimension
  prior to padding.  It should be between 0 and the time dimension.  One
  reason we want to keep track of it is so that we can take it into
  consideration when computing the loss.  If `pad=True` then `input_length`
  can be `None` and will be inferred.

  This method segments `input_sequences` into segments of length `num_unroll`.
  It batches input sequences from `batch_size` many examples.
  These mini-batches are available through the `sequences` property of the
  output.  Moreover, for each entry in the batch we can access its original
  `input_key` in `key` and its input length in `total_length`.  `length`
  records how many non-padded time steps there are within this segment.

  Static features of an example that do not vary across time can be part of
  the `input_context`, a dict with Tensor values.  This method copies the
  context for each segment and makes it available in the `context` of the
  output.

  This method can maintain and update a state for each example.  It accepts
  some initial_states as a dict with Tensor values.  The first mini-batch in
  which an example appears has `initial_states` as its `state` entries.  If
  `save_state` is called then the next segment of that example will receive
  the updated `state` entries.  See `NextQueuedSequenceBatch` for a complete
  list of properties and methods.

  Example usage:

  ```python
  batch_size = 32
  num_unroll = 20
  num_enqueue_threads = 3
  lstm_size = 8
  cell = tf.contrib.rnn.BasicLSTMCell(num_units=lstm_size)

  raw_data = get_single_input_from_input_reader()
  key, sequences, context = my_parser(raw_data)
  initial_state_values = tf.zeros((lstm_size,), dtype=tf.float32)
  initial_states = {"lstm_state": initial_state_values}
  batch = tf.contrib.training.batch_sequences_with_states(
      input_key=key,
      input_sequences=sequences,
      input_context=context,
      input_length=tf.shape(sequences["input"])[0],
      initial_states=initial_states,
      num_unroll=num_unroll,
      batch_size=batch_size,
      num_threads=num_enqueue_threads,
      capacity=batch_size * num_enqueue_threads * 2)

  inputs = batch.sequences["input"]
  context_label = batch.context["label"]

  inputs_by_time = tf.split(value=inputs, num_or_size_splits=num_unroll, axis=1)
  assert len(inputs_by_time) == num_unroll

  lstm_output, _ = tf.contrib.rnn.static_state_saving_rnn(
      cell,
      inputs_by_time,
      state_saver=batch,
      state_name="lstm_state")

  # Start a prefetcher in the background
  sess = tf.Session()

  tf.train.start_queue_runners(sess=sess)

  while True:
    # Step through batches, perform training or inference...
    sess.run([lstm_output])
  ```

  Args:
    input_key: A string scalar `Tensor`, the **unique** key for the given
      input example.  This is used to keep track of the split minibatch
      elements of this input.  Batched keys of the current iteration are made
      accessible via the `key` property.  The shape of `input_key` (scalar)
      must be fully specified.  Consider setting `make_keys_unique` to True
      when iterating over the same input multiple times.

      **Note**: if `make_keys_unique=False` then `input_key`s must be unique.
    input_sequences: A dict mapping string names to `Tensor` values.  The
      values must all have matching first dimension, called `value_length`.
      They may vary from input to input.  The remainder of the shape (other
      than the first dimension) must be fully specified.
      The `SequenceQueueingStateSaver` will split these tensors along
      this first dimension into minibatch elements of dimension `num_unroll`.
      Batched and segmented sequences of the current iteration are made
      accessible via the `sequences` property.

      **Note**: if `pad=False`, then `value_length` must always be a multiple
      of `num_unroll`.
    input_context: A dict mapping string names to `Tensor` values.  The values
      are treated as "global" across all time splits of the given input
      example, and will be copied across for all minibatch elements
      accordingly.  Batched and copied context of the current iteration are
      made accessible via the `context` property.

      **Note**: All input_context values must have fully defined shapes.
    input_length: None or an int32 scalar `Tensor`, the length of the sequence
      prior to padding.  If `input_length=None` and `pad=True` then the length
      will be inferred and will be equal to `value_length`.  If `pad=False`
      then `input_length` cannot be `None`: `input_length` must be specified.
      The shape of `input_length` (scalar) must be fully specified.  Its value
      may be at most `value_length` for any given input (see above for the
      definition of `value_length`).  Batched and total lengths of the current
      iteration are made accessible via the `length` and `total_length`
      properties.
    initial_states: A dict mapping string state names to multi-dimensional
      values (e.g. constants or tensors).  This input defines the set of
      states that will be kept track of during computing iterations, and
      which can be accessed via the `state` and `save_state` methods.

      **Note**: All initial_state values must have fully defined shapes.
    num_unroll: Python integer, how many time steps to unroll at a time.
      The input sequences of length k are then split into k / num_unroll many
      segments.
    batch_size: int or int32 scalar `Tensor`, how large minibatches should
      be when accessing the `state()` method and `context`, `sequences`,
      etc., properties.
    num_threads: The int number of threads enqueuing input examples into a
      queue.
    capacity: The max capacity of the queue in number of examples.  Needs to
      be at least `batch_size`.  Defaults to 1000.  When iterating over the
      same input example multiple times reusing their keys the `capacity`
      must be smaller than the number of examples.
    allow_small_batch: If true, the queue will return smaller batches when
      there aren't enough input examples to fill a whole batch and the end of
      the input has been reached.
    pad: If `True`, `input_sequences` will be padded to a multiple of
      `num_unroll`.  In that case `input_length` may be `None` and is assumed
      to be the length of the first dimension of values in `input_sequences`
      (i.e. `value_length`).
    make_keys_unique: Whether to append a random integer to the `input_key` in
      an effort to make it unique.  The seed can be set via
      `make_keys_unique_seed`.
    make_keys_unique_seed: If `make_keys_unique=True` this fixes the seed with
      which a random postfix is generated.
    name: An op name string (optional).

  Returns:
    A NextQueuedSequenceBatch with segmented and batched inputs and their
    states.

  Raises:
    TypeError: if any of the inputs is not an expected type.
    ValueError: if any of the input values is inconsistent, e.g. if
      not enough shape information is available from inputs to build
      the state saver.
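
  When reading the same input examples more than once, a sketch of relying on
  `make_keys_unique` (otherwise repeated keys would collide in the barrier;
  the surrounding variables are as in the example above):

  ```python
  batch = tf.contrib.training.batch_sequences_with_states(
      input_key=key,
      input_sequences=sequences,
      input_context=context,
      input_length=tf.shape(sequences["input"])[0],
      initial_states=initial_states,
      num_unroll=num_unroll,
      batch_size=batch_size,
      make_keys_unique=True,
      make_keys_unique_seed=11)
  ```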
1437 """ 1438 tensor_list = (list(input_sequences.values()) + list(input_context.values()) + 1439 list(initial_states.values())) 1440 with ops.name_scope(name, "batch_sequences_with_states", tensor_list) as name: 1441 if pad: 1442 length, input_sequences = _padding(input_sequences, num_unroll) 1443 input_length = input_length if input_length is not None else length 1444 elif input_sequences: 1445 # Assert that value_length is a multiple of num_unroll. 1446 checked_input_sequences = {} 1447 for key, value in input_sequences.items(): 1448 if (isinstance(value, sparse_tensor.SparseTensor) or 1449 isinstance(value, sparse_tensor.SparseTensorValue)): 1450 value_length = value.dense_shape[0] 1451 with ops.control_dependencies([ 1452 control_flow_ops.Assert( 1453 math_ops.logical_and( 1454 math_ops.equal(value_length % num_unroll, 0), 1455 math_ops.not_equal(value_length, 0)), 1456 [ 1457 string_ops.string_join([ 1458 "SparseTensor %s first dimension should be a " 1459 "multiple of: " % key, 1460 string_ops.as_string(num_unroll), 1461 ", but saw value: ", 1462 string_ops.as_string(value_length), 1463 ". Consider setting pad=True."])])]): 1464 checked_input_sequences[key] = sparse_tensor.SparseTensor( 1465 indices=array_ops.identity( 1466 value.indices, name="multiple_of_checked"), 1467 values=array_ops.identity( 1468 value.values, name="multiple_of_checked"), 1469 dense_shape=array_ops.identity( 1470 value.dense_shape, name="multiple_of_checked")) 1471 else: 1472 if not isinstance(value, ops.Tensor): 1473 try: 1474 value = ops.convert_to_tensor(value) 1475 except TypeError: 1476 raise TypeError( 1477 "Unsupported input_sequences expected Tensor or SparseTensor " 1478 "values, got: %s for key %s" % (str(type(value)), key)) 1479 value_length = array_ops.shape(value)[0] 1480 with ops.control_dependencies([ 1481 control_flow_ops.Assert( 1482 math_ops.logical_and( 1483 math_ops.equal(value_length % num_unroll, 0), 1484 math_ops.not_equal(value_length, 0)), 1485 [ 1486 string_ops.string_join([ 1487 "Tensor %s first dimension should be a multiple " 1488 "of: " % key, 1489 string_ops.as_string(num_unroll), 1490 ", but saw value: ", 1491 string_ops.as_string(value_length), 1492 ". Consider setting pad=True." 1493 ]) 1494 ]) 1495 ]): 1496 checked_input_sequences[key] = array_ops.identity( 1497 value, name="multiple_of_checked") 1498 input_sequences = checked_input_sequences 1499 # Move SparseTensors in context into input_sequences. 1500 _move_sparse_tensor_out_context(input_context, input_sequences, num_unroll) 1501 # Deconstruct SparseTensors in sequence into a dense Tensor before inputting 1502 # to SQSS. 


def _padding(sequences, num_unroll):
  """For a dictionary of sequences, pads tensors to a multiple of `num_unroll`.

  Args:
    sequences: dictionary with `Tensor` values.
    num_unroll: int specifying the multiple to pad the first dimension up to.
  Returns:
    length: Scalar `Tensor`, the original (unpadded) size of the first
      dimension of the values in `sequences`.
    padded_sequence: Dictionary of sequences that are padded to a multiple of
      `num_unroll`.
  Raises:
    ValueError: If `num_unroll` is not an int.
    TypeError: If `sequences` is not a dictionary from string to `Tensor`.
  """
  if not isinstance(num_unroll, numbers.Integral):
    raise ValueError("Unsupported num_unroll expected int, got: %s" %
                     str(num_unroll))
  if not isinstance(sequences, dict):
    raise TypeError("Unsupported sequences expected dict, got: %s" %
                    str(sequences))
  for key, value in sequences.items():
    if not isinstance(key, six.string_types):
      raise TypeError("Unsupported sequences key expected string, got: %s" %
                      str(key))
  if not sequences:
    return 0, {}

  # Sort 'sequences_dict' so 'length' will have a predictable value below.
  sequences_dict = collections.OrderedDict()
  for key, value in sorted(sequences.items()):
    if not (isinstance(value, sparse_tensor.SparseTensor) or
            isinstance(value, sparse_tensor.SparseTensorValue)):
      sequences_dict[key] = ops.convert_to_tensor(value)
    else:
      sequences_dict[key] = value

  lengths = [array_ops.shape(value)[0] for value in sequences_dict.values()
             if isinstance(value, ops.Tensor)]
  if lengths:
    length = lengths[0]
    all_lengths_equal = [
        control_flow_ops.Assert(
            math_ops.equal(l, length), [string_ops.string_join(
                ["All sequence lengths must match, but received lengths: ",
                 string_ops.as_string(lengths)])])
        for l in lengths]
    length = control_flow_ops.with_dependencies(all_lengths_equal, length)
  else:  # Only have SparseTensors
    sparse_lengths = [value.dense_shape[0] for value in sequences_dict.values()
                      if isinstance(value, sparse_tensor.SparseTensor)]
    length = math_ops.reduce_max(math_ops.cast(sparse_lengths, dtypes.int32))

  unroll = array_ops.constant(num_unroll)
  padded_length = length + ((unroll - (length % unroll)) % unroll)
  padded_sequences = {}
  for key, value in sequences_dict.items():
    if isinstance(value, ops.Tensor):
      # 1. Create the shape of the paddings: the first dimension of value
      # will be increased by num_paddings to padded_length. The shape of the
      # paddings that we concat with the original value will be
      # [num_paddings, tf.shape(value)[1], tf.shape(value)[2], ...,
      #  tf.shape(value)[tf.rank(value) - 1]]
      num_paddings = [padded_length - array_ops.shape(value)[0]]
      padding_shape = array_ops.concat(
          (num_paddings, array_ops.shape(value)[1:]), 0)
      # 2. Fill the padding shape with dummies.
      dummy = array_ops.constant(
          "" if value.dtype == dtypes.string else 0, dtype=value.dtype)
      paddings = array_ops.fill(dims=padding_shape, value=dummy)
      # 3. Concat the values with the paddings.
      padded_sequences[key] = array_ops.concat([value, paddings], 0)
    else:
      padded_shape = array_ops.concat(
          [[math_ops.cast(padded_length, dtypes.int64)],
           value.dense_shape[1:]], 0)
      padded_sequences[key] = sparse_tensor.SparseTensor(
          indices=value.indices,
          values=value.values,
          dense_shape=padded_shape)
  return length, padded_sequences
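

# Example of `_padding` (illustrative comment only; shapes shown as if
# evaluated):
#
#   sequences = {"input": tf.ones([23, 5]), "label": tf.zeros([23])}
#   length, padded = _padding(sequences, num_unroll=10)
#   # length evaluates to 23; padded["input"] has shape [30, 5] and
#   # padded["label"] has shape [30], with rows 23..29 filled with zeros.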


_SPARSE_CONTEXT_PREFIX_KEY = "_context_in_seq_"


def _move_sparse_tensor_out_context(input_context, input_sequences, num_unroll):
  """Moves `SparseTensor`s from `input_context` into `input_sequences` as seq.

  For each `key, value` pair in `input_context` where `value` is a
  `SparseTensor`, removes the pair from `input_context`, transforms `value`
  into a sequence, and adds the `key`, transformed `value` pair to
  `input_sequences`. The transformation is done by adding a new first
  dimension of size `value_length` (matching that of the other values in
  `input_sequences`) and tiling `value` once every `num_unroll` steps.

  Args:
    input_context: dictionary with `Tensor` or `SparseTensor` values. To be
      modified to take out `SparseTensor` values.
    input_sequences: dictionary with `Tensor` or `SparseTensor` values. To be
      modified to add transformed `SparseTensor` values from `input_context`.
    num_unroll: int, the number of time steps in each segment; transformed
      context values are tiled once per segment.
  """
  value_length = array_ops.constant(1)
  if input_sequences:
    seq = list(input_sequences.values())[0]
    if isinstance(seq, ops.Tensor):
      with ops.control_dependencies([seq]):
        value_length = array_ops.shape(seq)[0]
    else:
      value_length = seq.dense_shape[0]
  value_length = math_ops.cast(value_length, dtype=dtypes.int64)

  def _copy_sparse_tensor(sp_tensor):
    """Tiles a sparse tensor along a newly added first (time) dimension.

    Adds a new first dimension of size `value_length` and tiles `sp_tensor`
    once every `num_unroll` steps.

    Args:
      sp_tensor: `SparseTensor`.
    Returns:
      `SparseTensor` sequence with `sp_tensor` tiled.
    """
    n = value_length // num_unroll
    n = math_ops.cast(n, dtype=dtypes.int32)
    values = array_ops.tile(sp_tensor.values, array_ops.expand_dims(n, 0))
    shape = array_ops.concat(
        [array_ops.expand_dims(value_length, 0), sp_tensor.dense_shape], 0)

    # Construct new indices by tiling the old ones and prepending multiples
    # of num_unroll in [0, value_length).
    # First tile the indices n times along a newly created 0-dimension.
    multiplied_indices = array_ops.tile(
        array_ops.expand_dims(sp_tensor.indices, 0),
        array_ops.stack([n, 1, 1]))

    # Construct indicator for [0, n).
    # [ [ [0] [0] ... [0] ]
    #   [ [num_unroll] [num_unroll] ... [num_unroll] ]
    #   ...
    #   [ [num_unroll*(n-1)] [num_unroll*(n-1)] ... [num_unroll*(n-1)] ] ]
    # of shape [n, shape(sp_tensor.indices)[0], 1]
    # Get current dimensions of indices.
    dim0 = array_ops.shape(sp_tensor.indices)[0]
    dim1 = array_ops.shape(sp_tensor.indices)[1]
    ind = math_ops.range(start=0, limit=value_length, delta=num_unroll)

    # ind.set_shape([n])
    ind = array_ops.expand_dims(ind, 1)
    ind = array_ops.expand_dims(ind, 2)
    ind = array_ops.tile(ind, [1, dim0, 1])

    # Concatenate both and reshape.
    indices = array_ops.concat([ind, multiplied_indices], 2)
    indices = array_ops.reshape(indices, [dim0 * n, dim1 + 1])

    return sparse_tensor.SparseTensor(indices=indices,
                                      values=values,
                                      dense_shape=shape)

  sparse_tensor_keys = [
      k for k in sorted(input_context.keys())
      if (isinstance(input_context[k], sparse_tensor.SparseTensor) or
          isinstance(input_context[k], sparse_tensor.SparseTensorValue))]
  for key in sparse_tensor_keys:
    input_sequences[_SPARSE_CONTEXT_PREFIX_KEY + key] = _copy_sparse_tensor(
        input_context[key])
    del input_context[key]
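

# Example of the context-to-sequence transform above (illustrative comment
# only): a context SparseTensor with dense_shape [3] and values at indices
# [0] and [2], combined with value_length=40 and num_unroll=20 (i.e. n=2),
# becomes a sequence SparseTensor with dense_shape [40, 3] whose values sit
# at indices [0, 0], [0, 2], [20, 0] and [20, 2], i.e. once per segment.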
1652 """ 1653 value_length = array_ops.constant(1) 1654 if input_sequences: 1655 seq = list(input_sequences.values())[0] 1656 if isinstance(seq, ops.Tensor): 1657 with ops.control_dependencies([seq]): 1658 value_length = array_ops.shape(seq)[0] 1659 else: 1660 value_length = seq.dense_shape[0] 1661 value_length = math_ops.cast(value_length, dtype=dtypes.int64) 1662 def _copy_sparse_tensor(sp_tensor): 1663 """Operation to tile a sparse tensor along a newly added 0 dimension. 1664 1665 Adding a new first dimension of `value_length` and tiling the `sp_tensor` 1666 every `num_unroll` steps. 1667 1668 Args: 1669 sp_tensor: `SparseTensor`. 1670 Returns: 1671 `SparseTensor` sequence with `sp_tensor` tiled. 1672 """ 1673 n = value_length // num_unroll 1674 n = math_ops.cast(n, dtype=dtypes.int32) 1675 values = array_ops.tile(sp_tensor.values, array_ops.expand_dims(n, 0)) 1676 shape = array_ops.concat( 1677 [array_ops.expand_dims(value_length, 0), sp_tensor.dense_shape], 0) 1678 1679 # Construct new indices by multiplying old ones and prepending [0, n). 1680 # First multiply indices n times along a newly created 0-dimension. 1681 multiplied_indices = array_ops.tile( 1682 array_ops.expand_dims(sp_tensor.indices, 0), 1683 array_ops.stack([n, 1, 1])) 1684 1685 # Construct indicator for [0, n). 1686 # [ [ [0] [0] ... [0] ] 1687 # [ [num_unroll] [num_unroll] ... [num_unroll] ] 1688 # ... 1689 # [ [num_unroll*(n-1)] [num_unroll*(n-1)] ... [num_unroll*(n-1)] ] ] 1690 # of shape [n, shape(sp_tensor.indices)[0], 1] 1691 # Get current dimensions of indices. 1692 dim0 = array_ops.shape(sp_tensor.indices)[0] 1693 dim1 = array_ops.shape(sp_tensor.indices)[1] 1694 ind = math_ops.range(start=0, limit=value_length, delta=num_unroll) 1695 1696 # ind.set_shape([n]) 1697 ind = array_ops.expand_dims(ind, 1) 1698 ind = array_ops.expand_dims(ind, 2) 1699 ind = array_ops.tile(ind, [1, dim0, 1]) 1700 1701 # Concatenate both and reshape. 1702 indices = array_ops.concat([ind, multiplied_indices], 2) 1703 indices = array_ops.reshape(indices, [dim0 * n, dim1 + 1]) 1704 1705 return sparse_tensor.SparseTensor(indices=indices, 1706 values=values, 1707 dense_shape=shape) 1708 1709 sparse_tensor_keys = [ 1710 k for k in sorted(input_context.keys()) 1711 if (isinstance(input_context[k], sparse_tensor.SparseTensor) or 1712 isinstance(input_context[k], sparse_tensor.SparseTensorValue))] 1713 for key in sparse_tensor_keys: 1714 input_sequences[_SPARSE_CONTEXT_PREFIX_KEY + key] = _copy_sparse_tensor( 1715 input_context[key]) 1716 del input_context[key] 1717 1718 1719def _move_sparse_tensor_in_context(context, sequences): 1720 sparse_tensor_keys = [ 1721 k for k in sorted(sequences) if k.startswith(_SPARSE_CONTEXT_PREFIX_KEY)] 1722 for key in sparse_tensor_keys: 1723 new_key = key[len(_SPARSE_CONTEXT_PREFIX_KEY):] 1724 sp_tensor = sequences[key] 1725 # Take out time dimension. 1726 sp_tensor = sparse_tensor.SparseTensor( 1727 sp_tensor.indices, # with only 0s at column 1 representing time. 1728 sp_tensor.values, 1729 array_ops.concat( 1730 [[sp_tensor.dense_shape[0]], # batch 1731 [1], # time 1732 sp_tensor.dense_shape[2:]], # SparseTensor shape prior to batching 1733 0)) 1734 new_shape = array_ops.concat( 1735 [[sp_tensor.dense_shape[0]], sp_tensor.dense_shape[2:]], 0) 1736 context[new_key] = sparse_ops.sparse_reshape(sp_tensor, new_shape) 1737 del sequences[key] 1738 1739 1740def _deconstruct_sparse_tensor_seq(input_sequence, shared_name=None): 1741 """Converts `SparseTensor` values into `Tensors` of IDs and meta data. 


def _reconstruct_sparse_tensor_seq(sequence,
                                   sparse_tensor_keys,
                                   tensor_op_list,
                                   batch_size,
                                   num_unroll):
  """Inverse of _deconstruct_sparse_tensor_seq.

  Given a dict of keys -> `Tensor`, reconstructs `SparseTensor` values for
  keys in `sparse_tensor_keys`. Their `Tensor` values are assumed to be IDs
  into the underlying `SparseTensorsMap`. The `dense_shape` of the
  `SparseTensor`s is `[batch_size, num_unroll, d_0, d_1, ..., d_n]` when the
  original `SparseTensor` that got deconstructed with
  `_deconstruct_sparse_tensor_seq` has a `dense_shape` of
  `[None, d_0, d_1, ..., d_n]`.

  Args:
    sequence: dictionary with only `Tensor` values that is being updated.
    sparse_tensor_keys: list of the keys present in `sequence` identifying
      `SparseTensor` values that should be reconstructed.
    tensor_op_list: list of the same length as `sparse_tensor_keys` with
      `Operation` objects.
    batch_size: int or int32 scalar `Tensor`, how large minibatches should
      be.
    num_unroll: Python integer, how many time steps were unrolled at a time.
  """
  def _flatten_tensor(tensor):
    """Flattens a `Tensor` of shape `[batch_size, num_unroll]` into 1D.

    The main use of this function is to work around the limitation of
    `_restore_sparse` to only accept 1D handles.

    Args:
      tensor: 2D `Tensor` of shape `[batch_size, num_unroll]`.
    Returns:
      1D `Tensor`.
    """
    return array_ops.reshape(tensor, [-1])

  def _unflatten_sparse_tensor(sp_tensor):
    """Recreates `[batch_size, num_unroll]` dimensions in the `SparseTensor`.

    Counterpart of `_flatten_tensor`, which is called on the input of
    `_restore_sparse` while this method is called on its output. Together
    they work around the limitation of `_restore_sparse` to only accept 1D
    handles.

    The `indices` in `sp_tensor` is a 2D `Tensor` of shape `[N, ndims]`,
    where `N` is the number of `values` and `ndims` is the number of
    dimensions of its dense counterpart. The first of these `ndims` entries
    spans the flattened batch-time range `[0, num_unroll * batch_size)`,
    from which we recreate the two dimensions `batch_size` and `num_unroll`.

    This reconstruction works because the output of `_restore_sparse`,
    despite being a `SparseTensor`, is actually dense with respect to that
    first entry.

    Args:
      sp_tensor: A SparseTensor.
    Returns:
      A SparseTensor with a rank that is 1 higher than the input's.
    """
    idx_batch = math_ops.cast(
        math_ops.floor(sp_tensor.indices[:, 0] / num_unroll), dtypes.int64)
    idx_time = math_ops.mod(sp_tensor.indices[:, 0], num_unroll)
    indices = array_ops.concat(
        [
            array_ops.expand_dims(idx_batch, 1),
            array_ops.expand_dims(idx_time, 1), sp_tensor.indices[:, 1:]
        ],
        axis=1)
    dense_shape = array_ops.concat(
        [[math_ops.cast(batch_size, dtype=dtypes.int64)],
         [math_ops.cast(num_unroll, dtype=dtypes.int64)],
         sp_tensor.dense_shape[1:]], axis=0)
    return sparse_tensor.SparseTensor(
        indices=indices,
        values=sp_tensor.values,
        dense_shape=dense_shape)
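
  # Index arithmetic used by _unflatten_sparse_tensor above, illustrated:
  # with batch_size=2 and num_unroll=3, a flattened first index i in [0, 6)
  # maps to (batch=i // 3, time=i % 3), e.g. i=4 -> (batch=1, time=1).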
  if not sparse_tensor_keys:
    return
  tensor_list = [sequence[k] for k in sparse_tensor_keys]
  sp_tensors = [
      _restore_sparse(sparse_map_op=i,
                      # Flatten the 2D Tensor [batch_size, num_unroll] of
                      # handles to a 1D Tensor.
                      # Reconstruct the dimensions later.
                      # TODO(b/34247140): Remove this workaround.
                      sparse_handles=_flatten_tensor(s), rank=None)
      for i, s in zip(tensor_op_list, tensor_list)]
  num_unroll = ops.convert_to_tensor(num_unroll, dtype=dtypes.int64,
                                     name="num_unroll_int64")

  # Recreate the [batch_size, num_unroll] dimensions in the SparseTensors.
  # The dense_shape will have a +1 higher rank.
  # TODO(b/34247140): Remove this workaround.
  sp_tensors_higher_dim = [_unflatten_sparse_tensor(s) for s in sp_tensors]

  # Set values to SparseTensors for sparse_tensor_keys.
  for i, key in enumerate(sparse_tensor_keys):
    sequence[key] = sp_tensors_higher_dim[i]
  return