1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""CTC (Connectionist Temporal Classification) Operations.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import uuid 22 23from tensorflow.python.eager import context 24from tensorflow.python.eager import function as function_eager 25 26from tensorflow.python.framework import constant_op 27from tensorflow.python.framework import device 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import function 30from tensorflow.python.framework import ops 31from tensorflow.python.framework import sparse_tensor 32from tensorflow.python.framework import tensor_shape 33 34from tensorflow.python.ops import array_ops 35from tensorflow.python.ops import custom_gradient 36from tensorflow.python.ops import functional_ops 37from tensorflow.python.ops import gen_ctc_ops 38from tensorflow.python.ops import inplace_ops 39from tensorflow.python.ops import linalg_ops 40from tensorflow.python.ops import map_fn 41from tensorflow.python.ops import math_ops 42from tensorflow.python.ops import nn_ops 43from tensorflow.python.ops import sparse_ops 44from tensorflow.python.ops.nn_grad import _BroadcastMul 45from tensorflow.python.util import deprecation 46from tensorflow.python.util import dispatch 47from tensorflow.python.util import nest 48from tensorflow.python.util.tf_export import tf_export 49 50_DEFUN_API_NAME_ATTRIBUTE = "api_implements" 51_DEFUN_DEVICE_ATTRIBUTE = "api_preferred_device" 52_CPU_DEVICE_NAME = "CPU" 53_GPU_DEVICE_NAME = "GPU" 54 55 56def _get_context_device_type(): 57 """Parse the current context and return the device type, eg CPU/GPU.""" 58 current_device = context.context().device_name 59 if current_device is None: 60 return None 61 return device.DeviceSpec.from_string(current_device).device_type 62 63 64def _generate_defun_backend(unique_api_name, preferred_device, func): 65 function_attributes = { 66 _DEFUN_API_NAME_ATTRIBUTE: unique_api_name, 67 _DEFUN_DEVICE_ATTRIBUTE: preferred_device, 68 } 69 return function_eager.defun_with_attributes( 70 func=func, attributes=function_attributes, autograph=False) 71 72# pylint: disable=protected-access, invalid-name 73@tf_export(v1=["nn.ctc_loss"]) 74@dispatch.add_dispatch_support 75def ctc_loss(labels, 76 inputs=None, 77 sequence_length=None, 78 preprocess_collapse_repeated=False, 79 ctc_merge_repeated=True, 80 ignore_longer_outputs_than_inputs=False, 81 time_major=True, 82 logits=None): 83 """Computes the CTC (Connectionist Temporal Classification) Loss. 84 85 This op implements the CTC loss as presented in (Graves et al., 2006). 86 87 Input requirements: 88 89 ``` 90 sequence_length(b) <= time for all b 91 92 max(labels.indices(labels.indices[:, 1] == b, 2)) 93 <= sequence_length(b) for all b. 94 ``` 95 96 Notes: 97 98 This class performs the softmax operation for you, so inputs should 99 be e.g. linear projections of outputs by an LSTM. 100 101 The `inputs` Tensor's innermost dimension size, `num_classes`, represents 102 `num_labels + 1` classes, where num_labels is the number of true labels, and 103 the largest value `(num_classes - 1)` is reserved for the blank label. 104 105 For example, for a vocabulary containing 3 labels `[a, b, c]`, 106 `num_classes = 4` and the labels indexing is `{a: 0, b: 1, c: 2, blank: 3}`. 107 108 Regarding the arguments `preprocess_collapse_repeated` and 109 `ctc_merge_repeated`: 110 111 If `preprocess_collapse_repeated` is True, then a preprocessing step runs 112 before loss calculation, wherein repeated labels passed to the loss 113 are merged into single labels. This is useful if the training labels come 114 from, e.g., forced alignments and therefore have unnecessary repetitions. 115 116 If `ctc_merge_repeated` is set False, then deep within the CTC calculation, 117 repeated non-blank labels will not be merged and are interpreted 118 as individual labels. This is a simplified (non-standard) version of CTC. 119 120 Here is a table of the (roughly) expected first order behavior: 121 122 * `preprocess_collapse_repeated=False`, `ctc_merge_repeated=True` 123 124 Classical CTC behavior: Outputs true repeated classes with blanks in 125 between, and can also output repeated classes with no blanks in 126 between that need to be collapsed by the decoder. 127 128 * `preprocess_collapse_repeated=True`, `ctc_merge_repeated=False` 129 130 Never learns to output repeated classes, as they are collapsed 131 in the input labels before training. 132 133 * `preprocess_collapse_repeated=False`, `ctc_merge_repeated=False` 134 135 Outputs repeated classes with blanks in between, but generally does not 136 require the decoder to collapse/merge repeated classes. 137 138 * `preprocess_collapse_repeated=True`, `ctc_merge_repeated=True` 139 140 Untested. Very likely will not learn to output repeated classes. 141 142 The `ignore_longer_outputs_than_inputs` option allows to specify the behavior 143 of the CTCLoss when dealing with sequences that have longer outputs than 144 inputs. If true, the CTCLoss will simply return zero gradient for those 145 items, otherwise an InvalidArgument error is returned, stopping training. 146 147 Args: 148 labels: An `int32` `SparseTensor`. 149 `labels.indices[i, :] == [b, t]` means `labels.values[i]` stores the id 150 for (batch b, time t). `labels.values[i]` must take on values in `[0, 151 num_labels)`. See `core/ops/ctc_ops.cc` for more details. 152 inputs: 3-D `float` `Tensor`. 153 If time_major == False, this will be a `Tensor` shaped: `[batch_size, 154 max_time, num_classes]`. 155 If time_major == True (default), this will be a `Tensor` shaped: 156 `[max_time, batch_size, num_classes]`. The logits. 157 sequence_length: 1-D `int32` vector, size `[batch_size]`. The sequence 158 lengths. 159 preprocess_collapse_repeated: Boolean. Default: False. If True, repeated 160 labels are collapsed prior to the CTC calculation. 161 ctc_merge_repeated: Boolean. Default: True. 162 ignore_longer_outputs_than_inputs: Boolean. Default: False. If True, 163 sequences with longer outputs than inputs will be ignored. 164 time_major: The shape format of the `inputs` Tensors. If True, these 165 `Tensors` must be shaped `[max_time, batch_size, num_classes]`. If False, 166 these `Tensors` must be shaped `[batch_size, max_time, num_classes]`. 167 Using `time_major = True` (default) is a bit more efficient because it 168 avoids transposes at the beginning of the ctc_loss calculation. However, 169 most TensorFlow data is batch-major, so by this function also accepts 170 inputs in batch-major form. 171 logits: Alias for inputs. 172 173 Returns: 174 A 1-D `float` `Tensor`, size `[batch]`, containing the negative log 175 probabilities. 176 177 Raises: 178 TypeError: if labels is not a `SparseTensor`. 179 180 References: 181 Connectionist Temporal Classification - Labeling Unsegmented Sequence Data 182 with Recurrent Neural Networks: 183 [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891) 184 ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf)) 185 """ 186 return _ctc_loss_impl( 187 labels, 188 inputs, 189 sequence_length, 190 preprocess_collapse_repeated, 191 ctc_merge_repeated, 192 ignore_longer_outputs_than_inputs, 193 time_major, 194 logits, 195 use_cudnn=False) 196 197 198def _ctc_loss_impl(labels, 199 inputs=None, 200 sequence_length=None, 201 preprocess_collapse_repeated=False, 202 ctc_merge_repeated=True, 203 ignore_longer_outputs_than_inputs=False, 204 time_major=True, 205 logits=None, 206 use_cudnn=False): 207 # Helper function of ctc_loss with one additional param: 208 # use_cudnn: A bool to enable cuDNN CTC loss operation. If true, the blank 209 # index has to be 0. 210 211 # The second, third, etc output tensors contain the gradients. We use it in 212 # _CTCLossGrad() below. 213 if not isinstance(labels, sparse_tensor.SparseTensor): 214 raise TypeError("Expected labels (first argument) to be a SparseTensor") 215 216 # For internal calculations, we transpose to [time, batch, num_classes] 217 inputs = deprecation.deprecated_argument_lookup("logits", logits, "inputs", 218 inputs) 219 if not time_major: 220 inputs = array_ops.transpose(inputs, [1, 0, 2]) # (B,T,N) => (T,B,N) 221 222 # gen_ctc_ops.ctc_loss_v2 differs from gen_ctc_ops.ctc_loss. v2 assumes the 223 # blank index to be 0, but v1 views it as the last index. 224 if use_cudnn: 225 ctc_loss_func = gen_ctc_ops.ctc_loss_v2 226 else: 227 ctc_loss_func = gen_ctc_ops.ctc_loss 228 229 loss, _ = ctc_loss_func( 230 inputs, 231 labels.indices, 232 labels.values, 233 sequence_length, 234 preprocess_collapse_repeated=preprocess_collapse_repeated, 235 ctc_merge_repeated=ctc_merge_repeated, 236 ignore_longer_outputs_than_inputs=ignore_longer_outputs_than_inputs) 237 238 return loss 239 240# pylint: disable=unused-argument 241def _CTCLossGradImpl(op, grad_loss, _): 242 # Outputs are: loss, grad 243 # 244 # Currently there is no way to take the second derivative of this op 245 # due to the fused implementation's interaction with tf.gradients(), 246 # so we make sure we prevent silently incorrect results by raising 247 # an error if the second derivative is requested via prevent_gradient. 248 grad_without_gradient = array_ops.prevent_gradient( 249 op.outputs[1], 250 message="Currently there is no way to take the second " 251 " derivative of ctc_loss due to the fused implementation's interaction " 252 " with tf.gradients()") 253 # Return gradient for inputs and None for 254 # labels_indices, labels_values and sequence_length 255 return [_BroadcastMul(grad_loss, grad_without_gradient), None, None, None] 256 257 258# pylint: disable=unused-argument 259@ops.RegisterGradient("CTCLoss") 260def _CTCLossGrad(op, grad_loss, _): 261 """The derivative provided by CTC Loss. 262 263 Args: 264 op: the CTCLoss op. 265 grad_loss: The backprop for cost. 266 267 Returns: 268 The CTC Loss gradient. 269 """ 270 return _CTCLossGradImpl(op, grad_loss, _) 271 272 273# pylint: disable=unused-argument 274@ops.RegisterGradient("CTCLossV2") 275def _CTCLossV2Grad(op, grad_loss, _): 276 """The derivative provided by CTC Loss V2. 277 278 Args: 279 op: the CTCLossV2 op. 280 grad_loss: The backprop for cost. 281 282 Returns: 283 The CTC Loss V2 gradient. 284 """ 285 return _CTCLossGradImpl(op, grad_loss, _) 286 287 288@tf_export("nn.ctc_greedy_decoder") 289@dispatch.add_dispatch_support 290def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True): 291 """Performs greedy decoding on the logits given in input (best path). 292 293 Note: Regardless of the value of merge_repeated, if the maximum index of a 294 given time and batch corresponds to the blank index `(num_classes - 1)`, no 295 new element is emitted. 296 297 If `merge_repeated` is `True`, merge repeated classes in output. 298 This means that if consecutive logits' maximum indices are the same, 299 only the first of these is emitted. The sequence `A B B * B * B` (where '*' 300 is the blank label) becomes 301 302 * `A B B B` if `merge_repeated=True`. 303 * `A B B B B` if `merge_repeated=False`. 304 305 Args: 306 inputs: 3-D `float` `Tensor` sized `[max_time, batch_size, num_classes]`. 307 The logits. 308 sequence_length: 1-D `int32` vector containing sequence lengths, having size 309 `[batch_size]`. 310 merge_repeated: Boolean. Default: True. 311 312 Returns: 313 A tuple `(decoded, neg_sum_logits)` where 314 315 decoded: A single-element list. `decoded[0]` 316 is an `SparseTensor` containing the decoded outputs s.t.: 317 318 `decoded.indices`: Indices matrix `(total_decoded_outputs, 2)`. 319 The rows store: `[batch, time]`. 320 321 `decoded.values`: Values vector, size `(total_decoded_outputs)`. 322 The vector stores the decoded classes. 323 324 `decoded.dense_shape`: Shape vector, size `(2)`. 325 The shape values are: `[batch_size, max_decoded_length]` 326 327 neg_sum_logits: A `float` matrix `(batch_size x 1)` containing, for the 328 sequence found, the negative of the sum of the greatest logit at each 329 timeframe. 330 """ 331 outputs = gen_ctc_ops.ctc_greedy_decoder( 332 inputs, sequence_length, merge_repeated=merge_repeated) 333 (decoded_ix, decoded_val, decoded_shape, log_probabilities) = outputs 334 return ([sparse_tensor.SparseTensor(decoded_ix, decoded_val, 335 decoded_shape)], log_probabilities) 336 337 338@tf_export(v1=["nn.ctc_beam_search_decoder"]) 339@dispatch.add_dispatch_support 340def ctc_beam_search_decoder(inputs, 341 sequence_length, 342 beam_width=100, 343 top_paths=1, 344 merge_repeated=True): 345 """Performs beam search decoding on the logits given in input. 346 347 **Note** The `ctc_greedy_decoder` is a special case of the 348 `ctc_beam_search_decoder` with `top_paths=1` and `beam_width=1` (but 349 that decoder is faster for this special case). 350 351 If `merge_repeated` is `True`, merge repeated classes in the output beams. 352 This means that if consecutive entries in a beam are the same, 353 only the first of these is emitted. That is, when the sequence is 354 `A B B * B * B` (where '*' is the blank label), the return value is: 355 356 * `A B` if `merge_repeated = True`. 357 * `A B B B` if `merge_repeated = False`. 358 359 Args: 360 inputs: 3-D `float` `Tensor`, size `[max_time x batch_size x num_classes]`. 361 The logits. 362 sequence_length: 1-D `int32` vector containing sequence lengths, having size 363 `[batch_size]`. 364 beam_width: An int scalar >= 0 (beam search beam width). 365 top_paths: An int scalar >= 0, <= beam_width (controls output size). 366 merge_repeated: Boolean. Default: True. 367 368 Returns: 369 A tuple `(decoded, log_probabilities)` where 370 371 decoded: A list of length top_paths, where `decoded[j]` 372 is a `SparseTensor` containing the decoded outputs: 373 374 `decoded[j].indices`: Indices matrix `(total_decoded_outputs[j] x 2)` 375 The rows store: [batch, time]. 376 377 `decoded[j].values`: Values vector, size `(total_decoded_outputs[j])`. 378 The vector stores the decoded classes for beam j. 379 380 `decoded[j].dense_shape`: Shape vector, size `(2)`. 381 The shape values are: `[batch_size, max_decoded_length[j]]`. 382 383 log_probability: A `float` matrix `(batch_size x top_paths)` containing 384 sequence log-probabilities. 385 """ 386 387 decoded_ixs, decoded_vals, decoded_shapes, log_probabilities = ( 388 gen_ctc_ops.ctc_beam_search_decoder( 389 inputs, 390 sequence_length, 391 beam_width=beam_width, 392 top_paths=top_paths, 393 merge_repeated=merge_repeated)) 394 395 return ([ 396 sparse_tensor.SparseTensor(ix, val, shape) 397 for (ix, val, shape) in zip(decoded_ixs, decoded_vals, decoded_shapes) 398 ], log_probabilities) 399 400 401@tf_export("nn.ctc_beam_search_decoder", v1=["nn.ctc_beam_search_decoder_v2"]) 402@dispatch.add_dispatch_support 403def ctc_beam_search_decoder_v2(inputs, 404 sequence_length, 405 beam_width=100, 406 top_paths=1): 407 """Performs beam search decoding on the logits given in input. 408 409 **Note** The `ctc_greedy_decoder` is a special case of the 410 `ctc_beam_search_decoder` with `top_paths=1` and `beam_width=1` (but 411 that decoder is faster for this special case). 412 413 Args: 414 inputs: 3-D `float` `Tensor`, size `[max_time, batch_size, num_classes]`. 415 The logits. 416 sequence_length: 1-D `int32` vector containing sequence lengths, having size 417 `[batch_size]`. 418 beam_width: An int scalar >= 0 (beam search beam width). 419 top_paths: An int scalar >= 0, <= beam_width (controls output size). 420 421 Returns: 422 A tuple `(decoded, log_probabilities)` where 423 424 decoded: A list of length top_paths, where `decoded[j]` 425 is a `SparseTensor` containing the decoded outputs: 426 427 `decoded[j].indices`: Indices matrix `[total_decoded_outputs[j], 2]`; 428 The rows store: `[batch, time]`. 429 430 `decoded[j].values`: Values vector, size `[total_decoded_outputs[j]]`. 431 The vector stores the decoded classes for beam `j`. 432 433 `decoded[j].dense_shape`: Shape vector, size `(2)`. 434 The shape values are: `[batch_size, max_decoded_length[j]]`. 435 436 log_probability: A `float` matrix `[batch_size, top_paths]` containing 437 sequence log-probabilities. 438 """ 439 440 # Note, merge_repeated is an invalid optimization that is removed from the 441 # public API: it returns low probability paths. 442 return ctc_beam_search_decoder( 443 inputs, 444 sequence_length=sequence_length, 445 beam_width=beam_width, 446 top_paths=top_paths, 447 merge_repeated=False) 448 449 450ops.NotDifferentiable("CTCGreedyDecoder") 451ops.NotDifferentiable("CTCBeamSearchDecoder") 452 453 454def _ctc_state_trans(label_seq): 455 """Compute CTC alignment model transition matrix. 456 457 Args: 458 label_seq: tensor of shape [batch_size, max_seq_length] 459 460 Returns: 461 tensor of shape [batch_size, states, states] with a state transition matrix 462 computed for each sequence of the batch. 463 """ 464 465 with ops.name_scope("ctc_state_trans"): 466 label_seq = ops.convert_to_tensor(label_seq, name="label_seq") 467 batch_size = _get_dim(label_seq, 0) 468 num_labels = _get_dim(label_seq, 1) 469 470 num_label_states = num_labels + 1 471 num_states = 2 * num_label_states 472 473 label_states = math_ops.range(num_label_states) 474 blank_states = label_states + num_label_states 475 476 # Start state to first label. 477 start_to_label = [[1, 0]] 478 479 # Blank to label transitions. 480 blank_to_label = array_ops.stack([label_states[1:], blank_states[:-1]], 1) 481 482 # Label to blank transitions. 483 label_to_blank = array_ops.stack([blank_states, label_states], 1) 484 485 # Scatter transitions that don't depend on sequence. 486 indices = array_ops.concat([start_to_label, blank_to_label, label_to_blank], 487 0) 488 values = array_ops.ones([_get_dim(indices, 0)]) 489 trans = array_ops.scatter_nd( 490 indices, values, shape=[num_states, num_states]) 491 trans += linalg_ops.eye(num_states) # Self-loops. 492 493 # Label to label transitions. Disallow transitions between repeated labels 494 # with no blank state in between. 495 batch_idx = array_ops.zeros_like(label_states[2:]) 496 indices = array_ops.stack([batch_idx, label_states[2:], label_states[1:-1]], 497 1) 498 indices = array_ops.tile( 499 array_ops.expand_dims(indices, 0), [batch_size, 1, 1]) 500 batch_idx = array_ops.expand_dims(math_ops.range(batch_size), 1) * [1, 0, 0] 501 indices += array_ops.expand_dims(batch_idx, 1) 502 repeats = math_ops.equal(label_seq[:, :-1], label_seq[:, 1:]) 503 values = 1.0 - math_ops.cast(repeats, dtypes.float32) 504 batched_shape = [batch_size, num_states, num_states] 505 label_to_label = array_ops.scatter_nd(indices, values, batched_shape) 506 507 return array_ops.expand_dims(trans, 0) + label_to_label 508 509 510def ctc_state_log_probs(seq_lengths, max_seq_length): 511 """Computes CTC alignment initial and final state log probabilities. 512 513 Create the initial/final state values directly as log values to avoid 514 having to take a float64 log on tpu (which does not exist). 515 516 Args: 517 seq_lengths: int tensor of shape [batch_size], seq lengths in the batch. 518 max_seq_length: int, max sequence length possible. 519 520 Returns: 521 initial_state_log_probs, final_state_log_probs 522 """ 523 524 batch_size = _get_dim(seq_lengths, 0) 525 num_label_states = max_seq_length + 1 526 num_duration_states = 2 527 num_states = num_duration_states * num_label_states 528 log_0 = math_ops.cast( 529 math_ops.log(math_ops.cast(0, dtypes.float64) + 1e-307), dtypes.float32) 530 531 initial_state_log_probs = array_ops.one_hot( 532 indices=array_ops.zeros([batch_size], dtype=dtypes.int32), 533 depth=num_states, 534 on_value=0.0, 535 off_value=log_0, 536 axis=1) 537 538 label_final_state_mask = array_ops.one_hot( 539 seq_lengths, depth=num_label_states, axis=0) 540 duration_final_state_mask = array_ops.ones( 541 [num_duration_states, 1, batch_size]) 542 final_state_mask = duration_final_state_mask * label_final_state_mask 543 final_state_log_probs = (1.0 - final_state_mask) * log_0 544 final_state_log_probs = array_ops.reshape(final_state_log_probs, 545 [num_states, batch_size]) 546 547 return initial_state_log_probs, array_ops.transpose(final_state_log_probs) 548 549 550def _ilabel_to_state(labels, num_labels, ilabel_log_probs): 551 """Project ilabel log probs to state log probs.""" 552 553 num_label_states = _get_dim(labels, 1) 554 blank = ilabel_log_probs[:, :, :1] 555 blank = array_ops.tile(blank, [1, 1, num_label_states + 1]) 556 one_hot = array_ops.one_hot(labels, depth=num_labels) 557 one_hot = array_ops.expand_dims(one_hot, axis=0) 558 ilabel_log_probs = array_ops.expand_dims(ilabel_log_probs, axis=2) 559 state_log_probs = math_ops.reduce_sum(ilabel_log_probs * one_hot, axis=3) 560 state_log_probs = array_ops.concat([state_log_probs, blank], axis=2) 561 return array_ops.pad( 562 state_log_probs, [[0, 0], [0, 0], [1, 0]], 563 constant_values=math_ops.log(0.0)) 564 565 566def _state_to_olabel(labels, num_labels, states): 567 """Sum state log probs to ilabel log probs.""" 568 569 num_label_states = _get_dim(labels, 1) + 1 570 label_states = states[:, :, 1:num_label_states] 571 blank_states = states[:, :, num_label_states:] 572 one_hot = array_ops.one_hot( 573 labels - 1, 574 depth=(num_labels - 1), 575 on_value=0.0, 576 off_value=math_ops.log(0.0)) 577 one_hot = array_ops.expand_dims(one_hot, axis=0) 578 label_states = array_ops.expand_dims(label_states, axis=3) 579 label_olabels = math_ops.reduce_logsumexp(label_states + one_hot, axis=2) 580 blank_olabels = math_ops.reduce_logsumexp(blank_states, axis=2, keepdims=True) 581 return array_ops.concat([blank_olabels, label_olabels], axis=-1) 582 583 584# pylint: disable=redefined-outer-name 585def _state_to_olabel_unique(labels, num_labels, states, unique): 586 """Sum state log probs to ilabel log probs using unique label indices.""" 587 588 num_label_states = _get_dim(labels, 1) + 1 589 label_states = states[:, :, 1:num_label_states] 590 blank_states = states[:, :, num_label_states:] 591 592 unique_y, unique_idx = unique 593 mul_reduce = _sum_states(unique_idx, label_states) 594 595 num_frames = states.shape[0] 596 batch_size = states.shape[1] 597 num_states = num_label_states - 1 598 batch_state_major = array_ops.transpose(mul_reduce, perm=[1, 2, 0]) 599 batch_state_major = array_ops.reshape(batch_state_major, 600 [batch_size * num_states, num_frames]) 601 batch_offset = math_ops.range(batch_size, dtype=unique_y.dtype) * num_labels 602 indices = unique_y + array_ops.expand_dims(batch_offset, axis=-1) 603 indices = array_ops.reshape(indices, [-1, 1]) 604 scatter = array_ops.scatter_nd( 605 indices=indices, 606 updates=batch_state_major, 607 shape=[batch_size * num_labels, num_frames]) 608 scatter = array_ops.reshape(scatter, [batch_size, num_labels, num_frames]) 609 610 mask = array_ops.ones_like(batch_state_major, dtype=dtypes.bool) 611 mask = array_ops.scatter_nd( 612 indices=indices, 613 updates=mask, 614 shape=[batch_size * num_labels, num_frames]) 615 mask = array_ops.reshape(mask, [batch_size, num_labels, num_frames]) 616 617 scatter = array_ops.where( 618 mask, scatter, 619 array_ops.fill(array_ops.shape(scatter), math_ops.log(0.0))) 620 621 label_olabels = array_ops.transpose(scatter, [2, 0, 1]) 622 label_olabels = label_olabels[:, :, 1:] 623 624 blank_olabels = math_ops.reduce_logsumexp(blank_states, axis=2, keepdims=True) 625 626 return array_ops.concat([blank_olabels, label_olabels], axis=-1) 627 628 629def ctc_loss_and_grad(logits, labels, label_length, logit_length, unique=None): 630 """Computes the CTC loss and gradients. 631 632 Most users will want fwd_bwd.ctc_loss 633 634 This function returns the computed gradient, it does not have a gradient 635 of its own defined. 636 637 Args: 638 logits: tensor of shape [frames, batch_size, num_labels] 639 labels: tensor of shape [batch_size, max_label_seq_length] 640 label_length: tensor of shape [batch_size] Length of reference label 641 sequence in labels. 642 logit_length: tensor of shape [batch_size] Length of input sequence in 643 logits. 644 unique: (optional) unique label indices as computed by unique(labels) If 645 supplied, enables an implementation that is faster and more memory 646 efficient on TPU. 647 648 Returns: 649 loss: tensor of shape [batch_size] 650 gradient: tensor of shape [frames, batch_size, num_labels] 651 """ 652 653 num_labels = _get_dim(logits, 2) 654 max_label_seq_length = _get_dim(labels, 1) 655 656 ilabel_log_probs = nn_ops.log_softmax(logits) 657 state_log_probs = _ilabel_to_state(labels, num_labels, ilabel_log_probs) 658 state_trans_probs = _ctc_state_trans(labels) 659 initial_state_log_probs, final_state_log_probs = ctc_state_log_probs( 660 label_length, max_label_seq_length) 661 fwd_bwd_log_probs, log_likelihood = _forward_backward_log( 662 state_trans_log_probs=math_ops.log(state_trans_probs), 663 initial_state_log_probs=initial_state_log_probs, 664 final_state_log_probs=final_state_log_probs, 665 observed_log_probs=state_log_probs, 666 sequence_length=logit_length) 667 668 if unique: 669 olabel_log_probs = _state_to_olabel_unique(labels, num_labels, 670 fwd_bwd_log_probs, unique) 671 else: 672 olabel_log_probs = _state_to_olabel(labels, num_labels, fwd_bwd_log_probs) 673 674 grad = math_ops.exp(ilabel_log_probs) - math_ops.exp(olabel_log_probs) 675 676 # Applies the sequence mask for the gradient. It is enough to appply the mask 677 # only for ilabel_log_probs because olabel_log_probs already consider the 678 # mask. However, it is just safe and clean to apply it for the gradient. 679 max_logit_length = _get_dim(logits, 0) 680 logit_mask = array_ops.sequence_mask(logit_length, max_logit_length, 681 dtypes.float32) 682 logit_mask = array_ops.transpose(logit_mask, perm=[1, 0]) 683 logit_mask = array_ops.expand_dims(logit_mask, axis=2) 684 grad *= logit_mask 685 686 loss = -log_likelihood 687 return loss, grad 688 689 690def _ctc_loss_grad(op, grad_loss, _): 691 grad = op.outputs[1] 692 grad = [array_ops.reshape(grad_loss, [1, -1, 1]) * grad] 693 grad += [None] * (len(op.inputs) - len(grad)) 694 return grad 695 696 697def _ctc_loss_op_standard(labels, logits, logit_length, logits_time_major, 698 blank_index): 699 part_before = logits[:, :, :blank_index] 700 part_after = logits[:, :, blank_index + 1:] 701 part_blank = logits[:, :, blank_index:blank_index + 1] 702 logits = array_ops.concat([part_before, part_after, part_blank], axis=2) 703 labels = sparse_tensor.SparseTensor( 704 labels.indices, 705 array_ops.where(labels.values < blank_index, labels.values, 706 labels.values - 1), labels.dense_shape) 707 return _ctc_loss_impl( 708 labels=labels, 709 inputs=logits, 710 sequence_length=logit_length, 711 time_major=logits_time_major, 712 use_cudnn=False) 713 714 715def _ctc_loss_op_cudnn(labels, logits, logit_length, logits_time_major, 716 blank_index): 717 part_before = logits[:, :, :blank_index] 718 part_after = logits[:, :, blank_index + 1:] 719 part_blank = logits[:, :, blank_index:blank_index + 1] 720 logits = array_ops.concat([part_blank, part_before, part_after], axis=2) 721 labels = sparse_tensor.SparseTensor( 722 labels.indices, 723 array_ops.where(labels.values < blank_index, labels.values + 1, 724 labels.values), labels.dense_shape) 725 return _ctc_loss_impl( 726 labels=labels, 727 inputs=logits, 728 sequence_length=logit_length, 729 time_major=logits_time_major, 730 use_cudnn=True) 731 732 733def _ctc_loss_shape(op): 734 return [op.inputs[2].get_shape(), op.inputs[0].get_shape()] 735 736 737# pylint: disable=protected-access, invalid-name 738@tf_export(v1=["nn.ctc_loss_v2"]) 739@dispatch.add_dispatch_support 740def ctc_loss_v2(labels, 741 logits, 742 label_length, 743 logit_length, 744 logits_time_major=True, 745 unique=None, 746 blank_index=None, 747 name=None): 748 """Computes CTC (Connectionist Temporal Classification) loss. 749 750 This op implements the CTC loss as presented in (Graves et al., 2006). 751 752 Notes: 753 754 - Same as the "Classic CTC" in TensorFlow 1.x's tf.compat.v1.nn.ctc_loss 755 setting of preprocess_collapse_repeated=False, ctc_merge_repeated=True 756 - Labels may be supplied as either a dense, zero-padded tensor with a 757 vector of label sequence lengths OR as a SparseTensor. 758 - On TPU and GPU: Only dense padded labels are supported. 759 - On CPU: Caller may use SparseTensor or dense padded labels but calling with 760 a SparseTensor will be significantly faster. 761 - Default blank label is 0 rather num_classes - 1, unless overridden by 762 blank_index. 763 764 Args: 765 labels: tensor of shape [batch_size, max_label_seq_length] or SparseTensor 766 logits: tensor of shape [frames, batch_size, num_labels], if 767 logits_time_major == False, shape is [batch_size, frames, num_labels]. 768 label_length: tensor of shape [batch_size], None if labels is SparseTensor 769 Length of reference label sequence in labels. 770 logit_length: tensor of shape [batch_size] Length of input sequence in 771 logits. 772 logits_time_major: (optional) If True (default), logits is shaped [time, 773 batch, logits]. If False, shape is [batch, time, logits] 774 unique: (optional) Unique label indices as computed by 775 ctc_unique_labels(labels). If supplied, enable a faster, memory efficient 776 implementation on TPU. 777 blank_index: (optional) Set the class index to use for the blank label. 778 Negative values will start from num_classes, ie, -1 will reproduce the 779 ctc_loss behavior of using num_classes - 1 for the blank symbol. There is 780 some memory/performance overhead to switching from the default of 0 as an 781 additional shifted copy of the logits may be created. 782 name: A name for this `Op`. Defaults to "ctc_loss_dense". 783 784 Returns: 785 loss: tensor of shape [batch_size], negative log probabilities. 786 787 References: 788 Connectionist Temporal Classification - Labeling Unsegmented Sequence Data 789 with Recurrent Neural Networks: 790 [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891) 791 ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf)) 792 """ 793 if isinstance(labels, sparse_tensor.SparseTensor): 794 if blank_index is None: 795 raise ValueError( 796 "blank_index must be given when using SparseTensor labels.") 797 798 if blank_index < 0: 799 blank_index += _get_dim(logits, 2) 800 801 if blank_index != _get_dim(logits, 2) - 1: 802 logits = array_ops.concat([ 803 logits[:, :, :blank_index], 804 logits[:, :, blank_index + 1:], 805 logits[:, :, blank_index:blank_index + 1], 806 ], 807 axis=2) 808 labels = sparse_tensor.SparseTensor( 809 labels.indices, 810 array_ops.where(labels.values < blank_index, labels.values, 811 labels.values - 1), labels.dense_shape) 812 813 return ctc_loss( 814 labels=labels, 815 inputs=logits, 816 sequence_length=logit_length, 817 time_major=logits_time_major) 818 819 if blank_index is None: 820 blank_index = 0 821 822 return ctc_loss_dense( 823 labels=labels, 824 logits=logits, 825 label_length=label_length, 826 logit_length=logit_length, 827 logits_time_major=logits_time_major, 828 unique=unique, 829 blank_index=blank_index, 830 name=name) 831 832 833@tf_export("nn.ctc_loss", v1=[]) 834@dispatch.add_dispatch_support 835def ctc_loss_v3(labels, 836 logits, 837 label_length, 838 logit_length, 839 logits_time_major=True, 840 unique=None, 841 blank_index=None, 842 name=None): 843 """Computes CTC (Connectionist Temporal Classification) loss. 844 845 This op implements the CTC loss as presented in (Graves et al., 2006). 846 847 Notes: 848 849 - Same as the "Classic CTC" in TensorFlow 1.x's tf.compat.v1.nn.ctc_loss 850 setting of preprocess_collapse_repeated=False, ctc_merge_repeated=True 851 - Labels may be supplied as either a dense, zero-padded tensor with a 852 vector of label sequence lengths OR as a SparseTensor. 853 - On TPU and GPU: Only dense padded labels are supported. 854 - On CPU: Caller may use SparseTensor or dense padded labels but calling with 855 a SparseTensor will be significantly faster. 856 - Default blank label is 0 rather num_classes - 1, unless overridden by 857 blank_index. 858 859 Args: 860 labels: tensor of shape [batch_size, max_label_seq_length] or SparseTensor 861 logits: tensor of shape [frames, batch_size, num_labels], if 862 logits_time_major == False, shape is [batch_size, frames, num_labels]. 863 label_length: tensor of shape [batch_size], None if labels is SparseTensor 864 Length of reference label sequence in labels. 865 logit_length: tensor of shape [batch_size] Length of input sequence in 866 logits. 867 logits_time_major: (optional) If True (default), logits is shaped [time, 868 batch, logits]. If False, shape is [batch, time, logits] 869 unique: (optional) Unique label indices as computed by 870 ctc_unique_labels(labels). If supplied, enable a faster, memory efficient 871 implementation on TPU. 872 blank_index: (optional) Set the class index to use for the blank label. 873 Negative values will start from num_classes, ie, -1 will reproduce the 874 ctc_loss behavior of using num_classes - 1 for the blank symbol. There is 875 some memory/performance overhead to switching from the default of 0 as an 876 additional shifted copy of the logits may be created. 877 name: A name for this `Op`. Defaults to "ctc_loss_dense". 878 879 Returns: 880 loss: tensor of shape [batch_size], negative log probabilities. 881 882 References: 883 Connectionist Temporal Classification - Labeling Unsegmented Sequence Data 884 with Recurrent Neural Networks: 885 [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891) 886 ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf)) 887 """ 888 if isinstance(labels, sparse_tensor.SparseTensor): 889 if blank_index is None: 890 raise ValueError( 891 "blank_index must be given when using SparseTensor labels.") 892 893 if blank_index < 0: 894 blank_index += _get_dim(logits, 2) 895 896 params = { 897 "labels": labels, 898 "logits": logits, 899 "logit_length": logit_length, 900 "logits_time_major": logits_time_major, 901 "blank_index": blank_index 902 } 903 904 if context.executing_eagerly(): 905 device_type = _get_context_device_type() 906 can_use_gpu = ( 907 # Either user specified GPU or unspecified but GPU is available. 908 (device_type == _GPU_DEVICE_NAME or 909 (device_type is None and context.num_gpus() > 0))) 910 # Under eager context, check the device placement and prefer the 911 if can_use_gpu: 912 res = _ctc_loss_op_cudnn(**params) 913 else: 914 res = _ctc_loss_op_standard(**params) 915 else: 916 api_name = "ctc_loss_" + str(uuid.uuid4()) 917 ctc_loss_op_standard = _generate_defun_backend(api_name, _CPU_DEVICE_NAME, 918 _ctc_loss_op_standard) 919 ctc_loss_op_cudnn = _generate_defun_backend(api_name, _GPU_DEVICE_NAME, 920 _ctc_loss_op_cudnn) 921 res = ctc_loss_op_standard(**params) 922 function_eager.register(ctc_loss_op_cudnn, **params) 923 return res 924 925 if blank_index is None: 926 blank_index = 0 927 928 return ctc_loss_dense( 929 labels=labels, 930 logits=logits, 931 label_length=label_length, 932 logit_length=logit_length, 933 logits_time_major=logits_time_major, 934 unique=unique, 935 blank_index=blank_index, 936 name=name) 937 938 939def ctc_loss_dense(labels, 940 logits, 941 label_length, 942 logit_length, 943 logits_time_major=True, 944 unique=None, 945 blank_index=0, 946 name=None): 947 """Computes CTC (Connectionist Temporal Classification) loss. 948 949 This op implements the CTC loss as presented in (Graves et al., 2006), 950 using the batched forward backward algorithm described in (Sim et al., 2017). 951 952 Notes: 953 Significant differences from tf.compat.v1.nn.ctc_loss: 954 Supports GPU and TPU (tf.compat.v1.nn.ctc_loss supports CPU only): 955 For batched operations, GPU and TPU are significantly faster than using 956 ctc_loss on CPU. 957 This implementation runs on CPU, but significantly slower than ctc_loss. 958 Blank label is 0 rather num_classes - 1, unless overridden by blank_index. 959 Logits and labels are dense arrays with padding rather than SparseTensor. 960 The only mode supported is the same as: 961 preprocess_collapse_repeated=False, ctc_merge_repeated=True 962 To collapse labels, the caller can preprocess label sequence first. 963 964 The dense implementation supports both CPU, GPU and TPU. A fast path is 965 provided that significantly improves memory use for large vocabulary if the 966 caller preprocesses label sequences to get unique label indices on the CPU 967 (eg. in the data input pipeline) using ctc_ops.unique and simplifies this in 968 the optional "unique" kwarg. This is especially useful for TPU and GPU but 969 also works with if used on CPU. 970 971 Args: 972 labels: tensor of shape [batch_size, max_label_seq_length] 973 logits: tensor of shape [frames, batch_size, num_labels], if 974 logits_time_major == False, shape is [batch_size, frames, num_labels]. 975 label_length: tensor of shape [batch_size] Length of reference label 976 sequence in labels. 977 logit_length: tensor of shape [batch_size] Length of input sequence in 978 logits. 979 logits_time_major: (optional) If True (default), logits is shaped [time, 980 batch, logits]. If False, shape is [batch, time, logits] 981 unique: (optional) Unique label indices as computed by unique(labels). If 982 supplied, enable a faster, memory efficient implementation on TPU. 983 blank_index: (optional) Set the class index to use for the blank label. 984 Negative values will start from num_classes, ie, -1 will reproduce the 985 ctc_loss behavior of using num_classes - 1 for the blank symbol. There is 986 some memory/performance overhead to switching from the default of 0 as an 987 additional shifted copy of the logits may be created. 988 name: A name for this `Op`. Defaults to "ctc_loss_dense". 989 990 Returns: 991 loss: tensor of shape [batch_size], negative log probabilities. 992 993 References: 994 Connectionist Temporal Classification - Labeling Unsegmented Sequence Data 995 with Recurrent Neural Networks: 996 [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891) 997 ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf)) 998 Improving the efficiency of forward-backward algorithm using batched 999 computation in TensorFlow: 1000 [Sim et al., 2017](https://ieeexplore.ieee.org/document/8268944) 1001 ([pdf](http://bacchiani.net/resume/papers/ASRU2017.pdf)) 1002 """ 1003 1004 with ops.name_scope(name, "ctc_loss_dense", 1005 [logits, labels, label_length, logit_length]): 1006 logits = ops.convert_to_tensor(logits, name="logits") 1007 labels = ops.convert_to_tensor(labels, name="labels") 1008 label_length = ops.convert_to_tensor(label_length, name="label_length") 1009 logit_length = ops.convert_to_tensor(logit_length, name="logit_length") 1010 1011 if not logits_time_major: 1012 logits = array_ops.transpose(logits, perm=[1, 0, 2]) 1013 1014 if blank_index != 0: 1015 if blank_index < 0: 1016 blank_index += _get_dim(logits, 2) 1017 logits = array_ops.concat([ 1018 logits[:, :, blank_index:blank_index + 1], 1019 logits[:, :, :blank_index], 1020 logits[:, :, blank_index + 1:], 1021 ], 1022 axis=2) 1023 labels = array_ops.where(labels < blank_index, labels + 1, labels) 1024 1025 args = [logits, labels, label_length, logit_length] 1026 1027 if unique: 1028 unique_y, unique_idx = unique 1029 if blank_index != 0: 1030 unique_y = array_ops.where(unique_y < blank_index, unique_y + 1, 1031 unique_y) 1032 label_mask_len = math_ops.reduce_max(unique_idx, axis=1) + 1 1033 max_label_length = _get_dim(unique_y, 1) 1034 label_mask = array_ops.sequence_mask(label_mask_len, max_label_length) 1035 unique_y = array_ops.where(label_mask, unique_y, 1036 array_ops.zeros_like(unique_y)) 1037 args.extend([unique_y, unique_idx]) 1038 1039 @custom_gradient.custom_gradient 1040 def compute_ctc_loss(logits_t, labels_t, label_length_t, logit_length_t, 1041 *unique_t): 1042 """Compute CTC loss.""" 1043 logits_t.set_shape(logits.shape) 1044 labels_t.set_shape(labels.shape) 1045 label_length_t.set_shape(label_length.shape) 1046 logit_length_t.set_shape(logit_length.shape) 1047 kwargs = dict( 1048 logits=logits_t, 1049 labels=labels_t, 1050 label_length=label_length_t, 1051 logit_length=logit_length_t) 1052 if unique_t: 1053 kwargs["unique"] = unique_t 1054 result = ctc_loss_and_grad(**kwargs) 1055 def grad(grad_loss): 1056 grad = [array_ops.reshape(grad_loss, [1, -1, 1]) * result[1]] 1057 grad += [None] * (len(args) - len(grad)) 1058 return grad 1059 1060 return result[0], grad 1061 1062 return compute_ctc_loss(*args) 1063 1064 1065@tf_export("nn.collapse_repeated") 1066@dispatch.add_dispatch_support 1067def collapse_repeated(labels, seq_length, name=None): 1068 """Merge repeated labels into single labels. 1069 1070 Args: 1071 labels: Tensor of shape [batch, max value in seq_length] 1072 seq_length: Tensor of shape [batch], sequence length of each batch element. 1073 name: A name for this `Op`. Defaults to "collapse_repeated_labels". 1074 1075 Returns: 1076 A tuple `(collapsed_labels, new_seq_length)` where 1077 1078 collapsed_labels: Tensor of shape [batch, max_seq_length] with repeated 1079 labels collapsed and padded to max_seq_length, eg: 1080 `[[A, A, B, B, A], [A, B, C, D, E]] => [[A, B, A, 0, 0], [A, B, C, D, E]]` 1081 1082 new_seq_length: int tensor of shape [batch] with new sequence lengths. 1083 """ 1084 1085 with ops.name_scope(name, "collapse_repeated_labels", [labels, seq_length]): 1086 labels = ops.convert_to_tensor(labels, name="labels") 1087 seq_length = ops.convert_to_tensor(seq_length, name="seq_length") 1088 1089 # Mask labels that don't equal previous label. 1090 label_mask = array_ops.concat([ 1091 array_ops.ones_like(labels[:, :1], dtypes.bool), 1092 math_ops.not_equal(labels[:, 1:], labels[:, :-1]) 1093 ], 1094 axis=1) 1095 1096 # Filter labels that aren't in the original sequence. 1097 maxlen = _get_dim(labels, 1) 1098 seq_mask = array_ops.sequence_mask(seq_length, maxlen=maxlen) 1099 label_mask = math_ops.logical_and(label_mask, seq_mask) 1100 1101 # Count masks for new sequence lengths. 1102 new_seq_len = math_ops.reduce_sum( 1103 math_ops.cast(label_mask, dtypes.int32), axis=1) 1104 1105 # Mask indexes based on sequence length mask. 1106 new_maxlen = math_ops.reduce_max(new_seq_len) 1107 idx_mask = array_ops.sequence_mask(new_seq_len, maxlen=new_maxlen) 1108 1109 # Flatten everything and mask out labels to keep and sparse indices. 1110 flat_labels = array_ops.reshape(labels, [-1]) 1111 flat_label_mask = array_ops.reshape(label_mask, [-1]) 1112 flat_idx_mask = array_ops.reshape(idx_mask, [-1]) 1113 idx = math_ops.range(_get_dim(flat_idx_mask, 0)) 1114 1115 # Scatter to flat shape. 1116 flat = array_ops.scatter_nd( 1117 indices=array_ops.expand_dims( 1118 array_ops.boolean_mask(idx, flat_idx_mask), axis=1), 1119 updates=array_ops.boolean_mask(flat_labels, flat_label_mask), 1120 shape=array_ops.shape(flat_idx_mask)) 1121 1122 # Reshape back to square batch. 1123 batch_size = _get_dim(labels, 0) 1124 new_shape = [batch_size, new_maxlen] 1125 return (array_ops.reshape(flat, new_shape), 1126 math_ops.cast(new_seq_len, seq_length.dtype)) 1127 1128 1129def dense_labels_to_sparse(dense, length): 1130 """Convert dense labels with sequence lengths to sparse tensor. 1131 1132 Args: 1133 dense: tensor of shape [batch, max_length] 1134 length: int tensor of shape [batch] The length of each sequence in dense. 1135 1136 Returns: 1137 tf.sparse.SparseTensor with values only for the valid elements of sequences. 1138 """ 1139 1140 flat_values = array_ops.reshape(dense, [-1]) 1141 flat_indices = math_ops.range( 1142 array_ops.shape(flat_values, out_type=dtypes.int64)[0]) 1143 mask = array_ops.sequence_mask(length, maxlen=array_ops.shape(dense)[1]) 1144 flat_mask = array_ops.reshape(mask, [-1]) 1145 indices = array_ops.expand_dims( 1146 array_ops.boolean_mask(flat_indices, flat_mask), 1) 1147 values = array_ops.boolean_mask(flat_values, flat_mask) 1148 sparse = sparse_tensor.SparseTensor( 1149 indices=indices, 1150 values=math_ops.cast(values, dtypes.int32), 1151 dense_shape=array_ops.shape(flat_values, out_type=dtypes.int64)) 1152 reshaped = sparse_ops.sparse_reshape(sparse, array_ops.shape(dense)) 1153 max_length = math_ops.reduce_max(length) 1154 return sparse_tensor.SparseTensor( 1155 indices=reshaped.indices, 1156 values=reshaped.values, 1157 dense_shape=[ 1158 math_ops.cast(reshaped.dense_shape[0], dtypes.int64), 1159 math_ops.cast(max_length, dtypes.int64) 1160 ]) 1161 1162 1163@tf_export("nn.ctc_unique_labels") 1164@dispatch.add_dispatch_support 1165def ctc_unique_labels(labels, name=None): 1166 """Get unique labels and indices for batched labels for `tf.nn.ctc_loss`. 1167 1168 For use with `tf.nn.ctc_loss` optional argument `unique`: This op can be 1169 used to preprocess labels in input pipeline to for better speed/memory use 1170 computing the ctc loss on TPU. 1171 1172 Example: 1173 ctc_unique_labels([[3, 4, 4, 3]]) -> 1174 unique labels padded with 0: [[3, 4, 0, 0]] 1175 indices of original labels in unique: [0, 1, 1, 0] 1176 1177 Args: 1178 labels: tensor of shape [batch_size, max_label_length] padded with 0. 1179 name: A name for this `Op`. Defaults to "ctc_unique_labels". 1180 1181 Returns: 1182 tuple of 1183 - unique labels, tensor of shape `[batch_size, max_label_length]` 1184 - indices into unique labels, shape `[batch_size, max_label_length]` 1185 """ 1186 1187 with ops.name_scope(name, "ctc_unique_labels", [labels]): 1188 labels = ops.convert_to_tensor(labels, name="labels") 1189 1190 def _unique(x): 1191 u = array_ops.unique(x) 1192 y = array_ops.pad(u.y, [[0, _get_dim(u.idx, 0) - _get_dim(u.y, 0)]]) 1193 y = math_ops.cast(y, dtypes.int64) 1194 return [y, u.idx] 1195 1196 return map_fn.map_fn(_unique, labels, dtype=[dtypes.int64, dtypes.int32]) 1197 1198 1199def _sum_states(idx, states): 1200 """Take logsumexp for each unique state out of all label states. 1201 1202 Args: 1203 idx: tensor of shape [batch, label_length] For each sequence, indices into a 1204 set of unique labels as computed by calling unique. 1205 states: tensor of shape [frames, batch, label_length] Log probabilities for 1206 each label state. 1207 1208 Returns: 1209 tensor of shape [frames, batch_size, label_length], log probabilites summed 1210 for each unique label of the sequence. 1211 """ 1212 1213 with ops.name_scope("sum_states"): 1214 idx = ops.convert_to_tensor(idx, name="idx") 1215 num_states = _get_dim(states, 2) 1216 states = array_ops.expand_dims(states, axis=2) 1217 one_hot = array_ops.one_hot( 1218 idx, 1219 depth=num_states, 1220 on_value=0.0, 1221 off_value=math_ops.log(0.0), 1222 axis=1) 1223 return math_ops.reduce_logsumexp(states + one_hot, axis=-1) 1224 1225 1226def _forward_backward_log(state_trans_log_probs, initial_state_log_probs, 1227 final_state_log_probs, observed_log_probs, 1228 sequence_length): 1229 """Forward-backward algorithm computed in log domain. 1230 1231 Args: 1232 state_trans_log_probs: tensor of shape [states, states] or if different 1233 transition matrix per batch [batch_size, states, states] 1234 initial_state_log_probs: tensor of shape [batch_size, states] 1235 final_state_log_probs: tensor of shape [batch_size, states] 1236 observed_log_probs: tensor of shape [frames, batch_size, states] 1237 sequence_length: tensor of shape [batch_size] 1238 1239 Returns: 1240 forward backward log probabilites: tensor of shape [frames, batch, states] 1241 log_likelihood: tensor of shape [batch_size] 1242 1243 Raises: 1244 ValueError: If state_trans_log_probs has unknown or incorrect rank. 1245 """ 1246 1247 if state_trans_log_probs.shape.ndims == 2: 1248 perm = [1, 0] 1249 elif state_trans_log_probs.shape.ndims == 3: 1250 perm = [0, 2, 1] 1251 else: 1252 raise ValueError( 1253 "state_trans_log_probs rank must be known and == 2 or 3, is: %s" % 1254 state_trans_log_probs.shape.ndims) 1255 1256 bwd_state_trans_log_probs = array_ops.transpose(state_trans_log_probs, perm) 1257 batch_size = _get_dim(observed_log_probs, 1) 1258 1259 def _forward(state_log_prob, obs_log_prob): 1260 state_log_prob = array_ops.expand_dims(state_log_prob, axis=1) # Broadcast. 1261 state_log_prob += state_trans_log_probs 1262 state_log_prob = math_ops.reduce_logsumexp(state_log_prob, axis=-1) 1263 state_log_prob += obs_log_prob 1264 log_prob_sum = math_ops.reduce_logsumexp( 1265 state_log_prob, axis=-1, keepdims=True) 1266 state_log_prob -= log_prob_sum 1267 return state_log_prob 1268 1269 fwd = _scan( 1270 _forward, observed_log_probs, initial_state_log_probs, inclusive=True) 1271 1272 def _backward(accs, elems): 1273 """Calculate log probs and cumulative sum masked for sequence length.""" 1274 state_log_prob, cum_log_sum = accs 1275 obs_log_prob, mask = elems 1276 state_log_prob += obs_log_prob 1277 state_log_prob = array_ops.expand_dims(state_log_prob, axis=1) # Broadcast. 1278 state_log_prob += bwd_state_trans_log_probs 1279 state_log_prob = math_ops.reduce_logsumexp(state_log_prob, axis=-1) 1280 1281 log_prob_sum = math_ops.reduce_logsumexp( 1282 state_log_prob, axis=-1, keepdims=True) 1283 state_log_prob -= log_prob_sum 1284 1285 cum_log_sum += array_ops.squeeze(log_prob_sum) * mask 1286 batched_mask = array_ops.expand_dims(mask, axis=1) 1287 out = state_log_prob * batched_mask 1288 out += final_state_log_probs * (1.0 - batched_mask) 1289 return out, cum_log_sum 1290 1291 zero_log_sum = array_ops.zeros([batch_size]) 1292 maxlen = _get_dim(observed_log_probs, 0) 1293 mask = array_ops.sequence_mask(sequence_length, maxlen, dtypes.float32) 1294 mask = array_ops.transpose(mask, perm=[1, 0]) 1295 1296 bwd, cum_log_sum = _scan( 1297 _backward, (observed_log_probs, mask), 1298 (final_state_log_probs, zero_log_sum), 1299 reverse=True, 1300 inclusive=True) 1301 1302 fwd_bwd_log_probs = fwd[1:] + bwd[1:] 1303 fwd_bwd_log_probs_sum = math_ops.reduce_logsumexp( 1304 fwd_bwd_log_probs, axis=2, keepdims=True) 1305 fwd_bwd_log_probs -= fwd_bwd_log_probs_sum 1306 fwd_bwd_log_probs += math_ops.log(array_ops.expand_dims(mask, axis=2)) 1307 1308 log_likelihood = bwd[0, :, 0] + cum_log_sum[0] 1309 1310 return fwd_bwd_log_probs, log_likelihood 1311 1312 1313# TODO(tombagby): This is currently faster for the ctc implementation than using 1314# functional_ops.scan, but could be replaced by that or something similar if 1315# things change. 1316def _scan(fn, elems, initial, reverse=False, inclusive=False, final_only=False): 1317 """Repeatedly applies callable `fn` to a sequence of elements. 1318 1319 Implemented by functional_ops.While, tpu friendly, no gradient. 1320 1321 This is similar to functional_ops.scan but significantly faster on tpu/gpu 1322 for the forward backward use case. 1323 1324 Examples: 1325 scan(lambda a, e: a + e, [1.0, 2.0, 3.0], 1.0) => [2.0, 4.0, 7.0] 1326 1327 Multiple accumulators: 1328 scan(lambda a, e: (a[0] + e, a[1] * e), [1.0, 2.0, 3.0], (0.0, 1.0)) 1329 1330 Multiple inputs: 1331 scan(lambda a, e: a + (e[0] * e[1]), (elems1, elems2), 0.0) 1332 1333 Args: 1334 fn: callable, fn(accumulators, element) return new accumulator values. The 1335 (possibly nested) sequence of accumulators is the same as `initial` and 1336 the return value must have the same structure. 1337 elems: A (possibly nested) tensor which will be unpacked along the first 1338 dimension. The resulting slices will be the second argument to fn. The 1339 first dimension of all nested input tensors must be the same. 1340 initial: A tensor or (possibly nested) sequence of tensors with initial 1341 values for the accumulators. 1342 reverse: (optional) True enables scan and output elems in reverse order. 1343 inclusive: (optional) True includes the initial accumulator values in the 1344 output. Length of output will be len(elem sequence) + 1. Not meaningful if 1345 final_only is True. 1346 final_only: (optional) When True, return only the final accumulated values, 1347 not the concatenation of accumulated values for each input. 1348 1349 Returns: 1350 A (possibly nested) sequence of tensors with the results of applying fn 1351 to tensors unpacked from elems and previous accumulator values. 1352 """ 1353 1354 flat_elems = [ops.convert_to_tensor(x) for x in nest.flatten(elems)] 1355 num_elems = array_ops.shape(flat_elems[0])[0] 1356 pack_elems = lambda x: nest.pack_sequence_as(structure=elems, flat_sequence=x) 1357 flat_initial = [ops.convert_to_tensor(x) for x in nest.flatten(initial)] 1358 pack = lambda x: nest.pack_sequence_as(structure=initial, flat_sequence=x) 1359 accum_dtypes = [x.dtype for x in flat_initial] 1360 num_accums = len(flat_initial) 1361 1362 # Types for counter, [outputs], [accumulators] loop arguments. 1363 if final_only: 1364 loop_dtypes = [dtypes.int32, dtypes.int32] + accum_dtypes 1365 else: 1366 loop_dtypes = [dtypes.int32, dtypes.int32] + accum_dtypes + accum_dtypes 1367 1368 # TODO(tombagby): Update to tfe.defun 1369 def cond(i, num_elems, *args): 1370 del args 1371 return i >= 0 if reverse else i < num_elems 1372 1373 # The loop *args are [output tensors] + [accumulator tensors] which must 1374 # be paired. Each output corresponds to one accumulator. 1375 def body(i, num_elems, *args): 1376 """Loop body.""" 1377 i.set_shape([]) 1378 if final_only: 1379 accum = args 1380 else: 1381 out, accum = args[:num_accums], args[num_accums:] 1382 slices = [array_ops.gather(e, i) for e in flat_elems] 1383 accum = fn(pack(accum), pack_elems(slices)) 1384 flat_accum = nest.flatten(accum) 1385 if final_only: 1386 new_out = [] 1387 else: 1388 update_i = i + 1 if inclusive and not reverse else i 1389 new_out = [ 1390 inplace_ops.alias_inplace_update(x, update_i, y) 1391 for x, y in zip(out, flat_accum) 1392 ] 1393 i = i - 1 if reverse else i + 1 1394 return [i, num_elems] + new_out + flat_accum 1395 1396 init_i = ( 1397 array_ops.shape(flat_elems[0])[0] - 1398 1 if reverse else constant_op.constant(0, dtype=dtypes.int32)) 1399 outputs = [] 1400 if not final_only: 1401 num_outputs = array_ops.shape(flat_elems[0])[0] + (1 if inclusive else 0) 1402 for initial_accum in flat_initial: 1403 out_shape = array_ops.concat( 1404 [[num_outputs], array_ops.shape(initial_accum)], 0) 1405 out = inplace_ops.empty(out_shape, dtype=initial_accum.dtype, init=True) 1406 if inclusive: 1407 out = inplace_ops.alias_inplace_add(out, init_i + (1 if reverse else 0), 1408 initial_accum) 1409 outputs.append(out) 1410 loop_in = [init_i, num_elems] + outputs + flat_initial 1411 hostmem = [ 1412 i for i, x in enumerate(loop_in) 1413 if x.dtype.base_dtype in (dtypes.int32, dtypes.int64) 1414 ] 1415 1416 if context.executing_eagerly(): 1417 loop_results = loop_in 1418 while cond(*loop_results): 1419 loop_results = body(*loop_results) 1420 else: 1421 # TODO(tombagby): Update to while_v2. 1422 cond = function.Defun(*loop_dtypes)(cond) 1423 body = function.Defun(*loop_dtypes)(body) 1424 loop_results = functional_ops.While(loop_in, cond, body, hostmem=hostmem) 1425 out = loop_results[2:num_accums + 2] 1426 return pack(out) 1427 1428 1429def _get_dim(tensor, i): 1430 """Get value of tensor shape[i] preferring static value if available.""" 1431 return tensor_shape.dimension_value( 1432 tensor.shape[i]) or array_ops.shape(tensor)[i] 1433