# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""CTC (Connectionist Temporal Classification) Operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_ctc_ops
from tensorflow.python.ops import inplace_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.nn_grad import _BroadcastMul
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


# pylint: disable=protected-access, invalid-name
@tf_export(v1=["nn.ctc_loss"])
def ctc_loss(labels, inputs=None, sequence_length=None,
             preprocess_collapse_repeated=False,
             ctc_merge_repeated=True,
             ignore_longer_outputs_than_inputs=False, time_major=True,
             logits=None):
  """Computes the CTC (Connectionist Temporal Classification) Loss.

  This op implements the CTC loss as presented in the article:

  [A. Graves, S. Fernandez, F. Gomez, J. Schmidhuber.
  Connectionist Temporal Classification: Labeling Unsegmented Sequence Data
  with Recurrent Neural Networks. ICML 2006, Pittsburgh, USA,
  pp. 369-376.](http://www.cs.toronto.edu/~graves/icml_2006.pdf)

  Input requirements:

  ```
  sequence_length(b) <= time for all b

  max(labels.indices(labels.indices[:, 1] == b, 2))
    <= sequence_length(b) for all b.
  ```

  Notes:

  This op performs the softmax operation for you, so inputs should be
  e.g. linear projections of LSTM outputs.

  The `inputs` Tensor's innermost dimension size, `num_classes`, represents
  `num_labels + 1` classes, where num_labels is the number of true labels, and
  the largest value `(num_classes - 1)` is reserved for the blank label.

  For example, for a vocabulary containing 3 labels `[a, b, c]`,
  `num_classes = 4` and the labels indexing is `{a: 0, b: 1, c: 2, blank: 3}`.

  Regarding the arguments `preprocess_collapse_repeated` and
  `ctc_merge_repeated`:

  If `preprocess_collapse_repeated` is True, then a preprocessing step runs
  before loss calculation, wherein repeated labels passed to the loss
  are merged into single labels.  This is useful if the training labels come
  from, e.g., forced alignments and therefore have unnecessary repetitions.

  If `ctc_merge_repeated` is set False, then deep within the CTC calculation,
  repeated non-blank labels will not be merged and are interpreted
  as individual labels.  This is a simplified (non-standard) version of CTC.

  Here is a table of the (roughly) expected first order behavior:

  * `preprocess_collapse_repeated=False`, `ctc_merge_repeated=True`

    Classical CTC behavior: Outputs true repeated classes with blanks in
    between, and can also output repeated classes with no blanks in
    between that need to be collapsed by the decoder.

  * `preprocess_collapse_repeated=True`, `ctc_merge_repeated=False`

    Never learns to output repeated classes, as they are collapsed
    in the input labels before training.

  * `preprocess_collapse_repeated=False`, `ctc_merge_repeated=False`

    Outputs repeated classes with blanks in between, but generally does not
    require the decoder to collapse/merge repeated classes.

  * `preprocess_collapse_repeated=True`, `ctc_merge_repeated=True`

    Untested.  Very likely will not learn to output repeated classes.

  The `ignore_longer_outputs_than_inputs` option allows specifying the behavior
  of the CTCLoss when dealing with sequences that have longer outputs than
  inputs. If true, the CTCLoss will simply return zero gradient for those
  items; otherwise an InvalidArgument error is raised, stopping training.

  Args:
    labels: An `int32` `SparseTensor`.
      `labels.indices[i, :] == [b, t]` means `labels.values[i]` stores
      the id for (batch b, time t).
      `labels.values[i]` must take on values in `[0, num_labels)`.
      See `core/ops/ctc_ops.cc` for more details.
    inputs: 3-D `float` `Tensor`.
      If time_major == False, this will be a `Tensor` shaped:
        `[batch_size, max_time, num_classes]`.
      If time_major == True (default), this will be a `Tensor` shaped:
        `[max_time, batch_size, num_classes]`.
      The logits.
    sequence_length: 1-D `int32` vector, size `[batch_size]`.
      The sequence lengths.
    preprocess_collapse_repeated: Boolean.  Default: False.
      If True, repeated labels are collapsed prior to the CTC calculation.
    ctc_merge_repeated: Boolean.  Default: True.
    ignore_longer_outputs_than_inputs: Boolean. Default: False.
      If True, sequences with longer outputs than inputs will be ignored.
    time_major: The shape format of the `inputs` Tensors.
      If True, these `Tensors` must be shaped `[max_time, batch_size,
      num_classes]`.
      If False, these `Tensors` must be shaped `[batch_size, max_time,
      num_classes]`.
      Using `time_major = True` (default) is a bit more efficient because it
      avoids transposes at the beginning of the ctc_loss calculation.  However,
      most TensorFlow data is batch-major, so this function also accepts
      inputs in batch-major form.
    logits: Alias for inputs.

  Returns:
    A 1-D `float` `Tensor`, size `[batch]`, containing the negative log
      probabilities.

  Raises:
    TypeError: if labels is not a `SparseTensor`.
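
  Example:

  A minimal usage sketch (the shapes and label values below are illustrative
  assumptions, not taken from the op definition):

  ```
  import tensorflow as tf

  max_time, batch_size, num_classes = 50, 2, 28  # 27 labels + 1 blank.
  logits = tf.random_normal([max_time, batch_size, num_classes])
  sequence_length = tf.constant([50, 40], dtype=tf.int32)
  # Sparse labels: batch item 0 is [1, 2], batch item 1 is [3].
  labels = tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
                           values=tf.constant([1, 2, 3], dtype=tf.int32),
                           dense_shape=[2, 2])
  loss = tf.nn.ctc_loss(labels, inputs=logits,
                        sequence_length=sequence_length)  # Shape: [batch_size].
  ```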
  """
  # The second, third, etc output tensors contain the gradients.  We use it in
  # _CTCLossGrad() below.
  if not isinstance(labels, sparse_tensor.SparseTensor):
    raise TypeError("Expected labels (first argument) to be a SparseTensor")

  # For internal calculations, we transpose to [time, batch, num_classes]
  inputs = deprecation.deprecated_argument_lookup(
      "logits", logits, "inputs", inputs)
  if not time_major:
    inputs = array_ops.transpose(inputs, [1, 0, 2])  # (B,T,N) => (T,B,N)

  loss, _ = gen_ctc_ops.ctc_loss(
      inputs,
      labels.indices,
      labels.values,
      sequence_length,
      preprocess_collapse_repeated=preprocess_collapse_repeated,
      ctc_merge_repeated=ctc_merge_repeated,
      ignore_longer_outputs_than_inputs=ignore_longer_outputs_than_inputs)

  return loss


# pylint: disable=unused-argument
@ops.RegisterGradient("CTCLoss")
def _CTCLossGrad(op, grad_loss, _):
  """The derivative provided by CTC Loss.

  Args:
     op: the CTCLoss op.
     grad_loss: The backprop for cost.

  Returns:
     The CTC Loss gradient.
  """
  # Outputs are: loss, grad
  #
  # Currently there is no way to take the second derivative of this op
  # due to the fused implementation's interaction with tf.gradients(),
  # so we make sure we prevent silently incorrect results by raising
  # an error if the second derivative is requested via prevent_gradient.
  grad_without_gradient = array_ops.prevent_gradient(
      op.outputs[1], message="Currently there is no way to take the second "
      "derivative of ctc_loss due to the fused implementation's interaction "
      "with tf.gradients()")
  # Return gradient for inputs and None for
  # labels_indices, labels_values and sequence_length
  return [_BroadcastMul(grad_loss, grad_without_gradient), None, None, None]


@tf_export("nn.ctc_greedy_decoder")
def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
  """Performs greedy decoding on the logits given in input (best path).

  Note: Regardless of the value of merge_repeated, if the maximum index of a
  given time and batch corresponds to the blank index `(num_classes - 1)`, no
  new element is emitted.

  If `merge_repeated` is `True`, merge repeated classes in output.
  This means that if consecutive logits' maximum indices are the same,
  only the first of these is emitted.  The sequence `A B B * B * B` (where '*'
  is the blank label) becomes

    * `A B B B` if `merge_repeated=True`.
    * `A B B B B` if `merge_repeated=False`.

  Args:
    inputs: 3-D `float` `Tensor` sized
      `[max_time, batch_size, num_classes]`.  The logits.
    sequence_length: 1-D `int32` vector containing sequence lengths,
      having size `[batch_size]`.
    merge_repeated: Boolean.  Default: True.

  Returns:
    A tuple `(decoded, neg_sum_logits)` where

    decoded: A single-element list. `decoded[0]`
      is a `SparseTensor` containing the decoded outputs s.t.:

      `decoded.indices`: Indices matrix `(total_decoded_outputs, 2)`.
        The rows store: `[batch, time]`.

      `decoded.values`: Values vector, size `(total_decoded_outputs)`.
        The vector stores the decoded classes.

      `decoded.dense_shape`: Shape vector, size `(2)`.
        The shape values are: `[batch_size, max_decoded_length]`

    neg_sum_logits: A `float` matrix `(batch_size x 1)` containing, for the
        sequence found, the negative of the sum of the greatest logit at each
        timeframe.
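
  Example:

  A small usage sketch (the logits below are random placeholders, purely
  illustrative):

  ```
  import tensorflow as tf

  max_time, batch_size, num_classes = 50, 2, 28
  logits = tf.random_normal([max_time, batch_size, num_classes])
  sequence_length = tf.constant([50, 40], dtype=tf.int32)
  decoded, neg_sum_logits = tf.nn.ctc_greedy_decoder(logits, sequence_length)
  # decoded[0] is a SparseTensor; densify it for inspection.
  dense_decoded = tf.sparse_tensor_to_dense(decoded[0], default_value=-1)
  ```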
  """
  outputs = gen_ctc_ops.ctc_greedy_decoder(
      inputs, sequence_length, merge_repeated=merge_repeated)
  (decoded_ix, decoded_val, decoded_shape, log_probabilities) = outputs
  return ([sparse_tensor.SparseTensor(decoded_ix, decoded_val, decoded_shape)],
          log_probabilities)


@tf_export(v1=["nn.ctc_beam_search_decoder"])
def ctc_beam_search_decoder(inputs, sequence_length, beam_width=100,
                            top_paths=1, merge_repeated=True):
  """Performs beam search decoding on the logits given in input.

  **Note** The `ctc_greedy_decoder` is a special case of the
  `ctc_beam_search_decoder` with `top_paths=1` and `beam_width=1` (but
  that decoder is faster for this special case).

  If `merge_repeated` is `True`, merge repeated classes in the output beams.
  This means that if consecutive entries in a beam are the same,
  only the first of these is emitted.  That is, when the sequence is
  `A B B * B * B` (where '*' is the blank label), the return value is:

    * `A B` if `merge_repeated = True`.
    * `A B B B` if `merge_repeated = False`.

  Args:
    inputs: 3-D `float` `Tensor`, size
      `[max_time x batch_size x num_classes]`.  The logits.
    sequence_length: 1-D `int32` vector containing sequence lengths,
      having size `[batch_size]`.
    beam_width: An int scalar >= 0 (beam search beam width).
    top_paths: An int scalar >= 0, <= beam_width (controls output size).
    merge_repeated: Boolean.  Default: True.

  Returns:
    A tuple `(decoded, log_probabilities)` where

    decoded: A list of length top_paths, where `decoded[j]`
      is a `SparseTensor` containing the decoded outputs:

      `decoded[j].indices`: Indices matrix `(total_decoded_outputs[j] x 2)`
        The rows store: [batch, time].

      `decoded[j].values`: Values vector, size `(total_decoded_outputs[j])`.
        The vector stores the decoded classes for beam j.

      `decoded[j].dense_shape`: Shape vector, size `(2)`.
        The shape values are: `[batch_size, max_decoded_length[j]]`.

    log_probabilities: A `float` matrix `(batch_size x top_paths)` containing
        sequence log-probabilities.
  """

  decoded_ixs, decoded_vals, decoded_shapes, log_probabilities = (
      gen_ctc_ops.ctc_beam_search_decoder(
          inputs, sequence_length, beam_width=beam_width, top_paths=top_paths,
          merge_repeated=merge_repeated))

  return (
      [sparse_tensor.SparseTensor(ix, val, shape) for (ix, val, shape)
       in zip(decoded_ixs, decoded_vals, decoded_shapes)],
      log_probabilities)


@tf_export("nn.ctc_beam_search_decoder", v1=["nn.ctc_beam_search_decoder_v2"])
def ctc_beam_search_decoder_v2(inputs, sequence_length, beam_width=100,
                               top_paths=1):
  """Performs beam search decoding on the logits given in input.

  **Note** The `ctc_greedy_decoder` is a special case of the
  `ctc_beam_search_decoder` with `top_paths=1` and `beam_width=1` (but
  that decoder is faster for this special case).

  Args:
    inputs: 3-D `float` `Tensor`, size
      `[max_time, batch_size, num_classes]`.  The logits.
    sequence_length: 1-D `int32` vector containing sequence lengths,
      having size `[batch_size]`.
    beam_width: An int scalar >= 0 (beam search beam width).
    top_paths: An int scalar >= 0, <= beam_width (controls output size).

  Returns:
    A tuple `(decoded, log_probabilities)` where

    decoded: A list of length top_paths, where `decoded[j]`
      is a `SparseTensor` containing the decoded outputs:

      `decoded[j].indices`: Indices matrix `[total_decoded_outputs[j], 2]`;
        The rows store: `[batch, time]`.

      `decoded[j].values`: Values vector, size `[total_decoded_outputs[j]]`.
        The vector stores the decoded classes for beam `j`.

      `decoded[j].dense_shape`: Shape vector, size `(2)`.
        The shape values are: `[batch_size, max_decoded_length[j]]`.

    log_probabilities: A `float` matrix `[batch_size, top_paths]` containing
        sequence log-probabilities.
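
  Example:

  A brief usage sketch with random, purely illustrative logits (the TF1
  endpoint name `tf.nn.ctc_beam_search_decoder_v2` from the tf_export above
  is used):

  ```
  import tensorflow as tf

  max_time, batch_size, num_classes = 50, 2, 28
  logits = tf.random_normal([max_time, batch_size, num_classes])
  sequence_length = tf.constant([50, 40], dtype=tf.int32)
  decoded, log_probs = tf.nn.ctc_beam_search_decoder_v2(
      logits, sequence_length, beam_width=10, top_paths=3)
  best_path = tf.sparse_tensor_to_dense(decoded[0])  # Highest-probability beam.
  ```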
  """

  # Note, merge_repeated is an invalid optimization that is removed from the
  # public API: it returns low probability paths.
  return ctc_beam_search_decoder(inputs, sequence_length=sequence_length,
                                 beam_width=beam_width, top_paths=top_paths,
                                 merge_repeated=False)


ops.NotDifferentiable("CTCGreedyDecoder")
ops.NotDifferentiable("CTCBeamSearchDecoder")


def _ctc_state_trans(label_seq):
  """Compute CTC alignment model transition matrix.

  Args:
    label_seq: tensor of shape [batch_size, max_seq_length]

  Returns:
    tensor of shape [batch_size, states, states] with a state transition matrix
    computed for each sequence of the batch.
  """

  with ops.name_scope("ctc_state_trans"):
    label_seq = ops.convert_to_tensor(label_seq, name="label_seq")
    batch_size = _get_dim(label_seq, 0)
    num_labels = _get_dim(label_seq, 1)

    num_label_states = num_labels + 1
    num_states = 2 * num_label_states

    label_states = math_ops.range(num_label_states)
    blank_states = label_states + num_label_states

    # Start state to first label.
    start_to_label = [[1, 0]]

    # Blank to label transitions.
    blank_to_label = array_ops.stack([label_states[1:], blank_states[:-1]], 1)

    # Label to blank transitions.
    label_to_blank = array_ops.stack([blank_states, label_states], 1)

    # Scatter transitions that don't depend on sequence.
    indices = array_ops.concat(
        [start_to_label, blank_to_label, label_to_blank], 0)
    values = array_ops.ones([_get_dim(indices, 0)])
    trans = array_ops.scatter_nd(
        indices, values, shape=[num_states, num_states])
    trans += linalg_ops.eye(num_states)  # Self-loops.

    # Label to label transitions. Disallow transitions between repeated labels
    # with no blank state in between.
    batch_idx = array_ops.zeros_like(label_states[2:])
    indices = array_ops.stack(
        [batch_idx, label_states[2:], label_states[1:-1]], 1)
    indices = array_ops.tile(
        array_ops.expand_dims(indices, 0), [batch_size, 1, 1])
    batch_idx = array_ops.expand_dims(math_ops.range(batch_size), 1) * [1, 0, 0]
    indices += array_ops.expand_dims(batch_idx, 1)
    repeats = math_ops.equal(label_seq[:, :-1], label_seq[:, 1:])
    values = 1.0 - math_ops.cast(repeats, dtypes.float32)
    batched_shape = [batch_size, num_states, num_states]
    label_to_label = array_ops.scatter_nd(indices, values, batched_shape)

    return array_ops.expand_dims(trans, 0) + label_to_label


def ctc_state_log_probs(seq_lengths, max_seq_length):
  """Computes CTC alignment initial and final state log probabilities.

  Create the initial/final state values directly as log values to avoid
  having to take a float64 log on tpu (which does not exist).

  Args:
    seq_lengths: int tensor of shape [batch_size], seq lengths in the batch.
    max_seq_length: int, max sequence length possible.

  Returns:
    initial_state_log_probs, final_state_log_probs
  """

  batch_size = _get_dim(seq_lengths, 0)
  num_label_states = max_seq_length + 1
  num_duration_states = 2
  num_states = num_duration_states * num_label_states
  log_0 = math_ops.cast(
      math_ops.log(math_ops.cast(0, dtypes.float64) + 1e-307),
      dtypes.float32)

  initial_state_log_probs = array_ops.one_hot(
      indices=array_ops.zeros([batch_size], dtype=dtypes.int32),
      depth=num_states,
      on_value=0.0,
      off_value=log_0, axis=1)

  label_final_state_mask = array_ops.one_hot(
      seq_lengths, depth=num_label_states, axis=0)
  duration_final_state_mask = array_ops.ones(
      [num_duration_states, 1, batch_size])
  final_state_mask = duration_final_state_mask * label_final_state_mask
  final_state_log_probs = (1.0 - final_state_mask) * log_0
  final_state_log_probs = array_ops.reshape(
      final_state_log_probs, [num_states, batch_size])

  return initial_state_log_probs, array_ops.transpose(final_state_log_probs)


def _ilabel_to_state(labels, num_labels, ilabel_log_probs):
  """Project ilabel log probs to state log probs."""

  num_label_states = _get_dim(labels, 1)
  blank = ilabel_log_probs[:, :, :1]
  blank = array_ops.tile(blank, [1, 1, num_label_states + 1])
  one_hot = array_ops.one_hot(labels, depth=num_labels)
  one_hot = array_ops.expand_dims(one_hot, axis=0)
  ilabel_log_probs = array_ops.expand_dims(ilabel_log_probs, axis=2)
  state_log_probs = math_ops.reduce_sum(ilabel_log_probs * one_hot, axis=3)
  state_log_probs = array_ops.concat([state_log_probs, blank], axis=2)
  return array_ops.pad(
      state_log_probs, [[0, 0], [0, 0], [1, 0]],
      constant_values=math_ops.log(0.0))


def _state_to_olabel(labels, num_labels, states):
  """Sum state log probs to olabel log probs."""

  num_label_states = _get_dim(labels, 1) + 1
  label_states = states[:, :, 1:num_label_states]
  blank_states = states[:, :, num_label_states:]
  one_hot = array_ops.one_hot(
      labels - 1, depth=(num_labels - 1),
      on_value=0.0, off_value=math_ops.log(0.0))
  one_hot = array_ops.expand_dims(one_hot, axis=0)
  label_states = array_ops.expand_dims(label_states, axis=3)
  label_olabels = math_ops.reduce_logsumexp(label_states + one_hot, axis=2)
  blank_olabels = math_ops.reduce_logsumexp(
      blank_states, axis=2, keepdims=True)
  return array_ops.concat([blank_olabels, label_olabels], axis=-1)


# pylint: disable=redefined-outer-name
def _state_to_olabel_unique(labels, num_labels, states, unique):
  """Sum state log probs to olabel log probs using unique label indices."""

  num_label_states = _get_dim(labels, 1) + 1
  label_states = states[:, :, 1:num_label_states]
  blank_states = states[:, :, num_label_states:]

  unique_y, unique_idx = unique
  mul_reduce = _sum_states(unique_idx, label_states)

  num_frames = states.shape[0]
  batch_size = states.shape[1]
  num_states = num_label_states - 1
  batch_state_major = array_ops.transpose(mul_reduce, perm=[1, 2, 0])
  batch_state_major = array_ops.reshape(
      batch_state_major, [batch_size * num_states, num_frames])
  batch_offset = math_ops.range(batch_size, dtype=unique_y.dtype) * num_labels
  indices = unique_y + array_ops.expand_dims(batch_offset, axis=-1)
  indices = array_ops.reshape(indices, [-1, 1])
  scatter = array_ops.scatter_nd(
      indices=indices,
      updates=batch_state_major,
      shape=[batch_size * num_labels, num_frames])
  scatter = array_ops.reshape(scatter, [batch_size, num_labels, num_frames])
  scatter = array_ops.where(
      math_ops.equal(scatter, 0.0),
      array_ops.fill(array_ops.shape(scatter), math_ops.log(0.0)),
      scatter)
  label_olabels = array_ops.transpose(scatter, [2, 0, 1])
  label_olabels = label_olabels[:, :, 1:]

  blank_olabels = math_ops.reduce_logsumexp(
      blank_states, axis=2, keepdims=True)

  return array_ops.concat([blank_olabels, label_olabels], axis=-1)


def ctc_loss_and_grad(logits, labels, label_length, logit_length, unique=None):
  """Computes the CTC loss and gradients.

  Most users will want fwd_bwd.ctc_loss

  This function returns the computed gradient; it does not have a gradient
  of its own defined.

  Args:
    logits: tensor of shape [frames, batch_size, num_labels]
    labels: tensor of shape [batch_size, max_label_seq_length]
    label_length: tensor of shape [batch_size]
      Length of reference label sequence in labels.
    logit_length: tensor of shape [batch_size]
      Length of input sequence in logits.
    unique: (optional) unique label indices as computed by unique(labels)
      If supplied, enables an implementation that is faster and more memory
      efficient on TPU.

  Returns:
    loss: tensor of shape [batch_size]
    gradient: tensor of shape [frames, batch_size, num_labels]
  """

  num_labels = _get_dim(logits, 2)
  max_label_seq_length = _get_dim(labels, 1)

  ilabel_log_probs = nn_ops.log_softmax(logits)
  state_log_probs = _ilabel_to_state(labels, num_labels, ilabel_log_probs)
  state_trans_probs = _ctc_state_trans(labels)
  initial_state_log_probs, final_state_log_probs = ctc_state_log_probs(
      label_length, max_label_seq_length)
  fwd_bwd_log_probs, log_likelihood = _forward_backward_log(
      state_trans_log_probs=math_ops.log(state_trans_probs),
      initial_state_log_probs=initial_state_log_probs,
      final_state_log_probs=final_state_log_probs,
      observed_log_probs=state_log_probs,
      sequence_length=logit_length)

  if unique:
    olabel_log_probs = _state_to_olabel_unique(
        labels, num_labels, fwd_bwd_log_probs, unique)
  else:
    olabel_log_probs = _state_to_olabel(labels, num_labels, fwd_bwd_log_probs)

  grad = math_ops.exp(ilabel_log_probs) - math_ops.exp(olabel_log_probs)
  loss = -log_likelihood
  return loss, grad


def _ctc_loss_grad(op, grad_loss, _):
  grad = op.outputs[1]
  grad = [array_ops.reshape(grad_loss, [1, -1, 1]) * grad]
  grad += [None] * (len(op.inputs) - len(grad))
  return grad


def _ctc_loss_shape(op):
  return [op.inputs[2].get_shape(), op.inputs[0].get_shape()]


@tf_export("nn.ctc_loss", v1=["nn.ctc_loss_v2"])
def ctc_loss_v2(labels, logits, label_length, logit_length,
                logits_time_major=True, unique=None,
                blank_index=None, name=None):
  """Computes CTC (Connectionist Temporal Classification) loss.

  This op implements the CTC loss as presented in the article:

  [A. Graves, S. Fernandez, F. Gomez, J. Schmidhuber.
  Connectionist Temporal Classification: Labeling Unsegmented Sequence Data
  with Recurrent Neural Networks. ICML 2006, Pittsburgh, USA,
  pp. 369-376.](http://www.cs.toronto.edu/~graves/icml_2006.pdf)

  Notes:
      - Same as the "Classic CTC" in TensorFlow 1.x's tf.nn.ctc_loss setting of
        preprocess_collapse_repeated=False, ctc_merge_repeated=True
      - Labels may be supplied as either a dense, zero-padded tensor with a
        vector of label sequence lengths OR as a SparseTensor.
      - On TPU and GPU:
          - Only dense padded labels are supported.
      - On CPU:
          - Caller may use SparseTensor or dense padded labels but calling with
            a SparseTensor will be significantly faster.
      - Default blank label is 0 rather than num_classes - 1, unless overridden
        by blank_index.

  Args:
    labels: tensor of shape [batch_size, max_label_seq_length] or SparseTensor
    logits: tensor of shape [frames, batch_size, num_labels],
      if logits_time_major == False, shape is [batch_size, frames, num_labels].
    label_length: tensor of shape [batch_size], None if labels is SparseTensor
      Length of reference label sequence in labels.
    logit_length: tensor of shape [batch_size]
      Length of input sequence in logits.
    logits_time_major: (optional) If True (default), logits is shaped
      [time, batch, logits]. If False, shape is [batch, time, logits]
    unique: (optional) Unique label indices as computed by
      ctc_unique_labels(labels).  If supplied, enable a faster, memory
      efficient implementation on TPU.
    blank_index: (optional) Set the class index to use for the blank label.
      Negative values will start from num_classes, i.e., -1 will reproduce the
      ctc_loss behavior of using num_classes - 1 for the blank symbol.
      There is some memory/performance overhead to switching from the default
      of 0 as an additional shifted copy of the logits may be created.
    name: A name for this `Op`. Defaults to "ctc_loss_dense".

  Returns:
    loss: tensor of shape [batch_size], negative log probabilities.
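
  Example:

  A short usage sketch with dense, zero-padded labels (shapes and values are
  illustrative assumptions; class 0 is the blank label by default, and the TF1
  endpoint name `tf.nn.ctc_loss_v2` from the tf_export above is used):

  ```
  import tensorflow as tf

  frames, batch_size, num_labels = 50, 2, 28
  logits = tf.random_normal([frames, batch_size, num_labels])
  logit_length = tf.constant([50, 40], dtype=tf.int32)
  labels = tf.constant([[1, 2, 0, 0],
                        [3, 4, 5, 0]], dtype=tf.int32)  # Zero padded.
  label_length = tf.constant([2, 3], dtype=tf.int32)
  loss = tf.nn.ctc_loss_v2(labels=labels, logits=logits,
                           label_length=label_length,
                           logit_length=logit_length)  # Shape: [batch_size].
  ```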
  """
  if isinstance(labels, sparse_tensor.SparseTensor):
    if blank_index is None:
      raise ValueError(
          "blank_index must be given when using SparseTensor labels.")

    if blank_index < 0:
      blank_index += _get_dim(logits, 2)

    if blank_index != _get_dim(logits, 2) - 1:
      logits = array_ops.concat([
          logits[:, :, :blank_index],
          logits[:, :, blank_index+1:],
          logits[:, :, blank_index:blank_index+1],
      ], axis=2)
      labels = sparse_tensor.SparseTensor(
          labels.indices,
          array_ops.where(labels.values < blank_index,
                          labels.values,
                          labels.values - 1),
          labels.dense_shape)

    return ctc_loss(labels=labels,
                    inputs=logits,
                    sequence_length=logit_length,
                    time_major=logits_time_major)

  if blank_index is None:
    blank_index = 0

  return ctc_loss_dense(labels=labels,
                        logits=logits,
                        label_length=label_length,
                        logit_length=logit_length,
                        logits_time_major=logits_time_major,
                        unique=unique,
                        blank_index=blank_index,
                        name=name)


def ctc_loss_dense(labels, logits, label_length, logit_length,
                   logits_time_major=True, unique=None,
                   blank_index=0, name=None):
  """Computes CTC (Connectionist Temporal Classification) loss.

  This op implements the CTC loss as presented in the article:

  [A. Graves, S. Fernandez, F. Gomez, J. Schmidhuber.
  Connectionist Temporal Classification: Labeling Unsegmented Sequence Data
  with Recurrent Neural Networks. ICML 2006, Pittsburgh, USA,
  pp. 369-376.](http://www.cs.toronto.edu/~graves/icml_2006.pdf)

  Using the batched forward backward algorithm described in:

  [Sim, K. C., Narayanan, A., Bagby, T., Sainath, T. N., & Bacchiani, M.
  Improving the efficiency of forward-backward algorithm using batched
    computation in TensorFlow.
  Automatic Speech Recognition and Understanding Workshop (ASRU),
    2017 IEEE (pp. 258-264).
  ](https://ieeexplore.ieee.org/iel7/8260578/8268903/08268944.pdf)

  Notes:
    Significant differences from tf.nn.ctc_loss:
      Supports GPU and TPU (tf.nn.ctc_loss supports CPU only):
        For batched operations, GPU and TPU are significantly faster than using
        ctc_loss on CPU.
        This implementation also runs on CPU, but is significantly slower than
        ctc_loss.
      Blank label is 0 rather than num_classes - 1, unless overridden by
        blank_index.
      Logits and labels are dense arrays with padding rather than SparseTensor.
      The only mode supported is the same as:
        preprocess_collapse_repeated=False, ctc_merge_repeated=True
        To collapse labels, the caller can preprocess the label sequence first.

    The dense implementation supports CPU, GPU and TPU. A fast path is
    provided that significantly improves memory use for large vocabularies if
    the caller preprocesses label sequences to get unique label indices on the
    CPU (e.g. in the data input pipeline) using ctc_unique_labels and supplies
    this in the optional "unique" kwarg. This is especially useful for TPU and
    GPU but also works on CPU.

  Args:
    labels: tensor of shape [batch_size, max_label_seq_length]
    logits: tensor of shape [frames, batch_size, num_labels],
      if logits_time_major == False, shape is [batch_size, frames, num_labels].
    label_length: tensor of shape [batch_size]
      Length of reference label sequence in labels.
    logit_length: tensor of shape [batch_size]
      Length of input sequence in logits.
    logits_time_major: (optional) If True (default), logits is shaped
      [time, batch, logits]. If False, shape is [batch, time, logits]
    unique: (optional) Unique label indices as computed by unique(labels).
      If supplied, enable a faster, memory efficient implementation on TPU.
    blank_index: (optional) Set the class index to use for the blank label.
      Negative values will start from num_classes, i.e., -1 will reproduce the
      ctc_loss behavior of using num_classes - 1 for the blank symbol.
      There is some memory/performance overhead to switching from the default
      of 0 as an additional shifted copy of the logits may be created.
    name: A name for this `Op`. Defaults to "ctc_loss_dense".

  Returns:
    loss: tensor of shape [batch_size], negative log probabilities.
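
  Example:

  A sketch of the unique-label fast path (shapes and values are illustrative
  assumptions; ctc_loss_dense and ctc_unique_labels are addressed through this
  module, tensorflow.python.ops.ctc_ops, since ctc_loss_dense is not exported):

  ```
  import tensorflow as tf
  from tensorflow.python.ops import ctc_ops

  frames, batch_size, num_labels = 50, 2, 28
  logits = tf.random_normal([frames, batch_size, num_labels])
  logit_length = tf.constant([50, 40], dtype=tf.int32)
  labels = tf.constant([[1, 2, 2, 1],
                        [3, 4, 5, 0]], dtype=tf.int32)
  label_length = tf.constant([4, 3], dtype=tf.int32)
  # Precompute unique labels, e.g. in the input pipeline.
  unique = ctc_ops.ctc_unique_labels(labels)
  loss = ctc_ops.ctc_loss_dense(labels=labels, logits=logits,
                                label_length=label_length,
                                logit_length=logit_length,
                                unique=unique)
  ```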
  """

  with ops.name_scope(name, "ctc_loss_dense",
                      [logits, labels, label_length, logit_length]):
    logits = ops.convert_to_tensor(logits, name="logits")
    labels = ops.convert_to_tensor(labels, name="labels")
    label_length = ops.convert_to_tensor(label_length, name="label_length")
    logit_length = ops.convert_to_tensor(logit_length, name="logit_length")

    if not logits_time_major:
      logits = array_ops.transpose(logits, perm=[1, 0, 2])

    if blank_index != 0:
      if blank_index < 0:
        blank_index += _get_dim(logits, 2)
      logits = array_ops.concat([
          logits[:, :, blank_index:blank_index+1],
          logits[:, :, :blank_index],
          logits[:, :, blank_index+1:],
      ], axis=2)
      labels = array_ops.where(labels < blank_index, labels + 1, labels)

    args = [logits, labels, label_length, logit_length]

    if unique:
      unique_y, unique_idx = unique
      args.extend([unique_y, unique_idx])

    # TODO(tombagby): Update to tfe.defun
    @function.Defun(*[x.dtype for x in args],
                    python_grad_func=_ctc_loss_grad,
                    shape_func=_ctc_loss_shape)
    def compute_ctc_loss(logits_t, labels_t, label_length_t, logit_length_t,
                         *unique_t):
      """Compute CTC loss."""
      logits_t.set_shape(logits.shape)
      labels_t.set_shape(labels.shape)
      label_length_t.set_shape(label_length.shape)
      logit_length_t.set_shape(logit_length.shape)
      kwargs = dict(
          logits=logits_t,
          labels=labels_t,
          label_length=label_length_t,
          logit_length=logit_length_t)
      if unique_t:
        kwargs["unique"] = unique_t
      return ctc_loss_and_grad(**kwargs)

    return compute_ctc_loss(*args)[0]


@tf_export("nn.collapse_repeated")
def collapse_repeated(labels, seq_length, name=None):
  """Merge repeated labels into single labels.

  Args:
    labels: Tensor of shape (batch, max value in seq_length)
    seq_length: Tensor of shape (batch), sequence length of each batch element.
    name: A name for this `Op`. Defaults to "collapse_repeated_labels".

  Returns:
    tuple of Tensor of shape (batch, max_seq_length) with repeated labels
    collapsed and padded to max_seq_length, e.g.:
        [[A, A, B, B, A],
         [A, B, C, D, E]] => [[A, B, A, 0, 0],
                              [A, B, C, D, E]]
    and int tensor of shape [batch] with new sequence lengths.
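
  Example:

  A runnable sketch of the behavior above with integer labels (the values are
  illustrative):

  ```
  import tensorflow as tf

  labels = tf.constant([[1, 1, 2, 2, 1],
                        [1, 2, 3, 4, 5]], dtype=tf.int32)
  seq_length = tf.constant([5, 5], dtype=tf.int32)
  collapsed, new_seq_length = tf.nn.collapse_repeated(labels, seq_length)
  # collapsed      => [[1, 2, 1, 0, 0],
  #                    [1, 2, 3, 4, 5]]
  # new_seq_length => [3, 5]
  ```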
  """

  with ops.name_scope(name, "collapse_repeated_labels",
                      [labels, seq_length]):
    labels = ops.convert_to_tensor(labels, name="labels")
    seq_length = ops.convert_to_tensor(seq_length, name="seq_length")

    # Mask labels that don't equal previous label.
    label_mask = array_ops.concat(
        [array_ops.ones_like(labels[:, :1], dtypes.bool),
         math_ops.not_equal(labels[:, 1:], labels[:, :-1])],
        axis=1)

    # Filter labels that aren't in the original sequence.
    maxlen = _get_dim(labels, 1)
    seq_mask = array_ops.sequence_mask(seq_length, maxlen=maxlen)
    label_mask = math_ops.logical_and(label_mask, seq_mask)

    # Count masks for new sequence lengths.
    new_seq_len = math_ops.reduce_sum(
        math_ops.cast(label_mask, dtypes.int32), axis=1)

    # Mask indexes based on sequence length mask.
    new_maxlen = math_ops.reduce_max(new_seq_len)
    idx_mask = array_ops.sequence_mask(new_seq_len, maxlen=new_maxlen)

    # Flatten everything and mask out labels to keep and sparse indices.
    flat_labels = array_ops.reshape(labels, [-1])
    flat_label_mask = array_ops.reshape(label_mask, [-1])
    flat_idx_mask = array_ops.reshape(idx_mask, [-1])
    idx = math_ops.range(_get_dim(flat_idx_mask, 0))

    # Scatter to flat shape.
    flat = array_ops.scatter_nd(
        indices=array_ops.expand_dims(
            array_ops.boolean_mask(idx, flat_idx_mask), axis=1),
        updates=array_ops.boolean_mask(flat_labels, flat_label_mask),
        shape=array_ops.shape(flat_idx_mask))

    # Reshape back to square batch.
    batch_size = _get_dim(labels, 0)
    new_shape = [batch_size, new_maxlen]
    return (array_ops.reshape(flat, new_shape),
            math_ops.cast(new_seq_len, seq_length.dtype))


def dense_labels_to_sparse(dense, length):
  """Convert dense labels with sequence lengths to sparse tensor.

  Args:
    dense: tensor of shape [batch, max_length]
    length: int tensor of shape [batch]
      The length of each sequence in dense.

  Returns:
    tf.SparseTensor with values only for the valid elements of sequences.
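
  Example:

  An illustrative sketch (this helper is not exported, so it is addressed here
  through its module, tensorflow.python.ops.ctc_ops; the values are made up):

  ```
  import tensorflow as tf
  from tensorflow.python.ops import ctc_ops

  dense = tf.constant([[1, 2, 0],
                       [3, 0, 0]], dtype=tf.int32)
  length = tf.constant([2, 1], dtype=tf.int32)
  sparse = ctc_ops.dense_labels_to_sparse(dense, length)
  # sparse.indices     => [[0, 0], [0, 1], [1, 0]]
  # sparse.values      => [1, 2, 3]
  # sparse.dense_shape => [2, 2]  (second dim trimmed to max(length))
  ```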
  """

  flat_values = array_ops.reshape(dense, [-1])
  flat_indices = math_ops.range(
      array_ops.shape(flat_values, out_type=dtypes.int64)[0])
  mask = array_ops.sequence_mask(length, maxlen=array_ops.shape(dense)[1])
  flat_mask = array_ops.reshape(mask, [-1])
  indices = array_ops.expand_dims(
      array_ops.boolean_mask(flat_indices, flat_mask), 1)
  values = array_ops.boolean_mask(flat_values, flat_mask)
  sparse = sparse_tensor.SparseTensor(
      indices=indices, values=math_ops.cast(values, dtypes.int32),
      dense_shape=array_ops.shape(flat_values, out_type=dtypes.int64))
  reshaped = sparse_ops.sparse_reshape(sparse, array_ops.shape(dense))
  max_length = math_ops.reduce_max(length)
  return sparse_tensor.SparseTensor(
      indices=reshaped.indices,
      values=reshaped.values,
      dense_shape=[
          math_ops.cast(reshaped.dense_shape[0], dtypes.int64),
          math_ops.cast(max_length, dtypes.int64)])


@tf_export("nn.ctc_unique_labels")
def ctc_unique_labels(labels, name=None):
  """Get unique labels and indices for batched labels for tf.nn.ctc_loss.

  For use with the tf.nn.ctc_loss_v2 optional argument `unique`: This op can be
  used to preprocess labels in the input pipeline for better speed/memory use
  when computing the ctc loss on TPU.

  Example:
    ctc_unique_labels([[3, 4, 4, 3]]) ->
      unique labels padded with 0: [[3, 4, 0, 0]]
      indices of original labels in unique: [0, 1, 1, 0]

  Args:
    labels: tensor of shape [batch_size, max_label_length] padded with 0.
    name: A name for this `Op`. Defaults to "ctc_unique_labels".

  Returns:
    tuple of
      - unique labels, tensor of shape `[batch_size, max_label_length]`
      - indices into unique labels, shape `[batch_size, max_label_length]`
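
  A runnable form of the example above (a sketch; the endpoint name
  tf.nn.ctc_unique_labels comes from the tf_export above):

  ```
  import tensorflow as tf

  labels = tf.constant([[3, 4, 4, 3]], dtype=tf.int32)
  unique_y, unique_idx = tf.nn.ctc_unique_labels(labels)
  # unique_y   => [[3, 4, 0, 0]]  (unique labels, zero padded)
  # unique_idx => [[0, 1, 1, 0]]  (index of each original label in unique_y)
  ```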
  """

  with ops.name_scope(name, "ctc_unique_labels", [labels]):
    labels = ops.convert_to_tensor(labels, name="labels")
    def _unique(x):
      u = array_ops.unique(x)
      y = array_ops.pad(
          u.y, [[0, _get_dim(u.idx, 0) - _get_dim(u.y, 0)]])
      y = math_ops.cast(y, dtypes.int64)
      return [y, u.idx]
    return map_fn.map_fn(
        _unique, labels, dtype=[dtypes.int64, dtypes.int32])


def _sum_states(idx, states):
  """Take logsumexp for each unique state out of all label states.

  Args:
    idx: tensor of shape [batch, label_length]
      For each sequence, indices into a set of unique labels as computed by
      calling unique.
    states: tensor of shape [frames, batch, label_length]
      Log probabilities for each label state.

  Returns:
    tensor of shape [frames, batch_size, label_length], log probabilities summed
      for each unique label of the sequence.
  """

  with ops.name_scope("sum_states"):
    idx = ops.convert_to_tensor(idx, name="idx")
    num_states = _get_dim(states, 2)
    states = array_ops.expand_dims(states, axis=2)
    one_hot = array_ops.one_hot(
        idx, depth=num_states, on_value=0.0, off_value=math_ops.log(0.0),
        axis=1)
    return math_ops.reduce_logsumexp(states + one_hot, axis=-1)


def _forward_backward_log(state_trans_log_probs, initial_state_log_probs,
                          final_state_log_probs, observed_log_probs,
                          sequence_length):
  """Forward-backward algorithm computed in log domain.

  Args:
    state_trans_log_probs: tensor of shape [states, states] or
      if different transition matrix per batch [batch_size, states, states]
    initial_state_log_probs: tensor of shape [batch_size, states]
    final_state_log_probs: tensor of shape [batch_size, states]
    observed_log_probs: tensor of shape [frames, batch_size, states]
    sequence_length: tensor of shape [batch_size]

  Returns:
    forward backward log probabilities: tensor of shape [frames, batch, states]
    log_likelihood: tensor of shape [batch_size]

  Raises:
    ValueError: If state_trans_log_probs has unknown or incorrect rank.
  """

  if state_trans_log_probs.shape.ndims == 2:
    perm = [1, 0]
  elif state_trans_log_probs.shape.ndims == 3:
    perm = [0, 2, 1]
  else:
    raise ValueError(
        "state_trans_log_probs rank must be known and == 2 or 3, is: %s" %
        state_trans_log_probs.shape.ndims)

  bwd_state_trans_log_probs = array_ops.transpose(state_trans_log_probs, perm)
  batch_size = _get_dim(observed_log_probs, 1)

  def _forward(state_log_prob, obs_log_prob):
    state_log_prob = array_ops.expand_dims(state_log_prob, axis=1)  # Broadcast.
    state_log_prob += state_trans_log_probs
    state_log_prob = math_ops.reduce_logsumexp(state_log_prob, axis=-1)
    state_log_prob += obs_log_prob
    log_prob_sum = math_ops.reduce_logsumexp(
        state_log_prob, axis=-1, keepdims=True)
    state_log_prob -= log_prob_sum
    return state_log_prob

  fwd = _scan(_forward, observed_log_probs, initial_state_log_probs,
              inclusive=True)

  def _backward(accs, elems):
    """Calculate log probs and cumulative sum masked for sequence length."""
    state_log_prob, cum_log_sum = accs
    obs_log_prob, mask = elems
    state_log_prob += obs_log_prob
    state_log_prob = array_ops.expand_dims(state_log_prob, axis=1)  # Broadcast.
    state_log_prob += bwd_state_trans_log_probs
    state_log_prob = math_ops.reduce_logsumexp(state_log_prob, axis=-1)

    log_prob_sum = math_ops.reduce_logsumexp(
        state_log_prob, axis=-1, keepdims=True)
    state_log_prob -= log_prob_sum

    cum_log_sum += array_ops.squeeze(log_prob_sum) * mask
    batched_mask = array_ops.expand_dims(mask, axis=1)
    out = state_log_prob * batched_mask
    out += final_state_log_probs * (1.0 - batched_mask)
    return out, cum_log_sum

  zero_log_sum = array_ops.zeros([batch_size])
  maxlen = _get_dim(observed_log_probs, 0)
  mask = array_ops.sequence_mask(sequence_length, maxlen, dtypes.float32)
  mask = array_ops.transpose(mask, perm=[1, 0])

  bwd, cum_log_sum = _scan(_backward, (observed_log_probs, mask),
                           (final_state_log_probs, zero_log_sum),
                           reverse=True, inclusive=True)

  fwd_bwd_log_probs = fwd[1:] + bwd[1:]
  fwd_bwd_log_probs_sum = math_ops.reduce_logsumexp(
      fwd_bwd_log_probs, axis=2, keepdims=True)
  fwd_bwd_log_probs -= fwd_bwd_log_probs_sum
  fwd_bwd_log_probs += math_ops.log(array_ops.expand_dims(mask, axis=2))

  log_likelihood = bwd[0, :, 0] + cum_log_sum[0]

  return fwd_bwd_log_probs, log_likelihood


# TODO(tombagby): This is currently faster for the ctc implementation than using
# functional_ops.scan, but could be replaced by that or something similar if
# things change.
def _scan(fn, elems, initial, reverse=False, inclusive=False, final_only=False):
  """Repeatedly applies callable `fn` to a sequence of elements.

  Implemented by functional_ops.While, tpu friendly, no gradient.

  This is similar to functional_ops.scan but significantly faster on tpu/gpu
  for the forward backward use case.

  Examples:
    scan(lambda a, e: a + e, [1.0, 2.0, 3.0], 1.0) => [2.0, 4.0, 7.0]

    Multiple accumulators:
      scan(lambda a, e: (a[0] + e, a[1] * e), [1.0, 2.0, 3.0], (0.0, 1.0))

    Multiple inputs:
      scan(lambda a, e: a + (e[0] * e[1]), (elems1, elems2), 0.0)

  Args:
    fn: callable, fn(accumulators, element) return new accumulator values.
      The (possibly nested) sequence of accumulators is the same as `initial`
      and the return value must have the same structure.
    elems: A (possibly nested) tensor which will be unpacked along the first
      dimension. The resulting slices will be the second argument to fn. The
      first dimension of all nested input tensors must be the same.
    initial: A tensor or (possibly nested) sequence of tensors with initial
      values for the accumulators.
    reverse: (optional) True enables scan and output elems in reverse order.
    inclusive: (optional) True includes the initial accumulator values in the
      output. Length of output will be len(elem sequence) + 1. Not meaningful
      if final_only is True.
    final_only: (optional) When True, return only the final accumulated values,
      not the concatenation of accumulated values for each input.

  Returns:
    A (possibly nested) sequence of tensors with the results of applying fn
    to tensors unpacked from elems and previous accumulator values.
  """

  flat_elems = [ops.convert_to_tensor(x) for x in nest.flatten(elems)]
  num_elems = array_ops.shape(flat_elems[0])[0]
  pack_elems = lambda x: nest.pack_sequence_as(structure=elems, flat_sequence=x)
  flat_initial = [ops.convert_to_tensor(x) for x in nest.flatten(initial)]
  pack = lambda x: nest.pack_sequence_as(structure=initial, flat_sequence=x)
  accum_dtypes = [x.dtype for x in flat_initial]
  num_accums = len(flat_initial)

  # Types for counter, [outputs], [accumulators] loop arguments.
  if final_only:
    loop_dtypes = [dtypes.int32, dtypes.int32] + accum_dtypes
  else:
    loop_dtypes = [dtypes.int32, dtypes.int32] + accum_dtypes + accum_dtypes

  # TODO(tombagby): Update to tfe.defun
  @function.Defun(*loop_dtypes)
  def cond(i, num_elems, *args):
    del args
    return i >= 0 if reverse else i < num_elems

  # The loop *args are [output tensors] + [accumulator tensors] which must
  # be paired. Each output corresponds to one accumulator.
  @function.Defun(*loop_dtypes)
  def body(i, num_elems, *args):
    """Loop body."""
    i.set_shape([])
    if final_only:
      accum = args
    else:
      out, accum = args[:num_accums], args[num_accums:]
    slices = [array_ops.gather(e, i) for e in flat_elems]
    accum = fn(pack(accum), pack_elems(slices))
    flat_accum = nest.flatten(accum)
    if final_only:
      new_out = []
    else:
      update_i = i + 1 if inclusive and not reverse else i
      new_out = [inplace_ops.alias_inplace_update(x, update_i, y)
                 for x, y in zip(out, flat_accum)]
    i = i - 1 if reverse else i + 1
    return [i, num_elems] + new_out + flat_accum

  init_i = (array_ops.shape(flat_elems[0])[0] - 1 if reverse
            else constant_op.constant(0, dtype=dtypes.int32))
  outputs = []
  if not final_only:
    num_outputs = array_ops.shape(flat_elems[0])[0] + (1 if inclusive else 0)
    for initial_accum in flat_initial:
      out_shape = array_ops.concat(
          [[num_outputs], array_ops.shape(initial_accum)], 0)
      out = inplace_ops.empty(out_shape, dtype=initial_accum.dtype, init=True)
      if inclusive:
        out = inplace_ops.alias_inplace_add(
            out, init_i + (1 if reverse else 0), initial_accum)
      outputs.append(out)
  loop_in = [init_i, num_elems] + outputs + flat_initial
  hostmem = [
      i for i, x in enumerate(loop_in)
      if x.dtype.base_dtype in (dtypes.int32, dtypes.int64)
  ]

  # TODO(tombagby): Update to while_v2.
  loop_results = functional_ops.While(loop_in, cond, body, hostmem=hostmem)
  out = loop_results[2:num_accums + 2]
  return pack(out)


def _get_dim(tensor, i):
  """Get value of tensor shape[i] preferring static value if available."""
  return tensor_shape.dimension_value(
      tensor.shape[i]) or array_ops.shape(tensor)[i]
