• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.ctc_ops.ctc_decoder_ops."""
16
17import itertools
18
19import numpy as np
20
21from tensorflow.python.framework import errors
22from tensorflow.python.framework import ops
23from tensorflow.python.framework import test_util
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import ctc_ops
26from tensorflow.python.platform import test
27
28
def grouper(iterable, n, fillvalue=None):
  """Split `iterable` into consecutive tuples of length `n`.

  The last tuple is right-padded with `fillvalue` when the number of items
  is not a multiple of `n`; e.g. grouper('ABCDEFG', 3, 'x') yields
  ('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x').

  Args:
    iterable: The items to chunk.
    n: Number of items per chunk.
    fillvalue: Padding used to complete the final chunk.

  Returns:
    An iterator of n-tuples.
  """
  # The same iterator object fills all n slots, so zip_longest pulls n
  # consecutive items from it for every output tuple.
  shared_iter = iter(iterable)
  return itertools.zip_longest(*([shared_iter] * n), fillvalue=fillvalue)
34
35
def flatten(list_of_lists):
  """Remove exactly one level of nesting from `list_of_lists`.

  Args:
    list_of_lists: An iterable whose elements are themselves iterables.

  Returns:
    A lazy iterator over the inner elements, in order.
  """
  return (item for inner in list_of_lists for item in inner)
39
40
class CTCGreedyDecoderTest(test.TestCase):
  """Tests for ctc_ops.ctc_greedy_decoder and ctc_ops.ctc_beam_search_decoder."""

  def _testCTCDecoder(self,
                      decoder,
                      inputs,
                      seq_lens,
                      log_prob_truth,
                      decode_truth,
                      expected_err_re=None,
                      **decoder_args):
    """Runs `decoder` on `inputs` and compares its outputs with the truth.

    Args:
      decoder: The CTC decoder under test. It is called as
        `decoder(inputs, sequence_length=seq_lens, **decoder_args)` and must
        return `(decoded_list, log_probability)`.
      inputs: A length-max_time sequence of [batch_size x depth] arrays.
      seq_lens: A batch_size vector of sequence lengths.
      log_prob_truth: Expected value of `log_probability`, compared with
        atol=1e-6.
      decode_truth: A list with one `(indices, values, dense_shape)` tuple of
        numpy arrays per decoded path, in the decoder's output order.
      expected_err_re: If not None, running the decoder is instead expected
        to raise an op error whose message matches this regex.
      **decoder_args: Extra keyword arguments forwarded to `decoder`.
    """
    inputs_t = [ops.convert_to_tensor(x) for x in inputs]
    # convert inputs_t into a [max_time x batch_size x depth] tensor
    # from a len time python list of [batch_size x depth] tensors
    inputs_t = array_ops.stack(inputs_t)

    with self.cached_session(use_gpu=False) as sess:
      decoded_list, log_probability = decoder(
          inputs_t, sequence_length=seq_lens, **decoder_args)
      # Split every SparseTensor into its (indices, values, dense_shape)
      # components so all of them can be fetched with a single sess.run.
      decoded_unwrapped = list(
          flatten([(st.indices, st.values, st.dense_shape)
                   for st in decoded_list]))

      if expected_err_re is None:
        outputs = sess.run(decoded_unwrapped + [log_probability])

        # Group outputs into (ix, vals, shape) tuples
        output_sparse_tensors = list(grouper(outputs[:-1], 3))

        output_log_probability = outputs[-1]

        # Check the number of decoded outputs (top_paths) match
        self.assertEqual(len(output_sparse_tensors), len(decode_truth))

        # For each SparseTensor tuple, compare (ix, vals, shape)
        for out_st, truth_st, tf_st in zip(output_sparse_tensors, decode_truth,
                                           decoded_list):
          self.assertAllEqual(out_st[0], truth_st[0])  # ix
          self.assertAllEqual(out_st[1], truth_st[1])  # vals
          self.assertAllEqual(out_st[2], truth_st[2])  # shape
          # Compare the static shapes of the components with the truth. The
          # `None` elements are not known statically.
          self.assertEqual([None, truth_st[0].shape[1]],
                           tf_st.indices.get_shape().as_list())
          self.assertEqual([None], tf_st.values.get_shape().as_list())
          self.assertShapeEqual(truth_st[2], tf_st.dense_shape)

        # Make sure decoded probabilities match
        self.assertAllClose(output_log_probability, log_prob_truth, atol=1e-6)
      else:
        with self.assertRaisesOpError(expected_err_re):
          sess.run(decoded_unwrapped + [log_probability])

  @test_util.run_deprecated_v1
  def testCTCGreedyDecoder(self):
    """Test two batch entries - best path decoder."""
    max_time_steps = 6
    # Number of classes per frame; by default the last column is the blank.
    depth = 4

    seq_len_0 = 4
    input_prob_matrix_0 = np.asarray(
        [
            [1.0, 0.0, 0.0, 0.0],  # t=0
            [0.0, 0.0, 0.4, 0.6],  # t=1
            [0.0, 0.0, 0.4, 0.6],  # t=2
            [0.0, 0.9, 0.1, 0.0],  # t=3
            [0.0, 0.0, 0.0, 0.0],  # t=4 (ignored)
            [0.0, 0.0, 0.0, 0.0],  # t=5 (ignored)
        ],
        dtype=np.float32)
    input_log_prob_matrix_0 = np.log(input_prob_matrix_0)

    seq_len_1 = 5
    # dimensions are time x depth
    input_prob_matrix_1 = np.asarray(
        [
            [0.1, 0.9, 0.0, 0.0],  # t=0
            [0.0, 0.9, 0.1, 0.0],  # t=1
            [0.0, 0.0, 0.1, 0.9],  # t=2
            [0.0, 0.9, 0.1, 0.1],  # t=3
            [0.9, 0.1, 0.0, 0.0],  # t=4
            [0.0, 0.0, 0.0, 0.0],  # t=5 (ignored)
        ],
        dtype=np.float32)
    input_log_prob_matrix_1 = np.log(input_prob_matrix_1)

    # max_time_steps x batch_size x depth array: frame t stacks row t of
    # both batch entries.
    inputs = np.array([
        np.vstack(
            [input_log_prob_matrix_0[t, :], input_log_prob_matrix_1[t, :]])
        for t in range(max_time_steps)
    ])

    # batch_size length vector of sequence_lengths
    seq_lens = np.array([seq_len_0, seq_len_1], dtype=np.int32)

    # batch_size x 1 matrix: for each entry, the negative sum of the
    # per-frame maximum log-probability over its sequence length.
    log_prob_truth = np.array([
        np.sum(-np.log([1.0, 0.6, 0.6, 0.9])),
        np.sum(-np.log([0.9, 0.9, 0.9, 0.9, 0.9]))
    ], np.float32)[:, np.newaxis]

    # decode_truth: one SparseTensor (ix, vals, shape)
    decode_truth = [
        (
            np.array(
                [
                    [0, 0],  # batch 0, 2 outputs
                    [0, 1],
                    [1, 0],  # batch 1, 3 outputs
                    [1, 1],
                    [1, 2]
                ],
                dtype=np.int64),
            np.array(
                [
                    0,  # batch 0, 2 values
                    1,
                    1,  # batch 1, 3 values
                    1,
                    0
                ],
                dtype=np.int64),
            # shape is batch x max_decoded_length
            np.array([2, 3], dtype=np.int64)),
    ]

    # Test without defining blank_index
    self._testCTCDecoder(ctc_ops.ctc_greedy_decoder, inputs, seq_lens,
                         log_prob_truth, decode_truth)

    # Move the blank column (originally the last one) to index blank_index,
    # shifting the label columns at or after that index right by one. Once
    # told where the blank is, the decoder must reproduce the same results.
    blank_index = 2
    inputs = np.concatenate(
        (inputs[:, :, :blank_index], inputs[:, :, -1:],
         inputs[:, :, blank_index:-1]),
        axis=2)

    # Test positive value in blank_index
    self._testCTCDecoder(
        ctc_ops.ctc_greedy_decoder,
        inputs,
        seq_lens,
        log_prob_truth,
        decode_truth,
        blank_index=blank_index)

    # Test negative value in blank_index: the same column counted from the
    # end (2 - 4 == -2).
    self._testCTCDecoder(
        ctc_ops.ctc_greedy_decoder,
        inputs,
        seq_lens,
        log_prob_truth,
        decode_truth,
        blank_index=blank_index - depth)

  @test_util.run_deprecated_v1
  def testCTCDecoderBeamSearch(self):
    """Test one batch, two beams - hibernating beam search."""
    depth = 6
    # seq_len_0 real frames followed by zero padding: 5 + 2 == 7 frames.
    max_time_steps = 7

    seq_len_0 = 5
    input_prob_matrix_0 = np.asarray(
        [
            [0.30999, 0.309938, 0.0679938, 0.0673362, 0.0708352, 0.173908],
            [0.215136, 0.439699, 0.0370931, 0.0393967, 0.0381581, 0.230517],
            [0.199959, 0.489485, 0.0233221, 0.0251417, 0.0233289, 0.238763],
            [0.279611, 0.452966, 0.0204795, 0.0209126, 0.0194803, 0.20655],
            [0.51286, 0.288951, 0.0243026, 0.0220788, 0.0219297, 0.129878],
            # Random entry added in at time=5. Only the first seq_len_0 rows
            # are used to build `inputs` below, so this row is never fed to
            # the decoder.
            [0.155251, 0.164444, 0.173517, 0.176138, 0.169979, 0.160671]
        ],
        dtype=np.float32)
    # Add arbitrary offset - this is fine
    input_prob_matrix_0 = input_prob_matrix_0 + 2.0

    # len max_time_steps list of batch_size x depth matrices: the first
    # seq_len_0 frames come from input_prob_matrix_0, the rest are zeros.
    inputs = (
        [input_prob_matrix_0[t, :][np.newaxis, :] for t in range(seq_len_0)] +
        (max_time_steps - seq_len_0) *
        [np.zeros((1, depth), dtype=np.float32)])

    # batch_size length vector of sequence_lengths
    seq_lens = np.array([seq_len_0], dtype=np.int32)

    # batch_size x top_paths matrix of log probabilities
    log_prob_truth = np.array(
        [
            -5.811451,  # output beam 0
            -6.63339  # output beam 1
        ],
        np.float32)[np.newaxis, :]

    # decode_truth: two SparseTensors, (ix, values, shape)
    decode_truth = [
        # beam 0, batch 0, two outputs decoded
        (np.array([[0, 0], [0, 1]], dtype=np.int64),
         np.array([1, 0], dtype=np.int64),
         np.array([1, 2], dtype=np.int64)),
        # beam 1, batch 0, one output decoded
        (np.array([[0, 0]], dtype=np.int64),
         np.array([1], dtype=np.int64),
         np.array([1, 1], dtype=np.int64)),
    ]

    # Test correct decoding.
    self._testCTCDecoder(
        ctc_ops.ctc_beam_search_decoder,
        inputs,
        seq_lens,
        log_prob_truth,
        decode_truth,
        beam_width=2,
        top_paths=2)

    # Requesting more paths than the beam width allows.
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                (".*requested more paths than the beam "
                                 "width.*")):
      self._testCTCDecoder(
          ctc_ops.ctc_beam_search_decoder,
          inputs,
          seq_lens,
          log_prob_truth,
          decode_truth,
          beam_width=2,
          top_paths=3)
270
271
if __name__ == "__main__":
  # Run the test cases in this module through TensorFlow's test runner.
  test.main()
274