# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ctc_ops.ctc_decoder_ops."""

import itertools

import numpy as np

from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import ctc_ops
from tensorflow.python.platform import test


def grouper(iterable, n, fillvalue=None):
  """Split `iterable` into consecutive `n`-tuples, padding the last one.

  Example: grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx

  Args:
    iterable: Any iterable to be chunked.
    n: Number of items per chunk.
    fillvalue: Value used to pad the final chunk when it comes up short.

  Returns:
    An iterator of `n`-tuples.
  """
  # The *same* iterator object is repeated n times, so each tuple built by
  # zip_longest pulls n successive items from the underlying iterable.
  shared_iterator = iter(iterable)
  return itertools.zip_longest(*([shared_iterator] * n), fillvalue=fillvalue)


def flatten(list_of_lists):
  """Chain the inner iterables of `list_of_lists` into one lazy iterator."""
  chain_inner = itertools.chain.from_iterable
  return chain_inner(list_of_lists)


class CTCGreedyDecoderTest(test.TestCase):
  """Checks the CTC greedy and beam search decoders against known outputs."""

  def _testCTCDecoder(self,
                      decoder,
                      inputs,
                      seq_lens,
                      log_prob_truth,
                      decode_truth,
                      expected_err_re=None,
                      **decoder_args):
    """Runs `decoder` on `inputs` and compares its outputs with the truth.

    Args:
      decoder: Callable invoked as
        `decoder(inputs_tensor, sequence_length=seq_lens, **decoder_args)`;
        must return `(decoded_list, log_probability)` where `decoded_list`
        holds sparse tensors.
      inputs: Sequence of per-time-step values; each one is converted to a
        tensor and all of them are stacked along a new leading (time) axis.
      seq_lens: Sequence lengths forwarded to the decoder.
      log_prob_truth: Expected value of the decoder's log probability output.
      decode_truth: List of `(indices, values, shape)` numpy triples, one per
        expected decoded sparse tensor.
      expected_err_re: If not None, running the decoder outputs must raise an
        op error matching this regex and no values are compared.
      **decoder_args: Extra keyword arguments forwarded to `decoder`.
    """
    # Stack the per-step [batch_size x depth] tensors into a single
    # [max_time x batch_size x depth] tensor.
    stacked_inputs = array_ops.stack(
        [ops.convert_to_tensor(step) for step in inputs])

    with self.cached_session(use_gpu=False) as sess:
      decoded_list, log_probability = decoder(
          stacked_inputs, sequence_length=seq_lens, **decoder_args)

      # Fetch order: (indices, values, dense_shape) for every decoded sparse
      # tensor, followed by the log probability tensor.
      sparse_components = [(st.indices, st.values, st.dense_shape)
                           for st in decoded_list]
      fetches = list(flatten(sparse_components)) + [log_probability]

      if expected_err_re is not None:
        with self.assertRaisesOpError(expected_err_re):
          sess.run(fetches)
        return

      fetched = sess.run(fetches)
      actual_log_probability = fetched[-1]
      # Regroup the flat fetch results into (ix, vals, shape) tuples.
      actual_triples = list(grouper(fetched[:-1], 3))

      # The number of decoded outputs (top_paths) must match the truth.
      self.assertEqual(len(actual_triples), len(decode_truth))

      for actual, truth_st, tf_st in zip(actual_triples, decode_truth,
                                         decoded_list):
        actual_ix, actual_vals, actual_shape = actual
        true_ix, true_vals, true_shape = truth_st[0], truth_st[1], truth_st[2]

        self.assertAllEqual(actual_ix, true_ix)
        self.assertAllEqual(actual_vals, true_vals)
        self.assertAllEqual(actual_shape, true_shape)

        # Static shapes of the graph tensors: dimensions that depend on the
        # decoded data are reported as `None`.
        self.assertEqual([None, true_ix.shape[1]],
                         tf_st.indices.get_shape().as_list())
        self.assertEqual([None], tf_st.values.get_shape().as_list())
        self.assertShapeEqual(true_shape, tf_st.dense_shape)

      # Finally, the decoded probabilities have to agree.
      self.assertAllClose(actual_log_probability, log_prob_truth, atol=1e-6)

  @test_util.run_deprecated_v1
  def testCTCGreedyDecoder(self):
    """Test two batch entries - best path decoder."""
    max_time_steps = 6
    # Every probability row has depth == 4.

    # Batch entry 0: only the first 4 time steps are decoded.
    seq_len_0 = 4
    input_prob_matrix_0 = np.asarray(
        [
            [1.0, 0.0, 0.0, 0.0],  # t=0
            [0.0, 0.0, 0.4, 0.6],  # t=1
            [0.0, 0.0, 0.4, 0.6],  # t=2
            [0.0, 0.9, 0.1, 0.0],  # t=3
            [0.0, 0.0, 0.0, 0.0],  # t=4 (ignored)
            [0.0, 0.0, 0.0, 0.0],  # t=5 (ignored)
        ],
        dtype=np.float32)
    input_log_prob_matrix_0 = np.log(input_prob_matrix_0)

    # Batch entry 1 (rows are time steps, columns are depth): the first 5
    # time steps are decoded.
    seq_len_1 = 5
    input_prob_matrix_1 = np.asarray(
        [
            [0.1, 0.9, 0.0, 0.0],  # t=0
            [0.0, 0.9, 0.1, 0.0],  # t=1
            [0.0, 0.0, 0.1, 0.9],  # t=2
            [0.0, 0.9, 0.1, 0.1],  # t=3
            [0.9, 0.1, 0.0, 0.0],  # t=4
            [0.0, 0.0, 0.0, 0.0],  # t=5 (ignored)
        ],
        dtype=np.float32)
    input_log_prob_matrix_1 = np.log(input_prob_matrix_1)

    # Interleave the two entries into a time-major array of shape
    # [max_time_steps, batch_size, depth].
    inputs = np.stack(
        [
            input_log_prob_matrix_0[:max_time_steps],
            input_log_prob_matrix_1[:max_time_steps],
        ],
        axis=1)

    # One sequence length per batch entry.
    seq_lens = np.array([seq_len_0, seq_len_1], dtype=np.int32)

    # Negative log probability of the best path, one row per batch entry.
    best_path_probs_0 = [1.0, 0.6, 0.6, 0.9]
    best_path_probs_1 = [0.9, 0.9, 0.9, 0.9, 0.9]
    log_prob_truth = np.array(
        [
            np.sum(-np.log(best_path_probs_0)),
            np.sum(-np.log(best_path_probs_1)),
        ],
        np.float32).reshape(-1, 1)

    # Expected decoding: a single sparse tensor given as (ix, vals, shape).
    expected_indices = np.array(
        [
            [0, 0],  # batch 0, 2 outputs
            [0, 1],
            [1, 0],  # batch 1, 3 outputs
            [1, 1],
            [1, 2],
        ],
        dtype=np.int64)
    expected_values = np.array(
        [
            0,  # batch 0, 2 values
            1,
            1,  # batch 1, 3 values
            1,
            0,
        ],
        dtype=np.int64)
    # Dense shape is batch x max_decoded_length.
    expected_shape = np.array([2, 3], dtype=np.int64)
    decode_truth = [(expected_indices, expected_values, expected_shape)]

    # First run: blank_index is not passed to the decoder at all.
    self._testCTCDecoder(ctc_ops.ctc_greedy_decoder, inputs, seq_lens,
                         log_prob_truth, decode_truth)

    # Move the last class column to position `blank_index`, pushing the
    # columns that were at and after that position one step to the right.
    blank_index = 2
    before_blank = inputs[..., :blank_index]
    last_column = inputs[..., -1:]
    after_blank = inputs[..., blank_index:-1]
    inputs = np.concatenate((before_blank, last_column, after_blank), axis=2)

    # Decode the rearranged inputs with blank_index given first as a positive
    # value and then as a negative value; the expected output is unchanged.
    for blank in (blank_index, -2):
      self._testCTCDecoder(
          ctc_ops.ctc_greedy_decoder,
          inputs,
          seq_lens,
          log_prob_truth,
          decode_truth,
          blank_index=blank)

  @test_util.run_deprecated_v1
  def testCTCDecoderBeamSearch(self):
    """Test one batch, two beams - hibernating beam search."""
    depth = 6

    seq_len_0 = 5
    input_prob_matrix_0 = np.asarray(
        [
            [0.30999, 0.309938, 0.0679938, 0.0673362, 0.0708352, 0.173908],
            [0.215136, 0.439699, 0.0370931, 0.0393967, 0.0381581, 0.230517],
            [0.199959, 0.489485, 0.0233221, 0.0251417, 0.0233289, 0.238763],
            [0.279611, 0.452966, 0.0204795, 0.0209126, 0.0194803, 0.20655],
            [0.51286, 0.288951, 0.0243026, 0.0220788, 0.0219297, 0.129878],
            # Random entry at time=5; it lies beyond seq_len_0 and is not
            # included in the inputs built below.
            [0.155251, 0.164444, 0.173517, 0.176138, 0.169979, 0.160671],
        ],
        dtype=np.float32)
    # Add arbitrary offset - this is fine
    input_prob_matrix_0 = input_prob_matrix_0 + 2.0

    # One [batch_size x depth] matrix per time step: the first seq_len_0 rows
    # of the matrix above, followed by two all-zero padding steps.
    inputs = [input_prob_matrix_0[t:t + 1, :] for t in range(seq_len_0)]
    inputs.extend(np.zeros((1, depth), dtype=np.float32) for _ in range(2))

    # One sequence length per batch entry.
    seq_lens = np.array([seq_len_0], dtype=np.int32)

    # Expected log probabilities, one column per output beam.
    log_prob_truth = np.array(
        [[
            -5.811451,  # output beam 0
            -6.63339,  # output beam 1
        ]],
        dtype=np.float32)

    # Expected decodings: two sparse tensors given as (ix, values, shape).
    beam_0_truth = (  # beam 0, batch 0, two outputs decoded
        np.array([[0, 0], [0, 1]], dtype=np.int64),
        np.array([1, 0], dtype=np.int64),
        np.array([1, 2], dtype=np.int64),
    )
    beam_1_truth = (  # beam 1, batch 0, one output decoded
        np.array([[0, 0]], dtype=np.int64),
        np.array([1], dtype=np.int64),
        np.array([1, 1], dtype=np.int64),
    )
    decode_truth = [beam_0_truth, beam_1_truth]

    # Test correct decoding.
    self._testCTCDecoder(
        ctc_ops.ctc_beam_search_decoder,
        inputs,
        seq_lens,
        log_prob_truth,
        decode_truth,
        beam_width=2,
        top_paths=2)

    # Requesting more paths than the beam width allows.
    too_many_paths_re = ".*requested more paths than the beam width.*"
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                too_many_paths_re):
      self._testCTCDecoder(
          ctc_ops.ctc_beam_search_decoder,
          inputs,
          seq_lens,
          log_prob_truth,
          decode_truth,
          beam_width=2,
          top_paths=3)


if __name__ == "__main__":
  test.main()