# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ctc_ops.ctc_loss_op."""

from absl.testing import parameterized
import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import ctc_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test


def SimpleSparseTensorFrom(x):
  """Create a very simple SparseTensor with dimensions (batch, time).

  Args:
    x: a list of lists of type int

  Returns:
    A SparseTensor of shape (batch, time) holding the values in x.
  """
  x_ix = []
  x_val = []
  for batch_i, batch in enumerate(x):
    for time, val in enumerate(batch):
      x_ix.append([batch_i, time])
      x_val.append(val)
  x_shape = [len(x), np.asarray(x_ix).max(0)[1] + 1]
  x_ix = constant_op.constant(x_ix, dtypes.int64)
  x_val = constant_op.constant(x_val, dtypes.int32)
  x_shape = constant_op.constant(x_shape, dtypes.int64)

  return sparse_tensor.SparseTensor(x_ix, x_val, x_shape)

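# Example (illustrative): SimpleSparseTensorFrom([[0, 1], [2]]) produces a
# SparseTensor with indices [[0, 0], [0, 1], [1, 0]], values [0, 1, 2], and
# dense_shape [2, 2].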

def _ctc_loss_v2(labels, inputs, sequence_length,
                 preprocess_collapse_repeated=False,
                 ctc_merge_repeated=True,
                 ignore_longer_outputs_than_inputs=False,
                 time_major=True):
  """Call ctc_loss_v2 with v1 args."""
  assert not preprocess_collapse_repeated
  assert ctc_merge_repeated
  assert not ignore_longer_outputs_than_inputs
  return ctc_ops.ctc_loss_v2(
      labels=labels,
      logits=inputs,
      logit_length=sequence_length,
      label_length=None,
      blank_index=-1,
      logits_time_major=time_major)

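# Note: passing blank_index=-1 above selects the last class as the blank,
# which matches the v1 ctc_loss convention of using num_classes - 1 as the
# blank label.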

class CTCLossTest(test.TestCase):

  def _testCTCLoss(self,
                   inputs,
                   seq_lens,
                   labels,
                   loss_truth,
                   grad_truth,
                   expected_err_re=None):
    self.assertEqual(len(inputs), len(grad_truth))

    inputs_t = constant_op.constant(inputs)

    with self.cached_session(use_gpu=False) as sess:
      loss = _ctc_loss_v2(
          inputs=inputs_t, labels=labels, sequence_length=seq_lens)
      grad = gradients_impl.gradients(loss, [inputs_t])[0]

      self.assertShapeEqual(loss_truth, loss)
      self.assertShapeEqual(grad_truth, grad)

      if expected_err_re is None:
        (tf_loss, tf_grad) = self.evaluate([loss, grad])
        self.assertAllClose(tf_loss, loss_truth, atol=1e-6)
        self.assertAllClose(tf_grad, grad_truth, atol=1e-6)
      else:
        with self.assertRaisesOpError(expected_err_re):
          self.evaluate([loss, grad])

  @test_util.run_v1_only("b/120545219")
  def testBasic(self):
    """Test two batch entries."""
    # Input and ground truth from Alex Graves' implementation.
    #
    #### Batch entry 0 #####
    # targets: 0 1 2 1 0
    # outputs:
    # 0 0.633766 0.221185 0.0917319 0.0129757 0.0142857 0.0260553
    # 1 0.111121 0.588392 0.278779 0.0055756 0.00569609 0.010436
    # 2 0.0357786 0.633813 0.321418 0.00249248 0.00272882 0.0037688
    # 3 0.0663296 0.643849 0.280111 0.00283995 0.0035545 0.00331533
    # 4 0.458235 0.396634 0.123377 0.00648837 0.00903441 0.00623107
    # alpha:
    # 0 -3.64753 -0.456075 -inf -inf -inf -inf -inf -inf -inf -inf -inf
    # 1 -inf -inf -inf -0.986437 -inf -inf -inf -inf -inf -inf -inf
    # 2 -inf -inf -inf -inf -inf -2.12145 -inf -inf -inf -inf -inf
    # 3 -inf -inf -inf -inf -inf -inf -inf -2.56174 -inf -inf -inf
    # 4 -inf -inf -inf -inf -inf -inf -inf -inf -inf -3.34211 -inf
    # beta:
    # 0 -inf -2.88604 -inf -inf -inf -inf -inf -inf -inf -inf -inf
    # 1 -inf -inf -inf -2.35568 -inf -inf -inf -inf -inf -inf -inf
    # 2 -inf -inf -inf -inf -inf -1.22066 -inf -inf -inf -inf -inf
    # 3 -inf -inf -inf -inf -inf -inf -inf -0.780373 -inf -inf -inf
    # 4 -inf -inf -inf -inf -inf -inf -inf -inf -inf 0 0
    # prob: -3.34211
    # outputDerivs:
    # 0 -0.366234 0.221185 0.0917319 0.0129757 0.0142857 0.0260553
    # 1 0.111121 -0.411608 0.278779 0.0055756 0.00569609 0.010436
    # 2 0.0357786 0.633813 -0.678582 0.00249248 0.00272882 0.0037688
    # 3 0.0663296 -0.356151 0.280111 0.00283995 0.0035545 0.00331533
    # 4 -0.541765 0.396634 0.123377 0.00648837 0.00903441 0.00623107
    #
    #### Batch entry 1 #####
    #
    # targets: 0 1 1 0
    # outputs:
    # 0 0.30176 0.28562 0.0831517 0.0862751 0.0816851 0.161508
    # 1 0.24082 0.397533 0.0557226 0.0546814 0.0557528 0.19549
    # 2 0.230246 0.450868 0.0389607 0.038309 0.0391602 0.202456
    # 3 0.280884 0.429522 0.0326593 0.0339046 0.0326856 0.190345
    # 4 0.423286 0.315517 0.0338439 0.0393744 0.0339315 0.154046
    # alpha:
    # 0 -1.8232 -1.19812 -inf -inf -inf -inf -inf -inf -inf
    # 1 -inf -2.19315 -2.83037 -2.1206 -inf -inf -inf -inf -inf
    # 2 -inf -inf -inf -2.03268 -3.71783 -inf -inf -inf -inf
    # 3 -inf -inf -inf -inf -inf -4.56292 -inf -inf -inf
    # 4 -inf -inf -inf -inf -inf -inf -inf -5.42262 -inf
    # beta:
    # 0 -inf -4.2245 -inf -inf -inf -inf -inf -inf -inf
    # 1 -inf -inf -inf -3.30202 -inf -inf -inf -inf -inf
    # 2 -inf -inf -inf -inf -1.70479 -0.856738 -inf -inf -inf
    # 3 -inf -inf -inf -inf -inf -0.859706 -0.859706 -0.549337 -inf
    # 4 -inf -inf -inf -inf -inf -inf -inf 0 0
    # prob: -5.42262
    # outputDerivs:
    # 0 -0.69824 0.28562 0.0831517 0.0862751 0.0816851 0.161508
    # 1 0.24082 -0.602467 0.0557226 0.0546814 0.0557528 0.19549
    # 2 0.230246 0.450868 0.0389607 0.038309 0.0391602 -0.797544
    # 3 0.280884 -0.570478 0.0326593 0.0339046 0.0326856 0.190345
    # 4 -0.576714 0.315517 0.0338439 0.0393744 0.0339315 0.154046

    # max_time_steps == 7
    depth = 6

    # seq_len_0 == 5
    targets_0 = [0, 1, 2, 1, 0]
    loss_log_prob_0 = -3.34211
    # dimensions are time x depth
    input_prob_matrix_0 = np.asarray(
        [[0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553],
         [0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436],
         [0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688],
         [0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533],
         [0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107]],
        dtype=np.float32)
    input_log_prob_matrix_0 = np.log(input_prob_matrix_0)
    gradient_log_prob_0 = np.asarray(
        [[-0.366234, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553],
         [0.111121, -0.411608, 0.278779, 0.0055756, 0.00569609, 0.010436],
         [0.0357786, 0.633813, -0.678582, 0.00249248, 0.00272882, 0.0037688],
         [0.0663296, -0.356151, 0.280111, 0.00283995, 0.0035545, 0.00331533],
         [-0.541765, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107]],
        dtype=np.float32)

    # seq_len_1 == 5
    targets_1 = [0, 1, 1, 0]
    loss_log_prob_1 = -5.42262
    # dimensions are time x depth

    input_prob_matrix_1 = np.asarray(
        [[0.30176, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508],
         [0.24082, 0.397533, 0.0557226, 0.0546814, 0.0557528, 0.19549],
         [0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, 0.202456],
         [0.280884, 0.429522, 0.0326593, 0.0339046, 0.0326856, 0.190345],
         [0.423286, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046]],
        dtype=np.float32)
    input_log_prob_matrix_1 = np.log(input_prob_matrix_1)
    gradient_log_prob_1 = np.asarray(
        [[-0.69824, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508],
         [0.24082, -0.602467, 0.0557226, 0.0546814, 0.0557528, 0.19549],
         [0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, -0.797544],
         [0.280884, -0.570478, 0.0326593, 0.0339046, 0.0326856, 0.190345],
         [-0.576714, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046]],
        dtype=np.float32)

    # len max_time_steps array of 2 x depth matrices
    inputs = [
        np.vstack(
            [input_log_prob_matrix_0[t, :], input_log_prob_matrix_1[t, :]])
        for t in range(5)
    ] + 2 * [np.nan * np.ones((2, depth), np.float32)]

    # convert inputs into a [max_time x batch_size x depth] Tensor
    inputs = np.asarray(inputs, dtype=np.float32)

    # len batch_size array of label vectors
    labels = SimpleSparseTensorFrom([targets_0, targets_1])

    # batch_size length vector of sequence_lengths
    seq_lens = np.array([5, 5], dtype=np.int32)

    # output: batch_size length vector of negative log probabilities
    loss_truth = np.array([-loss_log_prob_0, -loss_log_prob_1], np.float32)

    # output: len max_time_steps array of 2 x depth matrices
    grad_truth = [
        np.vstack([gradient_log_prob_0[t, :], gradient_log_prob_1[t, :]])
        for t in range(5)
    ] + 2 * [np.zeros((2, depth), np.float32)]

    # convert grad_truth into a [max_time x batch_size x depth] Tensor
    grad_truth = np.asarray(grad_truth, dtype=np.float32)

    self._testCTCLoss(inputs, seq_lens, labels, loss_truth, grad_truth)

  def test_time_major(self):
    """Test the time_major parameter.

    Transposing the inputs and setting time_major=False should yield the same
    loss.
    """
    # [max_time x batch_size x depth tensor]
    inputs = np.random.randn(2, 2, 3).astype(np.float32)
    labels = SimpleSparseTensorFrom([[0, 1], [1, 0]])
    seq_lens = np.array([2, 2], dtype=np.int32)

    inputs_t = constant_op.constant(inputs)

    # Transposing tensor to [batch_size x max_time x depth tensor]
    inputs_t_transposed = constant_op.constant(inputs.transpose(1, 0, 2))

    with self.session(use_gpu=False) as sess:
      loss = _ctc_loss_v2(
          inputs=inputs_t, labels=labels, sequence_length=seq_lens)
      loss_transposed = _ctc_loss_v2(
          inputs=inputs_t_transposed,
          labels=labels,
          sequence_length=seq_lens,
          time_major=False)

      (tf_loss, tf_loss_transposed) = self.evaluate([loss, loss_transposed])
      self.assertAllEqual(tf_loss, tf_loss_transposed)

  @test_util.run_v1_only("b/120545219")
  def testInvalidSecondGradient(self):
    inputs = np.random.randn(2, 2, 3).astype(np.float32)
    inputs_t = constant_op.constant(inputs)
    labels = SimpleSparseTensorFrom([[0, 1], [1, 0]])
    seq_lens = np.array([2, 2], dtype=np.int32)
    v = [1.0]

    with self.session(use_gpu=False):
      loss = _ctc_loss_v2(
          inputs=inputs_t, labels=labels, sequence_length=seq_lens)
      # Taking this second gradient should fail, since it is not
      # yet supported.
      with self.assertRaisesRegex(LookupError, "explicitly disabled"):
        _ = gradients_impl._hessian_vector_product(loss, [inputs_t], v)

  @test_util.run_v1_only("b/120545219")
  def testEmptyBatch(self):
    inputs = constant_op.constant([], dtype=dtypes.float32, shape=(1, 0, 2))
    sequence_lengths = constant_op.constant([], dtype=dtypes.int32)
    labels = sparse_tensor.SparseTensor(
        indices=constant_op.constant([], shape=(0, 2), dtype=dtypes.int64),
        values=constant_op.constant([], shape=(0,), dtype=dtypes.int32),
        dense_shape=[5, 5])

    with self.session(use_gpu=False) as sess:
      with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
                                  "batch_size must not be 0"):
        sess.run(_ctc_loss_v2(labels, inputs, sequence_lengths))

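# Illustrative sketch (not used by the tests): a brute-force reference for the
# CTC forward log-probability of a single example, assuming `log_probs` has
# shape [time, num_classes] with the blank as the last class, as in testBasic.
def _brute_force_ctc_log_prob(log_probs, targets):
  """Sum path log-probabilities over all valid CTC alignments (NumPy sketch)."""
  import itertools
  num_frames, num_classes = log_probs.shape
  blank = num_classes - 1
  total = -np.inf
  for path in itertools.product(range(num_classes), repeat=num_frames):
    # A path maps to a labeling by collapsing repeats, then removing blanks.
    collapsed = [k for k, _ in itertools.groupby(path) if k != blank]
    if collapsed == list(targets):
      log_p = sum(log_probs[t, k] for t, k in enumerate(path))
      total = np.logaddexp(total, log_p)
  return total
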

class CTCLossTestV2(test.TestCase, parameterized.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def testCtcLossV2(self):
    random_seed.set_random_seed(5)

    batch_size = 8
    num_labels = 6
    max_label_length = 5
    num_frames = 12

    labels = random_ops.random_uniform(
        [batch_size, max_label_length], minval=1, maxval=num_labels,
        dtype=dtypes.int64)
    logits = random_ops.random_uniform([num_frames, batch_size, num_labels])

    label_length = random_ops.random_uniform(
        [batch_size], minval=2, maxval=max_label_length, dtype=dtypes.int64)
    label_mask = array_ops.sequence_mask(
        label_length, maxlen=max_label_length, dtype=label_length.dtype)
    labels *= label_mask
    logit_length = [num_frames] * batch_size

    with backprop.GradientTape() as t:
      t.watch(logits)
      ref_loss = ctc_ops.ctc_loss_v2(
          labels=labels,
          logits=logits,
          label_length=label_length,
          logit_length=logit_length)
    ref_grad = t.gradient(ref_loss, [logits])

    sparse_labels = ctc_ops.dense_labels_to_sparse(labels, label_length)

    def assert_same_loss_and_grads(loss):
      if context.executing_eagerly():
        return
      with self.cached_session():
        self.assertAllClose(*self.evaluate([loss, ref_loss]))
        grad = gradients_impl.gradients(loss, [logits])
        self.assertAllClose(
            *self.evaluate([grad, ref_grad]), rtol=2e-06, atol=2e-06)

    assert_same_loss_and_grads(
        ctc_ops.ctc_loss_v2(
            labels=sparse_labels,
            logits=logits,
            label_length=label_length,
            logit_length=logit_length,
            blank_index=0))

  @test_util.run_v1_only("b/120545219")
  def testCtcLossDenseIsSameAsCtcLoss(self):
    with ops.device("/GPU:0" if test.is_gpu_available() else "/CPU:0"):
      random_seed.set_random_seed(5)

      batch_size = 8
      num_labels = 6
      label_length = 5
      minimum_logits_length = 10
      num_frames = minimum_logits_length + batch_size
      logits = random_ops.random_uniform([num_frames, batch_size, num_labels])
      labels = random_ops.random_uniform(
          [batch_size, label_length], minval=1, maxval=num_labels,
          dtype=dtypes.int64)

      label_lengths = random_ops.random_uniform(
          [batch_size], minval=2, maxval=label_length, dtype=dtypes.int64)
      label_mask = array_ops.sequence_mask(
          label_lengths, maxlen=label_length, dtype=label_lengths.dtype)
      labels *= label_mask

      logit_lengths = math_ops.range(batch_size) + minimum_logits_length

      ctc_loss = ctc_ops.ctc_loss_dense(
          labels=labels,
          logits=logits,
          label_length=label_lengths,
          logit_length=logit_lengths)
      ctc_loss_grads = gradients_impl.gradients(ctc_loss, [logits])[0]

      # Shift labels down by one (move blank from 0 to num_labels - 1).
      tf_ctc_loss_labels = math_ops.cast(labels, dtypes.int32) - 1
      tf_nn_ctc_logits = array_ops.concat([
          logits[:, :, 1:],
          logits[:, :, 0:1],
      ], axis=2)

      tf_ctc_loss_labels = ctc_ops.dense_labels_to_sparse(
          tf_ctc_loss_labels, label_lengths)

      tf_nn_ctc_loss = ctc_ops.ctc_loss(
          labels=tf_ctc_loss_labels,
          inputs=tf_nn_ctc_logits,
          sequence_length=logit_lengths,
          time_major=True)
      tf_nn_ctc_grads = gradients_impl.gradients(tf_nn_ctc_loss, [logits])[0]

      with self.cached_session() as sess:
        for _ in range(32):
          self.assertAllClose(*self.evaluate([ctc_loss, tf_nn_ctc_loss]))
          self.assertAllClose(
              *self.evaluate([ctc_loss_grads, tf_nn_ctc_grads]),
              rtol=4e-06,
              atol=4e-06)

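  # In these CtcLossDense*IsSameAsCtcLoss tests, the label/logit shifting
  # reconciles two blank conventions: ctc_loss_dense defaults to
  # blank_index=0, while the v1 ctc_loss treats the last class
  # (num_labels - 1) as the blank, so the class axis is rotated and the
  # labels re-indexed before the two implementations are compared.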
  @test_util.run_v1_only("b/120545219")
  def testCtcLossDenseUniqueFastPathIsSameAsCtcLoss(self):
    random_seed.set_random_seed(5)

    batch_size = 8
    num_labels = 6
    label_length = 5
    num_frames = 12
    logits = random_ops.random_uniform([num_frames, batch_size, num_labels])
    labels = random_ops.random_uniform(
        [batch_size, label_length], minval=1, maxval=num_labels,
        dtype=dtypes.int64)

    label_lengths = random_ops.random_uniform(
        [batch_size], minval=2, maxval=label_length, dtype=dtypes.int64)
    label_mask = array_ops.sequence_mask(
        label_lengths, maxlen=label_length, dtype=label_lengths.dtype)
    labels *= label_mask

    logit_lengths = [num_frames] * batch_size

    ctc_loss = ctc_ops.ctc_loss_dense(
        labels=labels,
        logits=logits,
        label_length=label_lengths,
        logit_length=logit_lengths,
        unique=ctc_ops.ctc_unique_labels(labels))
    ctc_loss_grads = gradients_impl.gradients(ctc_loss, [logits])[0]

    # Shift labels down by one (move blank from 0 to num_labels - 1).
    tf_ctc_loss_labels = math_ops.cast(labels, dtypes.int32) - 1
    tf_nn_ctc_logits = array_ops.concat([
        logits[:, :, 1:],
        logits[:, :, 0:1],
    ], axis=2)

    tf_ctc_loss_labels = ctc_ops.dense_labels_to_sparse(
        tf_ctc_loss_labels, label_lengths)

    tf_nn_ctc_loss = ctc_ops.ctc_loss(
        labels=tf_ctc_loss_labels,
        inputs=tf_nn_ctc_logits,
        sequence_length=logit_lengths,
        time_major=True)
    tf_nn_ctc_grads = gradients_impl.gradients(tf_nn_ctc_loss, [logits])[0]

    with self.cached_session():
      for _ in range(32):
        self.assertAllClose(*self.evaluate([ctc_loss, tf_nn_ctc_loss]))
        self.assertAllClose(
            *self.evaluate([ctc_loss_grads, tf_nn_ctc_grads]),
            rtol=2e-06,
            atol=2e-06)

  @test_util.run_v1_only("b/120545219")
  def testCtcLossDenseUniqueFastPathWithBlankIndexIsSameAsCtcLoss(self):
    random_seed.set_random_seed(5)

    batch_size = 8
    num_labels = 6
    label_length = 5
    num_frames = 12
    logits = random_ops.random_uniform([num_frames, batch_size, num_labels])
    labels = random_ops.random_uniform([batch_size, label_length],
                                       minval=0,
                                       maxval=num_labels - 1,
                                       dtype=dtypes.int64)

    label_lengths = random_ops.random_uniform([batch_size],
                                              minval=2,
                                              maxval=label_length,
                                              dtype=dtypes.int64)
    label_mask = array_ops.sequence_mask(
        label_lengths, maxlen=label_length, dtype=label_lengths.dtype)
    labels *= label_mask

    logit_lengths = [num_frames] * batch_size

    tf_ctc_loss_labels = math_ops.cast(labels, dtypes.int32)
    tf_ctc_loss_labels = ctc_ops.dense_labels_to_sparse(tf_ctc_loss_labels,
                                                        label_lengths)

    tf_nn_ctc_loss = ctc_ops.ctc_loss(
        labels=tf_ctc_loss_labels,
        inputs=logits,
        sequence_length=logit_lengths,
        time_major=True)
    tf_nn_ctc_grads = gradients_impl.gradients(tf_nn_ctc_loss, [logits])[0]

    # Shift the blank logits/labels to be somewhere in the middle.
    blank_index = 2
    shifted_logits = array_ops.concat([
        logits[:, :, :blank_index],
        logits[:, :, -1:],
        logits[:, :, blank_index:-1],
    ], axis=2)
    shifted_labels = array_ops.where_v2(labels < blank_index, labels,
                                        labels + 1)

    ctc_loss = ctc_ops.ctc_loss_dense(
        labels=shifted_labels,
        logits=shifted_logits,
        label_length=label_lengths,
        logit_length=logit_lengths,
        blank_index=blank_index,
        unique=ctc_ops.ctc_unique_labels(shifted_labels))
    ctc_loss_grads = gradients_impl.gradients(ctc_loss, [logits])[0]

    with self.cached_session() as sess:
      for _ in range(32):
        self.assertAllClose(*self.evaluate([ctc_loss, tf_nn_ctc_loss]))
        self.assertAllClose(
            *self.evaluate([ctc_loss_grads, tf_nn_ctc_grads]),
            rtol=2e-06,
            atol=2e-06)

  @test_util.run_v1_only("b/120545219")
  def testCtcLossDenseWithBlankIndexIsSameAsCtcLoss(self):
    random_seed.set_random_seed(5)

    batch_size = 8
    num_labels = 6
    label_length = 5
    num_frames = 12
    logits = random_ops.random_uniform([num_frames, batch_size, num_labels])
    labels = random_ops.random_uniform(
        [batch_size, label_length], minval=0, maxval=num_labels - 1,
        dtype=dtypes.int64)

    label_lengths = random_ops.random_uniform(
        [batch_size], minval=2, maxval=label_length, dtype=dtypes.int64)
    label_mask = array_ops.sequence_mask(
        label_lengths, maxlen=label_length, dtype=label_lengths.dtype)
    labels *= label_mask

    logit_lengths = [num_frames] * batch_size

    tf_ctc_loss_labels = math_ops.cast(labels, dtypes.int32)
    tf_ctc_loss_labels = ctc_ops.dense_labels_to_sparse(
        tf_ctc_loss_labels, label_lengths)

    tf_nn_ctc_loss = ctc_ops.ctc_loss(
        labels=tf_ctc_loss_labels,
        inputs=logits,
        sequence_length=logit_lengths,
        time_major=True)
    tf_nn_ctc_grads = gradients_impl.gradients(tf_nn_ctc_loss, [logits])[0]

    # Shift the blank logits/labels to be somewhere in the middle.
    blank_index = 2
    shifted_logits = array_ops.concat([
        logits[:, :, :blank_index],
        logits[:, :, -1:],
        logits[:, :, blank_index:-1],
    ], axis=2)
    shifted_labels = array_ops.where_v2(labels < blank_index, labels,
                                        labels + 1)

    ctc_loss = ctc_ops.ctc_loss_dense(
        labels=shifted_labels,
        logits=shifted_logits,
        label_length=label_lengths,
        logit_length=logit_lengths,
        blank_index=blank_index)
    ctc_loss_grads = gradients_impl.gradients(ctc_loss, [logits])[0]

    with self.cached_session() as sess:
      for _ in range(32):
        self.assertAllClose(*self.evaluate([ctc_loss, tf_nn_ctc_loss]))
        self.assertAllClose(
            *self.evaluate([ctc_loss_grads, tf_nn_ctc_grads]),
            rtol=2e-06,
            atol=2e-06)

  @test_util.run_v1_only("b/120545219")
  def testCtcLossDenseWithNegativeBlankIndexIsSameAsCtcLoss(self):
    with ops.device("/GPU:0" if test.is_gpu_available() else "/CPU:0"):
      random_seed.set_random_seed(5)

      batch_size = 8
      num_labels = 6
      label_length = 5
      num_frames = 12
      logits = random_ops.random_uniform([num_frames, batch_size, num_labels])
      labels = random_ops.random_uniform(
          [batch_size, label_length], minval=0, maxval=num_labels - 1,
          dtype=dtypes.int64)

      label_lengths = random_ops.random_uniform(
          [batch_size], minval=2, maxval=label_length, dtype=dtypes.int64)
      label_mask = array_ops.sequence_mask(
          label_lengths, maxlen=label_length, dtype=label_lengths.dtype)
      labels *= label_mask

      logit_lengths = [num_frames] * batch_size

      ctc_loss = ctc_ops.ctc_loss_dense(
          labels=labels,
          logits=logits,
          label_length=label_lengths,
          logit_length=logit_lengths,
          blank_index=-1)
      ctc_loss_grads = gradients_impl.gradients(ctc_loss, [logits])[0]

      tf_ctc_loss_labels = math_ops.cast(labels, dtypes.int32)
      tf_ctc_loss_labels = ctc_ops.dense_labels_to_sparse(
          tf_ctc_loss_labels, label_lengths)

      tf_nn_ctc_loss = ctc_ops.ctc_loss(
          labels=tf_ctc_loss_labels,
          inputs=logits,
          sequence_length=logit_lengths,
          time_major=True)
      tf_nn_ctc_grads = gradients_impl.gradients(tf_nn_ctc_loss, [logits])[0]

      with self.cached_session() as sess:
        for _ in range(32):
          self.assertAllClose(*self.evaluate([ctc_loss, tf_nn_ctc_loss]))
          self.assertAllClose(
              *self.evaluate([ctc_loss_grads, tf_nn_ctc_grads]),
              rtol=2e-06,
              atol=2e-06)

  @parameterized.parameters((False, 0), (True, 0), (False, -1), (True, -1))
  def testCtcLossDenseWithUndefinedStaticDimensions(self, unique, blank_index):
    random_seed.set_random_seed(5)

    # Trace without a batch size and number of frames
    batch_size = None
    num_labels = 6
    label_length = 5
    num_frames = None

    @def_function.function
    def func(labels, logits, label_lengths, logit_lengths):
      unique_labels = ctc_ops.ctc_unique_labels(labels) if unique else None
      return ctc_ops.ctc_loss_dense(
          labels=labels,
          logits=logits,
          label_length=label_lengths,
          logit_length=logit_lengths,
          unique=unique_labels,
          blank_index=blank_index)

    labels_spec = tensor_spec.TensorSpec([batch_size, label_length],
                                         dtypes.int64)
    logits_spec = tensor_spec.TensorSpec([num_frames, batch_size, num_labels],
                                         dtypes.float32)
    label_lengths_spec = tensor_spec.TensorSpec([batch_size], dtypes.int64)
    logit_lengths_spec = tensor_spec.TensorSpec([batch_size], dtypes.int64)

    f = func.get_concrete_function(
        labels_spec, logits_spec, label_lengths_spec, logit_lengths_spec)

    # Execute with a defined batch size and number of frames
    batch_size = 8
    num_frames = 12

    logits = random_ops.random_uniform([num_frames, batch_size, num_labels])
    labels = random_ops.random_uniform(
        [batch_size, label_length], minval=1, maxval=num_labels,
        dtype=dtypes.int64)

    label_lengths = random_ops.random_uniform(
        [batch_size], minval=2, maxval=label_length, dtype=dtypes.int64)
    label_mask = array_ops.sequence_mask(
        label_lengths, maxlen=label_length, dtype=label_lengths.dtype)
    labels *= label_mask

    logit_lengths = constant_op.constant(
        [num_frames] * batch_size, dtype=dtypes.int64)

    f(labels, logits, label_lengths, logit_lengths)

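  # Tracing func above from TensorSpecs with None batch/frame dimensions, then
  # calling the concrete function with fully-defined shapes, checks that
  # ctc_loss_dense does not depend on static shape information.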
  def testCollapseRepeated(self):
    collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
        labels=[[1, 3, 3, 3, 0],
                [1, 4, 4, 4, 0],
                [4, 2, 2, 9, 4]],
        seq_length=[4, 5, 5])
    self.assertAllEqual(new_seq_lengths, [2, 3, 4])
    self.assertAllEqual(
        collapsed,
        [[1, 3, 0, 0],
         [1, 4, 0, 0],
         [4, 2, 9, 4]])

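  # The collapse_repeated semantics exercised by these tests: within each row,
  # adjacent duplicates among the first seq_length entries are merged (e.g.
  # [1, 3, 3, 3] -> [1, 3]), the results are right-padded with zeros, and the
  # collapsed lengths are returned alongside.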
  def testCollapseRepeatedPreservesDtypes(self):
    collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
        labels=constant_op.constant(
            [[1, 3, 3, 3, 0],
             [1, 4, 4, 4, 0],
             [4, 2, 2, 9, 4]],
            dtype=dtypes.int64),
        seq_length=constant_op.constant([4, 5, 5], dtype=dtypes.int64))
    self.assertEqual(new_seq_lengths.dtype, dtypes.int64)
    self.assertEqual(collapsed.dtype, dtypes.int64)
    self.assertAllEqual(new_seq_lengths, [2, 3, 4])
    self.assertAllEqual(
        collapsed,
        [[1, 3, 0, 0],
         [1, 4, 0, 0],
         [4, 2, 9, 4]])

  def testCollapseRepeatedExtraPadding(self):
    collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
        labels=[[1, 3, 3, 3, 0, 0, 0],
                [1, 4, 4, 4, 0, 1, 2],
                [4, 2, 2, 9, 4, 0, 0]],
        seq_length=[4, 5, 5])
    self.assertAllEqual(new_seq_lengths, [2, 3, 4])
    self.assertAllEqual(
        collapsed,
        [[1, 3, 0, 0],
         [1, 4, 0, 0],
         [4, 2, 9, 4]])

  def testCollapseRepeatedFrontRepeats(self):
    collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
        labels=[[1, 1, 1, 2, 2],
                [1, 1, 1, 2, 2],
                [1, 1, 1, 2, 2]],
        seq_length=[5, 4, 3])
    self.assertAllEqual(new_seq_lengths, [2, 2, 1])
    self.assertAllEqual(
        collapsed,
        [[1, 2],
         [1, 2],
         [1, 0]])

  def testCollapseRepeatedAllLabelsTheSame(self):
    collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
        labels=[[1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1]],
        seq_length=[4, 5, 1])
    self.assertAllEqual(new_seq_lengths, [1, 1, 1])
    self.assertAllEqual(
        collapsed,
        [[1],
         [1],
         [1]])

  def testDenseSequencesToSparse(self):
    labels = [[1, 3, 3, 3, 0],
              [1, 4, 4, 4, 0],
              [4, 2, 2, 9, 4]]
    length = [4, 5, 5]
    sparse = ctc_ops.dense_labels_to_sparse(labels, length)
    new_dense = sparse_ops.sparse_tensor_to_dense(sparse)

    self.assertAllEqual(labels, new_dense)

    padded_labels = [[1, 3, 3, 3, 0, 0, 0, 0],
                     [1, 4, 4, 4, 0, 0, 0, 0],
                     [4, 2, 2, 9, 4, 0, 0, 0]]
    length = [4, 5, 5]
    sparse = ctc_ops.dense_labels_to_sparse(padded_labels, length)
    padded_dense = sparse_ops.sparse_tensor_to_dense(sparse)

    self.assertAllEqual(padded_dense, new_dense)

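  # dense_labels_to_sparse keeps only the first length[i] entries of row i, so
  # any padding beyond the given lengths is dropped from the SparseTensor,
  # which is why the extra-padded labels above round-trip to the same dense
  # result.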
  def testUnique(self):
    labels = [
        [3, 4, 4, 3],
        [1, 1, 1, 0],
    ]
    unique, idx = ctc_ops.ctc_unique_labels(labels)
    self.assertAllEqual([
        [3, 4, 0, 0],
        [1, 0, 0, 0],
    ], unique)
    self.assertAllEqual([
        [0, 1, 1, 0],
        [0, 0, 0, 1],
    ], idx)

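  # ctc_unique_labels returns, per row, the distinct labels in order of first
  # appearance (right-padded with zeros) plus, for each original position, the
  # index of its label within that unique row; this pair is what
  # ctc_loss_dense accepts via its `unique` argument in the fast-path tests
  # above.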
  def testSumStates(self):
    idx = [
        [0, 1, 0, 1],
        [0, 0, 0, 1],
    ]
    states = math_ops.log([
        [[1.0, 2.0, 3.0, 4.0],
         [5.0, 6.0, 7.0, 8.0]],
        [[0.1, 0.2, 0.3, 0.4],
         [0.5, 0.6, 0.7, 0.8]],
    ])
    sum_of_states = math_ops.exp(ctc_ops._sum_states(idx, states))
    self.assertAllClose([
        [[4.0, 6.0, 0.0, 0.0],
         [18.0, 8.0, 0.0, 0.0]],
        [[0.4, 0.6, 0.0, 0.0],
         [1.8, 0.8, 0.0, 0.0]]
    ], sum_of_states)

  def testStateToOlabel(self):
    labels = [
        [3, 4, 3, 4],
        [1, 1, 1, 0],
    ]
    num_labels = 8

    # 3 frames, 2 batch, 10 states (5 label, 5 blank).
    states = [
        [[0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.20],
         [0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.30]],
        [[1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0],
         [2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3.0]],
        [[11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0],
         [21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0]],
    ]
    labels = ops.convert_to_tensor(labels)
    states = math_ops.log(states)
    olabel = ctc_ops._state_to_olabel(labels, num_labels, states)
    olabel = math_ops.exp(olabel)
    blank = olabel[:, :, 0]
    self.assertAllClose(blank, [
        [0.16 + 0.17 + 0.18 + 0.19 + 0.20,
         0.26 + 0.27 + 0.28 + 0.29 + 0.30],
        [1.6 + 1.7 + 1.8 + 1.9 + 2.0,
         2.6 + 2.7 + 2.8 + 2.9 + 3.0],
        [16.0 + 17.0 + 18.0 + 19.0 + 20.0,
         26.0 + 27.0 + 28.0 + 29.0 + 30.0]
    ])
    self.assertAllClose(olabel[:, :, 1:], [
        [[0.0, 0.0, 0.12 + 0.14, 0.13 + 0.15, 0.0, 0.0, 0.0],
         [0.22 + 0.23 + 0.24, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
        [[0.0, 0.0, 1.2 + 1.4, 1.3 + 1.5, 0.0, 0.0, 0.0],
         [2.2 + 2.3 + 2.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
        [[0.0, 0.0, 12.0 + 14.0, 13.0 + 15.0, 0.0, 0.0, 0.0],
         [22.0 + 23.0 + 24.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
    ])

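  # In testStateToOlabel above and its unique variants below, _state_to_olabel
  # folds per-frame state posteriors back onto output labels: the blank olabel
  # sums the blank states (the last five columns here), while each label's
  # olabel sums the states of every label position holding that label (e.g.
  # label 3 sits at positions 0 and 2 of the first row, so its olabel is
  # states[..., 1] + states[..., 3]).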
  def testStateToOlabelUnique(self):
    labels = [
        [3, 4, 3, 4],
        [1, 1, 1, 0],
    ]
    num_labels = 8

    # 3 frames, 2 batch, 10 states (5 label, 5 blank).
    states = [
        [[0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.20],
         [0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.30]],
        [[1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0],
         [2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3.0]],
        [[11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0],
         [21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0]],
    ]
    labels = ops.convert_to_tensor(labels)
    states = math_ops.log(states)
    olabel = ctc_ops._state_to_olabel_unique(
        labels, num_labels, states, ctc_ops.ctc_unique_labels(labels))
    olabel = math_ops.exp(olabel)
    blank = olabel[:, :, 0]
    self.assertAllClose(blank, [
        [0.16 + 0.17 + 0.18 + 0.19 + 0.20,
         0.26 + 0.27 + 0.28 + 0.29 + 0.30],
        [1.6 + 1.7 + 1.8 + 1.9 + 2.0,
         2.6 + 2.7 + 2.8 + 2.9 + 3.0],
        [16.0 + 17.0 + 18.0 + 19.0 + 20.0,
         26.0 + 27.0 + 28.0 + 29.0 + 30.0]])
    self.assertAllClose(olabel[:, :, 1:], [
        [[0.0, 0.0, 0.12 + 0.14, 0.13 + 0.15, 0.0, 0.0, 0.0],
         [0.22 + 0.23 + 0.24, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
        [[0.0, 0.0, 1.2 + 1.4, 1.3 + 1.5, 0.0, 0.0, 0.0],
         [2.2 + 2.3 + 2.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
        [[0.0, 0.0, 12.0 + 14.0, 13.0 + 15.0, 0.0, 0.0, 0.0],
         [22.0 + 23.0 + 24.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
    ])

  def testStateToOlabelUniqueSinglePath(self):
    labels = [
        [3, 4, 3],
        [1, 0, 0],
    ]
    num_labels = 8

    # 3 frames, 2 batch, 8 states (4 label, 4 blank).
    #
    # There is only a single valid path for each sequence because the frame
    # lengths and the label lengths are the same.
    states = [[[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
               [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
              [[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
               [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
              [[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
               [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]]
    labels = ops.convert_to_tensor(labels)
    states = math_ops.log(states)
    olabel = ctc_ops._state_to_olabel_unique(labels, num_labels, states,
                                             ctc_ops.ctc_unique_labels(labels))
    olabel = math_ops.exp(olabel)
    blank = olabel[:, :, 0]

    self.assertAllClose(blank, [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
    self.assertAllClose(olabel[:, :, 1:],
                        [
                            [[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
                             [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
                            [[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
                             [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
                            [[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
                             [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
                        ])

  @test_util.run_deprecated_v1
  def testScan(self):
    with ops.device("/GPU:0" if test.is_gpu_available() else "/CPU:0"):
      out = ctc_ops._scan(
          lambda accum, elem: accum + elem,
          constant_op.constant([1.0, 2.0, 3.0]), 23.0)
      self.assertAllEqual([24.0, 26.0, 29.0], out)

      out = ctc_ops._scan(
          lambda a, e: a + e,
          constant_op.constant([1.0, 2.0, 3.0]), 23.0,
          inclusive=True)
      self.assertAllEqual([23.0, 24.0, 26.0, 29.0], out)

      out = ctc_ops._scan(
          lambda a, e: a + e,
          constant_op.constant([1.0, 2.0, 3.0]), 23.0,
          reverse=True)
      self.assertAllEqual([29.0, 28.0, 26.0], out)

      out = ctc_ops._scan(
          lambda a, e: a + e,
          constant_op.constant([1.0, 2.0, 3.0]), 23.0,
          reverse=True,
          inclusive=True)
      self.assertAllEqual([29.0, 28.0, 26.0, 23.0], out)

      out = ctc_ops._scan(
          lambda a, e: a + e,
          constant_op.constant([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]),
          constant_op.constant([23.0, 24.0]))
      self.assertAllEqual([[23.0, 25.0], [25.0, 28.0], [29.0, 33.0]], out)

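  # As exercised above, _scan is a cumulative fold: given an initial
  # accumulator a0 and elements e_1..e_n, it yields [a_1, ..., a_n] with
  # a_i = fn(a_{i-1}, e_i); inclusive=True prepends a0 to the output, and
  # reverse=True folds from the last element toward the first.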
  @test_util.run_deprecated_v1
  def testScanCapturesVariables(self):
    with self.cached_session() as sess:
      x = random_ops.random_uniform([])
      fn = lambda accum, elem: accum + x * elem
      out = ctc_ops._scan(fn, constant_op.constant([0.0, 1.0, 2.0]), 23.0)
      self.assertAllClose(*sess.run([
          [23.0 + x * 0.0, 23.0 + x * 1.0, 23.0 + x * 3.0], out
      ]))

  @test_util.run_deprecated_v1
  def testScanMultipleAccumulators(self):
    with ops.device("/GPU:0" if test.is_gpu_available() else "/CPU:0"):
      def fn(accum, elem):
        accum_a, accum_b = accum
        return accum_a + elem, accum_b * elem
      out = ctc_ops._scan(
          fn, constant_op.constant([1.0, 2.0, 3.0]),
          (23.0, constant_op.constant([1.0, 2.0])))
      a, b = out
      self.assertAllEqual([24.0, 26.0, 29.0], a)
      self.assertAllEqual([[1.0, 2.0], [2.0, 4.0], [6.0, 12.0]], b)

  @test_util.run_deprecated_v1
  def testScanMultipleElements(self):
    with ops.device("/GPU:0" if test.is_gpu_available() else "/CPU:0"):
      def fn(accum, elem):
        elem_a, elem_b = elem
        return accum + (elem_a * elem_b)
      elems_a = constant_op.constant([1.0, 2.0, 3.0])
      elems_b = constant_op.constant([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]])
      out = ctc_ops._scan(
          fn, (elems_a, elems_b),
          initial=constant_op.constant([0.0, 0.0]))
      self.assertAllEqual(
          [[1.0, 2.0], [5.0, 8.0], [14.0, 20.0]], out)


def _ctc_loss_v3(labels, logits, label_length, logit_length, use_gpu,
                 sparse=True):
  with test_util.device(use_gpu=use_gpu):
    if sparse:
      labels = ctc_ops.dense_labels_to_sparse(labels, label_length)
    with backprop.GradientTape() as t:
      t.watch(logits)
      ref_loss = ctc_ops.ctc_loss_v3(
          labels=labels,
          logits=logits,
          label_length=label_length,
          logit_length=logit_length,
          blank_index=0)
    ref_grad = t.gradient(ref_loss, logits)
    return ref_loss, ref_grad

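# Example (illustrative) of the helper above: the loss and its gradient with
# respect to the logits are produced in one call, e.g.
#   loss, grad = _ctc_loss_v3(labels, logits, label_length, logit_length,
#                             use_gpu=False, sparse=True)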

@test_util.run_all_in_graph_and_eager_modes
class CTCLossTestV3(test.TestCase, parameterized.TestCase):

  @parameterized.parameters([False, True])
  @test_util.run_v2_only
  def testCtcLossV3(self, run_tf_func):
    """Test the GPU CTC loss.

    The GPU CTC loss should produce the same result as the CPU version.
    """
    if not test.is_gpu_available():
      self.skipTest("Need GPU for testing.")
    if not context.executing_eagerly():
      self.skipTest("Need eager execution for testing.")
    random_seed.set_random_seed(5)

    batch_size = 8
    num_labels = 6
    max_label_length = 5
    num_frames = 12

    labels = random_ops.random_uniform([batch_size, max_label_length],
                                       minval=1,
                                       maxval=num_labels,
                                       dtype=dtypes.int64)
    logits = random_ops.random_uniform([num_frames, batch_size, num_labels])

    label_length = random_ops.random_uniform([batch_size],
                                             minval=2,
                                             maxval=max_label_length,
                                             dtype=dtypes.int64)
    label_mask = array_ops.sequence_mask(
        label_length, maxlen=max_label_length, dtype=label_length.dtype)
    labels *= label_mask
    logit_length = [num_frames] * batch_size

    if run_tf_func:
      ctc_loss = def_function.function(_ctc_loss_v3)
    else:
      ctc_loss = _ctc_loss_v3

    ref_loss, ref_grad = ctc_loss(labels, logits, label_length, logit_length,
                                  False)
    loss, grad = ctc_loss(labels, logits, label_length, logit_length, True)

    self.assertAllClose(loss, ref_loss, atol=1e-6)
    self.assertAllClose(grad, ref_grad, atol=2e-6)

  @parameterized.parameters([False, True])
  def testCtcLossFp16(self, sparse_labels):
    batch_size = 8
    num_labels = 6
    max_label_length = 5
    num_frames = 12

    labels = np.random.randint(1, num_labels, [batch_size, max_label_length])
    labels = ops.convert_to_tensor(labels, dtypes.int64)
    fp16_logits = np.random.uniform(size=[num_frames, batch_size, num_labels])
    fp16_logits = ops.convert_to_tensor(fp16_logits, dtypes.float16)
    label_length = np.random.randint(2, max_label_length, [batch_size])
    label_length = ops.convert_to_tensor(label_length, dtypes.int64)

    label_mask = array_ops.sequence_mask(
        label_length, maxlen=max_label_length, dtype=label_length.dtype)
    labels *= label_mask
    logit_length = [num_frames] * batch_size

    fp16_loss, fp16_grad = _ctc_loss_v3(
        labels, fp16_logits, label_length, logit_length, use_gpu=True,
        sparse=sparse_labels)
    fp32_loss, fp32_grad = _ctc_loss_v3(
        labels, math_ops.cast(fp16_logits, dtypes.float32), label_length,
        logit_length, use_gpu=True, sparse=sparse_labels)

    self.assertEqual(fp16_loss.dtype, dtypes.float16)
    self.assertEqual(fp16_grad.dtype, dtypes.float16)
    self.assertAllClose(
        self.evaluate(fp16_loss),
        self.evaluate(math_ops.cast(fp32_loss, dtypes.float16))
    )
    self.assertAllClose(
        self.evaluate(fp16_grad),
        self.evaluate(math_ops.cast(fp32_grad, dtypes.float16))
    )

  @parameterized.parameters([False, True])
  def testCtcLossWithListLogits(self, sparse_labels):
    batch_size = 8
    num_labels = 6
    max_label_length = 5
    num_frames = 12

    labels = np.random.randint(1, num_labels, [batch_size, max_label_length])
    labels = ops.convert_to_tensor(labels, dtypes.int64)
    logits = np.random.uniform(size=[num_frames, batch_size, num_labels])
    label_length = np.random.randint(2, max_label_length, [batch_size])
    label_length = ops.convert_to_tensor(label_length, dtypes.int64)

    label_mask = array_ops.sequence_mask(
        label_length, maxlen=max_label_length, dtype=label_length.dtype)
    labels *= label_mask
    logit_length = [num_frames] * batch_size
    if sparse_labels:
      labels = ctc_ops.dense_labels_to_sparse(labels, label_length)

    list_loss = ctc_ops.ctc_loss_v3(
        labels=labels,
        logits=logits.tolist(),
        label_length=label_length,
        logit_length=logit_length,
        blank_index=0)
    tensor_loss = ctc_ops.ctc_loss_v3(
        labels=labels,
        logits=ops.convert_to_tensor(logits, dtypes.float32),
        label_length=label_length,
        logit_length=logit_length,
        blank_index=0)

    self.assertAllClose(self.evaluate(list_loss), self.evaluate(tensor_loss))

  @test_util.run_v2_only
  def testCtcLossAlgorithmFallback(self):
    """Test if GPU CTC loss can fall back to the correct algorithm."""
    if not test.is_gpu_available():
      self.skipTest("Need GPU for testing.")
    if not context.executing_eagerly():
      self.skipTest("Need eager execution for testing.")
    random_seed.set_random_seed(5)

    batch_size = 1
    num_labels = 11777
    max_label_length = 2
    num_frames = 1

    labels = random_ops.random_uniform([batch_size, max_label_length],
                                       minval=1,
                                       maxval=num_labels,
                                       dtype=dtypes.int64)
    logits = random_ops.random_uniform([num_frames, batch_size, num_labels])

    label_length = random_ops.random_uniform([batch_size],
                                             minval=1,
                                             maxval=max_label_length,
                                             dtype=dtypes.int64)
    logit_length = [num_frames] * batch_size

    loss, grad = _ctc_loss_v3(labels, logits, label_length, logit_length, True)
    ref_loss, ref_grad = _ctc_loss_v3(labels, logits, label_length,
                                      logit_length, False)

    self.assertAllClose(loss, ref_loss, atol=1e-6)
    self.assertAllClose(grad, ref_grad, atol=2e-6)


@test_util.run_all_in_graph_and_eager_modes
class CTCLossDeterministicTest(test.TestCase, parameterized.TestCase):

  def _randomFloats(self, shape):
    x = (2 * np.random.random_sample(shape) - 1)
    return constant_op.constant(x, dtype=dtypes.float32)

  def _genInputParams(self,
                      num_classes=10,
                      batch_size=32,
                      max_label_sequence_length=50,
                      num_frames=100,
                      logits_time_major=True,
                      sparse_labels=True):
    assert num_frames >= max_label_sequence_length

    labels_shape = (batch_size, max_label_sequence_length)
    # Zero-pad the labels. Zero is the default blank index in the TF2 API.
    # num_classes includes the blank class.
    unmasked_labels = np.random.randint(
        1, num_classes, size=labels_shape, dtype=np.int32)
    labels_lengths = np.random.randint(
        1, high=max_label_sequence_length, size=batch_size, dtype=np.int32)
    labels_masks = (np.arange(max_label_sequence_length) <
                    labels_lengths.reshape(batch_size, 1)).astype(np.int32)
    labels = unmasked_labels * labels_masks
    if sparse_labels:
      labels = ctc_ops.dense_labels_to_sparse(labels, labels_lengths)

    if logits_time_major:
      logits_shape = (num_frames, batch_size, num_classes)
    else:
      logits_shape = (batch_size, num_frames, num_classes)
    logits = self._randomFloats(logits_shape)

    labels_lengths = constant_op.constant(labels_lengths)

    logits_lengths = [num_frames] * batch_size
    logits_lengths = constant_op.constant(logits_lengths)

    return labels, logits, labels_lengths, logits_lengths

  def _forwardAndBackward(self, sparse_labels, logits_time_major, seed):
    np.random.seed(seed)
    params = self._genInputParams(
        logits_time_major=logits_time_major, sparse_labels=sparse_labels)
    labels, logits, labels_lengths, logits_lengths = params
    output_shape = (labels_lengths.shape[0],)
    upstream_gradients = self._randomFloats(output_shape)
    with backprop.GradientTape() as tape:
      tape.watch(logits)
      loss = ctc_ops.ctc_loss_v3(
          labels,
          logits,
          labels_lengths,
          logits_lengths,
          logits_time_major=logits_time_major,
          blank_index=0)
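      # Multiply by random upstream gradients before differentiating so that
      # determinism of the full backward pass is checked, not just the
      # gradient of sum(loss).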
      gradient_injector_output = loss * upstream_gradients
    return loss, tape.gradient(gradient_injector_output, logits)

  @parameterized.parameters(  # parameterized.product not yet available
      (False, False), (False, True), (True, False), (True, True))
  def testForwardAndBackward(self, sparse_labels, logits_time_major):
    with test_util.deterministic_ops():
      for seed in range(2):
        loss_a, gradient_a = self._forwardAndBackward(sparse_labels,
                                                      logits_time_major, seed)
        loss_b, gradient_b = self._forwardAndBackward(sparse_labels,
                                                      logits_time_major, seed)
        loss_a, loss_b, gradient_a, gradient_b = self.evaluate(
            (loss_a, loss_b, gradient_a, gradient_b))
        self.assertAllEqual(loss_a, loss_b, "Loss mismatch")
        self.assertAllEqual(gradient_a, gradient_b, "Gradient mismatch")


if __name__ == "__main__":
  test.main()