# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""CTC (Connectionist Temporal Classification) Operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import uuid

from tensorflow.python.eager import context
from tensorflow.python.eager import function as function_eager

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_ctc_ops
from tensorflow.python.ops import inplace_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.nn_grad import _BroadcastMul
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

_DEFUN_API_NAME_ATTRIBUTE = "api_implements"
_DEFUN_DEVICE_ATTRIBUTE = "api_preferred_device"
_CPU_DEVICE_NAME = "CPU"
_GPU_DEVICE_NAME = "GPU"


def _get_context_device_type():
  """Parse the current context and return the device type, eg CPU/GPU."""
  current_device = context.context().device_name
  if current_device is None:
    return None
  return device.DeviceSpec.from_string(current_device).device_type


def _generate_defun_backend(unique_api_name, preferred_device, func):
  function_attributes = {
      _DEFUN_API_NAME_ATTRIBUTE: unique_api_name,
      _DEFUN_DEVICE_ATTRIBUTE: preferred_device,
  }
  return function_eager.defun_with_attributes(
      func=func, attributes=function_attributes, autograph=False)

# pylint: disable=protected-access, invalid-name
@tf_export(v1=["nn.ctc_loss"])
@dispatch.add_dispatch_support
def ctc_loss(labels,
             inputs=None,
             sequence_length=None,
             preprocess_collapse_repeated=False,
             ctc_merge_repeated=True,
             ignore_longer_outputs_than_inputs=False,
             time_major=True,
             logits=None):
  """Computes the CTC (Connectionist Temporal Classification) Loss.

  This op implements the CTC loss as presented in (Graves et al., 2006).

  Input requirements:

  ```
  sequence_length(b) <= time for all b

  max(labels.indices(labels.indices[:, 1] == b, 2))
    <= sequence_length(b) for all b.
  ```

  Notes:

  This op performs the softmax operation for you, so inputs should be e.g.
  linear projections of outputs by an LSTM.

  The `inputs` Tensor's innermost dimension size, `num_classes`, represents
  `num_labels + 1` classes, where num_labels is the number of true labels, and
  the largest value `(num_classes - 1)` is reserved for the blank label.

  For example, for a vocabulary containing 3 labels `[a, b, c]`,
  `num_classes = 4` and the labels indexing is `{a: 0, b: 1, c: 2, blank: 3}`.

  Regarding the arguments `preprocess_collapse_repeated` and
  `ctc_merge_repeated`:

  If `preprocess_collapse_repeated` is True, then a preprocessing step runs
  before loss calculation, wherein repeated labels passed to the loss
  are merged into single labels.  This is useful if the training labels come
  from, e.g., forced alignments and therefore have unnecessary repetitions.

  If `ctc_merge_repeated` is set False, then deep within the CTC calculation,
  repeated non-blank labels will not be merged and are interpreted
  as individual labels.  This is a simplified (non-standard) version of CTC.

  Here is a table of the (roughly) expected first order behavior:

  * `preprocess_collapse_repeated=False`, `ctc_merge_repeated=True`

    Classical CTC behavior: Outputs true repeated classes with blanks in
    between, and can also output repeated classes with no blanks in
    between that need to be collapsed by the decoder.

  * `preprocess_collapse_repeated=True`, `ctc_merge_repeated=False`

    Never learns to output repeated classes, as they are collapsed
    in the input labels before training.

  * `preprocess_collapse_repeated=False`, `ctc_merge_repeated=False`

    Outputs repeated classes with blanks in between, but generally does not
    require the decoder to collapse/merge repeated classes.

  * `preprocess_collapse_repeated=True`, `ctc_merge_repeated=True`

    Untested.  Very likely will not learn to output repeated classes.

  The `ignore_longer_outputs_than_inputs` option allows you to specify the
  behavior of the CTCLoss when dealing with sequences that have longer outputs
  than inputs. If true, the CTCLoss will simply return zero gradient for those
  items, otherwise an InvalidArgument error is returned, stopping training.
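
  As a minimal illustration (shapes and label values here are arbitrary
  assumptions, not requirements of the op):

  ```
  # batch_size=1, max_time=3, num_classes=4 (labels {a, b, c} plus blank=3).
  logits = tf.random.normal([3, 1, 4])   # [max_time, batch_size, num_classes]
  labels = tf.SparseTensor(indices=[[0, 0], [0, 1]],
                           values=[0, 1],   # "a b"
                           dense_shape=[1, 2])
  seq_len = tf.constant([3])
  loss = tf.compat.v1.nn.ctc_loss(labels, logits, seq_len)
  ```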

  Args:
    labels: An `int32` `SparseTensor`.
      `labels.indices[i, :] == [b, t]` means `labels.values[i]` stores the id
        for (batch b, time t). `labels.values[i]` must take on values in `[0,
        num_labels)`. See `core/ops/ctc_ops.cc` for more details.
    inputs: 3-D `float` `Tensor`.
      If time_major == False, this will be a `Tensor` shaped: `[batch_size,
        max_time, num_classes]`.
      If time_major == True (default), this will be a `Tensor` shaped:
        `[max_time, batch_size, num_classes]`. The logits.
    sequence_length: 1-D `int32` vector, size `[batch_size]`. The sequence
      lengths.
    preprocess_collapse_repeated: Boolean.  Default: False. If True, repeated
      labels are collapsed prior to the CTC calculation.
    ctc_merge_repeated: Boolean.  Default: True.
    ignore_longer_outputs_than_inputs: Boolean. Default: False. If True,
      sequences with longer outputs than inputs will be ignored.
    time_major: The shape format of the `inputs` Tensors. If True, these
      `Tensors` must be shaped `[max_time, batch_size, num_classes]`. If False,
      these `Tensors` must be shaped `[batch_size, max_time, num_classes]`.
      Using `time_major = True` (default) is a bit more efficient because it
      avoids transposes at the beginning of the ctc_loss calculation.  However,
      most TensorFlow data is batch-major, so this function also accepts
      inputs in batch-major form.
    logits: Alias for inputs.

  Returns:
    A 1-D `float` `Tensor`, size `[batch]`, containing the negative log
      probabilities.

  Raises:
    TypeError: if labels is not a `SparseTensor`.

  References:
      Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
      with Recurrent Neural Networks:
        [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
        ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
  """
  return _ctc_loss_impl(
      labels,
      inputs,
      sequence_length,
      preprocess_collapse_repeated,
      ctc_merge_repeated,
      ignore_longer_outputs_than_inputs,
      time_major,
      logits,
      use_cudnn=False)


def _ctc_loss_impl(labels,
                   inputs=None,
                   sequence_length=None,
                   preprocess_collapse_repeated=False,
                   ctc_merge_repeated=True,
                   ignore_longer_outputs_than_inputs=False,
                   time_major=True,
                   logits=None,
                   use_cudnn=False):
  # Helper function of ctc_loss with one additional param:
  # use_cudnn: A bool to enable cuDNN CTC loss operation. If true, the blank
  #   index has to be 0.

  # The second, third, etc output tensors contain the gradients.  We use it in
  # _CTCLossGrad() below.
  if not isinstance(labels, sparse_tensor.SparseTensor):
    raise TypeError("Expected labels (first argument) to be a SparseTensor")

  # For internal calculations, we transpose to [time, batch, num_classes]
  inputs = deprecation.deprecated_argument_lookup("logits", logits, "inputs",
                                                  inputs)
  if not time_major:
    inputs = array_ops.transpose(inputs, [1, 0, 2])  # (B,T,N) => (T,B,N)

  # gen_ctc_ops.ctc_loss_v2 differs from gen_ctc_ops.ctc_loss. v2 assumes the
  # blank index to be 0, but v1 views it as the last index.
  if use_cudnn:
    ctc_loss_func = gen_ctc_ops.ctc_loss_v2
  else:
    ctc_loss_func = gen_ctc_ops.ctc_loss

  loss, _ = ctc_loss_func(
      inputs,
      labels.indices,
      labels.values,
      sequence_length,
      preprocess_collapse_repeated=preprocess_collapse_repeated,
      ctc_merge_repeated=ctc_merge_repeated,
      ignore_longer_outputs_than_inputs=ignore_longer_outputs_than_inputs)

  return loss

# pylint: disable=unused-argument
def _CTCLossGradImpl(op, grad_loss, _):
  # Outputs are: loss, grad
  #
  # Currently there is no way to take the second derivative of this op
  # due to the fused implementation's interaction with tf.gradients(),
  # so we make sure we prevent silently incorrect results by raising
  # an error if the second derivative is requested via prevent_gradient.
  grad_without_gradient = array_ops.prevent_gradient(
      op.outputs[1],
      message="Currently there is no way to take the second "
      "derivative of ctc_loss due to the fused implementation's interaction "
      "with tf.gradients()")
  # Return gradient for inputs and None for
  # labels_indices, labels_values and sequence_length
  return [_BroadcastMul(grad_loss, grad_without_gradient), None, None, None]


# pylint: disable=unused-argument
@ops.RegisterGradient("CTCLoss")
def _CTCLossGrad(op, grad_loss, _):
  """The derivative provided by CTC Loss.

  Args:
     op: the CTCLoss op.
     grad_loss: The backprop for cost.

  Returns:
     The CTC Loss gradient.
  """
  return _CTCLossGradImpl(op, grad_loss, _)


# pylint: disable=unused-argument
@ops.RegisterGradient("CTCLossV2")
def _CTCLossV2Grad(op, grad_loss, _):
  """The derivative provided by CTC Loss V2.

  Args:
     op: the CTCLossV2 op.
     grad_loss: The backprop for cost.

  Returns:
     The CTC Loss V2 gradient.
  """
  return _CTCLossGradImpl(op, grad_loss, _)


@tf_export("nn.ctc_greedy_decoder")
@dispatch.add_dispatch_support
def ctc_greedy_decoder(inputs,
                       sequence_length,
                       merge_repeated=True,
                       blank_index=None):
  """Performs greedy decoding on the logits given in input (best path).

  Given a tensor as `inputs`, the `blank_index` parameter defines the class
  index of the blank symbol.

  For example:

  If `blank_index` is equal to 1:

  >>> inf = float("inf")
  >>> logits = tf.constant([[[   0., -inf, -inf],
  ...                        [ -2.3, -inf, -0.1]],
  ...                       [[ -inf, -0.5, -inf],
  ...                        [ -inf, -inf, -0.1]],
  ...                       [[ -inf, -inf, -inf],
  ...                        [ -0.1, -inf, -2.3]]])
  >>> seq_lens = tf.constant([2, 3])
  >>> outputs = tf.nn.ctc_greedy_decoder(
  ...     logits,
  ...     seq_lens,
  ...     blank_index=1)
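
  The decoded result is returned as a `SparseTensor`; it can be densified for
  inspection, e.g. (a sketch; `-1` is an arbitrary padding value):

  ```
  decoded, neg_sum_logits = outputs
  dense_decoded = tf.sparse.to_dense(decoded[0], default_value=-1)
  ```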

  Notes:

  - Regardless of the value of `merge_repeated`, if an index of a
    given time and batch corresponds to the `blank_index`, no new
    element is emitted.
  - Default `blank_index` is `(num_classes - 1)`, unless overridden.

  If `merge_repeated` is `True`, merge repeated classes in output.
  This means that if consecutive logits' maximum indices are the same,
  only the first of these is emitted.  The sequence `A B B * B * B` (where '*'
  is the blank label) becomes

    * `A B B B` if `merge_repeated=True`.
    * `A B B B B` if `merge_repeated=False`.

  Args:
    inputs: 3-D `float` `Tensor` sized `[max_time, batch_size, num_classes]`.
      The logits.
    sequence_length: 1-D `int32` vector containing sequence lengths, having size
      `[batch_size]`.
    merge_repeated: Boolean.  Default: True.
    blank_index: (Optional). Default: `num_classes - 1`. Define the class index
      to use for the blank label. Negative values will start from num_classes,
      i.e., -1 will reproduce the ctc_greedy_decoder behavior of using
      num_classes - 1 for the blank symbol, which corresponds to the default.

  Returns:
    A tuple `(decoded, neg_sum_logits)` where

    decoded: A single-element list. `decoded[0]`
      is a `SparseTensor` containing the decoded outputs s.t.:

      `decoded.indices`: Indices matrix `(total_decoded_outputs, 2)`.
        The rows store: `[batch, time]`.

      `decoded.values`: Values vector, size `(total_decoded_outputs)`.
        The vector stores the decoded classes.

      `decoded.dense_shape`: Shape vector, size `(2)`.
        The shape values are: `[batch_size, max_decoded_length]`

    neg_sum_logits: A `float` matrix `(batch_size x 1)` containing, for the
        sequence found, the negative of the sum of the greatest logit at each
        timeframe.
  """

  outputs = gen_ctc_ops.ctc_greedy_decoder(
      inputs,
      sequence_length,
      merge_repeated=merge_repeated,
      blank_index=blank_index)
  (decoded_ix, decoded_val, decoded_shape, log_probabilities) = outputs
  return ([sparse_tensor.SparseTensor(decoded_ix, decoded_val,
                                      decoded_shape)], log_probabilities)


@tf_export(v1=["nn.ctc_beam_search_decoder"])
@dispatch.add_dispatch_support
def ctc_beam_search_decoder(inputs,
                            sequence_length,
                            beam_width=100,
                            top_paths=1,
                            merge_repeated=True):
  """Performs beam search decoding on the logits given in input.

  **Note** The `ctc_greedy_decoder` is a special case of the
  `ctc_beam_search_decoder` with `top_paths=1` and `beam_width=1` (but
  that decoder is faster for this special case).

  If `merge_repeated` is `True`, merge repeated classes in the output beams.
  This means that if consecutive entries in a beam are the same,
  only the first of these is emitted.  That is, when the sequence is
  `A B B * B * B` (where '*' is the blank label), the return value is:

    * `A B` if `merge_repeated = True`.
    * `A B B B` if `merge_repeated = False`.
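
  A minimal call sketch (`logits` and `seq_lens` are assumed inputs shaped as
  described under Args):

  ```
  decoded, log_probs = tf.compat.v1.nn.ctc_beam_search_decoder(
      logits, seq_lens, beam_width=10, top_paths=3)
  # decoded is a list of 3 SparseTensors, one per returned path.
  ```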

  Args:
    inputs: 3-D `float` `Tensor`, size `[max_time x batch_size x num_classes]`.
      The logits.
    sequence_length: 1-D `int32` vector containing sequence lengths, having size
      `[batch_size]`.
    beam_width: An int scalar >= 0 (beam search beam width).
    top_paths: An int scalar >= 0, <= beam_width (controls output size).
    merge_repeated: Boolean.  Default: True.

  Returns:
    A tuple `(decoded, log_probabilities)` where

    decoded: A list of length top_paths, where `decoded[j]`
      is a `SparseTensor` containing the decoded outputs:

      `decoded[j].indices`: Indices matrix `(total_decoded_outputs[j] x 2)`
        The rows store: [batch, time].

      `decoded[j].values`: Values vector, size `(total_decoded_outputs[j])`.
        The vector stores the decoded classes for beam j.

      `decoded[j].dense_shape`: Shape vector, size `(2)`.
        The shape values are: `[batch_size, max_decoded_length[j]]`.

    log_probabilities: A `float` matrix `(batch_size x top_paths)` containing
        sequence log-probabilities.
  """

  decoded_ixs, decoded_vals, decoded_shapes, log_probabilities = (
      gen_ctc_ops.ctc_beam_search_decoder(
          inputs,
          sequence_length,
          beam_width=beam_width,
          top_paths=top_paths,
          merge_repeated=merge_repeated))

  return ([
      sparse_tensor.SparseTensor(ix, val, shape)
      for (ix, val, shape) in zip(decoded_ixs, decoded_vals, decoded_shapes)
  ], log_probabilities)


@tf_export("nn.ctc_beam_search_decoder", v1=["nn.ctc_beam_search_decoder_v2"])
@dispatch.add_dispatch_support
def ctc_beam_search_decoder_v2(inputs,
                               sequence_length,
                               beam_width=100,
                               top_paths=1):
  """Performs beam search decoding on the logits given in input.

  **Note** The `ctc_greedy_decoder` is a special case of the
  `ctc_beam_search_decoder` with `top_paths=1` and `beam_width=1` (but
  that decoder is faster for this special case).
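
  For example (a sketch; `logits` and `seq_lens` are assumed inputs):

  ```
  decoded, log_probs = tf.nn.ctc_beam_search_decoder(
      logits, seq_lens, beam_width=100, top_paths=5)
  best_path = tf.sparse.to_dense(decoded[0], default_value=-1)
  ```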

  Args:
    inputs: 3-D `float` `Tensor`, size `[max_time, batch_size, num_classes]`.
      The logits.
    sequence_length: 1-D `int32` vector containing sequence lengths, having size
      `[batch_size]`.
    beam_width: An int scalar >= 0 (beam search beam width).
    top_paths: An int scalar >= 0, <= beam_width (controls output size).

  Returns:
    A tuple `(decoded, log_probabilities)` where

    decoded: A list of length top_paths, where `decoded[j]`
      is a `SparseTensor` containing the decoded outputs:

      `decoded[j].indices`: Indices matrix `[total_decoded_outputs[j], 2]`;
        The rows store: `[batch, time]`.

      `decoded[j].values`: Values vector, size `[total_decoded_outputs[j]]`.
        The vector stores the decoded classes for beam `j`.

      `decoded[j].dense_shape`: Shape vector, size `(2)`.
        The shape values are: `[batch_size, max_decoded_length[j]]`.

    log_probabilities: A `float` matrix `[batch_size, top_paths]` containing
        sequence log-probabilities.
  """

  # Note, merge_repeated is an invalid optimization that is removed from the
  # public API: it returns low probability paths.
  return ctc_beam_search_decoder(
      inputs,
      sequence_length=sequence_length,
      beam_width=beam_width,
      top_paths=top_paths,
      merge_repeated=False)


ops.NotDifferentiable("CTCGreedyDecoder")
ops.NotDifferentiable("CTCBeamSearchDecoder")


def _ctc_state_trans(label_seq):
  """Compute CTC alignment model transition matrix.

  Args:
    label_seq: tensor of shape [batch_size, max_seq_length]

  Returns:
    tensor of shape [batch_size, states, states] with a state transition matrix
    computed for each sequence of the batch.
  """

  with ops.name_scope("ctc_state_trans"):
    label_seq = ops.convert_to_tensor(label_seq, name="label_seq")
    batch_size = _get_dim(label_seq, 0)
    num_labels = _get_dim(label_seq, 1)

    num_label_states = num_labels + 1
    num_states = 2 * num_label_states

    label_states = math_ops.range(num_label_states)
    blank_states = label_states + num_label_states

    # Start state to first label.
    start_to_label = [[1, 0]]

    # Blank to label transitions.
    blank_to_label = array_ops.stack([label_states[1:], blank_states[:-1]], 1)

    # Label to blank transitions.
    label_to_blank = array_ops.stack([blank_states, label_states], 1)

    # Scatter transitions that don't depend on sequence.
    indices = array_ops.concat([start_to_label, blank_to_label, label_to_blank],
                               0)
    values = array_ops.ones([_get_dim(indices, 0)])
    trans = array_ops.scatter_nd(
        indices, values, shape=[num_states, num_states])
    trans += linalg_ops.eye(num_states)  # Self-loops.

    # Label to label transitions. Disallow transitions between repeated labels
    # with no blank state in between.
    batch_idx = array_ops.zeros_like(label_states[2:])
    indices = array_ops.stack([batch_idx, label_states[2:], label_states[1:-1]],
                              1)
    indices = array_ops.tile(
        array_ops.expand_dims(indices, 0), [batch_size, 1, 1])
    batch_idx = array_ops.expand_dims(math_ops.range(batch_size), 1) * [1, 0, 0]
    indices += array_ops.expand_dims(batch_idx, 1)
    repeats = math_ops.equal(label_seq[:, :-1], label_seq[:, 1:])
    values = 1.0 - math_ops.cast(repeats, dtypes.float32)
    batched_shape = [batch_size, num_states, num_states]
    label_to_label = array_ops.scatter_nd(indices, values, batched_shape)

    return array_ops.expand_dims(trans, 0) + label_to_label


def ctc_state_log_probs(seq_lengths, max_seq_length):
  """Computes CTC alignment initial and final state log probabilities.

  Create the initial/final state values directly as log values to avoid
  having to take a float64 log on TPU (where float64 log is not available).

  Args:
    seq_lengths: int tensor of shape [batch_size], seq lengths in the batch.
    max_seq_length: int, max sequence length possible.

  Returns:
    initial_state_log_probs, final_state_log_probs
  """

  batch_size = _get_dim(seq_lengths, 0)
  num_label_states = max_seq_length + 1
  num_duration_states = 2
  num_states = num_duration_states * num_label_states
  log_0 = math_ops.cast(
      math_ops.log(math_ops.cast(0, dtypes.float64) + 1e-307), dtypes.float32)

  initial_state_log_probs = array_ops.one_hot(
      indices=array_ops.zeros([batch_size], dtype=dtypes.int32),
      depth=num_states,
      on_value=0.0,
      off_value=log_0,
      axis=1)

  label_final_state_mask = array_ops.one_hot(
      seq_lengths, depth=num_label_states, axis=0)
  duration_final_state_mask = array_ops.ones(
      [num_duration_states, 1, batch_size])
  final_state_mask = duration_final_state_mask * label_final_state_mask
  final_state_log_probs = (1.0 - final_state_mask) * log_0
  final_state_log_probs = array_ops.reshape(final_state_log_probs,
                                            [num_states, batch_size])

  return initial_state_log_probs, array_ops.transpose(final_state_log_probs)


def _ilabel_to_state(labels, num_labels, ilabel_log_probs):
  """Project ilabel log probs to state log probs."""

  num_label_states = _get_dim(labels, 1)
  blank = ilabel_log_probs[:, :, :1]
  blank = array_ops.tile(blank, [1, 1, num_label_states + 1])
  one_hot = array_ops.one_hot(labels, depth=num_labels)
  one_hot = array_ops.expand_dims(one_hot, axis=0)
  ilabel_log_probs = array_ops.expand_dims(ilabel_log_probs, axis=2)
  state_log_probs = math_ops.reduce_sum(ilabel_log_probs * one_hot, axis=3)
  state_log_probs = array_ops.concat([state_log_probs, blank], axis=2)
  return array_ops.pad(
      state_log_probs, [[0, 0], [0, 0], [1, 0]],
      constant_values=math_ops.log(0.0))


def _state_to_olabel(labels, num_labels, states):
  """Sum state log probs to olabel log probs."""

  num_label_states = _get_dim(labels, 1) + 1
  label_states = states[:, :, 1:num_label_states]
  blank_states = states[:, :, num_label_states:]
  one_hot = array_ops.one_hot(
      labels - 1,
      depth=(num_labels - 1),
      on_value=0.0,
      off_value=math_ops.log(0.0))
  one_hot = array_ops.expand_dims(one_hot, axis=0)
  label_states = array_ops.expand_dims(label_states, axis=3)
  label_olabels = math_ops.reduce_logsumexp(label_states + one_hot, axis=2)
  blank_olabels = math_ops.reduce_logsumexp(blank_states, axis=2, keepdims=True)
  return array_ops.concat([blank_olabels, label_olabels], axis=-1)


# pylint: disable=redefined-outer-name
def _state_to_olabel_unique(labels, num_labels, states, unique):
  """Sum state log probs to olabel log probs using unique label indices."""

  num_label_states = _get_dim(labels, 1) + 1
  label_states = states[:, :, 1:num_label_states]
  blank_states = states[:, :, num_label_states:]

  unique_y, unique_idx = unique
  mul_reduce = _sum_states(unique_idx, label_states)

  num_frames = states.shape[0]
  batch_size = states.shape[1]
  num_states = num_label_states - 1
  batch_state_major = array_ops.transpose(mul_reduce, perm=[1, 2, 0])
  batch_state_major = array_ops.reshape(batch_state_major,
                                        [batch_size * num_states, num_frames])
  batch_offset = math_ops.range(batch_size, dtype=unique_y.dtype) * num_labels
  indices = unique_y + array_ops.expand_dims(batch_offset, axis=-1)
  indices = array_ops.reshape(indices, [-1, 1])
  scatter = array_ops.scatter_nd(
      indices=indices,
      updates=batch_state_major,
      shape=[batch_size * num_labels, num_frames])
  scatter = array_ops.reshape(scatter, [batch_size, num_labels, num_frames])

  mask = array_ops.ones_like(batch_state_major, dtype=dtypes.bool)
  mask = array_ops.scatter_nd(
      indices=indices,
      updates=mask,
      shape=[batch_size * num_labels, num_frames])
  mask = array_ops.reshape(mask, [batch_size, num_labels, num_frames])

  scatter = array_ops.where(
      mask, scatter,
      array_ops.fill(array_ops.shape(scatter), math_ops.log(0.0)))

  label_olabels = array_ops.transpose(scatter, [2, 0, 1])
  label_olabels = label_olabels[:, :, 1:]

  blank_olabels = math_ops.reduce_logsumexp(blank_states, axis=2, keepdims=True)

  return array_ops.concat([blank_olabels, label_olabels], axis=-1)


def ctc_loss_and_grad(logits, labels, label_length, logit_length, unique=None):
  """Computes the CTC loss and gradients.

  Most users will want fwd_bwd.ctc_loss

  This function returns the computed gradient; it does not have a gradient
  of its own defined.

  Args:
    logits: tensor of shape [frames, batch_size, num_labels]
    labels: tensor of shape [batch_size, max_label_seq_length]
    label_length: tensor of shape [batch_size] Length of reference label
      sequence in labels.
    logit_length: tensor of shape [batch_size] Length of input sequence in
      logits.
    unique: (optional) unique label indices as computed by unique(labels). If
      supplied, enables an implementation that is faster and more memory
      efficient on TPU.

  Returns:
    loss: tensor of shape [batch_size]
    gradient: tensor of shape [frames, batch_size, num_labels]
  """

  num_labels = _get_dim(logits, 2)
  max_label_seq_length = _get_dim(labels, 1)

  ilabel_log_probs = nn_ops.log_softmax(logits)
  state_log_probs = _ilabel_to_state(labels, num_labels, ilabel_log_probs)
  state_trans_probs = _ctc_state_trans(labels)
  initial_state_log_probs, final_state_log_probs = ctc_state_log_probs(
      label_length, max_label_seq_length)
  fwd_bwd_log_probs, log_likelihood = _forward_backward_log(
      state_trans_log_probs=math_ops.log(state_trans_probs),
      initial_state_log_probs=initial_state_log_probs,
      final_state_log_probs=final_state_log_probs,
      observed_log_probs=state_log_probs,
      sequence_length=logit_length)

  if unique:
    olabel_log_probs = _state_to_olabel_unique(labels, num_labels,
                                               fwd_bwd_log_probs, unique)
  else:
    olabel_log_probs = _state_to_olabel(labels, num_labels, fwd_bwd_log_probs)

  grad = math_ops.exp(ilabel_log_probs) - math_ops.exp(olabel_log_probs)

  # Applies the sequence mask for the gradient. It is enough to apply the mask
  # only for ilabel_log_probs because olabel_log_probs already consider the
  # mask. However, it is just safe and clean to apply it for the gradient.
  max_logit_length = _get_dim(logits, 0)
  logit_mask = array_ops.sequence_mask(logit_length, max_logit_length,
                                       dtypes.float32)
  logit_mask = array_ops.transpose(logit_mask, perm=[1, 0])
  logit_mask = array_ops.expand_dims(logit_mask, axis=2)
  grad *= logit_mask

  loss = -log_likelihood
  return loss, grad


def _ctc_loss_grad(op, grad_loss, _):
  grad = op.outputs[1]
  grad = [array_ops.reshape(grad_loss, [1, -1, 1]) * grad]
  grad += [None] * (len(op.inputs) - len(grad))
  return grad


def _ctc_loss_op_standard(labels, logits, logit_length, logits_time_major,
                          blank_index):
  part_before = logits[:, :, :blank_index]
  part_after = logits[:, :, blank_index + 1:]
  part_blank = logits[:, :, blank_index:blank_index + 1]
  logits = array_ops.concat([part_before, part_after, part_blank], axis=2)
  labels = sparse_tensor.SparseTensor(
      labels.indices,
      array_ops.where(labels.values < blank_index, labels.values,
                      labels.values - 1), labels.dense_shape)
  return _ctc_loss_impl(
      labels=labels,
      inputs=logits,
      sequence_length=logit_length,
      time_major=logits_time_major,
      use_cudnn=False)


def _ctc_loss_op_cudnn(labels, logits, logit_length, logits_time_major,
                       blank_index):
  part_before = logits[:, :, :blank_index]
  part_after = logits[:, :, blank_index + 1:]
  part_blank = logits[:, :, blank_index:blank_index + 1]
  logits = array_ops.concat([part_blank, part_before, part_after], axis=2)
  labels = sparse_tensor.SparseTensor(
      labels.indices,
      array_ops.where(labels.values < blank_index, labels.values + 1,
                      labels.values), labels.dense_shape)
  return _ctc_loss_impl(
      labels=labels,
      inputs=logits,
      sequence_length=logit_length,
      time_major=logits_time_major,
      use_cudnn=True)


def _ctc_loss_shape(op):
  return [op.inputs[2].get_shape(), op.inputs[0].get_shape()]


# pylint: disable=protected-access, invalid-name
@tf_export(v1=["nn.ctc_loss_v2"])
@dispatch.add_dispatch_support
def ctc_loss_v2(labels,
                logits,
                label_length,
                logit_length,
                logits_time_major=True,
                unique=None,
                blank_index=None,
                name=None):
  """Computes CTC (Connectionist Temporal Classification) loss.

  This op implements the CTC loss as presented in (Graves et al., 2006).

  Notes:

  - Same as the "Classic CTC" in TensorFlow 1.x's tf.compat.v1.nn.ctc_loss
    setting of preprocess_collapse_repeated=False, ctc_merge_repeated=True
  - Labels may be supplied as either a dense, zero-padded tensor with a
    vector of label sequence lengths OR as a SparseTensor (see the example
    below).
  - On TPU and GPU: Only dense padded labels are supported.
  - On CPU: Caller may use SparseTensor or dense padded labels but calling with
    a SparseTensor will be significantly faster.
  - Default blank label is 0 rather than num_classes - 1, unless overridden by
    blank_index.
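
  For example, when labels are a `SparseTensor`, `blank_index` must be set
  (a sketch; the tensors shown are assumed stand-ins for real data):

  ```
  loss = tf.compat.v1.nn.ctc_loss_v2(
      labels=sparse_labels,        # tf.SparseTensor with int32 values
      logits=logits,               # [frames, batch_size, num_labels]
      label_length=None,
      logit_length=logit_length,   # [batch_size]
      blank_index=-1)              # use the last class as the blank
  ```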

  Args:
    labels: tensor of shape [batch_size, max_label_seq_length] or SparseTensor
    logits: tensor of shape [frames, batch_size, num_labels], if
      logits_time_major == False, shape is [batch_size, frames, num_labels].
    label_length: tensor of shape [batch_size], None if labels is SparseTensor
      Length of reference label sequence in labels.
    logit_length: tensor of shape [batch_size] Length of input sequence in
      logits.
    logits_time_major: (optional) If True (default), logits is shaped [time,
      batch, logits]. If False, shape is [batch, time, logits]
    unique: (optional) Unique label indices as computed by
      ctc_unique_labels(labels).  If supplied, enables a faster, more memory
      efficient implementation on TPU.
    blank_index: (optional) Set the class index to use for the blank label.
      Negative values will start from num_classes, i.e., -1 will reproduce the
      ctc_loss behavior of using num_classes - 1 for the blank symbol. There is
      some memory/performance overhead to switching from the default of 0 as an
      additional shifted copy of the logits may be created.
    name: A name for this `Op`. Defaults to "ctc_loss_dense".

  Returns:
    loss: tensor of shape [batch_size], negative log probabilities.

  References:
      Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
      with Recurrent Neural Networks:
        [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
        ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
  """
  if isinstance(labels, sparse_tensor.SparseTensor):
    if blank_index is None:
      raise ValueError(
          "blank_index must be given when using SparseTensor labels.")

    if blank_index < 0:
      blank_index += _get_dim(logits, 2)

    if blank_index != _get_dim(logits, 2) - 1:
      logits = array_ops.concat([
          logits[:, :, :blank_index],
          logits[:, :, blank_index + 1:],
          logits[:, :, blank_index:blank_index + 1],
      ],
                                axis=2)
      labels = sparse_tensor.SparseTensor(
          labels.indices,
          array_ops.where(labels.values < blank_index, labels.values,
                          labels.values - 1), labels.dense_shape)

    return ctc_loss(
        labels=labels,
        inputs=logits,
        sequence_length=logit_length,
        time_major=logits_time_major)

  if blank_index is None:
    blank_index = 0

  return ctc_loss_dense(
      labels=labels,
      logits=logits,
      label_length=label_length,
      logit_length=logit_length,
      logits_time_major=logits_time_major,
      unique=unique,
      blank_index=blank_index,
      name=name)


@tf_export("nn.ctc_loss", v1=[])
@dispatch.add_dispatch_support
def ctc_loss_v3(labels,
                logits,
                label_length,
                logit_length,
                logits_time_major=True,
                unique=None,
                blank_index=None,
                name=None):
  """Computes CTC (Connectionist Temporal Classification) loss.

  This op implements the CTC loss as presented in (Graves et al., 2006).

  Notes:

  - Same as the "Classic CTC" in TensorFlow 1.x's tf.compat.v1.nn.ctc_loss
    setting of preprocess_collapse_repeated=False, ctc_merge_repeated=True
  - Labels may be supplied as either a dense, zero-padded tensor with a
    vector of label sequence lengths OR as a SparseTensor (see the example
    below).
  - On TPU and GPU: Only dense padded labels are supported.
  - On CPU: Caller may use SparseTensor or dense padded labels but calling with
    a SparseTensor will be significantly faster.
  - Default blank label is 0 rather than num_classes - 1, unless overridden by
    blank_index.
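
  For example, with dense zero-padded labels (a minimal sketch; shapes and
  values are illustrative assumptions):

  ```
  batch_size, frames, num_labels = 2, 50, 29
  logits = tf.random.normal([frames, batch_size, num_labels])
  labels = tf.constant([[1, 2, 3, 0], [4, 5, 0, 0]])  # 0 is padding
  label_length = tf.constant([3, 2])
  logit_length = tf.fill([batch_size], frames)
  loss = tf.nn.ctc_loss(labels, logits, label_length, logit_length,
                        blank_index=0)
  ```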

  Args:
    labels: tensor of shape [batch_size, max_label_seq_length] or SparseTensor
    logits: tensor of shape [frames, batch_size, num_labels], if
      logits_time_major == False, shape is [batch_size, frames, num_labels].
    label_length: tensor of shape [batch_size], None if labels is SparseTensor
      Length of reference label sequence in labels.
    logit_length: tensor of shape [batch_size] Length of input sequence in
      logits.
    logits_time_major: (optional) If True (default), logits is shaped [time,
      batch, logits]. If False, shape is [batch, time, logits]
    unique: (optional) Unique label indices as computed by
      ctc_unique_labels(labels).  If supplied, enables a faster, more memory
      efficient implementation on TPU.
    blank_index: (optional) Set the class index to use for the blank label.
      Negative values will start from num_classes, i.e., -1 will reproduce the
      ctc_loss behavior of using num_classes - 1 for the blank symbol. There is
      some memory/performance overhead to switching from the default of 0 as an
      additional shifted copy of the logits may be created.
    name: A name for this `Op`. Defaults to "ctc_loss_dense".

  Returns:
    loss: tensor of shape [batch_size], negative log probabilities.

  References:
      Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
      with Recurrent Neural Networks:
        [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
        ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
  """
  if isinstance(labels, sparse_tensor.SparseTensor):
    if blank_index is None:
      raise ValueError(
          "blank_index must be given when using SparseTensor labels.")

    if blank_index < 0:
      blank_index += _get_dim(logits, 2)

    params = {
        "labels": labels,
        "logits": logits,
        "logit_length": logit_length,
        "logits_time_major": logits_time_major,
        "blank_index": blank_index
    }

    if context.executing_eagerly():
      device_type = _get_context_device_type()
      can_use_gpu = (
          # Either user specified GPU or unspecified but GPU is available.
          (device_type == _GPU_DEVICE_NAME or
           (device_type is None and context.num_gpus() > 0)))
      # Under eager context, check the device placement and prefer the
      # cuDNN implementation when a GPU can be used.
      if can_use_gpu:
        res = _ctc_loss_op_cudnn(**params)
      else:
        res = _ctc_loss_op_standard(**params)
    else:
      api_name = "ctc_loss_" + str(uuid.uuid4())
      ctc_loss_op_standard = _generate_defun_backend(api_name, _CPU_DEVICE_NAME,
                                                     _ctc_loss_op_standard)
      ctc_loss_op_cudnn = _generate_defun_backend(api_name, _GPU_DEVICE_NAME,
                                                  _ctc_loss_op_cudnn)
      res = ctc_loss_op_standard(**params)
      function_eager.register(ctc_loss_op_cudnn, **params)
    return res

  if blank_index is None:
    blank_index = 0

  return ctc_loss_dense(
      labels=labels,
      logits=logits,
      label_length=label_length,
      logit_length=logit_length,
      logits_time_major=logits_time_major,
      unique=unique,
      blank_index=blank_index,
      name=name)


def ctc_loss_dense(labels,
                   logits,
                   label_length,
                   logit_length,
                   logits_time_major=True,
                   unique=None,
                   blank_index=0,
                   name=None):
  """Computes CTC (Connectionist Temporal Classification) loss.

  This op implements the CTC loss as presented in (Graves et al., 2006),
  using the batched forward backward algorithm described in (Sim et al., 2017).

  Notes:
    Significant differences from tf.compat.v1.nn.ctc_loss:
      Supports GPU and TPU (tf.compat.v1.nn.ctc_loss supports CPU only):
        For batched operations, GPU and TPU are significantly faster than using
        ctc_loss on CPU.
        This implementation runs on CPU, but significantly slower than ctc_loss.
      Blank label is 0 rather than num_classes - 1, unless overridden by
        blank_index.
      Logits and labels are dense arrays with padding rather than SparseTensor.
      The only mode supported is the same as:
        preprocess_collapse_repeated=False, ctc_merge_repeated=True
        To collapse labels, the caller can preprocess label sequence first.

    The dense implementation supports CPU, GPU and TPU. A fast path is
    provided that significantly improves memory use for large vocabulary if the
    caller preprocesses label sequences to get unique label indices on the CPU
    (eg. in the data input pipeline) using ctc_ops.unique and supplies this in
    the optional "unique" kwarg. This is especially useful for TPU and GPU but
    also works if used on CPU.
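
  For example, the fast path can be used roughly as follows (a sketch;
  `labels`, `logits` and the length tensors are assumed inputs):

  ```
  unique = tf.nn.ctc_unique_labels(labels)
  loss = tf.nn.ctc_loss(labels, logits, label_length, logit_length,
                        blank_index=0, unique=unique)
  ```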

  Args:
    labels: tensor of shape [batch_size, max_label_seq_length]
    logits: tensor of shape [frames, batch_size, num_labels], if
      logits_time_major == False, shape is [batch_size, frames, num_labels].
    label_length: tensor of shape [batch_size] Length of reference label
      sequence in labels.
    logit_length: tensor of shape [batch_size] Length of input sequence in
      logits.
    logits_time_major: (optional) If True (default), logits is shaped [time,
      batch, logits]. If False, shape is [batch, time, logits]
    unique: (optional) Unique label indices as computed by unique(labels). If
      supplied, enables a faster, more memory efficient implementation on TPU.
    blank_index: (optional) Set the class index to use for the blank label.
      Negative values will start from num_classes, i.e., -1 will reproduce the
      ctc_loss behavior of using num_classes - 1 for the blank symbol. There is
      some memory/performance overhead to switching from the default of 0 as an
      additional shifted copy of the logits may be created.
    name: A name for this `Op`. Defaults to "ctc_loss_dense".

  Returns:
    loss: tensor of shape [batch_size], negative log probabilities.

  References:
      Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
      with Recurrent Neural Networks:
        [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
        ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
      Improving the efficiency of forward-backward algorithm using batched
      computation in TensorFlow:
        [Sim et al., 2017](https://ieeexplore.ieee.org/document/8268944)
        ([pdf](http://bacchiani.net/resume/papers/ASRU2017.pdf))
  """

  with ops.name_scope(name, "ctc_loss_dense",
                      [logits, labels, label_length, logit_length]):
    logits = ops.convert_to_tensor(logits, name="logits")
    labels = ops.convert_to_tensor(labels, name="labels")
    label_length = ops.convert_to_tensor(label_length, name="label_length")
    logit_length = ops.convert_to_tensor(logit_length, name="logit_length")

    if not logits_time_major:
      logits = array_ops.transpose(logits, perm=[1, 0, 2])

    if blank_index != 0:
      if blank_index < 0:
        blank_index += _get_dim(logits, 2)
      logits = array_ops.concat([
          logits[:, :, blank_index:blank_index + 1],
          logits[:, :, :blank_index],
          logits[:, :, blank_index + 1:],
      ],
                                axis=2)
      labels = array_ops.where(labels < blank_index, labels + 1, labels)

    args = [logits, labels, label_length, logit_length]

    if unique:
      unique_y, unique_idx = unique
      if blank_index != 0:
        unique_y = array_ops.where(unique_y < blank_index, unique_y + 1,
                                   unique_y)
        label_mask_len = math_ops.reduce_max(unique_idx, axis=1) + 1
        max_label_length = _get_dim(unique_y, 1)
        label_mask = array_ops.sequence_mask(label_mask_len, max_label_length)
        unique_y = array_ops.where(label_mask, unique_y,
                                   array_ops.zeros_like(unique_y))
      args.extend([unique_y, unique_idx])

    @custom_gradient.custom_gradient
    def compute_ctc_loss(logits_t, labels_t, label_length_t, logit_length_t,
                         *unique_t):
      """Compute CTC loss."""
      logits_t.set_shape(logits.shape)
      labels_t.set_shape(labels.shape)
      label_length_t.set_shape(label_length.shape)
      logit_length_t.set_shape(logit_length.shape)
      kwargs = dict(
          logits=logits_t,
          labels=labels_t,
          label_length=label_length_t,
          logit_length=logit_length_t)
      if unique_t:
        kwargs["unique"] = unique_t
      result = ctc_loss_and_grad(**kwargs)
      def grad(grad_loss):
        grad = [array_ops.reshape(grad_loss, [1, -1, 1]) * result[1]]
        grad += [None] * (len(args) - len(grad))
        return grad

      return result[0], grad

    return compute_ctc_loss(*args)


@tf_export("nn.collapse_repeated")
@dispatch.add_dispatch_support
def collapse_repeated(labels, seq_length, name=None):
  """Merge repeated labels into single labels.

  Args:
    labels: Tensor of shape [batch, max value in seq_length]
    seq_length: Tensor of shape [batch], sequence length of each batch element.
    name: A name for this `Op`. Defaults to "collapse_repeated_labels".

  Returns:
    A tuple `(collapsed_labels, new_seq_length)` where

    collapsed_labels: Tensor of shape [batch, max_seq_length] with repeated
    labels collapsed and padded to max_seq_length, e.g.:
    `[[A, A, B, B, A], [A, B, C, D, E]] => [[A, B, A, 0, 0], [A, B, C, D, E]]`

    new_seq_length: int tensor of shape [batch] with new sequence lengths.
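
  For example (a sketch that mirrors the transformation above):

  ```
  labels = tf.constant([[1, 1, 2, 2, 1], [1, 2, 3, 4, 5]])
  seq_length = tf.constant([5, 5])
  collapsed, new_seq_length = tf.nn.collapse_repeated(labels, seq_length)
  # collapsed -> [[1, 2, 1, 0, 0], [1, 2, 3, 4, 5]], new_seq_length -> [3, 5]
  ```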
1117  """
1118
1119  with ops.name_scope(name, "collapse_repeated_labels", [labels, seq_length]):
1120    labels = ops.convert_to_tensor(labels, name="labels")
1121    seq_length = ops.convert_to_tensor(seq_length, name="seq_length")
1122
1123    # Mask labels that don't equal previous label.
1124    label_mask = array_ops.concat([
1125        array_ops.ones_like(labels[:, :1], dtypes.bool),
1126        math_ops.not_equal(labels[:, 1:], labels[:, :-1])
1127    ],
1128                                  axis=1)
1129
1130    # Filter labels that aren't in the original sequence.
1131    maxlen = _get_dim(labels, 1)
1132    seq_mask = array_ops.sequence_mask(seq_length, maxlen=maxlen)
1133    label_mask = math_ops.logical_and(label_mask, seq_mask)
1134
1135    # Count masks for new sequence lengths.
1136    new_seq_len = math_ops.reduce_sum(
1137        math_ops.cast(label_mask, dtypes.int32), axis=1)
1138
1139    # Mask indexes based on sequence length mask.
1140    new_maxlen = math_ops.reduce_max(new_seq_len)
1141    idx_mask = array_ops.sequence_mask(new_seq_len, maxlen=new_maxlen)
1142
1143    # Flatten everything and mask out labels to keep and sparse indices.
1144    flat_labels = array_ops.reshape(labels, [-1])
1145    flat_label_mask = array_ops.reshape(label_mask, [-1])
1146    flat_idx_mask = array_ops.reshape(idx_mask, [-1])
1147    idx = math_ops.range(_get_dim(flat_idx_mask, 0))
1148
1149    # Scatter to flat shape.
1150    flat = array_ops.scatter_nd(
1151        indices=array_ops.expand_dims(
1152            array_ops.boolean_mask(idx, flat_idx_mask), axis=1),
1153        updates=array_ops.boolean_mask(flat_labels, flat_label_mask),
1154        shape=array_ops.shape(flat_idx_mask))
1155
1156    # Reshape back to square batch.
1157    batch_size = _get_dim(labels, 0)
1158    new_shape = [batch_size, new_maxlen]
1159    return (array_ops.reshape(flat, new_shape),
1160            math_ops.cast(new_seq_len, seq_length.dtype))
1161
1162
def dense_labels_to_sparse(dense, length):
  """Convert dense labels with sequence lengths to sparse tensor.

  Args:
    dense: tensor of shape [batch, max_length]
    length: int tensor of shape [batch] The length of each sequence in dense.

  Returns:
    tf.sparse.SparseTensor with values only for the valid elements of sequences.
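
  A small sketch of the conversion (values past each sequence length are
  dropped):

  ```
  dense = tf.constant([[1, 2, 0], [3, 0, 0]])
  length = tf.constant([2, 1])
  sparse = dense_labels_to_sparse(dense, length)  # 3 stored values: 1, 2, 3
  ```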
1172  """
1173
1174  flat_values = array_ops.reshape(dense, [-1])
1175  flat_indices = math_ops.range(
1176      array_ops.shape(flat_values, out_type=dtypes.int64)[0])
1177  mask = array_ops.sequence_mask(length, maxlen=array_ops.shape(dense)[1])
1178  flat_mask = array_ops.reshape(mask, [-1])
1179  indices = array_ops.expand_dims(
1180      array_ops.boolean_mask(flat_indices, flat_mask), 1)
1181  values = array_ops.boolean_mask(flat_values, flat_mask)
1182  sparse = sparse_tensor.SparseTensor(
1183      indices=indices,
1184      values=math_ops.cast(values, dtypes.int32),
1185      dense_shape=array_ops.shape(flat_values, out_type=dtypes.int64))
1186  reshaped = sparse_ops.sparse_reshape(sparse, array_ops.shape(dense))
1187  max_length = math_ops.reduce_max(length)
1188  return sparse_tensor.SparseTensor(
1189      indices=reshaped.indices,
1190      values=reshaped.values,
1191      dense_shape=[
1192          math_ops.cast(reshaped.dense_shape[0], dtypes.int64),
1193          math_ops.cast(max_length, dtypes.int64)
1194      ])
1195
1196
1197@tf_export("nn.ctc_unique_labels")
1198@dispatch.add_dispatch_support
1199def ctc_unique_labels(labels, name=None):
1200  """Get unique labels and indices for batched labels for `tf.nn.ctc_loss`.
1201
1202  For use with `tf.nn.ctc_loss` optional argument `unique`: This op can be
1203  used to preprocess labels in input pipeline to for better speed/memory use
1204  computing the ctc loss on TPU.
1205
1206  Example:
1207    ctc_unique_labels([[3, 4, 4, 3]]) ->
1208      unique labels padded with 0: [[3, 4, 0, 0]]
1209      indices of original labels in unique: [0, 1, 1, 0]
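
  A runnable form of the same call (a sketch):

  ```
  unique_labels, idx = tf.nn.ctc_unique_labels(tf.constant([[3, 4, 4, 3]]))
  ```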

  Args:
    labels: tensor of shape [batch_size, max_label_length] padded with 0.
    name: A name for this `Op`. Defaults to "ctc_unique_labels".

  Returns:
    tuple of
      - unique labels, tensor of shape `[batch_size, max_label_length]`
      - indices into unique labels, shape `[batch_size, max_label_length]`
  """

  with ops.name_scope(name, "ctc_unique_labels", [labels]):
    labels = ops.convert_to_tensor(labels, name="labels")

    def _unique(x):
      u = array_ops.unique(x)
      y = array_ops.pad(u.y, [[0, _get_dim(u.idx, 0) - _get_dim(u.y, 0)]])
      y = math_ops.cast(y, dtypes.int64)
      return [y, u.idx]

    return map_fn.map_fn(_unique, labels, dtype=[dtypes.int64, dtypes.int32])


def _sum_states(idx, states):
  """Take logsumexp for each unique state out of all label states.

  Args:
    idx: tensor of shape [batch, label_length] For each sequence, indices into a
      set of unique labels as computed by calling unique.
    states: tensor of shape [frames, batch, label_length] Log probabilities for
      each label state.

  Returns:
    tensor of shape [frames, batch_size, label_length], log probabilities summed
      for each unique label of the sequence.
  """

  with ops.name_scope("sum_states"):
    idx = ops.convert_to_tensor(idx, name="idx")
    num_states = _get_dim(states, 2)
    states = array_ops.expand_dims(states, axis=2)
    one_hot = array_ops.one_hot(
        idx,
        depth=num_states,
        on_value=0.0,
        off_value=math_ops.log(0.0),
        axis=1)
    return math_ops.reduce_logsumexp(states + one_hot, axis=-1)


def _forward_backward_log(state_trans_log_probs, initial_state_log_probs,
                          final_state_log_probs, observed_log_probs,
                          sequence_length):
  """Forward-backward algorithm computed in log domain.

  Args:
    state_trans_log_probs: tensor of shape [states, states] or if different
      transition matrix per batch [batch_size, states, states]
    initial_state_log_probs: tensor of shape [batch_size, states]
    final_state_log_probs: tensor of shape [batch_size, states]
    observed_log_probs: tensor of shape [frames, batch_size, states]
    sequence_length: tensor of shape [batch_size]

  Returns:
    forward backward log probabilities: tensor of shape [frames, batch, states]
    log_likelihood: tensor of shape [batch_size]

  Raises:
    ValueError: If state_trans_log_probs has unknown or incorrect rank.
  """

  if state_trans_log_probs.shape.ndims == 2:
    perm = [1, 0]
  elif state_trans_log_probs.shape.ndims == 3:
    perm = [0, 2, 1]
  else:
    raise ValueError(
        "state_trans_log_probs rank must be known and == 2 or 3, is: %s" %
        state_trans_log_probs.shape.ndims)

  bwd_state_trans_log_probs = array_ops.transpose(state_trans_log_probs, perm)
  batch_size = _get_dim(observed_log_probs, 1)

  def _forward(state_log_prob, obs_log_prob):
    state_log_prob = array_ops.expand_dims(state_log_prob, axis=1)  # Broadcast.
    state_log_prob += state_trans_log_probs
    state_log_prob = math_ops.reduce_logsumexp(state_log_prob, axis=-1)
    state_log_prob += obs_log_prob
    log_prob_sum = math_ops.reduce_logsumexp(
        state_log_prob, axis=-1, keepdims=True)
    state_log_prob -= log_prob_sum
    return state_log_prob

  fwd = _scan(
      _forward, observed_log_probs, initial_state_log_probs, inclusive=True)

  def _backward(accs, elems):
    """Calculate log probs and cumulative sum masked for sequence length."""
    state_log_prob, cum_log_sum = accs
    obs_log_prob, mask = elems
    state_log_prob += obs_log_prob
    state_log_prob = array_ops.expand_dims(state_log_prob, axis=1)  # Broadcast.
    state_log_prob += bwd_state_trans_log_probs
    state_log_prob = math_ops.reduce_logsumexp(state_log_prob, axis=-1)

    log_prob_sum = math_ops.reduce_logsumexp(
        state_log_prob, axis=-1, keepdims=True)
    state_log_prob -= log_prob_sum

    cum_log_sum += array_ops.squeeze(log_prob_sum) * mask
    batched_mask = array_ops.expand_dims(mask, axis=1)
    out = state_log_prob * batched_mask
    out += final_state_log_probs * (1.0 - batched_mask)
    return out, cum_log_sum

  zero_log_sum = array_ops.zeros([batch_size])
  maxlen = _get_dim(observed_log_probs, 0)
  mask = array_ops.sequence_mask(sequence_length, maxlen, dtypes.float32)
  mask = array_ops.transpose(mask, perm=[1, 0])

  bwd, cum_log_sum = _scan(
      _backward, (observed_log_probs, mask),
      (final_state_log_probs, zero_log_sum),
      reverse=True,
      inclusive=True)

  fwd_bwd_log_probs = fwd[1:] + bwd[1:]
  fwd_bwd_log_probs_sum = math_ops.reduce_logsumexp(
      fwd_bwd_log_probs, axis=2, keepdims=True)
  fwd_bwd_log_probs -= fwd_bwd_log_probs_sum
  fwd_bwd_log_probs += math_ops.log(array_ops.expand_dims(mask, axis=2))

  log_likelihood = bwd[0, :, 0] + cum_log_sum[0]

  return fwd_bwd_log_probs, log_likelihood


# TODO(tombagby): This is currently faster for the ctc implementation than using
# functional_ops.scan, but could be replaced by that or something similar if
# things change.
def _scan(fn, elems, initial, reverse=False, inclusive=False, final_only=False):
  """Repeatedly applies callable `fn` to a sequence of elements.

  Implemented by functional_ops.While, tpu friendly, no gradient.

  This is similar to functional_ops.scan but significantly faster on tpu/gpu
  for the forward backward use case.

  Examples:
    scan(lambda a, e: a + e, [1.0, 2.0, 3.0], 1.0) => [2.0, 4.0, 7.0]

    Multiple accumulators:
      scan(lambda a, e: (a[0] + e, a[1] * e), [1.0, 2.0, 3.0], (0.0, 1.0))

    Multiple inputs:
      scan(lambda a, e: a + (e[0] * e[1]), (elems1, elems2), 0.0)

  Args:
    fn: callable, fn(accumulators, element) return new accumulator values. The
      (possibly nested) sequence of accumulators is the same as `initial` and
      the return value must have the same structure.
    elems: A (possibly nested) tensor which will be unpacked along the first
      dimension. The resulting slices will be the second argument to fn. The
      first dimension of all nested input tensors must be the same.
    initial: A tensor or (possibly nested) sequence of tensors with initial
      values for the accumulators.
    reverse: (optional) True enables scan and output elems in reverse order.
    inclusive: (optional) True includes the initial accumulator values in the
      output. Length of output will be len(elem sequence) + 1. Not meaningful if
      final_only is True.
    final_only: (optional) When True, return only the final accumulated values,
      not the concatenation of accumulated values for each input.

  Returns:
    A (possibly nested) sequence of tensors with the results of applying fn
    to tensors unpacked from elems and previous accumulator values.
  """

  flat_elems = [ops.convert_to_tensor(x) for x in nest.flatten(elems)]
  num_elems = array_ops.shape(flat_elems[0])[0]
  pack_elems = lambda x: nest.pack_sequence_as(structure=elems, flat_sequence=x)
  flat_initial = [ops.convert_to_tensor(x) for x in nest.flatten(initial)]
  pack = lambda x: nest.pack_sequence_as(structure=initial, flat_sequence=x)
  accum_dtypes = [x.dtype for x in flat_initial]
  num_accums = len(flat_initial)

  # Types for counter, [outputs], [accumulators] loop arguments.
  if final_only:
    loop_dtypes = [dtypes.int32, dtypes.int32] + accum_dtypes
  else:
    loop_dtypes = [dtypes.int32, dtypes.int32] + accum_dtypes + accum_dtypes

  # TODO(tombagby): Update to tfe.defun
  def cond(i, num_elems, *args):
    del args
    return i >= 0 if reverse else i < num_elems

  # The loop *args are [output tensors] + [accumulator tensors] which must
  # be paired. Each output corresponds to one accumulator.
  def body(i, num_elems, *args):
    """Loop body."""
    i.set_shape([])
    if final_only:
      accum = args
    else:
      out, accum = args[:num_accums], args[num_accums:]
    slices = [array_ops.gather(e, i) for e in flat_elems]
    accum = fn(pack(accum), pack_elems(slices))
    flat_accum = nest.flatten(accum)
    if final_only:
      new_out = []
    else:
      update_i = i + 1 if inclusive and not reverse else i
      new_out = [
          inplace_ops.alias_inplace_update(x, update_i, y)
          for x, y in zip(out, flat_accum)
      ]
    i = i - 1 if reverse else i + 1
    return [i, num_elems] + new_out + flat_accum

  init_i = (
      array_ops.shape(flat_elems[0])[0] -
      1 if reverse else constant_op.constant(0, dtype=dtypes.int32))
  outputs = []
  if not final_only:
    num_outputs = array_ops.shape(flat_elems[0])[0] + (1 if inclusive else 0)
    for initial_accum in flat_initial:
      out_shape = array_ops.concat(
          [[num_outputs], array_ops.shape(initial_accum)], 0)
      out = inplace_ops.empty(out_shape, dtype=initial_accum.dtype, init=True)
      if inclusive:
        out = inplace_ops.alias_inplace_add(out, init_i + (1 if reverse else 0),
                                            initial_accum)
      outputs.append(out)
  loop_in = [init_i, num_elems] + outputs + flat_initial
  hostmem = [
      i for i, x in enumerate(loop_in)
      if x.dtype.base_dtype in (dtypes.int32, dtypes.int64)
  ]

  if context.executing_eagerly():
    loop_results = loop_in
    while cond(*loop_results):
      loop_results = body(*loop_results)
  else:
    # TODO(tombagby): Update to while_v2.
    cond = function.Defun(*loop_dtypes)(cond)
    body = function.Defun(*loop_dtypes)(body)
    loop_results = functional_ops.While(loop_in, cond, body, hostmem=hostmem)
  out = loop_results[2:num_accums + 2]
  return pack(out)


def _get_dim(tensor, i):
  """Get value of tensor shape[i] preferring static value if available."""
  return tensor_shape.dimension_value(
      tensor.shape[i]) or array_ops.shape(tensor)[i]