• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ctc_ops.ctc_decoder_ops."""
16
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

import numpy as np
from six.moves import zip_longest

from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import ctc_ops
from tensorflow.python.platform import test
32
33
def grouper(iterable, n, fillvalue=None):
  """Collect data into fixed-length chunks or blocks.

  Example: grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx

  Args:
    iterable: Any iterable whose items should be grouped.
    n: Number of items per chunk.
    fillvalue: Value used to pad the last chunk when the number of items is
      not a multiple of `n`.

  Returns:
    An iterator of `n`-tuples, in input order.
  """
  # The same iterator object is repeated n times, so each output tuple pulls
  # n consecutive items from the underlying iterable.
  args = [iter(iterable)] * n
  return zip_longest(*args, fillvalue=fillvalue)
39
40
def flatten(list_of_lists):
  """Flatten one level of nesting.

  Args:
    list_of_lists: An iterable of iterables.

  Returns:
    A lazy iterator over the items of each inner iterable, in order. Only the
    outermost level is removed; deeper nesting is left intact.
  """
  return itertools.chain.from_iterable(list_of_lists)
44
45
class CTCGreedyDecoderTest(test.TestCase):
  """Tests for the CTC greedy and beam search decoders in `ctc_ops`."""

  def _testCTCDecoder(self,
                      decoder,
                      inputs,
                      seq_lens,
                      log_prob_truth,
                      decode_truth,
                      expected_err_re=None,
                      **decoder_args):
    """Run `decoder` on `inputs` and compare its outputs with the truth.

    Args:
      decoder: Callable invoked as
        `decoder(inputs_t, sequence_length=seq_lens, **decoder_args)`. It must
        return `(decoded_list, log_probability)`, where `decoded_list` is a
        list of objects exposing `indices`, `values` and `dense_shape`.
      inputs: Length max_time sequence of [batch_size x depth] arrays.
      seq_lens: Per-batch-entry sequence lengths, passed as `sequence_length`.
      log_prob_truth: Expected value of the `log_probability` output.
      decode_truth: List of `(indices, values, dense_shape)` numpy triples, one
        per expected decoded output.
      expected_err_re: If not None, running the decoder ops is expected to
        raise an op error matching this regex and no outputs are compared.
      **decoder_args: Extra keyword arguments forwarded to `decoder`.
    """
    inputs_t = [ops.convert_to_tensor(x) for x in inputs]
    # convert inputs_t into a [max_time x batch_size x depth] tensor
    # from a len time python list of [batch_size x depth] tensors
    inputs_t = array_ops.stack(inputs_t)

    with self.cached_session(use_gpu=False) as sess:
      decoded_list, log_probability = decoder(
          inputs_t, sequence_length=seq_lens, **decoder_args)
      # Flatten every decoded output into its three component tensors so that
      # everything can be fetched with a single sess.run call.
      decoded_unwrapped = list(
          flatten([(st.indices, st.values, st.dense_shape) for st in
                   decoded_list]))

      if expected_err_re is None:
        outputs = sess.run(decoded_unwrapped + [log_probability])

        # Group outputs into (ix, vals, shape) tuples
        output_sparse_tensors = list(grouper(outputs[:-1], 3))

        # The log probability was appended last to the fetch list.
        output_log_probability = outputs[-1]

        # Check the number of decoded outputs (top_paths) match
        self.assertEqual(len(output_sparse_tensors), len(decode_truth))

        # For each SparseTensor tuple, compare (ix, vals, shape)
        for out_st, truth_st, tf_st in zip(output_sparse_tensors, decode_truth,
                                           decoded_list):
          self.assertAllEqual(out_st[0], truth_st[0])  # ix
          self.assertAllEqual(out_st[1], truth_st[1])  # vals
          self.assertAllEqual(out_st[2], truth_st[2])  # shape
          # Compare the shapes of the components with the truth. The
          # `None` elements are not known statically.
          self.assertEqual([None, truth_st[0].shape[1]],
                           tf_st.indices.get_shape().as_list())
          self.assertEqual([None], tf_st.values.get_shape().as_list())
          self.assertShapeEqual(truth_st[2], tf_st.dense_shape)

        # Make sure decoded probabilities match
        self.assertAllClose(output_log_probability, log_prob_truth, atol=1e-6)
      else:
        with self.assertRaisesOpError(expected_err_re):
          sess.run(decoded_unwrapped + [log_probability])

  @test_util.run_deprecated_v1
  def testCTCGreedyDecoder(self):
    """Test two batch entries - best path decoder.

    The same truth values are checked three times: with the default blank
    index, with the blank column moved to index 2 and `blank_index=2`, and
    with that same input and `blank_index=-2`.
    """
    max_time_steps = 6
    # depth == 4
    seq_len_0 = 4
    # dimensions are time x depth
    input_prob_matrix_0 = np.asarray(
        [
            [1.0, 0.0, 0.0, 0.0],  # t=0
            [0.0, 0.0, 0.4, 0.6],  # t=1
            [0.0, 0.0, 0.4, 0.6],  # t=2
            [0.0, 0.9, 0.1, 0.0],  # t=3
            [0.0, 0.0, 0.0, 0.0],  # t=4 (ignored)
            [0.0, 0.0, 0.0, 0.0]  # t=5 (ignored)
        ],
        dtype=np.float32)
    input_log_prob_matrix_0 = np.log(input_prob_matrix_0)

    seq_len_1 = 5
    # dimensions are time x depth
    input_prob_matrix_1 = np.asarray(
        [
            [0.1, 0.9, 0.0, 0.0],  # t=0
            [0.0, 0.9, 0.1, 0.0],  # t=1
            [0.0, 0.0, 0.1, 0.9],  # t=2
            [0.0, 0.9, 0.1, 0.1],  # t=3
            [0.9, 0.1, 0.0, 0.0],  # t=4
            [0.0, 0.0, 0.0, 0.0]  # t=5 (ignored)
        ],
        dtype=np.float32)
    input_log_prob_matrix_1 = np.log(input_prob_matrix_1)

    # len max_time_steps array of batch_size x depth matrices
    inputs = np.array([
        np.vstack(
            [input_log_prob_matrix_0[t, :], input_log_prob_matrix_1[t, :]])
        for t in range(max_time_steps)
    ])

    # batch_size length vector of sequence_lengths
    seq_lens = np.array([seq_len_0, seq_len_1], dtype=np.int32)

    # batch_size length vector of negative log probabilities: the sum over the
    # used time steps of -log(largest probability in that row).
    log_prob_truth = np.array([
        np.sum(-np.log([1.0, 0.6, 0.6, 0.9])),
        np.sum(-np.log([0.9, 0.9, 0.9, 0.9, 0.9]))
    ], np.float32)[:, np.newaxis]

    # decode_truth: one SparseTensor (ix, vals, shape)
    decode_truth = [
        (
            np.array(
                [
                    [0, 0],  # batch 0, 2 outputs
                    [0, 1],
                    [1, 0],  # batch 1, 3 outputs
                    [1, 1],
                    [1, 2]
                ],
                dtype=np.int64),
            np.array(
                [
                    0,  # batch 0, 2 values
                    1,
                    1,  # batch 1, 3 values
                    1,
                    0
                ],
                dtype=np.int64),
            # shape is batch x max_decoded_length
            np.array([2, 3], dtype=np.int64)),
    ]

    # Test without defining blank_index
    self._testCTCDecoder(ctc_ops.ctc_greedy_decoder, inputs, seq_lens,
                         log_prob_truth, decode_truth)

    # Shift blank_index to be somewhere in the middle of inputs: the last
    # depth column is moved to position `blank_index` and the columns that
    # were at or after that position shift one place to the right.
    blank_index = 2
    inputs = np.concatenate(
        (inputs[:, :, :blank_index],
         inputs[:, :, -1:],
         inputs[:, :, blank_index:-1]),
        axis=2)

    # Test positive value in blank_index
    self._testCTCDecoder(
        ctc_ops.ctc_greedy_decoder,
        inputs,
        seq_lens,
        log_prob_truth,
        decode_truth,
        blank_index=2)

    # Test negative value in blank_index; the same truth values are expected
    # as for blank_index=2 on this depth-4 input.
    self._testCTCDecoder(
        ctc_ops.ctc_greedy_decoder,
        inputs,
        seq_lens,
        log_prob_truth,
        decode_truth,
        blank_index=-2)

  @test_util.run_deprecated_v1
  def testCTCDecoderBeamSearch(self):
    """Test one batch, two beams - hibernating beam search.

    Checks the two best decoded paths and their log probabilities, then checks
    that asking for more paths than the beam width raises an error.
    """
    # max_time_steps == 7 (seq_len_0 used rows plus 2 rows of zero padding)
    depth = 6

    seq_len_0 = 5
    input_prob_matrix_0 = np.asarray(
        [
            [0.30999, 0.309938, 0.0679938, 0.0673362, 0.0708352, 0.173908],
            [0.215136, 0.439699, 0.0370931, 0.0393967, 0.0381581, 0.230517],
            [0.199959, 0.489485, 0.0233221, 0.0251417, 0.0233289, 0.238763],
            [0.279611, 0.452966, 0.0204795, 0.0209126, 0.0194803, 0.20655],
            [0.51286, 0.288951, 0.0243026, 0.0220788, 0.0219297, 0.129878],
            # Random entry added in at time=5. It is never fed to the decoder:
            # only the first seq_len_0 rows are copied into `inputs` below.
            [0.155251, 0.164444, 0.173517, 0.176138, 0.169979, 0.160671]
        ],
        dtype=np.float32)
    # Add arbitrary offset - this is fine
    input_prob_matrix_0 = input_prob_matrix_0 + 2.0

    # len max_time_steps list of batch_size x depth matrices: the first
    # seq_len_0 rows of the matrix, then 2 all-zero steps of padding.
    inputs = (
        [input_prob_matrix_0[t, :][np.newaxis, :] for t in range(seq_len_0)] +
        2 * [np.zeros((1, depth), dtype=np.float32)])

    # batch_size length vector of sequence_lengths
    seq_lens = np.array([seq_len_0], dtype=np.int32)

    # batch_size x top_paths matrix of log probabilities
    log_prob_truth = np.array(
        [
            -5.811451,  # output beam 0
            -6.63339  # output beam 1
        ],
        np.float32)[np.newaxis, :]

    # decode_truth: two SparseTensors, (ix, values, shape)
    decode_truth = [
        # beam 0, batch 0, two outputs decoded
        (np.array([[0, 0], [0, 1]], dtype=np.int64),
         np.array([1, 0], dtype=np.int64),
         np.array([1, 2], dtype=np.int64)),
        # beam 1, batch 0, one output decoded
        (np.array([[0, 0]], dtype=np.int64),
         np.array([1], dtype=np.int64),
         np.array([1, 1], dtype=np.int64)),
    ]

    # Test correct decoding.
    self._testCTCDecoder(
        ctc_ops.ctc_beam_search_decoder,
        inputs,
        seq_lens,
        log_prob_truth,
        decode_truth,
        beam_width=2,
        top_paths=2)

    # Requesting more paths than the beam width allows.
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                (".*requested more paths than the beam "
                                 "width.*")):
      self._testCTCDecoder(
          ctc_ops.ctc_beam_search_decoder,
          inputs,
          seq_lens,
          log_prob_truth,
          decode_truth,
          beam_width=2,
          top_paths=3)
275
276
# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
  test.main()
279