• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for summary V1 audio op."""
16
17import numpy as np
18
19from tensorflow.core.framework import summary_pb2
20from tensorflow.python.framework import ops
21from tensorflow.python.platform import test
22from tensorflow.python.summary import summary
23
24
class SummaryV1AudioOpTest(test.TestCase):
  """Tests for the TF1-style `summary.audio` op."""

  def _AsSummary(self, s):
    """Parses a serialized `Summary` proto.

    Args:
      s: Serialized bytes obtained by evaluating a summary op.

    Returns:
      The parsed `summary_pb2.Summary` message.
    """
    summ = summary_pb2.Summary()
    summ.ParseFromString(s)
    return summ

  def _CheckProto(self, audio_summ, sample_rate, num_channels, length_frames,
                  max_outputs=3):
    """Verifies that the non-audio parts of `audio_summ` match the expected shape.

    The `encoded_audio_string` field of every value is cleared IN PLACE before
    the comparison, so only the tag, content type, sample rate, channel count
    and frame count are checked.

    Args:
      audio_summ: Parsed `Summary` proto to check. Mutated in place.
      sample_rate: Expected `sample_rate` of every audio value.
      num_channels: Expected `num_channels` of every audio value.
      length_frames: Expected `length_frames` of every audio value.
      max_outputs: Number of audio values expected in the proto (tags
        `snd/audio/0` .. `snd/audio/{max_outputs - 1}`). Must agree with the
        `max_outputs` passed to `summary.audio`. Defaults to 3.
    """
    # The encoded WAV bytes are not deterministic to spell out here, so drop
    # them and compare only the metadata.
    for v in audio_summ.value:
      v.audio.ClearField("encoded_audio_string")
    expected = "\n".join("""
        value {
          tag: "snd/audio/%d"
          audio { content_type: "audio/wav" sample_rate: %d
                  num_channels: %d length_frames: %d }
        }""" % (i, sample_rate, num_channels, length_frames)
                         for i in range(max_outputs))
    self.assertProtoEquals(expected, audio_summ)

  def testAudioSummary(self):
    """Checks `summary.audio` metadata for several channel counts."""
    np.random.seed(7)
    # These are identical for every case; only the channel count varies.
    batch_size = 4
    num_frames = 7
    sample_rate = 8000
    # Fewer than `batch_size`, so the summary must truncate the batch.
    max_outputs = 3
    for channels in (1, 2, 5, 8):
      # Build each case in its own fresh graph.
      with self.session(graph=ops.Graph()):
        shape = (batch_size, num_frames, channels)
        # Generate random audio in the range [-1.0, 1.0).
        const = 2.0 * np.random.random(shape) - 1.0

        # Summarize
        summ = summary.audio(
            "snd", const, max_outputs=max_outputs, sample_rate=sample_rate)
        value = self.evaluate(summ)
        # The summary op yields a scalar (serialized proto).
        self.assertEqual([], summ.get_shape())
        audio_summ = self._AsSummary(value)

        # Check the rest of the proto; only `max_outputs` clips are expected.
        self._CheckProto(audio_summ, sample_rate, channels, num_frames,
                         max_outputs=max_outputs)
64
65
if __name__ == "__main__":
  # Run this module's tests through TensorFlow's test entry point.
  test.main()
68