# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ctc_ops.ctc_loss_op."""

from absl.testing import parameterized
import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import ctc_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test


def SimpleSparseTensorFrom(x):
  """Create a very simple SparseTensor with dimensions (batch, time).

  Args:
    x: a list of lists of type int

  Returns:
    x_ix and x_val, the indices and values of the SparseTensor<2>.
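    For example, SimpleSparseTensorFrom([[0, 1], [2]]) has indices
    [[0, 0], [0, 1], [1, 0]], values [0, 1, 2] and dense_shape [2, 2].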
  """
  x_ix = []
  x_val = []
  for batch_i, batch in enumerate(x):
    for time, val in enumerate(batch):
      x_ix.append([batch_i, time])
      x_val.append(val)
  x_shape = [len(x), np.asarray(x_ix).max(0)[1] + 1]
  x_ix = constant_op.constant(x_ix, dtypes.int64)
  x_val = constant_op.constant(x_val, dtypes.int32)
  x_shape = constant_op.constant(x_shape, dtypes.int64)

  return sparse_tensor.SparseTensor(x_ix, x_val, x_shape)


def _ctc_loss_v2(labels, inputs, sequence_length,
                 preprocess_collapse_repeated=False,
                 ctc_merge_repeated=True,
                 ignore_longer_outputs_than_inputs=False,
                 time_major=True):
  """Call ctc_loss_v2 with v1 args."""
  assert not preprocess_collapse_repeated
  assert ctc_merge_repeated
  assert not ignore_longer_outputs_than_inputs
  return ctc_ops.ctc_loss_v2(
      labels=labels,
      logits=inputs,
      logit_length=sequence_length,
      label_length=None,
      blank_index=-1,
      logits_time_major=time_major)


class CTCLossTest(test.TestCase):

  def _testCTCLoss(self,
                   inputs,
                   seq_lens,
                   labels,
                   loss_truth,
                   grad_truth,
                   expected_err_re=None):
    self.assertEqual(len(inputs), len(grad_truth))

    inputs_t = constant_op.constant(inputs)

    with self.cached_session(use_gpu=False) as sess:
      loss = _ctc_loss_v2(
          inputs=inputs_t, labels=labels, sequence_length=seq_lens)
      grad = gradients_impl.gradients(loss, [inputs_t])[0]

      self.assertShapeEqual(loss_truth, loss)
      self.assertShapeEqual(grad_truth, grad)

      if expected_err_re is None:
        (tf_loss, tf_grad) = self.evaluate([loss, grad])
        self.assertAllClose(tf_loss, loss_truth, atol=1e-6)
        self.assertAllClose(tf_grad, grad_truth, atol=1e-6)
      else:
        with self.assertRaisesOpError(expected_err_re):
          self.evaluate([loss, grad])

  @test_util.run_v1_only("b/120545219")
  def testBasic(self):
    """Test two batch entries."""
    # Input and ground truth from Alex Graves' implementation.
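    # In the tables below, "outputs" rows are time steps and columns are the
    # five label classes followed by the blank; "alpha"/"beta" are the CTC
    # forward/backward log-probabilities over the 2 * |targets| + 1 expanded
    # states, and "prob" is the resulting log-likelihood.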
    #
    #### Batch entry 0 #####
    # targets: 0 1 2 1 0
    # outputs:
    # 0 0.633766 0.221185 0.0917319 0.0129757 0.0142857 0.0260553
    # 1 0.111121 0.588392 0.278779 0.0055756 0.00569609 0.010436
    # 2 0.0357786 0.633813 0.321418 0.00249248 0.00272882 0.0037688
    # 3 0.0663296 0.643849 0.280111 0.00283995 0.0035545 0.00331533
    # 4 0.458235 0.396634 0.123377 0.00648837 0.00903441 0.00623107
    # alpha:
    # 0 -3.64753 -0.456075 -inf -inf -inf -inf -inf -inf -inf -inf -inf
    # 1 -inf -inf -inf -0.986437 -inf -inf -inf -inf -inf -inf -inf
    # 2 -inf -inf -inf -inf -inf -2.12145 -inf -inf -inf -inf -inf
    # 3 -inf -inf -inf -inf -inf -inf -inf -2.56174 -inf -inf -inf
    # 4 -inf -inf -inf -inf -inf -inf -inf -inf -inf -3.34211 -inf
    # beta:
    # 0 -inf -2.88604 -inf -inf -inf -inf -inf -inf -inf -inf -inf
    # 1 -inf -inf -inf -2.35568 -inf -inf -inf -inf -inf -inf -inf
    # 2 -inf -inf -inf -inf -inf -1.22066 -inf -inf -inf -inf -inf
    # 3 -inf -inf -inf -inf -inf -inf -inf -0.780373 -inf -inf -inf
    # 4 -inf -inf -inf -inf -inf -inf -inf -inf -inf 0 0
    # prob: -3.34211
    # outputDerivs:
    # 0 -0.366234 0.221185 0.0917319 0.0129757 0.0142857 0.0260553
    # 1 0.111121 -0.411608 0.278779 0.0055756 0.00569609 0.010436
    # 2 0.0357786 0.633813 -0.678582 0.00249248 0.00272882 0.0037688
    # 3 0.0663296 -0.356151 0.280111 0.00283995 0.0035545 0.00331533
    # 4 -0.541765 0.396634 0.123377 0.00648837 0.00903441 0.00623107
    #
    #### Batch entry 1 #####
    #
    # targets: 0 1 1 0
    # outputs:
    # 0 0.30176 0.28562 0.0831517 0.0862751 0.0816851 0.161508
    # 1 0.24082 0.397533 0.0557226 0.0546814 0.0557528 0.19549
    # 2 0.230246 0.450868 0.0389607 0.038309 0.0391602 0.202456
    # 3 0.280884 0.429522 0.0326593 0.0339046 0.0326856 0.190345
    # 4 0.423286 0.315517 0.0338439 0.0393744 0.0339315 0.154046
    # alpha:
    # 0 -1.8232 -1.19812 -inf -inf -inf -inf -inf -inf -inf
    # 1 -inf -2.19315 -2.83037 -2.1206 -inf -inf -inf -inf -inf
    # 2 -inf -inf -inf -2.03268 -3.71783 -inf -inf -inf -inf
    # 3 -inf -inf -inf -inf -inf -4.56292 -inf -inf -inf
    # 4 -inf -inf -inf -inf -inf -inf -inf -5.42262 -inf
    # beta:
    # 0 -inf -4.2245 -inf -inf -inf -inf -inf -inf -inf
    # 1 -inf -inf -inf -3.30202 -inf -inf -inf -inf -inf
    # 2 -inf -inf -inf -inf -1.70479 -0.856738 -inf -inf -inf
    # 3 -inf -inf -inf -inf -inf -0.859706 -0.859706 -0.549337 -inf
    # 4 -inf -inf -inf -inf -inf -inf -inf 0 0
    # prob: -5.42262
    # outputDerivs:
    # 0 -0.69824 0.28562 0.0831517 0.0862751 0.0816851 0.161508
    # 1 0.24082 -0.602467 0.0557226 0.0546814 0.0557528 0.19549
    # 2 0.230246 0.450868 0.0389607 0.038309 0.0391602 -0.797544
    # 3 0.280884 -0.570478 0.0326593 0.0339046 0.0326856 0.190345
    # 4 -0.576714 0.315517 0.0338439 0.0393744 0.0339315 0.154046

    # max_time_steps == 7
    depth = 6

    # seq_len_0 == 5
    targets_0 = [0, 1, 2, 1, 0]
    loss_log_prob_0 = -3.34211
    # dimensions are time x depth
    input_prob_matrix_0 = np.asarray(
        [[0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553],
         [0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436],
         [0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688],
         [0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533],
         [0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107]],
        dtype=np.float32)
    input_log_prob_matrix_0 = np.log(input_prob_matrix_0)
    gradient_log_prob_0 = np.asarray(
        [[-0.366234, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553],
         [0.111121, -0.411608, 0.278779, 0.0055756, 0.00569609, 0.010436],
         [0.0357786, 0.633813, -0.678582, 0.00249248, 0.00272882, 0.0037688],
         [0.0663296, -0.356151, 0.280111, 0.00283995, 0.0035545, 0.00331533],
         [-0.541765, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107]],
        dtype=np.float32)

    # seq_len_1 == 5
    targets_1 = [0, 1, 1, 0]
    loss_log_prob_1 = -5.42262
    # dimensions are time x depth

    input_prob_matrix_1 = np.asarray(
        [[0.30176, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508],
         [0.24082, 0.397533, 0.0557226, 0.0546814, 0.0557528, 0.19549],
         [0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, 0.202456],
         [0.280884, 0.429522, 0.0326593, 0.0339046, 0.0326856, 0.190345],
         [0.423286, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046]],
        dtype=np.float32)
    input_log_prob_matrix_1 = np.log(input_prob_matrix_1)
    gradient_log_prob_1 = np.asarray(
        [[-0.69824, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508],
         [0.24082, -0.602467, 0.0557226, 0.0546814, 0.0557528, 0.19549],
         [0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, -0.797544],
         [0.280884, -0.570478, 0.0326593, 0.0339046, 0.0326856, 0.190345],
         [-0.576714, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046]],
        dtype=np.float32)

    # len max_time_steps array of 2 x depth matrices
    inputs = [
        np.vstack(
            [input_log_prob_matrix_0[t, :], input_log_prob_matrix_1[t, :]])
        for t in range(5)
    ] + 2 * [np.nan * np.ones((2, depth), np.float32)]

    # convert inputs into [max_time x batch_size x depth] Tensor
    inputs = np.asarray(inputs, dtype=np.float32)

    # len batch_size array of label vectors
    labels = SimpleSparseTensorFrom([targets_0, targets_1])

    # batch_size length vector of sequence_lengths
    seq_lens = np.array([5, 5], dtype=np.int32)

    # output: batch_size length vector of negative log probabilities
    loss_truth = np.array([-loss_log_prob_0, -loss_log_prob_1], np.float32)

    # output: len max_time_steps array of 2 x depth matrices
    grad_truth = [
        np.vstack([gradient_log_prob_0[t, :], gradient_log_prob_1[t, :]])
        for t in range(5)
    ] + 2 * [np.zeros((2, depth), np.float32)]

    # convert grad_truth into [max_time x batch_size x depth] Tensor
    grad_truth = np.asarray(grad_truth, dtype=np.float32)

    self._testCTCLoss(inputs, seq_lens, labels, loss_truth, grad_truth)

  def test_time_major(self):
    """Testing time_major param.

    Tests that transposing the inputs and setting time_major=False gives the
    same loss.
    """
    # [max_time x batch_size x depth] tensor
    inputs = np.random.randn(2, 2, 3).astype(np.float32)
    labels = SimpleSparseTensorFrom([[0, 1], [1, 0]])
    seq_lens = np.array([2, 2], dtype=np.int32)

    inputs_t = constant_op.constant(inputs)

    # Transposing tensor to [batch_size x max_time x depth]
    inputs_t_transposed = constant_op.constant(inputs.transpose(1, 0, 2))

    with self.session(use_gpu=False) as sess:
      loss = _ctc_loss_v2(
          inputs=inputs_t, labels=labels, sequence_length=seq_lens)
      loss_transposed = _ctc_loss_v2(
          inputs=inputs_t_transposed,
          labels=labels,
          sequence_length=seq_lens,
          time_major=False)

      (tf_loss, tf_loss_transposed) = self.evaluate([loss, loss_transposed])
      self.assertAllEqual(tf_loss, tf_loss_transposed)

  @test_util.run_v1_only("b/120545219")
  def testInvalidSecondGradient(self):
    inputs = np.random.randn(2, 2, 3).astype(np.float32)
    inputs_t = constant_op.constant(inputs)
    labels = SimpleSparseTensorFrom([[0, 1], [1, 0]])
    seq_lens = np.array([2, 2], dtype=np.int32)
    v = [1.0]

    with self.session(use_gpu=False):
      loss = _ctc_loss_v2(
          inputs=inputs_t, labels=labels, sequence_length=seq_lens)
      # Taking this second gradient should fail, since it is not
      # yet supported.
      with self.assertRaisesRegex(LookupError, "explicitly disabled"):
        _ = gradients_impl._hessian_vector_product(loss, [inputs_t], v)

  @test_util.run_v1_only("b/120545219")
  def testEmptyBatch(self):
    inputs = constant_op.constant([], dtype=dtypes.float32, shape=(1, 0, 2))
    sequence_lengths = constant_op.constant([], dtype=dtypes.int32)
    labels = sparse_tensor.SparseTensor(
        indices=constant_op.constant([], shape=(0, 2), dtype=dtypes.int64),
        values=constant_op.constant([], shape=(0,), dtype=dtypes.int32),
        dense_shape=[5, 5])

    with self.session(use_gpu=False) as sess:
      with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
                                  "batch_size must not be 0"):
        sess.run(_ctc_loss_v2(labels, inputs, sequence_lengths))


class CTCLossTestV2(test.TestCase, parameterized.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def testCtcLossV2(self):
    random_seed.set_random_seed(5)

    batch_size = 8
    num_labels = 6
    max_label_length = 5
    num_frames = 12

    labels = random_ops.random_uniform(
        [batch_size, max_label_length], minval=1, maxval=num_labels,
        dtype=dtypes.int64)
    logits = random_ops.random_uniform([num_frames, batch_size, num_labels])

    label_length = random_ops.random_uniform(
        [batch_size], minval=2, maxval=max_label_length, dtype=dtypes.int64)
    label_mask = array_ops.sequence_mask(
        label_length, maxlen=max_label_length, dtype=label_length.dtype)
    labels *= label_mask
    logit_length = [num_frames] * batch_size

    with backprop.GradientTape() as t:
      t.watch(logits)
      ref_loss = ctc_ops.ctc_loss_v2(
          labels=labels,
          logits=logits,
          label_length=label_length,
          logit_length=logit_length)
    ref_grad = t.gradient(ref_loss, [logits])

    sparse_labels = ctc_ops.dense_labels_to_sparse(labels, label_length)

    def assert_same_loss_and_grads(loss):
      if context.executing_eagerly():
        return
      with self.cached_session():
        self.assertAllClose(*self.evaluate([loss, ref_loss]))
        grad = gradients_impl.gradients(loss, [logits])
        self.assertAllClose(
            *self.evaluate([grad, ref_grad]), rtol=2e-06, atol=2e-06)

    assert_same_loss_and_grads(
        ctc_ops.ctc_loss_v2(
            labels=sparse_labels,
            logits=logits,
            label_length=label_length,
            logit_length=logit_length,
            blank_index=0))

  @test_util.run_v1_only("b/120545219")
  def testCtcLossDenseIsSameAsCtcLoss(self):
    with ops.device("/GPU:0" if test.is_gpu_available() else "/CPU:0"):
      random_seed.set_random_seed(5)

      batch_size = 8
      num_labels = 6
      label_length = 5
      minimum_logits_length = 10
      num_frames = minimum_logits_length + batch_size
      logits = random_ops.random_uniform([num_frames, batch_size, num_labels])
      labels = random_ops.random_uniform(
          [batch_size, label_length], minval=1, maxval=num_labels,
          dtype=dtypes.int64)

      label_lengths = random_ops.random_uniform(
          [batch_size], minval=2, maxval=label_length, dtype=dtypes.int64)
      label_mask = array_ops.sequence_mask(
          label_lengths, maxlen=label_length, dtype=label_lengths.dtype)
      labels *= label_mask

      logit_lengths = math_ops.range(batch_size) + minimum_logits_length

      ctc_loss = ctc_ops.ctc_loss_dense(
          labels=labels,
          logits=logits,
          label_length=label_lengths,
          logit_length=logit_lengths)
      ctc_loss_grads = gradients_impl.gradients(ctc_loss, [logits])[0]

      # Shift labels down by one (move blank from 0 to num_labels - 1)
      tf_ctc_loss_labels = math_ops.cast(labels, dtypes.int32) - 1
      tf_nn_ctc_logits = array_ops.concat([
          logits[:, :, 1:],
          logits[:, :, 0:1],
      ], axis=2)

      tf_ctc_loss_labels = ctc_ops.dense_labels_to_sparse(
          tf_ctc_loss_labels, label_lengths)

      tf_nn_ctc_loss = ctc_ops.ctc_loss(
          labels=tf_ctc_loss_labels,
          inputs=tf_nn_ctc_logits,
          sequence_length=logit_lengths,
          time_major=True)
      tf_nn_ctc_grads = gradients_impl.gradients(tf_nn_ctc_loss, [logits])[0]

      with self.cached_session() as sess:
        for _ in range(32):
          self.assertAllClose(*self.evaluate([ctc_loss, tf_nn_ctc_loss]))
          self.assertAllClose(
              *self.evaluate([ctc_loss_grads, tf_nn_ctc_grads]),
              rtol=4e-06,
              atol=4e-06)

  @test_util.run_v1_only("b/120545219")
  def testCtcLossDenseUniqueFastPathIsSameAsCtcLoss(self):
    random_seed.set_random_seed(5)

    batch_size = 8
    num_labels = 6
    label_length = 5
    num_frames = 12
    logits = random_ops.random_uniform([num_frames, batch_size, num_labels])
    labels = random_ops.random_uniform(
        [batch_size, label_length], minval=1, maxval=num_labels,
        dtype=dtypes.int64)

    label_lengths = random_ops.random_uniform(
        [batch_size], minval=2, maxval=label_length, dtype=dtypes.int64)
    label_mask = array_ops.sequence_mask(
        label_lengths, maxlen=label_length, dtype=label_lengths.dtype)
    labels *= label_mask

    logit_lengths = [num_frames] * batch_size

    ctc_loss = ctc_ops.ctc_loss_dense(
        labels=labels,
        logits=logits,
        label_length=label_lengths,
        logit_length=logit_lengths,
        unique=ctc_ops.ctc_unique_labels(labels))
    ctc_loss_grads = gradients_impl.gradients(ctc_loss, [logits])[0]

    # Shift labels down by one (move blank from 0 to num_labels - 1)
    tf_ctc_loss_labels = math_ops.cast(labels, dtypes.int32) - 1
    tf_nn_ctc_logits = array_ops.concat([
        logits[:, :, 1:],
        logits[:, :, 0:1],
    ], axis=2)

    tf_ctc_loss_labels = ctc_ops.dense_labels_to_sparse(
        tf_ctc_loss_labels, label_lengths)

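    # (The label shift and logit rotation above map from ctc_loss_dense's
    # default blank_index=0 to the v1 op's blank at num_labels - 1.)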
    tf_nn_ctc_loss = ctc_ops.ctc_loss(
        labels=tf_ctc_loss_labels,
        inputs=tf_nn_ctc_logits,
        sequence_length=logit_lengths,
        time_major=True)
    tf_nn_ctc_grads = gradients_impl.gradients(tf_nn_ctc_loss, [logits])[0]

    with self.cached_session():
      for _ in range(32):
        self.assertAllClose(*self.evaluate([ctc_loss, tf_nn_ctc_loss]))
        self.assertAllClose(
            *self.evaluate([ctc_loss_grads, tf_nn_ctc_grads]),
            rtol=2e-06,
            atol=2e-06)

  @test_util.run_v1_only("b/120545219")
  def testCtcLossDenseUniqueFastPathWithBlankIndexIsSameAsCtcLoss(self):
    random_seed.set_random_seed(5)

    batch_size = 8
    num_labels = 6
    label_length = 5
    num_frames = 12
    logits = random_ops.random_uniform([num_frames, batch_size, num_labels])
    labels = random_ops.random_uniform([batch_size, label_length],
                                       minval=0,
                                       maxval=num_labels - 1,
                                       dtype=dtypes.int64)

    label_lengths = random_ops.random_uniform([batch_size],
                                              minval=2,
                                              maxval=label_length,
                                              dtype=dtypes.int64)
    label_mask = array_ops.sequence_mask(
        label_lengths, maxlen=label_length, dtype=label_lengths.dtype)
    labels *= label_mask

    logit_lengths = [num_frames] * batch_size

    tf_ctc_loss_labels = math_ops.cast(labels, dtypes.int32)
    tf_ctc_loss_labels = ctc_ops.dense_labels_to_sparse(tf_ctc_loss_labels,
                                                        label_lengths)

    tf_nn_ctc_loss = ctc_ops.ctc_loss(
        labels=tf_ctc_loss_labels,
        inputs=logits,
        sequence_length=logit_lengths,
        time_major=True)
    tf_nn_ctc_grads = gradients_impl.gradients(tf_nn_ctc_loss, [logits])[0]

    # Shift the blank logits/labels to be somewhere in the middle.
    blank_index = 2
    shifted_logits = array_ops.concat([
        logits[:, :, :blank_index],
        logits[:, :, -1:],
        logits[:, :, blank_index:-1],
    ],
                                      axis=2)
    shifted_labels = array_ops.where_v2(labels < blank_index, labels,
                                        labels + 1)

    ctc_loss = ctc_ops.ctc_loss_dense(
        labels=shifted_labels,
        logits=shifted_logits,
        label_length=label_lengths,
        logit_length=logit_lengths,
        blank_index=blank_index,
        unique=ctc_ops.ctc_unique_labels(shifted_labels))
    ctc_loss_grads = gradients_impl.gradients(ctc_loss, [logits])[0]

    with self.cached_session() as sess:
      for _ in range(32):
        self.assertAllClose(*self.evaluate([ctc_loss, tf_nn_ctc_loss]))
        self.assertAllClose(
            *self.evaluate([ctc_loss_grads, tf_nn_ctc_grads]),
            rtol=2e-06,
            atol=2e-06)

  @test_util.run_v1_only("b/120545219")
  def testCtcLossDenseWithBlankIndexIsSameAsCtcLoss(self):
    random_seed.set_random_seed(5)

    batch_size = 8
    num_labels = 6
    label_length = 5
    num_frames = 12
    logits = random_ops.random_uniform([num_frames, batch_size, num_labels])
    labels = random_ops.random_uniform(
        [batch_size, label_length], minval=0, maxval=num_labels - 1,
        dtype=dtypes.int64)

    label_lengths = random_ops.random_uniform(
        [batch_size], minval=2, maxval=label_length, dtype=dtypes.int64)
    label_mask = array_ops.sequence_mask(
        label_lengths, maxlen=label_length, dtype=label_lengths.dtype)
    labels *= label_mask

    logit_lengths = [num_frames] * batch_size

    tf_ctc_loss_labels = math_ops.cast(labels, dtypes.int32)
    tf_ctc_loss_labels = ctc_ops.dense_labels_to_sparse(
        tf_ctc_loss_labels, label_lengths)

    tf_nn_ctc_loss = ctc_ops.ctc_loss(
        labels=tf_ctc_loss_labels,
        inputs=logits,
        sequence_length=logit_lengths,
        time_major=True)
    tf_nn_ctc_grads = gradients_impl.gradients(tf_nn_ctc_loss, [logits])[0]

    # Shift the blank logits/labels to be somewhere in the middle.
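    # The concat moves the last logit column (the v1 blank) to position
    # `blank_index`, and labels at or above blank_index are bumped up by one
    # so they still address the same logit columns.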
    blank_index = 2
    shifted_logits = array_ops.concat([
        logits[:, :, :blank_index],
        logits[:, :, -1:],
        logits[:, :, blank_index:-1],
    ], axis=2)
    shifted_labels = array_ops.where_v2(labels < blank_index, labels,
                                        labels + 1)

    ctc_loss = ctc_ops.ctc_loss_dense(
        labels=shifted_labels,
        logits=shifted_logits,
        label_length=label_lengths,
        logit_length=logit_lengths,
        blank_index=blank_index)
    ctc_loss_grads = gradients_impl.gradients(ctc_loss, [logits])[0]

    with self.cached_session() as sess:
      for _ in range(32):
        self.assertAllClose(*self.evaluate([ctc_loss, tf_nn_ctc_loss]))
        self.assertAllClose(
            *self.evaluate([ctc_loss_grads, tf_nn_ctc_grads]),
            rtol=2e-06,
            atol=2e-06)

  @test_util.run_v1_only("b/120545219")
  def testCtcLossDenseWithNegativeBlankIndexIsSameAsCtcLoss(self):
    with ops.device("/GPU:0" if test.is_gpu_available() else "/CPU:0"):
      random_seed.set_random_seed(5)

      batch_size = 8
      num_labels = 6
      label_length = 5
      num_frames = 12
      logits = random_ops.random_uniform([num_frames, batch_size, num_labels])
      labels = random_ops.random_uniform(
          [batch_size, label_length], minval=0, maxval=num_labels - 1,
          dtype=dtypes.int64)

      label_lengths = random_ops.random_uniform(
          [batch_size], minval=2, maxval=label_length, dtype=dtypes.int64)
      label_mask = array_ops.sequence_mask(
          label_lengths, maxlen=label_length, dtype=label_lengths.dtype)
      labels *= label_mask

      logit_lengths = [num_frames] * batch_size

      ctc_loss = ctc_ops.ctc_loss_dense(
          labels=labels,
          logits=logits,
          label_length=label_lengths,
          logit_length=logit_lengths,
          blank_index=-1)
      ctc_loss_grads = gradients_impl.gradients(ctc_loss, [logits])[0]

      tf_ctc_loss_labels = math_ops.cast(labels, dtypes.int32)
      tf_ctc_loss_labels = ctc_ops.dense_labels_to_sparse(
          tf_ctc_loss_labels, label_lengths)

      tf_nn_ctc_loss = ctc_ops.ctc_loss(
          labels=tf_ctc_loss_labels,
          inputs=logits,
          sequence_length=logit_lengths,
          time_major=True)
      tf_nn_ctc_grads = gradients_impl.gradients(tf_nn_ctc_loss, [logits])[0]

      with self.cached_session() as sess:
        for _ in range(32):
          self.assertAllClose(*self.evaluate([ctc_loss, tf_nn_ctc_loss]))
          self.assertAllClose(
              *self.evaluate([ctc_loss_grads, tf_nn_ctc_grads]),
              rtol=2e-06,
              atol=2e-06)

  @parameterized.parameters((False, 0), (True, 0), (False, -1), (True, -1))
  def testCtcLossDenseWithUndefinedStaticDimensions(self, unique, blank_index):
    random_seed.set_random_seed(5)

    # Trace without a batch size and number of frames
    batch_size = None
    num_labels = 6
    label_length = 5
    num_frames = None

    @def_function.function
    def func(labels, logits, label_lengths, logit_lengths):
      unique_labels = ctc_ops.ctc_unique_labels(labels) if unique else None
      return ctc_ops.ctc_loss_dense(
          labels=labels,
          logits=logits,
          label_length=label_lengths,
          logit_length=logit_lengths,
          unique=unique_labels,
          blank_index=blank_index)

    labels_spec = tensor_spec.TensorSpec([batch_size, label_length],
                                         dtypes.int64)
    logits_spec = tensor_spec.TensorSpec([num_frames, batch_size, num_labels],
                                         dtypes.float32)
    label_lengths_spec = tensor_spec.TensorSpec([batch_size], dtypes.int64)
    logit_lengths_spec = tensor_spec.TensorSpec([batch_size], dtypes.int64)

    f = func.get_concrete_function(
        labels_spec, logits_spec, label_lengths_spec, logit_lengths_spec)

    # Execute with a defined batch size and number of frames
    batch_size = 8
    num_frames = 12

    logits = random_ops.random_uniform([num_frames, batch_size, num_labels])
    labels = random_ops.random_uniform(
        [batch_size, label_length], minval=1, maxval=num_labels,
        dtype=dtypes.int64)

    label_lengths = random_ops.random_uniform(
        [batch_size], minval=2, maxval=label_length, dtype=dtypes.int64)
    label_mask = array_ops.sequence_mask(
        label_lengths, maxlen=label_length, dtype=label_lengths.dtype)
    labels *= label_mask

    logit_lengths = constant_op.constant(
        [num_frames] * batch_size, dtype=dtypes.int64)

    f(labels, logits, label_lengths, logit_lengths)

  def testCollapseRepeated(self):
    collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
        labels=[[1, 3, 3, 3, 0],
                [1, 4, 4, 4, 0],
                [4, 2, 2, 9, 4]],
        seq_length=[4, 5, 5])
    self.assertAllEqual(new_seq_lengths, [2, 3, 4])
    self.assertAllEqual(
        collapsed,
        [[1, 3, 0, 0],
         [1, 4, 0, 0],
         [4, 2, 9, 4]])

  def testCollapseRepeatedPreservesDtypes(self):
    collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
        labels=constant_op.constant(
            [[1, 3, 3, 3, 0],
             [1, 4, 4, 4, 0],
             [4, 2, 2, 9, 4]],
            dtype=dtypes.int64),
        seq_length=constant_op.constant([4, 5, 5], dtype=dtypes.int64))
    self.assertEqual(new_seq_lengths.dtype, dtypes.int64)
    self.assertEqual(collapsed.dtype, dtypes.int64)
    self.assertAllEqual(new_seq_lengths, [2, 3, 4])
    self.assertAllEqual(
        collapsed,
        [[1, 3, 0, 0],
         [1, 4, 0, 0],
         [4, 2, 9, 4]])

  def testCollapseRepeatedExtraPadding(self):
    collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
        labels=[[1, 3, 3, 3, 0, 0, 0],
                [1, 4, 4, 4, 0, 1, 2],
                [4, 2, 2, 9, 4, 0, 0]],
        seq_length=[4, 5, 5])
    self.assertAllEqual(new_seq_lengths, [2, 3, 4])
    self.assertAllEqual(
        collapsed,
        [[1, 3, 0, 0],
         [1, 4, 0, 0],
         [4, 2, 9, 4]])

  def testCollapseRepeatedFrontRepeats(self):
    collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
        labels=[[1, 1, 1, 2, 2],
                [1, 1, 1, 2, 2],
                [1, 1, 1, 2, 2]],
        seq_length=[5, 4, 3])
    self.assertAllEqual(new_seq_lengths, [2, 2, 1])
    self.assertAllEqual(
        collapsed,
        [[1, 2],
         [1, 2],
         [1, 0]])

  def testCollapseRepeatedAllLabelsTheSame(self):
    collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
        labels=[[1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1]],
        seq_length=[4, 5, 1])
    self.assertAllEqual(new_seq_lengths, [1, 1, 1])
    self.assertAllEqual(
        collapsed,
        [[1],
         [1],
         [1]])

  def testDenseSequencesToSparse(self):
    labels = [[1, 3, 3, 3, 0],
              [1, 4, 4, 4, 0],
              [4, 2, 2, 9, 4]]
    length = [4, 5, 5]
    sparse = ctc_ops.dense_labels_to_sparse(labels, length)
    new_dense = sparse_ops.sparse_tensor_to_dense(sparse)

    self.assertAllEqual(labels, new_dense)

    padded_labels = [[1, 3, 3, 3, 0, 0, 0, 0],
                     [1, 4, 4, 4, 0, 0, 0, 0],
                     [4, 2, 2, 9, 4, 0, 0, 0]]
    length = [4, 5, 5]
    sparse = ctc_ops.dense_labels_to_sparse(padded_labels, length)
    padded_dense = sparse_ops.sparse_tensor_to_dense(sparse)

    self.assertAllEqual(padded_dense, new_dense)

  def testUnique(self):
    labels = [
        [3, 4, 4, 3],
        [1, 1, 1, 0],
    ]
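    # ctc_unique_labels returns each row's unique labels in order of first
    # occurrence (zero-padded on the right), plus, for every original
    # position, its index into that unique list.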
    unique, idx = ctc_ops.ctc_unique_labels(labels)
    self.assertAllEqual([
        [3, 4, 0, 0],
        [1, 0, 0, 0],
    ], unique)
    self.assertAllEqual([
        [0, 1, 1, 0],
        [0, 0, 0, 1],
    ], idx)

  def testSumStates(self):
    idx = [
        [0, 1, 0, 1],
        [0, 0, 0, 1],
    ]
    states = math_ops.log([
        [[1.0, 2.0, 3.0, 4.0],
         [5.0, 6.0, 7.0, 8.0]],
        [[0.1, 0.2, 0.3, 0.4],
         [0.5, 0.6, 0.7, 0.8]],
    ])
    sum_of_states = math_ops.exp(ctc_ops._sum_states(idx, states))
    self.assertAllClose([
        [[4.0, 6.0, 0.0, 0.0],
         [18.0, 8.0, 0.0, 0.0]],
        [[0.4, 0.6, 0.0, 0.0],
         [1.8, 0.8, 0.0, 0.0]]
    ], sum_of_states)

  def testStateToOlabel(self):
    labels = [
        [3, 4, 3, 4],
        [1, 1, 1, 0],
    ]
    num_labels = 8

    # 3 frames, 2 batch, 10 states (5 label, 5 blank).
    states = [
        [[0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.20],
         [0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.30]],
        [[1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0],
         [2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3.0]],
        [[11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0],
         [21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0]],
    ]
    labels = ops.convert_to_tensor(labels)
    states = math_ops.log(states)
    olabel = ctc_ops._state_to_olabel(labels, num_labels, states)
    olabel = math_ops.exp(olabel)
    blank = olabel[:, :, 0]
    self.assertAllClose(blank, [
        [0.16 + 0.17 + 0.18 + 0.19 + 0.20,
         0.26 + 0.27 + 0.28 + 0.29 + 0.30],
        [1.6 + 1.7 + 1.8 + 1.9 + 2.0,
         2.6 + 2.7 + 2.8 + 2.9 + 3.0],
        [16.0 + 17.0 + 18.0 + 19.0 + 20.0,
         26.0 + 27.0 + 28.0 + 29.0 + 30.0]
    ])
    self.assertAllClose(olabel[:, :, 1:], [
        [[0.0, 0.0, 0.12 + 0.14, 0.13 + 0.15, 0.0, 0.0, 0.0],
         [0.22 + 0.23 + 0.24, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
        [[0.0, 0.0, 1.2 + 1.4, 1.3 + 1.5, 0.0, 0.0, 0.0],
         [2.2 + 2.3 + 2.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
        [[0.0, 0.0, 12.0 + 14.0, 13.0 + 15.0, 0.0, 0.0, 0.0],
         [22.0 + 23.0 + 24.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
    ])

  def testStateToOlabelUnique(self):
    labels = [
        [3, 4, 3, 4],
        [1, 1, 1, 0],
    ]
    num_labels = 8

    # 3 frames, 2 batch, 10 states (5 label, 5 blank).
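    # Same labels and states as testStateToOlabel above; the unique fast path
    # is expected to produce identical olabel values.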
    states = [
        [[0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.20],
         [0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.30]],
        [[1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0],
         [2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3.0]],
        [[11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0],
         [21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0]],
    ]
    labels = ops.convert_to_tensor(labels)
    states = math_ops.log(states)
    olabel = ctc_ops._state_to_olabel_unique(
        labels, num_labels, states, ctc_ops.ctc_unique_labels(labels))
    olabel = math_ops.exp(olabel)
    blank = olabel[:, :, 0]
    self.assertAllClose(blank, [
        [0.16 + 0.17 + 0.18 + 0.19 + 0.20,
         0.26 + 0.27 + 0.28 + 0.29 + 0.30],
        [1.6 + 1.7 + 1.8 + 1.9 + 2.0,
         2.6 + 2.7 + 2.8 + 2.9 + 3.0],
        [16.0 + 17.0 + 18.0 + 19.0 + 20.0,
         26.0 + 27.0 + 28.0 + 29.0 + 30.0]])
    self.assertAllClose(olabel[:, :, 1:], [
        [[0.0, 0.0, 0.12 + 0.14, 0.13 + 0.15, 0.0, 0.0, 0.0],
         [0.22 + 0.23 + 0.24, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
        [[0.0, 0.0, 1.2 + 1.4, 1.3 + 1.5, 0.0, 0.0, 0.0],
         [2.2 + 2.3 + 2.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
        [[0.0, 0.0, 12.0 + 14.0, 13.0 + 15.0, 0.0, 0.0, 0.0],
         [22.0 + 23.0 + 24.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
    ])

  def testStateToOlabelUniqueSinglePath(self):
    labels = [
        [3, 4, 3],
        [1, 0, 0],
    ]
    num_labels = 8

    # 3 frames, 2 batch, 8 states (4 label, 4 blank).
    #
    # There is only a single valid path for each sequence because the frame
    # lengths and the label lengths are the same.
    states = [[[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
               [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
              [[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
               [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
              [[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
               [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]]
    labels = ops.convert_to_tensor(labels)
    states = math_ops.log(states)
    olabel = ctc_ops._state_to_olabel_unique(labels, num_labels, states,
                                             ctc_ops.ctc_unique_labels(labels))
    olabel = math_ops.exp(olabel)
    blank = olabel[:, :, 0]

    self.assertAllClose(blank, [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
    self.assertAllClose(olabel[:, :, 1:],
                        [
                            [[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
                             [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
                            [[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
                             [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
                            [[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
                             [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
                        ])

  @test_util.run_deprecated_v1
  def testScan(self):
    with ops.device("/GPU:0" if test.is_gpu_available() else "/CPU:0"):
      out = ctc_ops._scan(
          lambda accum, elem: accum + elem,
          constant_op.constant([1.0, 2.0, 3.0]), 23.0)
      self.assertAllEqual([24.0, 26.0, 29.0], out)

      out = ctc_ops._scan(
          lambda a, e: a + e,
          constant_op.constant([1.0, 2.0, 3.0]), 23.0,
          inclusive=True)
      self.assertAllEqual([23.0, 24.0, 26.0, 29.0], out)

      out = ctc_ops._scan(
          lambda a, e: a + e,
          constant_op.constant([1.0, 2.0, 3.0]), 23.0,
          reverse=True)
      self.assertAllEqual([29.0, 28.0, 26.0], out)

      out = ctc_ops._scan(
          lambda a, e: a + e,
          constant_op.constant([1.0, 2.0, 3.0]), 23.0,
          reverse=True,
          inclusive=True)
      self.assertAllEqual([29.0, 28.0, 26.0, 23.0], out)

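      # With matrix elems and a vector initial accumulator, the accumulation
      # runs elementwise (independently per column).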
      out = ctc_ops._scan(
          lambda a, e: a + e,
          constant_op.constant([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]),
          constant_op.constant([23.0, 24.0]))
      self.assertAllEqual([[23.0, 25.0], [25.0, 28.0], [29.0, 33.0]], out)

  @test_util.run_deprecated_v1
  def testScanCapturesVariables(self):
    with self.cached_session() as sess:
      x = random_ops.random_uniform([])
      fn = lambda accum, elem: accum + x * elem
      out = ctc_ops._scan(fn, constant_op.constant([0.0, 1.0, 2.0]), 23.0)
      self.assertAllClose(*sess.run([
          [23.0 + x * 0.0, 23.0 + x * 1.0, 23.0 + x * 3.0], out
      ]))

  @test_util.run_deprecated_v1
  def testScanMultipleAccumulators(self):
    with ops.device("/GPU:0" if test.is_gpu_available() else "/CPU:0"):
      def fn(accum, elem):
        accum_a, accum_b = accum
        return accum_a + elem, accum_b * elem
      out = ctc_ops._scan(
          fn, constant_op.constant([1.0, 2.0, 3.0]),
          (23.0, constant_op.constant([1.0, 2.0])))
      a, b = out
      self.assertAllEqual([24.0, 26.0, 29.0], a)
      self.assertAllEqual([[1.0, 2.0], [2.0, 4.0], [6.0, 12.0]], b)

  @test_util.run_deprecated_v1
  def testScanMultipleElements(self):
    with ops.device("/GPU:0" if test.is_gpu_available() else "/CPU:0"):
      def fn(accum, elem):
        elem_a, elem_b = elem
        return accum + (elem_a * elem_b)
      elems_a = constant_op.constant([1.0, 2.0, 3.0])
      elems_b = constant_op.constant([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]])
      out = ctc_ops._scan(
          fn, (elems_a, elems_b),
          initial=constant_op.constant([0.0, 0.0]))
      self.assertAllEqual(
          [[1.0, 2.0], [5.0, 8.0], [14.0, 20.0]], out)


def _ctc_loss_v3(labels, logits, label_length, logit_length, use_gpu,
                 sparse=True):
  with test_util.device(use_gpu=use_gpu):
    if sparse:
      labels = ctc_ops.dense_labels_to_sparse(labels, label_length)
    with backprop.GradientTape() as t:
      t.watch(logits)
      ref_loss = ctc_ops.ctc_loss_v3(
          labels=labels,
          logits=logits,
          label_length=label_length,
          logit_length=logit_length,
          blank_index=0)
    ref_grad = t.gradient(ref_loss, logits)
    return ref_loss, ref_grad


@test_util.run_all_in_graph_and_eager_modes
class CTCLossTestV3(test.TestCase, parameterized.TestCase):

  @parameterized.parameters([False, True])
  @test_util.run_v2_only
  def testCtcLossV3(self, run_tf_func):
    """Testing GPU CTC loss.

    Tests that the GPU CTC loss produces the same result as the CPU version.
    """
    if not test.is_gpu_available():
      self.skipTest("Need GPU for testing.")
    if not context.executing_eagerly():
      self.skipTest("Need eager execution for testing.")
    random_seed.set_random_seed(5)

    batch_size = 8
    num_labels = 6
    max_label_length = 5
    num_frames = 12

    labels = random_ops.random_uniform([batch_size, max_label_length],
                                       minval=1,
                                       maxval=num_labels,
                                       dtype=dtypes.int64)
    logits = random_ops.random_uniform([num_frames, batch_size, num_labels])

    label_length = random_ops.random_uniform([batch_size],
                                             minval=2,
                                             maxval=max_label_length,
                                             dtype=dtypes.int64)
    label_mask = array_ops.sequence_mask(
        label_length, maxlen=max_label_length, dtype=label_length.dtype)
    labels *= label_mask
    logit_length = [num_frames] * batch_size

    if run_tf_func:
      ctc_loss = def_function.function(_ctc_loss_v3)
    else:
      ctc_loss = _ctc_loss_v3

    ref_loss, ref_grad = ctc_loss(labels, logits, label_length, logit_length,
                                  False)
    loss, grad = ctc_loss(labels, logits, label_length, logit_length, True)

    self.assertAllClose(loss, ref_loss, atol=1e-6)
    self.assertAllClose(grad, ref_grad, atol=2e-6)

  @parameterized.parameters([False, True])
  def testCtcLossFp16(self, sparse_labels):
    batch_size = 8
    num_labels = 6
    max_label_length = 5
    num_frames = 12

    labels = np.random.randint(1, num_labels, [batch_size, max_label_length])
    labels = ops.convert_to_tensor(labels, dtypes.int64)
    fp16_logits = np.random.uniform(size=[num_frames, batch_size, num_labels])
    fp16_logits = ops.convert_to_tensor(fp16_logits, dtypes.float16)
    label_length = np.random.randint(2, max_label_length, [batch_size])
    label_length = ops.convert_to_tensor(label_length, dtypes.int64)

    label_mask = array_ops.sequence_mask(
        label_length, maxlen=max_label_length, dtype=label_length.dtype)
    labels *= label_mask
    logit_length = [num_frames] * batch_size

    fp16_loss, fp16_grad = _ctc_loss_v3(
        labels, fp16_logits, label_length, logit_length, use_gpu=True,
        sparse=sparse_labels)
    fp32_loss, fp32_grad = _ctc_loss_v3(
        labels, math_ops.cast(fp16_logits, dtypes.float32), label_length,
        logit_length, use_gpu=True, sparse=sparse_labels)

    self.assertEqual(fp16_loss.dtype, dtypes.float16)
    self.assertEqual(fp16_grad.dtype, dtypes.float16)
    self.assertAllClose(
        self.evaluate(fp16_loss),
        self.evaluate(math_ops.cast(fp32_loss, dtypes.float16))
    )
    self.assertAllClose(
        self.evaluate(fp16_grad),
        self.evaluate(math_ops.cast(fp32_grad, dtypes.float16))
    )

  @parameterized.parameters([False, True])
  def testCtcLossWithListLogits(self, sparse_labels):
    batch_size = 8
    num_labels = 6
    max_label_length = 5
    num_frames = 12

    labels = np.random.randint(1, num_labels, [batch_size, max_label_length])
    labels = ops.convert_to_tensor(labels, dtypes.int64)
    logits = np.random.uniform(size=[num_frames, batch_size, num_labels])
    label_length = np.random.randint(2, max_label_length, [batch_size])
    label_length = ops.convert_to_tensor(label_length, dtypes.int64)

    label_mask = array_ops.sequence_mask(
        label_length, maxlen=max_label_length, dtype=label_length.dtype)
    labels *= label_mask
    logit_length = [num_frames] * batch_size
    if sparse_labels:
      labels = ctc_ops.dense_labels_to_sparse(labels, label_length)

    list_loss = ctc_ops.ctc_loss_v3(
        labels=labels,
        logits=logits.tolist(),
        label_length=label_length,
        logit_length=logit_length,
        blank_index=0)
    tensor_loss = ctc_ops.ctc_loss_v3(
        labels=labels,
        logits=ops.convert_to_tensor(logits, dtypes.float32),
        label_length=label_length,
        logit_length=logit_length,
        blank_index=0)

    self.assertAllClose(self.evaluate(list_loss), self.evaluate(tensor_loss))

  @test_util.run_v2_only
  def testCtcLossAlgorithmFallback(self):
    """Test that the GPU CTC loss can fall back to the correct algorithm."""
    if not test.is_gpu_available():
      self.skipTest("Need GPU for testing.")
    if not context.executing_eagerly():
      self.skipTest("Need eager execution for testing.")
    random_seed.set_random_seed(5)

    batch_size = 1
    num_labels = 11777
    max_label_length = 2
    num_frames = 1

    labels = random_ops.random_uniform([batch_size, max_label_length],
                                       minval=1,
                                       maxval=num_labels,
                                       dtype=dtypes.int64)
    logits = random_ops.random_uniform([num_frames, batch_size, num_labels])

    label_length = random_ops.random_uniform([batch_size],
                                             minval=1,
                                             maxval=max_label_length,
                                             dtype=dtypes.int64)
    logit_length = [num_frames] * batch_size

    loss, grad = _ctc_loss_v3(labels, logits, label_length, logit_length, True)
    ref_loss, ref_grad = _ctc_loss_v3(labels, logits, label_length,
                                      logit_length, False)

    self.assertAllClose(loss, ref_loss, atol=1e-6)
    self.assertAllClose(grad, ref_grad, atol=2e-6)


@test_util.run_all_in_graph_and_eager_modes
class CTCLossDeterministicTest(test.TestCase, parameterized.TestCase):

  def _randomFloats(self, shape):
    x = (2 * np.random.random_sample(shape) - 1)
    return constant_op.constant(x, dtype=dtypes.float32)

  def _genInputParams(self,
                      num_classes=10,
                      batch_size=32,
                      max_label_sequence_length=50,
                      num_frames=100,
                      logits_time_major=True,
                      sparse_labels=True):
    assert num_frames >= max_label_sequence_length

    labels_shape = (batch_size, max_label_sequence_length)
    # Zero-pad the labels. Zero is the default blank index in the TF2 API.
    # num_classes includes the blank class
    unmasked_labels = np.random.randint(
        1, num_classes, size=labels_shape, dtype=np.int32)
    labels_lengths = np.random.randint(
        1, high=max_label_sequence_length, size=batch_size, dtype=np.int32)
    labels_masks = (np.arange(max_label_sequence_length) <
                    labels_lengths.reshape(batch_size, 1)).astype(np.int32)
    labels = unmasked_labels * labels_masks
    if sparse_labels:
      labels = ctc_ops.dense_labels_to_sparse(labels, labels_lengths)

    if logits_time_major:
      logits_shape = (num_frames, batch_size, num_classes)
    else:
      logits_shape = (batch_size, num_frames, num_classes)
    logits = self._randomFloats(logits_shape)

    labels_lengths = constant_op.constant(labels_lengths)

    logits_lengths = [num_frames] * batch_size
    logits_lengths = constant_op.constant(logits_lengths)

    return labels, logits, labels_lengths, logits_lengths

  def _forwardAndBackward(self, sparse_labels, logits_time_major, seed):
    np.random.seed(seed)
    params = self._genInputParams(
        logits_time_major=logits_time_major, sparse_labels=sparse_labels)
    labels, logits, labels_lengths, logits_lengths = params
    output_shape = (labels_lengths.shape[0],)
    upstream_gradients = self._randomFloats(output_shape)
    with backprop.GradientTape() as tape:
      tape.watch(logits)
      loss = ctc_ops.ctc_loss_v3(
          labels,
          logits,
          labels_lengths,
          logits_lengths,
          logits_time_major=logits_time_major,
          blank_index=0)
      gradient_injector_output = loss * upstream_gradients
    return loss, tape.gradient(gradient_injector_output, logits)

  @parameterized.parameters(  # parameterized.product not yet available
      (False, False), (False, True), (True, False), (True, True))
  def testForwardAndBackward(self, sparse_labels, logits_time_major):
    with test_util.deterministic_ops():
      for seed in range(2):
        loss_a, gradient_a = self._forwardAndBackward(sparse_labels,
                                                      logits_time_major, seed)
        loss_b, gradient_b = self._forwardAndBackward(sparse_labels,
                                                      logits_time_major, seed)
        loss_a, loss_b, gradient_a, gradient_b = self.evaluate(
            (loss_a, loss_b, gradient_a, gradient_b))
        self.assertAllEqual(loss_a, loss_b, "Loss mismatch")
        self.assertAllEqual(gradient_a, gradient_b, "Gradient mismatch")


if __name__ == "__main__":
  test.main()