# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""CTC (Connectionist Temporal Classification) Operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import uuid

from tensorflow.python.eager import context
from tensorflow.python.eager import function as function_eager

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_ctc_ops
from tensorflow.python.ops import inplace_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.nn_grad import _BroadcastMul
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

_DEFUN_API_NAME_ATTRIBUTE = "api_implements"
_DEFUN_DEVICE_ATTRIBUTE = "api_preferred_device"
_CPU_DEVICE_NAME = "CPU"
_GPU_DEVICE_NAME = "GPU"


def _get_context_device_type():
  """Parse the current context and return the device type, eg CPU/GPU."""
  current_device = context.context().device_name
  if current_device is None:
    return None
  return device.DeviceSpec.from_string(current_device).device_type


def _generate_defun_backend(unique_api_name, preferred_device, func):
  function_attributes = {
      _DEFUN_API_NAME_ATTRIBUTE: unique_api_name,
      _DEFUN_DEVICE_ATTRIBUTE: preferred_device,
  }
  return function_eager.defun_with_attributes(
      func=func, attributes=function_attributes, autograph=False)
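

# Illustrative sketch (not part of the original module) of how the helper and
# constants above are typically used: two equivalent implementations share one
# `api_implements` name but prefer different devices, so the runtime can pick
# the variant matching the available hardware. The function names below are
# hypothetical.
def _example_device_dispatch(x):
  """Hedged sketch of the api_implements / api_preferred_device pattern."""

  def _cpu_impl(y):
    return math_ops.square(y)

  def _gpu_impl(y):
    return y * y

  # One shared name ties the two variants together.
  api_name = "example_square_" + str(uuid.uuid4())
  defun_cpu = _generate_defun_backend(api_name, _CPU_DEVICE_NAME, _cpu_impl)
  defun_gpu = _generate_defun_backend(api_name, _GPU_DEVICE_NAME, _gpu_impl)
  res = defun_cpu(x)
  if not context.executing_eagerly():
    # Registering the alternate variant mirrors the graph-mode branch used by
    # ctc_loss_v3 later in this file.
    function_eager.register(defun_gpu, x)
  return res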


# pylint: disable=protected-access, invalid-name
@tf_export(v1=["nn.ctc_loss"])
@dispatch.add_dispatch_support
def ctc_loss(labels,
             inputs=None,
             sequence_length=None,
             preprocess_collapse_repeated=False,
             ctc_merge_repeated=True,
             ignore_longer_outputs_than_inputs=False,
             time_major=True,
             logits=None):
  """Computes the CTC (Connectionist Temporal Classification) Loss.

  This op implements the CTC loss as presented in (Graves et al., 2006).

  Input requirements:

  ```
  sequence_length(b) <= time for all b

  max(labels.indices(labels.indices[:, 1] == b, 2))
    <= sequence_length(b) for all b.
  ```

  Notes:

  This class performs the softmax operation for you, so inputs should
  be e.g. linear projections of outputs by an LSTM.

  The `inputs` Tensor's innermost dimension size, `num_classes`, represents
  `num_labels + 1` classes, where num_labels is the number of true labels, and
  the largest value `(num_classes - 1)` is reserved for the blank label.

  For example, for a vocabulary containing 3 labels `[a, b, c]`,
  `num_classes = 4` and the labels indexing is `{a: 0, b: 1, c: 2, blank: 3}`.

  Regarding the arguments `preprocess_collapse_repeated` and
  `ctc_merge_repeated`:

  If `preprocess_collapse_repeated` is True, then a preprocessing step runs
  before loss calculation, wherein repeated labels passed to the loss
  are merged into single labels. This is useful if the training labels come
  from, e.g., forced alignments and therefore have unnecessary repetitions.

  If `ctc_merge_repeated` is set False, then deep within the CTC calculation,
  repeated non-blank labels will not be merged and are interpreted
  as individual labels. This is a simplified (non-standard) version of CTC.

  Here is a table of the (roughly) expected first order behavior:

  * `preprocess_collapse_repeated=False`, `ctc_merge_repeated=True`

    Classical CTC behavior: Outputs true repeated classes with blanks in
    between, and can also output repeated classes with no blanks in
    between that need to be collapsed by the decoder.

  * `preprocess_collapse_repeated=True`, `ctc_merge_repeated=False`

    Never learns to output repeated classes, as they are collapsed
    in the input labels before training.

  * `preprocess_collapse_repeated=False`, `ctc_merge_repeated=False`

    Outputs repeated classes with blanks in between, but generally does not
    require the decoder to collapse/merge repeated classes.

  * `preprocess_collapse_repeated=True`, `ctc_merge_repeated=True`

    Untested. Very likely will not learn to output repeated classes.

  The `ignore_longer_outputs_than_inputs` option allows specifying the behavior
  of the CTCLoss when dealing with sequences that have longer outputs than
  inputs. If true, the CTCLoss will simply return zero gradient for those
  items, otherwise an InvalidArgument error is returned, stopping training.

  Args:
    labels: An `int32` `SparseTensor`.
      `labels.indices[i, :] == [b, t]` means `labels.values[i]` stores the id
      for (batch b, time t). `labels.values[i]` must take on values in `[0,
      num_labels)`. See `core/ops/ctc_ops.cc` for more details.
    inputs: 3-D `float` `Tensor`.
      If time_major == False, this will be a `Tensor` shaped: `[batch_size,
      max_time, num_classes]`.
      If time_major == True (default), this will be a `Tensor` shaped:
      `[max_time, batch_size, num_classes]`. The logits.
    sequence_length: 1-D `int32` vector, size `[batch_size]`. The sequence
      lengths.
    preprocess_collapse_repeated: Boolean. Default: False. If True, repeated
      labels are collapsed prior to the CTC calculation.
    ctc_merge_repeated: Boolean. Default: True.
    ignore_longer_outputs_than_inputs: Boolean. Default: False. If True,
      sequences with longer outputs than inputs will be ignored.
    time_major: The shape format of the `inputs` Tensors. If True, these
      `Tensors` must be shaped `[max_time, batch_size, num_classes]`. If False,
      these `Tensors` must be shaped `[batch_size, max_time, num_classes]`.
      Using `time_major = True` (default) is a bit more efficient because it
      avoids transposes at the beginning of the ctc_loss calculation. However,
      most TensorFlow data is batch-major, so this function also accepts
      inputs in batch-major form.
    logits: Alias for inputs.

  Returns:
    A 1-D `float` `Tensor`, size `[batch]`, containing the negative log
    probabilities.

  Raises:
    TypeError: if labels is not a `SparseTensor`.

  References:
    Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
    with Recurrent Neural Networks:
      [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
      ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
  """
  return _ctc_loss_impl(
      labels,
      inputs,
      sequence_length,
      preprocess_collapse_repeated,
      ctc_merge_repeated,
      ignore_longer_outputs_than_inputs,
      time_major,
      logits,
      use_cudnn=False)
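

# A minimal usage sketch (not part of the original module), assuming a tiny
# vocabulary of 3 labels plus the implicit blank. The tensors and sizes below
# are hypothetical; they only illustrate how labels, logits and sequence
# lengths fit together for the v1 op above.
def _example_ctc_loss_v1():
  """Hedged example: calling ctc_loss with time-major logits."""
  max_time, batch_size, num_classes = 5, 2, 4  # num_classes includes blank.
  logits = array_ops.zeros([max_time, batch_size, num_classes])
  # Batch element 0 has label sequence [0, 1]; batch element 1 has [2].
  labels = sparse_tensor.SparseTensor(
      indices=[[0, 0], [0, 1], [1, 0]],
      values=constant_op.constant([0, 1, 2], dtype=dtypes.int32),
      dense_shape=[batch_size, 2])
  sequence_length = constant_op.constant([max_time, max_time],
                                         dtype=dtypes.int32)
  # Returns a [batch_size] vector of negative log probabilities.
  return ctc_loss(labels, logits, sequence_length)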


def _ctc_loss_impl(labels,
                   inputs=None,
                   sequence_length=None,
                   preprocess_collapse_repeated=False,
                   ctc_merge_repeated=True,
                   ignore_longer_outputs_than_inputs=False,
                   time_major=True,
                   logits=None,
                   use_cudnn=False):
  # Helper function of ctc_loss with one additional param:
  # use_cudnn: A bool to enable cuDNN CTC loss operation. If true, the blank
  #   index has to be 0.

  # The second, third, etc output tensors contain the gradients. We use it in
  # _CTCLossGrad() below.
  if not isinstance(labels, sparse_tensor.SparseTensor):
    raise TypeError("Expected labels (first argument) to be a SparseTensor")

  # For internal calculations, we transpose to [time, batch, num_classes]
  inputs = deprecation.deprecated_argument_lookup("logits", logits, "inputs",
                                                  inputs)
  if not time_major:
    inputs = array_ops.transpose(inputs, [1, 0, 2])  # (B,T,N) => (T,B,N)

  # gen_ctc_ops.ctc_loss_v2 differs from gen_ctc_ops.ctc_loss. v2 assumes the
  # blank index to be 0, but v1 views it as the last index.
  if use_cudnn:
    ctc_loss_func = gen_ctc_ops.ctc_loss_v2
  else:
    ctc_loss_func = gen_ctc_ops.ctc_loss

  loss, _ = ctc_loss_func(
      inputs,
      labels.indices,
      labels.values,
      sequence_length,
      preprocess_collapse_repeated=preprocess_collapse_repeated,
      ctc_merge_repeated=ctc_merge_repeated,
      ignore_longer_outputs_than_inputs=ignore_longer_outputs_than_inputs)

  return loss


# pylint: disable=unused-argument
def _CTCLossGradImpl(op, grad_loss, _):
  # Outputs are: loss, grad
  #
  # Currently there is no way to take the second derivative of this op
  # due to the fused implementation's interaction with tf.gradients(),
  # so we make sure we prevent silently incorrect results by raising
  # an error if the second derivative is requested via prevent_gradient.
  grad_without_gradient = array_ops.prevent_gradient(
      op.outputs[1],
      message="Currently there is no way to take the second "
      " derivative of ctc_loss due to the fused implementation's interaction "
      " with tf.gradients()")
  # Return gradient for inputs and None for
  # labels_indices, labels_values and sequence_length
  return [_BroadcastMul(grad_loss, grad_without_gradient), None, None, None]


# pylint: disable=unused-argument
@ops.RegisterGradient("CTCLoss")
def _CTCLossGrad(op, grad_loss, _):
  """The derivative provided by CTC Loss.

  Args:
    op: the CTCLoss op.
    grad_loss: The backprop for cost.

  Returns:
    The CTC Loss gradient.
  """
  return _CTCLossGradImpl(op, grad_loss, _)


# pylint: disable=unused-argument
@ops.RegisterGradient("CTCLossV2")
def _CTCLossV2Grad(op, grad_loss, _):
  """The derivative provided by CTC Loss V2.

  Args:
    op: the CTCLossV2 op.
    grad_loss: The backprop for cost.

  Returns:
    The CTC Loss V2 gradient.
  """
  return _CTCLossGradImpl(op, grad_loss, _)


@tf_export("nn.ctc_greedy_decoder")
@dispatch.add_dispatch_support
def ctc_greedy_decoder(inputs,
                       sequence_length,
                       merge_repeated=True,
                       blank_index=None):
  """Performs greedy decoding on the logits given in input (best path).

  Given a tensor as `inputs`, the `blank_index` parameter defines the class
  index of the blank symbol.

  For example:

  If `blank_index` is equal to 1:

  >>> inf = float("inf")
  >>> logits = tf.constant([[[   0., -inf, -inf],
  ...                        [ -2.3, -inf, -0.1]],
  ...                       [[ -inf, -0.5, -inf],
  ...                        [ -inf, -inf, -0.1]],
  ...                       [[ -inf, -inf, -inf],
  ...                        [ -0.1, -inf, -2.3]]])
  >>> seq_lens = tf.constant([2, 3])
  >>> outputs = tf.nn.ctc_greedy_decoder(
  ...     logits,
  ...     seq_lens,
  ...     blank_index=1)

  Notes:

  - Regardless of the value of `merge_repeated`, if an index of a
    given time and batch corresponds to the `blank_index`, no new
    element is emitted.
  - Default `blank_index` is `(num_classes - 1)`, unless overridden.

  If `merge_repeated` is `True`, merge repeated classes in output.
  This means that if consecutive logits' maximum indices are the same,
  only the first of these is emitted. The sequence `A B B * B * B` (where '*'
  is the blank label) becomes

    * `A B B B` if `merge_repeated=True`.
    * `A B B B B` if `merge_repeated=False`.

  Args:
    inputs: 3-D `float` `Tensor` sized `[max_time, batch_size, num_classes]`.
      The logits.
    sequence_length: 1-D `int32` vector containing sequence lengths, having
      size `[batch_size]`.
    merge_repeated: Boolean. Default: True.
    blank_index: (Optional). Default: `num_classes - 1`. Define the class index
      to use for the blank label. Negative values will start from num_classes,
      ie, -1 will reproduce the ctc_greedy_decoder behavior of using
      num_classes - 1 for the blank symbol, which corresponds to the default.

  Returns:
    A tuple `(decoded, neg_sum_logits)` where

    decoded: A single-element list. `decoded[0]`
      is a `SparseTensor` containing the decoded outputs s.t.:

      `decoded.indices`: Indices matrix `(total_decoded_outputs, 2)`.
        The rows store: `[batch, time]`.

      `decoded.values`: Values vector, size `(total_decoded_outputs)`.
        The vector stores the decoded classes.

      `decoded.dense_shape`: Shape vector, size `(2)`.
        The shape values are: `[batch_size, max_decoded_length]`

    neg_sum_logits: A `float` matrix `(batch_size x 1)` containing, for the
      sequence found, the negative of the sum of the greatest logit at each
      timeframe.
  """

  outputs = gen_ctc_ops.ctc_greedy_decoder(
      inputs,
      sequence_length,
      merge_repeated=merge_repeated,
      blank_index=blank_index)
  (decoded_ix, decoded_val, decoded_shape, log_probabilities) = outputs
  return ([sparse_tensor.SparseTensor(decoded_ix, decoded_val,
                                      decoded_shape)], log_probabilities)
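

# A small follow-up sketch (not part of the original module): the decoder
# returns a SparseTensor, which is often easier to inspect after densifying.
# The helper name and the fill value of -1 are hypothetical choices.
def _example_ctc_greedy_decoder_to_dense(logits, sequence_length):
  """Hedged example: run greedy decoding and densify the best path."""
  decoded_list, neg_sum_logits = ctc_greedy_decoder(logits, sequence_length)
  decoded = decoded_list[0]  # Single best path.
  dense_decoded = sparse_ops.sparse_tensor_to_dense(decoded, default_value=-1)
  return dense_decoded, neg_sum_logits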


@tf_export(v1=["nn.ctc_beam_search_decoder"])
@dispatch.add_dispatch_support
def ctc_beam_search_decoder(inputs,
                            sequence_length,
                            beam_width=100,
                            top_paths=1,
                            merge_repeated=True):
  """Performs beam search decoding on the logits given in input.

  **Note** The `ctc_greedy_decoder` is a special case of the
  `ctc_beam_search_decoder` with `top_paths=1` and `beam_width=1` (but
  that decoder is faster for this special case).

  If `merge_repeated` is `True`, merge repeated classes in the output beams.
  This means that if consecutive entries in a beam are the same,
  only the first of these is emitted. That is, when the sequence is
  `A B B * B * B` (where '*' is the blank label), the return value is:

    * `A B` if `merge_repeated = True`.
    * `A B B B` if `merge_repeated = False`.

  Args:
    inputs: 3-D `float` `Tensor`, size `[max_time x batch_size x num_classes]`.
      The logits.
    sequence_length: 1-D `int32` vector containing sequence lengths, having
      size `[batch_size]`.
    beam_width: An int scalar >= 0 (beam search beam width).
    top_paths: An int scalar >= 0, <= beam_width (controls output size).
    merge_repeated: Boolean. Default: True.

  Returns:
    A tuple `(decoded, log_probabilities)` where

    decoded: A list of length top_paths, where `decoded[j]`
      is a `SparseTensor` containing the decoded outputs:

      `decoded[j].indices`: Indices matrix `(total_decoded_outputs[j] x 2)`
        The rows store: [batch, time].

      `decoded[j].values`: Values vector, size `(total_decoded_outputs[j])`.
        The vector stores the decoded classes for beam j.

      `decoded[j].dense_shape`: Shape vector, size `(2)`.
        The shape values are: `[batch_size, max_decoded_length[j]]`.

    log_probability: A `float` matrix `(batch_size x top_paths)` containing
      sequence log-probabilities.
  """

  decoded_ixs, decoded_vals, decoded_shapes, log_probabilities = (
      gen_ctc_ops.ctc_beam_search_decoder(
          inputs,
          sequence_length,
          beam_width=beam_width,
          top_paths=top_paths,
          merge_repeated=merge_repeated))

  return ([
      sparse_tensor.SparseTensor(ix, val, shape)
      for (ix, val, shape) in zip(decoded_ixs, decoded_vals, decoded_shapes)
  ], log_probabilities)


@tf_export("nn.ctc_beam_search_decoder", v1=["nn.ctc_beam_search_decoder_v2"])
@dispatch.add_dispatch_support
def ctc_beam_search_decoder_v2(inputs,
                               sequence_length,
                               beam_width=100,
                               top_paths=1):
  """Performs beam search decoding on the logits given in input.

  **Note** The `ctc_greedy_decoder` is a special case of the
  `ctc_beam_search_decoder` with `top_paths=1` and `beam_width=1` (but
  that decoder is faster for this special case).

  Args:
    inputs: 3-D `float` `Tensor`, size `[max_time, batch_size, num_classes]`.
      The logits.
    sequence_length: 1-D `int32` vector containing sequence lengths, having
      size `[batch_size]`.
    beam_width: An int scalar >= 0 (beam search beam width).
    top_paths: An int scalar >= 0, <= beam_width (controls output size).

  Returns:
    A tuple `(decoded, log_probabilities)` where

    decoded: A list of length top_paths, where `decoded[j]`
      is a `SparseTensor` containing the decoded outputs:

      `decoded[j].indices`: Indices matrix `[total_decoded_outputs[j], 2]`;
        The rows store: `[batch, time]`.

      `decoded[j].values`: Values vector, size `[total_decoded_outputs[j]]`.
        The vector stores the decoded classes for beam `j`.

      `decoded[j].dense_shape`: Shape vector, size `(2)`.
        The shape values are: `[batch_size, max_decoded_length[j]]`.

    log_probability: A `float` matrix `[batch_size, top_paths]` containing
      sequence log-probabilities.
  """

  # Note, merge_repeated is an invalid optimization that is removed from the
  # public API: it returns low probability paths.
  return ctc_beam_search_decoder(
      inputs,
      sequence_length=sequence_length,
      beam_width=beam_width,
      top_paths=top_paths,
      merge_repeated=False)


ops.NotDifferentiable("CTCGreedyDecoder")
ops.NotDifferentiable("CTCBeamSearchDecoder")
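

# Usage sketch (not part of the original module): decoding logits with a beam
# search and keeping several hypotheses per batch element. The beam settings
# and shapes are hypothetical.
def _example_ctc_beam_search(logits, sequence_length):
  """Hedged example: top-3 beam search hypotheses per batch element."""
  decoded_list, log_probabilities = ctc_beam_search_decoder_v2(
      logits, sequence_length, beam_width=10, top_paths=3)
  # decoded_list[j] is the j-th best SparseTensor hypothesis;
  # log_probabilities has shape [batch_size, top_paths].
  return decoded_list, log_probabilities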


def _ctc_state_trans(label_seq):
  """Compute CTC alignment model transition matrix.

  Args:
    label_seq: tensor of shape [batch_size, max_seq_length]

  Returns:
    tensor of shape [batch_size, states, states] with a state transition matrix
    computed for each sequence of the batch.
  """

  with ops.name_scope("ctc_state_trans"):
    label_seq = ops.convert_to_tensor(label_seq, name="label_seq")
    batch_size = _get_dim(label_seq, 0)
    num_labels = _get_dim(label_seq, 1)

    num_label_states = num_labels + 1
    num_states = 2 * num_label_states

    label_states = math_ops.range(num_label_states)
    blank_states = label_states + num_label_states

    # Start state to first label.
    start_to_label = [[1, 0]]

    # Blank to label transitions.
    blank_to_label = array_ops.stack([label_states[1:], blank_states[:-1]], 1)

    # Label to blank transitions.
    label_to_blank = array_ops.stack([blank_states, label_states], 1)

    # Scatter transitions that don't depend on sequence.
    indices = array_ops.concat(
        [start_to_label, blank_to_label, label_to_blank], 0)
    values = array_ops.ones([_get_dim(indices, 0)])
    trans = array_ops.scatter_nd(
        indices, values, shape=[num_states, num_states])
    trans += linalg_ops.eye(num_states)  # Self-loops.

    # Label to label transitions. Disallow transitions between repeated labels
    # with no blank state in between.
    batch_idx = array_ops.zeros_like(label_states[2:])
    indices = array_ops.stack(
        [batch_idx, label_states[2:], label_states[1:-1]], 1)
    indices = array_ops.tile(
        array_ops.expand_dims(indices, 0), [batch_size, 1, 1])
    batch_idx = array_ops.expand_dims(math_ops.range(batch_size), 1) * [1, 0, 0]
    indices += array_ops.expand_dims(batch_idx, 1)
    repeats = math_ops.equal(label_seq[:, :-1], label_seq[:, 1:])
    values = 1.0 - math_ops.cast(repeats, dtypes.float32)
    batched_shape = [batch_size, num_states, num_states]
    label_to_label = array_ops.scatter_nd(indices, values, batched_shape)

    return array_ops.expand_dims(trans, 0) + label_to_label
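

# Toy illustration (not part of the original module) of the alignment lattice
# built above: for a batch of one sequence with 3 labels there are 4 label
# states plus 4 blank states, so the returned transition matrix has shape
# [1, 8, 8]. The label values used here are arbitrary.
def _example_ctc_state_trans():
  """Hedged example: transition matrix for a single short label sequence."""
  label_seq = constant_op.constant([[1, 2, 2]], dtype=dtypes.int32)
  trans = _ctc_state_trans(label_seq)  # Shape [1, 8, 8].
  # Because the last two labels repeat, the direct label-to-label transition
  # between them is disallowed (a blank state must be visited in between).
  return trans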


def ctc_state_log_probs(seq_lengths, max_seq_length):
  """Computes CTC alignment initial and final state log probabilities.

  Create the initial/final state values directly as log values to avoid
  having to take a float64 log on tpu (which does not exist).

  Args:
    seq_lengths: int tensor of shape [batch_size], seq lengths in the batch.
    max_seq_length: int, max sequence length possible.

  Returns:
    initial_state_log_probs, final_state_log_probs
  """

  batch_size = _get_dim(seq_lengths, 0)
  num_label_states = max_seq_length + 1
  num_duration_states = 2
  num_states = num_duration_states * num_label_states
  log_0 = math_ops.cast(
      math_ops.log(math_ops.cast(0, dtypes.float64) + 1e-307), dtypes.float32)

  initial_state_log_probs = array_ops.one_hot(
      indices=array_ops.zeros([batch_size], dtype=dtypes.int32),
      depth=num_states,
      on_value=0.0,
      off_value=log_0,
      axis=1)

  label_final_state_mask = array_ops.one_hot(
      seq_lengths, depth=num_label_states, axis=0)
  duration_final_state_mask = array_ops.ones(
      [num_duration_states, 1, batch_size])
  final_state_mask = duration_final_state_mask * label_final_state_mask
  final_state_log_probs = (1.0 - final_state_mask) * log_0
  final_state_log_probs = array_ops.reshape(final_state_log_probs,
                                            [num_states, batch_size])

  return initial_state_log_probs, array_ops.transpose(final_state_log_probs)


def _ilabel_to_state(labels, num_labels, ilabel_log_probs):
  """Project ilabel log probs to state log probs."""

  num_label_states = _get_dim(labels, 1)
  blank = ilabel_log_probs[:, :, :1]
  blank = array_ops.tile(blank, [1, 1, num_label_states + 1])
  one_hot = array_ops.one_hot(labels, depth=num_labels)
  one_hot = array_ops.expand_dims(one_hot, axis=0)
  ilabel_log_probs = array_ops.expand_dims(ilabel_log_probs, axis=2)
  state_log_probs = math_ops.reduce_sum(ilabel_log_probs * one_hot, axis=3)
  state_log_probs = array_ops.concat([state_log_probs, blank], axis=2)
  return array_ops.pad(
      state_log_probs, [[0, 0], [0, 0], [1, 0]],
      constant_values=math_ops.log(0.0))


def _state_to_olabel(labels, num_labels, states):
  """Sum state log probs to ilabel log probs."""

  num_label_states = _get_dim(labels, 1) + 1
  label_states = states[:, :, 1:num_label_states]
  blank_states = states[:, :, num_label_states:]
  one_hot = array_ops.one_hot(
      labels - 1,
      depth=(num_labels - 1),
      on_value=0.0,
      off_value=math_ops.log(0.0))
  one_hot = array_ops.expand_dims(one_hot, axis=0)
  label_states = array_ops.expand_dims(label_states, axis=3)
  label_olabels = math_ops.reduce_logsumexp(label_states + one_hot, axis=2)
  blank_olabels = math_ops.reduce_logsumexp(
      blank_states, axis=2, keepdims=True)
  return array_ops.concat([blank_olabels, label_olabels], axis=-1)


# pylint: disable=redefined-outer-name
def _state_to_olabel_unique(labels, num_labels, states, unique):
  """Sum state log probs to ilabel log probs using unique label indices."""

  num_label_states = _get_dim(labels, 1) + 1
  label_states = states[:, :, 1:num_label_states]
  blank_states = states[:, :, num_label_states:]

  unique_y, unique_idx = unique
  mul_reduce = _sum_states(unique_idx, label_states)

  num_frames = states.shape[0]
  batch_size = states.shape[1]
  num_states = num_label_states - 1
  batch_state_major = array_ops.transpose(mul_reduce, perm=[1, 2, 0])
  batch_state_major = array_ops.reshape(batch_state_major,
                                        [batch_size * num_states, num_frames])
  batch_offset = math_ops.range(batch_size, dtype=unique_y.dtype) * num_labels
  indices = unique_y + array_ops.expand_dims(batch_offset, axis=-1)
  indices = array_ops.reshape(indices, [-1, 1])
  scatter = array_ops.scatter_nd(
      indices=indices,
      updates=batch_state_major,
      shape=[batch_size * num_labels, num_frames])
  scatter = array_ops.reshape(scatter, [batch_size, num_labels, num_frames])

  mask = array_ops.ones_like(batch_state_major, dtype=dtypes.bool)
  mask = array_ops.scatter_nd(
      indices=indices,
      updates=mask,
      shape=[batch_size * num_labels, num_frames])
  mask = array_ops.reshape(mask, [batch_size, num_labels, num_frames])

  scatter = array_ops.where(
      mask, scatter,
      array_ops.fill(array_ops.shape(scatter), math_ops.log(0.0)))

  label_olabels = array_ops.transpose(scatter, [2, 0, 1])
  label_olabels = label_olabels[:, :, 1:]

  blank_olabels = math_ops.reduce_logsumexp(
      blank_states, axis=2, keepdims=True)

  return array_ops.concat([blank_olabels, label_olabels], axis=-1)


def ctc_loss_and_grad(logits, labels, label_length, logit_length, unique=None):
  """Computes the CTC loss and gradients.

  Most users will want fwd_bwd.ctc_loss

  This function returns the computed gradient; it does not have a gradient
  of its own defined.

  Args:
    logits: tensor of shape [frames, batch_size, num_labels]
    labels: tensor of shape [batch_size, max_label_seq_length]
    label_length: tensor of shape [batch_size] Length of reference label
      sequence in labels.
    logit_length: tensor of shape [batch_size] Length of input sequence in
      logits.
    unique: (optional) unique label indices as computed by unique(labels) If
      supplied, enables an implementation that is faster and more memory
      efficient on TPU.

  Returns:
    loss: tensor of shape [batch_size]
    gradient: tensor of shape [frames, batch_size, num_labels]
  """

  num_labels = _get_dim(logits, 2)
  max_label_seq_length = _get_dim(labels, 1)

  ilabel_log_probs = nn_ops.log_softmax(logits)
  state_log_probs = _ilabel_to_state(labels, num_labels, ilabel_log_probs)
  state_trans_probs = _ctc_state_trans(labels)
  initial_state_log_probs, final_state_log_probs = ctc_state_log_probs(
      label_length, max_label_seq_length)
  fwd_bwd_log_probs, log_likelihood = _forward_backward_log(
      state_trans_log_probs=math_ops.log(state_trans_probs),
      initial_state_log_probs=initial_state_log_probs,
      final_state_log_probs=final_state_log_probs,
      observed_log_probs=state_log_probs,
      sequence_length=logit_length)

  if unique:
    olabel_log_probs = _state_to_olabel_unique(labels, num_labels,
                                               fwd_bwd_log_probs, unique)
  else:
    olabel_log_probs = _state_to_olabel(labels, num_labels, fwd_bwd_log_probs)

  grad = math_ops.exp(ilabel_log_probs) - math_ops.exp(olabel_log_probs)

  # Applies the sequence mask for the gradient. It is enough to apply the mask
  # only for ilabel_log_probs because olabel_log_probs already consider the
  # mask. However, it is just safe and clean to apply it for the gradient.
  max_logit_length = _get_dim(logits, 0)
  logit_mask = array_ops.sequence_mask(logit_length, max_logit_length,
                                       dtypes.float32)
  logit_mask = array_ops.transpose(logit_mask, perm=[1, 0])
  logit_mask = array_ops.expand_dims(logit_mask, axis=2)
  grad *= logit_mask

  loss = -log_likelihood
  return loss, grad


def _ctc_loss_grad(op, grad_loss, _):
  grad = op.outputs[1]
  grad = [array_ops.reshape(grad_loss, [1, -1, 1]) * grad]
  grad += [None] * (len(op.inputs) - len(grad))
  return grad


def _ctc_loss_op_standard(labels, logits, logit_length, logits_time_major,
                          blank_index):
  part_before = logits[:, :, :blank_index]
  part_after = logits[:, :, blank_index + 1:]
  part_blank = logits[:, :, blank_index:blank_index + 1]
  logits = array_ops.concat([part_before, part_after, part_blank], axis=2)
  labels = sparse_tensor.SparseTensor(
      labels.indices,
      array_ops.where(labels.values < blank_index, labels.values,
                      labels.values - 1), labels.dense_shape)
  return _ctc_loss_impl(
      labels=labels,
      inputs=logits,
      sequence_length=logit_length,
      time_major=logits_time_major,
      use_cudnn=False)


def _ctc_loss_op_cudnn(labels, logits, logit_length, logits_time_major,
                       blank_index):
  part_before = logits[:, :, :blank_index]
  part_after = logits[:, :, blank_index + 1:]
  part_blank = logits[:, :, blank_index:blank_index + 1]
  logits = array_ops.concat([part_blank, part_before, part_after], axis=2)
  labels = sparse_tensor.SparseTensor(
      labels.indices,
      array_ops.where(labels.values < blank_index, labels.values + 1,
                      labels.values), labels.dense_shape)
  return _ctc_loss_impl(
      labels=labels,
      inputs=logits,
      sequence_length=logit_length,
      time_major=logits_time_major,
      use_cudnn=True)


def _ctc_loss_shape(op):
  return [op.inputs[2].get_shape(), op.inputs[0].get_shape()]


# pylint: disable=protected-access, invalid-name
@tf_export(v1=["nn.ctc_loss_v2"])
@dispatch.add_dispatch_support
def ctc_loss_v2(labels,
                logits,
                label_length,
                logit_length,
                logits_time_major=True,
                unique=None,
                blank_index=None,
                name=None):
  """Computes CTC (Connectionist Temporal Classification) loss.

  This op implements the CTC loss as presented in (Graves et al., 2006).

  Notes:

  - Same as the "Classic CTC" in TensorFlow 1.x's tf.compat.v1.nn.ctc_loss
    setting of preprocess_collapse_repeated=False, ctc_merge_repeated=True
  - Labels may be supplied as either a dense, zero-padded tensor with a
    vector of label sequence lengths OR as a SparseTensor.
  - On TPU and GPU: Only dense padded labels are supported.
  - On CPU: Caller may use SparseTensor or dense padded labels but calling with
    a SparseTensor will be significantly faster.
  - Default blank label is 0 rather than num_classes - 1, unless overridden by
    blank_index.

  Args:
    labels: tensor of shape [batch_size, max_label_seq_length] or SparseTensor
    logits: tensor of shape [frames, batch_size, num_labels], if
      logits_time_major == False, shape is [batch_size, frames, num_labels].
    label_length: tensor of shape [batch_size], None if labels is SparseTensor
      Length of reference label sequence in labels.
    logit_length: tensor of shape [batch_size] Length of input sequence in
      logits.
    logits_time_major: (optional) If True (default), logits is shaped [time,
      batch, logits]. If False, shape is [batch, time, logits]
    unique: (optional) Unique label indices as computed by
      ctc_unique_labels(labels). If supplied, enable a faster, memory efficient
      implementation on TPU.
    blank_index: (optional) Set the class index to use for the blank label.
      Negative values will start from num_classes, ie, -1 will reproduce the
      ctc_loss behavior of using num_classes - 1 for the blank symbol. There is
      some memory/performance overhead to switching from the default of 0 as an
      additional shifted copy of the logits may be created.
    name: A name for this `Op`. Defaults to "ctc_loss_dense".

  Returns:
    loss: tensor of shape [batch_size], negative log probabilities.

  References:
    Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
    with Recurrent Neural Networks:
      [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
      ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
  """
  if isinstance(labels, sparse_tensor.SparseTensor):
    if blank_index is None:
      raise ValueError(
          "blank_index must be given when using SparseTensor labels.")

    if blank_index < 0:
      blank_index += _get_dim(logits, 2)

    if blank_index != _get_dim(logits, 2) - 1:
      logits = array_ops.concat([
          logits[:, :, :blank_index],
          logits[:, :, blank_index + 1:],
          logits[:, :, blank_index:blank_index + 1],
      ], axis=2)
      labels = sparse_tensor.SparseTensor(
          labels.indices,
          array_ops.where(labels.values < blank_index, labels.values,
                          labels.values - 1), labels.dense_shape)

    return ctc_loss(
        labels=labels,
        inputs=logits,
        sequence_length=logit_length,
        time_major=logits_time_major)

  if blank_index is None:
    blank_index = 0

  return ctc_loss_dense(
      labels=labels,
      logits=logits,
      label_length=label_length,
      logit_length=logit_length,
      logits_time_major=logits_time_major,
      unique=unique,
      blank_index=blank_index,
      name=name)
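

# Sketch (not part of the original module): calling the v1 endpoint above with
# SparseTensor labels. With sparse labels a blank_index must be supplied; -1
# selects the last class, matching the classic CTC convention. All shapes are
# hypothetical.
def _example_ctc_loss_v2_sparse(logits, sparse_labels, logit_length):
  """Hedged example: ctc_loss_v2 with SparseTensor labels."""
  return ctc_loss_v2(
      labels=sparse_labels,
      logits=logits,
      label_length=None,  # Unused for SparseTensor labels.
      logit_length=logit_length,
      logits_time_major=True,
      blank_index=-1)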


@tf_export("nn.ctc_loss", v1=[])
@dispatch.add_dispatch_support
def ctc_loss_v3(labels,
                logits,
                label_length,
                logit_length,
                logits_time_major=True,
                unique=None,
                blank_index=None,
                name=None):
  """Computes CTC (Connectionist Temporal Classification) loss.

  This op implements the CTC loss as presented in (Graves et al., 2006).

  Notes:

  - Same as the "Classic CTC" in TensorFlow 1.x's tf.compat.v1.nn.ctc_loss
    setting of preprocess_collapse_repeated=False, ctc_merge_repeated=True
  - Labels may be supplied as either a dense, zero-padded tensor with a
    vector of label sequence lengths OR as a SparseTensor.
  - On TPU and GPU: Only dense padded labels are supported.
  - On CPU: Caller may use SparseTensor or dense padded labels but calling with
    a SparseTensor will be significantly faster.
  - Default blank label is 0 rather than num_classes - 1, unless overridden by
    blank_index.

  Args:
    labels: tensor of shape [batch_size, max_label_seq_length] or SparseTensor
    logits: tensor of shape [frames, batch_size, num_labels], if
      logits_time_major == False, shape is [batch_size, frames, num_labels].
    label_length: tensor of shape [batch_size], None if labels is SparseTensor
      Length of reference label sequence in labels.
    logit_length: tensor of shape [batch_size] Length of input sequence in
      logits.
    logits_time_major: (optional) If True (default), logits is shaped [time,
      batch, logits]. If False, shape is [batch, time, logits]
    unique: (optional) Unique label indices as computed by
      ctc_unique_labels(labels). If supplied, enable a faster, memory efficient
      implementation on TPU.
    blank_index: (optional) Set the class index to use for the blank label.
      Negative values will start from num_classes, ie, -1 will reproduce the
      ctc_loss behavior of using num_classes - 1 for the blank symbol. There is
      some memory/performance overhead to switching from the default of 0 as an
      additional shifted copy of the logits may be created.
    name: A name for this `Op`. Defaults to "ctc_loss_dense".

  Returns:
    loss: tensor of shape [batch_size], negative log probabilities.

  References:
    Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
    with Recurrent Neural Networks:
      [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
      ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
  """
  if isinstance(labels, sparse_tensor.SparseTensor):
    if blank_index is None:
      raise ValueError(
          "blank_index must be given when using SparseTensor labels.")

    if blank_index < 0:
      blank_index += _get_dim(logits, 2)

    params = {
        "labels": labels,
        "logits": logits,
        "logit_length": logit_length,
        "logits_time_major": logits_time_major,
        "blank_index": blank_index
    }

    if context.executing_eagerly():
      device_type = _get_context_device_type()
      can_use_gpu = (
          # Either user specified GPU or unspecified but GPU is available.
          (device_type == _GPU_DEVICE_NAME or
           (device_type is None and context.num_gpus() > 0)))
      # Under eager context, check the device placement and prefer the
      # GPU-based (cuDNN) implementation when a GPU is available.
      if can_use_gpu:
        res = _ctc_loss_op_cudnn(**params)
      else:
        res = _ctc_loss_op_standard(**params)
    else:
      api_name = "ctc_loss_" + str(uuid.uuid4())
      ctc_loss_op_standard = _generate_defun_backend(api_name, _CPU_DEVICE_NAME,
                                                     _ctc_loss_op_standard)
      ctc_loss_op_cudnn = _generate_defun_backend(api_name, _GPU_DEVICE_NAME,
                                                  _ctc_loss_op_cudnn)
      res = ctc_loss_op_standard(**params)
      function_eager.register(ctc_loss_op_cudnn, **params)
    return res

  if blank_index is None:
    blank_index = 0

  return ctc_loss_dense(
      labels=labels,
      logits=logits,
      label_length=label_length,
      logit_length=logit_length,
      logits_time_major=logits_time_major,
      unique=unique,
      blank_index=blank_index,
      name=name)
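

# Sketch (not part of the original module): dense labels with the optional
# `unique` fast path described in the docstring above. The label/logit tensors
# passed in are hypothetical; the pattern shown is precomputing
# ctc_unique_labels and threading the result through.
def _example_ctc_loss_v3_dense(logits, labels, label_length, logit_length):
  """Hedged example: dense-label CTC loss with the unique-label fast path."""
  unique = ctc_unique_labels(labels)
  return ctc_loss_v3(
      labels=labels,
      logits=logits,
      label_length=label_length,
      logit_length=logit_length,
      logits_time_major=True,
      unique=unique,
      blank_index=0)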


def ctc_loss_dense(labels,
                   logits,
                   label_length,
                   logit_length,
                   logits_time_major=True,
                   unique=None,
                   blank_index=0,
                   name=None):
  """Computes CTC (Connectionist Temporal Classification) loss.

  This op implements the CTC loss as presented in (Graves et al., 2006),
  using the batched forward backward algorithm described in (Sim et al., 2017).

  Notes:
    Significant differences from tf.compat.v1.nn.ctc_loss:
      Supports GPU and TPU (tf.compat.v1.nn.ctc_loss supports CPU only):
        For batched operations, GPU and TPU are significantly faster than using
        ctc_loss on CPU.
        This implementation runs on CPU, but significantly slower than ctc_loss.
      Blank label is 0 rather than num_classes - 1, unless overridden by
        blank_index.
      Logits and labels are dense arrays with padding rather than SparseTensor.
      The only mode supported is the same as:
        preprocess_collapse_repeated=False, ctc_merge_repeated=True
      To collapse labels, the caller can preprocess label sequence first.

  The dense implementation supports both CPU, GPU and TPU. A fast path is
  provided that significantly improves memory use for large vocabulary if the
  caller preprocesses label sequences to get unique label indices on the CPU
  (eg. in the data input pipeline) using ctc_ops.unique and supplies this in
  the optional "unique" kwarg. This is especially useful for TPU and GPU but
  also works if used on CPU.

  Args:
    labels: tensor of shape [batch_size, max_label_seq_length]
    logits: tensor of shape [frames, batch_size, num_labels], if
      logits_time_major == False, shape is [batch_size, frames, num_labels].
    label_length: tensor of shape [batch_size] Length of reference label
      sequence in labels.
    logit_length: tensor of shape [batch_size] Length of input sequence in
      logits.
    logits_time_major: (optional) If True (default), logits is shaped [time,
      batch, logits]. If False, shape is [batch, time, logits]
    unique: (optional) Unique label indices as computed by unique(labels). If
      supplied, enable a faster, memory efficient implementation on TPU.
    blank_index: (optional) Set the class index to use for the blank label.
      Negative values will start from num_classes, ie, -1 will reproduce the
      ctc_loss behavior of using num_classes - 1 for the blank symbol. There is
      some memory/performance overhead to switching from the default of 0 as an
      additional shifted copy of the logits may be created.
    name: A name for this `Op`. Defaults to "ctc_loss_dense".

  Returns:
    loss: tensor of shape [batch_size], negative log probabilities.

  References:
    Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
    with Recurrent Neural Networks:
      [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891)
      ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
    Improving the efficiency of forward-backward algorithm using batched
    computation in TensorFlow:
      [Sim et al., 2017](https://ieeexplore.ieee.org/document/8268944)
      ([pdf](http://bacchiani.net/resume/papers/ASRU2017.pdf))
  """

  with ops.name_scope(name, "ctc_loss_dense",
                      [logits, labels, label_length, logit_length]):
    logits = ops.convert_to_tensor(logits, name="logits")
    labels = ops.convert_to_tensor(labels, name="labels")
    label_length = ops.convert_to_tensor(label_length, name="label_length")
    logit_length = ops.convert_to_tensor(logit_length, name="logit_length")

    if not logits_time_major:
      logits = array_ops.transpose(logits, perm=[1, 0, 2])

    if blank_index != 0:
      if blank_index < 0:
        blank_index += _get_dim(logits, 2)
      logits = array_ops.concat([
          logits[:, :, blank_index:blank_index + 1],
          logits[:, :, :blank_index],
          logits[:, :, blank_index + 1:],
      ], axis=2)
      labels = array_ops.where(labels < blank_index, labels + 1, labels)

    args = [logits, labels, label_length, logit_length]

    if unique:
      unique_y, unique_idx = unique
      if blank_index != 0:
        unique_y = array_ops.where(unique_y < blank_index, unique_y + 1,
                                   unique_y)
        label_mask_len = math_ops.reduce_max(unique_idx, axis=1) + 1
        max_label_length = _get_dim(unique_y, 1)
        label_mask = array_ops.sequence_mask(label_mask_len, max_label_length)
        unique_y = array_ops.where(label_mask, unique_y,
                                   array_ops.zeros_like(unique_y))
      args.extend([unique_y, unique_idx])

    @custom_gradient.custom_gradient
    def compute_ctc_loss(logits_t, labels_t, label_length_t, logit_length_t,
                         *unique_t):
      """Compute CTC loss."""
      logits_t.set_shape(logits.shape)
      labels_t.set_shape(labels.shape)
      label_length_t.set_shape(label_length.shape)
      logit_length_t.set_shape(logit_length.shape)
      kwargs = dict(
          logits=logits_t,
          labels=labels_t,
          label_length=label_length_t,
          logit_length=logit_length_t)
      if unique_t:
        kwargs["unique"] = unique_t
      result = ctc_loss_and_grad(**kwargs)

      def grad(grad_loss):
        grad = [array_ops.reshape(grad_loss, [1, -1, 1]) * result[1]]
        grad += [None] * (len(args) - len(grad))
        return grad

      return result[0], grad

    return compute_ctc_loss(*args)


@tf_export("nn.collapse_repeated")
@dispatch.add_dispatch_support
def collapse_repeated(labels, seq_length, name=None):
  """Merge repeated labels into single labels.

  Args:
    labels: Tensor of shape [batch, max value in seq_length]
    seq_length: Tensor of shape [batch], sequence length of each batch element.
    name: A name for this `Op`. Defaults to "collapse_repeated_labels".

  Returns:
    A tuple `(collapsed_labels, new_seq_length)` where

    collapsed_labels: Tensor of shape [batch, max_seq_length] with repeated
      labels collapsed and padded to max_seq_length, eg:
      `[[A, A, B, B, A], [A, B, C, D, E]] => [[A, B, A, 0, 0], [A, B, C, D, E]]`

    new_seq_length: int tensor of shape [batch] with new sequence lengths.
  """

  with ops.name_scope(name, "collapse_repeated_labels", [labels, seq_length]):
    labels = ops.convert_to_tensor(labels, name="labels")
    seq_length = ops.convert_to_tensor(seq_length, name="seq_length")

    # Mask labels that don't equal previous label.
    label_mask = array_ops.concat([
        array_ops.ones_like(labels[:, :1], dtypes.bool),
        math_ops.not_equal(labels[:, 1:], labels[:, :-1])
    ], axis=1)

    # Filter labels that aren't in the original sequence.
    maxlen = _get_dim(labels, 1)
    seq_mask = array_ops.sequence_mask(seq_length, maxlen=maxlen)
    label_mask = math_ops.logical_and(label_mask, seq_mask)

    # Count masks for new sequence lengths.
    new_seq_len = math_ops.reduce_sum(
        math_ops.cast(label_mask, dtypes.int32), axis=1)

    # Mask indexes based on sequence length mask.
    new_maxlen = math_ops.reduce_max(new_seq_len)
    idx_mask = array_ops.sequence_mask(new_seq_len, maxlen=new_maxlen)

    # Flatten everything and mask out labels to keep and sparse indices.
    flat_labels = array_ops.reshape(labels, [-1])
    flat_label_mask = array_ops.reshape(label_mask, [-1])
    flat_idx_mask = array_ops.reshape(idx_mask, [-1])
    idx = math_ops.range(_get_dim(flat_idx_mask, 0))

    # Scatter to flat shape.
    flat = array_ops.scatter_nd(
        indices=array_ops.expand_dims(
            array_ops.boolean_mask(idx, flat_idx_mask), axis=1),
        updates=array_ops.boolean_mask(flat_labels, flat_label_mask),
        shape=array_ops.shape(flat_idx_mask))

    # Reshape back to square batch.
    batch_size = _get_dim(labels, 0)
    new_shape = [batch_size, new_maxlen]
    return (array_ops.reshape(flat, new_shape),
            math_ops.cast(new_seq_len, seq_length.dtype))
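

# Worked sketch (not part of the original module) of the collapse above: with a
# batch of two padded sequences, repeated labels are merged and the new lengths
# count only the surviving entries. The values are hypothetical.
def _example_collapse_repeated():
  """Hedged example: collapsing [1, 1, 2, 2] and [3, 3, 3, 0]."""
  labels = constant_op.constant([[1, 1, 2, 2], [3, 3, 3, 0]], dtypes.int32)
  seq_length = constant_op.constant([4, 3], dtypes.int32)
  # Expected result: labels ~ [[1, 2], [3, 0]] and new lengths [2, 1].
  return collapse_repeated(labels, seq_length)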


def dense_labels_to_sparse(dense, length):
  """Convert dense labels with sequence lengths to sparse tensor.

  Args:
    dense: tensor of shape [batch, max_length]
    length: int tensor of shape [batch] The length of each sequence in dense.

  Returns:
    tf.sparse.SparseTensor with values only for the valid elements of sequences.
  """

  flat_values = array_ops.reshape(dense, [-1])
  flat_indices = math_ops.range(
      array_ops.shape(flat_values, out_type=dtypes.int64)[0])
  mask = array_ops.sequence_mask(length, maxlen=array_ops.shape(dense)[1])
  flat_mask = array_ops.reshape(mask, [-1])
  indices = array_ops.expand_dims(
      array_ops.boolean_mask(flat_indices, flat_mask), 1)
  values = array_ops.boolean_mask(flat_values, flat_mask)
  sparse = sparse_tensor.SparseTensor(
      indices=indices,
      values=math_ops.cast(values, dtypes.int32),
      dense_shape=array_ops.shape(flat_values, out_type=dtypes.int64))
  reshaped = sparse_ops.sparse_reshape(sparse, array_ops.shape(dense))
  max_length = math_ops.reduce_max(length)
  return sparse_tensor.SparseTensor(
      indices=reshaped.indices,
      values=reshaped.values,
      dense_shape=[
          math_ops.cast(reshaped.dense_shape[0], dtypes.int64),
          math_ops.cast(max_length, dtypes.int64)
      ])


@tf_export("nn.ctc_unique_labels")
@dispatch.add_dispatch_support
def ctc_unique_labels(labels, name=None):
  """Get unique labels and indices for batched labels for `tf.nn.ctc_loss`.

  For use with `tf.nn.ctc_loss` optional argument `unique`: This op can be
  used to preprocess labels in the input pipeline for better speed/memory use
  when computing the ctc loss on TPU.

  Example:
    ctc_unique_labels([[3, 4, 4, 3]]) ->
      unique labels padded with 0: [[3, 4, 0, 0]]
      indices of original labels in unique: [0, 1, 1, 0]

  Args:
    labels: tensor of shape [batch_size, max_label_length] padded with 0.
    name: A name for this `Op`. Defaults to "ctc_unique_labels".

  Returns:
    tuple of
      - unique labels, tensor of shape `[batch_size, max_label_length]`
      - indices into unique labels, shape `[batch_size, max_label_length]`
  """

  with ops.name_scope(name, "ctc_unique_labels", [labels]):
    labels = ops.convert_to_tensor(labels, name="labels")

    def _unique(x):
      u = array_ops.unique(x)
      y = array_ops.pad(u.y, [[0, _get_dim(u.idx, 0) - _get_dim(u.y, 0)]])
      y = math_ops.cast(y, dtypes.int64)
      return [y, u.idx]

    return map_fn.map_fn(_unique, labels, dtype=[dtypes.int64, dtypes.int32])
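

# Sketch (not part of the original module) of the dense-to-sparse conversion
# above: only the first `length` entries of each row survive, and the dense
# shape is trimmed to the longest remaining sequence. Values are hypothetical.
def _example_dense_labels_to_sparse():
  """Hedged example: sparsifying a padded [2, 3] label batch."""
  dense = constant_op.constant([[1, 2, 0], [3, 0, 0]], dtypes.int32)
  length = constant_op.constant([2, 1], dtypes.int32)
  # Resulting SparseTensor holds values [1, 2, 3] with dense_shape [2, 2].
  return dense_labels_to_sparse(dense, length)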


def _sum_states(idx, states):
  """Take logsumexp for each unique state out of all label states.

  Args:
    idx: tensor of shape [batch, label_length] For each sequence, indices into
      a set of unique labels as computed by calling unique.
    states: tensor of shape [frames, batch, label_length] Log probabilities for
      each label state.

  Returns:
    tensor of shape [frames, batch_size, label_length], log probabilities
      summed for each unique label of the sequence.
  """

  with ops.name_scope("sum_states"):
    idx = ops.convert_to_tensor(idx, name="idx")
    num_states = _get_dim(states, 2)
    states = array_ops.expand_dims(states, axis=2)
    one_hot = array_ops.one_hot(
        idx,
        depth=num_states,
        on_value=0.0,
        off_value=math_ops.log(0.0),
        axis=1)
    return math_ops.reduce_logsumexp(states + one_hot, axis=-1)


def _forward_backward_log(state_trans_log_probs, initial_state_log_probs,
                          final_state_log_probs, observed_log_probs,
                          sequence_length):
  """Forward-backward algorithm computed in log domain.

  Args:
    state_trans_log_probs: tensor of shape [states, states] or if different
      transition matrix per batch [batch_size, states, states]
    initial_state_log_probs: tensor of shape [batch_size, states]
    final_state_log_probs: tensor of shape [batch_size, states]
    observed_log_probs: tensor of shape [frames, batch_size, states]
    sequence_length: tensor of shape [batch_size]

  Returns:
    forward backward log probabilities: tensor of shape [frames, batch, states]
    log_likelihood: tensor of shape [batch_size]

  Raises:
    ValueError: If state_trans_log_probs has unknown or incorrect rank.
  """

  if state_trans_log_probs.shape.ndims == 2:
    perm = [1, 0]
  elif state_trans_log_probs.shape.ndims == 3:
    perm = [0, 2, 1]
  else:
    raise ValueError(
        "state_trans_log_probs rank must be known and == 2 or 3, is: %s" %
        state_trans_log_probs.shape.ndims)

  bwd_state_trans_log_probs = array_ops.transpose(state_trans_log_probs, perm)
  batch_size = _get_dim(observed_log_probs, 1)

  def _forward(state_log_prob, obs_log_prob):
    state_log_prob = array_ops.expand_dims(state_log_prob, axis=1)  # Broadcast.
    state_log_prob += state_trans_log_probs
    state_log_prob = math_ops.reduce_logsumexp(state_log_prob, axis=-1)
    state_log_prob += obs_log_prob
    log_prob_sum = math_ops.reduce_logsumexp(
        state_log_prob, axis=-1, keepdims=True)
    state_log_prob -= log_prob_sum
    return state_log_prob

  fwd = _scan(
      _forward, observed_log_probs, initial_state_log_probs, inclusive=True)

  def _backward(accs, elems):
    """Calculate log probs and cumulative sum masked for sequence length."""
    state_log_prob, cum_log_sum = accs
    obs_log_prob, mask = elems
    state_log_prob += obs_log_prob
    state_log_prob = array_ops.expand_dims(state_log_prob, axis=1)  # Broadcast.
    state_log_prob += bwd_state_trans_log_probs
    state_log_prob = math_ops.reduce_logsumexp(state_log_prob, axis=-1)

    log_prob_sum = math_ops.reduce_logsumexp(
        state_log_prob, axis=-1, keepdims=True)
    state_log_prob -= log_prob_sum

    cum_log_sum += array_ops.squeeze(log_prob_sum) * mask
    batched_mask = array_ops.expand_dims(mask, axis=1)
    out = state_log_prob * batched_mask
    out += final_state_log_probs * (1.0 - batched_mask)
    return out, cum_log_sum

  zero_log_sum = array_ops.zeros([batch_size])
  maxlen = _get_dim(observed_log_probs, 0)
  mask = array_ops.sequence_mask(sequence_length, maxlen, dtypes.float32)
  mask = array_ops.transpose(mask, perm=[1, 0])

  bwd, cum_log_sum = _scan(
      _backward, (observed_log_probs, mask),
      (final_state_log_probs, zero_log_sum),
      reverse=True,
      inclusive=True)

  fwd_bwd_log_probs = fwd[1:] + bwd[1:]
  fwd_bwd_log_probs_sum = math_ops.reduce_logsumexp(
      fwd_bwd_log_probs, axis=2, keepdims=True)
  fwd_bwd_log_probs -= fwd_bwd_log_probs_sum
  fwd_bwd_log_probs += math_ops.log(array_ops.expand_dims(mask, axis=2))

  log_likelihood = bwd[0, :, 0] + cum_log_sum[0]

  return fwd_bwd_log_probs, log_likelihood


# TODO(tombagby): This is currently faster for the ctc implementation than
# using functional_ops.scan, but could be replaced by that or something similar
# if things change.
def _scan(fn, elems, initial, reverse=False, inclusive=False, final_only=False):
  """Repeatedly applies callable `fn` to a sequence of elements.

  Implemented by functional_ops.While, tpu friendly, no gradient.

  This is similar to functional_ops.scan but significantly faster on tpu/gpu
  for the forward backward use case.

  Examples:
    scan(lambda a, e: a + e, [1.0, 2.0, 3.0], 1.0) => [2.0, 4.0, 7.0]

    Multiple accumulators:
      scan(lambda a, e: (a[0] + e, a[1] * e), [1.0, 2.0, 3.0], (0.0, 1.0))

    Multiple inputs:
      scan(lambda a, e: a + (e[0] * e[1]), (elems1, elems2), 0.0)

  Args:
    fn: callable, fn(accumulators, element) return new accumulator values. The
      (possibly nested) sequence of accumulators is the same as `initial` and
      the return value must have the same structure.
    elems: A (possibly nested) tensor which will be unpacked along the first
      dimension. The resulting slices will be the second argument to fn. The
      first dimension of all nested input tensors must be the same.
    initial: A tensor or (possibly nested) sequence of tensors with initial
      values for the accumulators.
    reverse: (optional) True enables scan and output elems in reverse order.
    inclusive: (optional) True includes the initial accumulator values in the
      output. Length of output will be len(elem sequence) + 1. Not meaningful
      if final_only is True.
    final_only: (optional) When True, return only the final accumulated values,
      not the concatenation of accumulated values for each input.

  Returns:
    A (possibly nested) sequence of tensors with the results of applying fn
    to tensors unpacked from elems and previous accumulator values.
  """

  flat_elems = [ops.convert_to_tensor(x) for x in nest.flatten(elems)]
  num_elems = array_ops.shape(flat_elems[0])[0]
  pack_elems = lambda x: nest.pack_sequence_as(structure=elems, flat_sequence=x)
  flat_initial = [ops.convert_to_tensor(x) for x in nest.flatten(initial)]
  pack = lambda x: nest.pack_sequence_as(structure=initial, flat_sequence=x)
  accum_dtypes = [x.dtype for x in flat_initial]
  num_accums = len(flat_initial)

  # Types for counter, [outputs], [accumulators] loop arguments.
  if final_only:
    loop_dtypes = [dtypes.int32, dtypes.int32] + accum_dtypes
  else:
    loop_dtypes = [dtypes.int32, dtypes.int32] + accum_dtypes + accum_dtypes

  # TODO(tombagby): Update to tfe.defun
  def cond(i, num_elems, *args):
    del args
    return i >= 0 if reverse else i < num_elems

  # The loop *args are [output tensors] + [accumulator tensors] which must
  # be paired. Each output corresponds to one accumulator.
  def body(i, num_elems, *args):
    """Loop body."""
    i.set_shape([])
    if final_only:
      accum = args
    else:
      out, accum = args[:num_accums], args[num_accums:]
    slices = [array_ops.gather(e, i) for e in flat_elems]
    accum = fn(pack(accum), pack_elems(slices))
    flat_accum = nest.flatten(accum)
    if final_only:
      new_out = []
    else:
      update_i = i + 1 if inclusive and not reverse else i
      new_out = [
          inplace_ops.alias_inplace_update(x, update_i, y)
          for x, y in zip(out, flat_accum)
      ]
    i = i - 1 if reverse else i + 1
    return [i, num_elems] + new_out + flat_accum

  init_i = (
      array_ops.shape(flat_elems[0])[0] -
      1 if reverse else constant_op.constant(0, dtype=dtypes.int32))
  outputs = []
  if not final_only:
    num_outputs = array_ops.shape(flat_elems[0])[0] + (1 if inclusive else 0)
    for initial_accum in flat_initial:
      out_shape = array_ops.concat(
          [[num_outputs], array_ops.shape(initial_accum)], 0)
      out = inplace_ops.empty(out_shape, dtype=initial_accum.dtype, init=True)
      if inclusive:
        out = inplace_ops.alias_inplace_add(out,
                                            init_i + (1 if reverse else 0),
                                            initial_accum)
      outputs.append(out)
  loop_in = [init_i, num_elems] + outputs + flat_initial
  hostmem = [
      i for i, x in enumerate(loop_in)
      if x.dtype.base_dtype in (dtypes.int32, dtypes.int64)
  ]

  if context.executing_eagerly():
    loop_results = loop_in
    while cond(*loop_results):
      loop_results = body(*loop_results)
  else:
    # TODO(tombagby): Update to while_v2.
    cond = function.Defun(*loop_dtypes)(cond)
    body = function.Defun(*loop_dtypes)(body)
    loop_results = functional_ops.While(loop_in, cond, body, hostmem=hostmem)
  out = loop_results[2:num_accums + 2]
  return pack(out)


def _get_dim(tensor, i):
  """Get value of tensor shape[i] preferring static value if available."""
  return tensor_shape.dimension_value(
      tensor.shape[i]) or array_ops.shape(tensor)[i]
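

# A closing sketch (not part of the original module): the internal _scan helper
# behaves like a running fold over the leading dimension. The cumulative-sum
# example mirrors the one in its docstring; the inputs are hypothetical.
def _example_scan_cumsum():
  """Hedged example: running sum over a small vector via _scan."""
  elems = constant_op.constant([1.0, 2.0, 3.0])
  initial = constant_op.constant(1.0)
  # Produces [2.0, 4.0, 7.0]; with inclusive=True the initial 1.0 would be
  # prepended as well.
  return _scan(lambda a, e: a + e, elems, initial)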