• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for speech commands models."""
16
import tensorflow as tf

from tensorflow.examples.speech_commands import models
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
22
23
class ModelsTest(test.TestCase):
  """Graph-construction smoke tests for `models.create_model`.

  Each test builds a model from an all-zeros fingerprint batch of shape
  `[1, fingerprint_size]` and checks only that the returned tensors are not
  None and can be looked up by name in the session graph. No numerical
  output of any model is verified here.
  """

  def _modelSettings(self):
    """Returns the model settings shared by every test in this class."""
    return models.prepare_model_settings(
        label_count=10,
        sample_rate=16000,
        clip_duration_ms=1000,
        window_size_ms=20,
        window_stride_ms=10,
        feature_bin_count=40,
        preprocess="mfcc")

  def _zeroFingerprint(self, model_settings):
    """Returns a single-row all-zeros input sized by `fingerprint_size`."""
    return tf.zeros([1, model_settings["fingerprint_size"]])

  def _assertTrainingModelBuilds(self, model_architecture):
    """Builds `model_architecture` with the training flag set and checks it.

    With the last argument True, `create_model` is expected to return a
    `(logits, dropout_rate)` pair; both must be non-None and resolvable by
    name in the session graph.

    Args:
      model_architecture: Architecture name passed through to
        `models.create_model`.
    """
    model_settings = self._modelSettings()
    with self.cached_session() as sess:
      fingerprint_input = self._zeroFingerprint(model_settings)
      logits, dropout_rate = models.create_model(
          fingerprint_input, model_settings, model_architecture, True)
      self.assertIsNotNone(logits)
      self.assertIsNotNone(dropout_rate)
      self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))
      self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_rate.name))

  def testPrepareModelSettings(self):
    """`prepare_model_settings` returns something for the shared settings."""
    # Reuse the shared helper so the settings live in exactly one place.
    self.assertIsNotNone(self._modelSettings())

  @test_util.run_deprecated_v1
  def testCreateModelConvTraining(self):
    """The "conv" architecture builds with the training flag set."""
    self._assertTrainingModelBuilds("conv")

  @test_util.run_deprecated_v1
  def testCreateModelConvInference(self):
    """The "conv" architecture builds with the training flag cleared.

    In this mode `create_model` returns a single logits tensor rather than
    a `(logits, dropout_rate)` pair.
    """
    model_settings = self._modelSettings()
    with self.cached_session() as sess:
      fingerprint_input = self._zeroFingerprint(model_settings)
      logits = models.create_model(fingerprint_input, model_settings, "conv",
                                   False)
      self.assertIsNotNone(logits)
      self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))

  @test_util.run_deprecated_v1
  def testCreateModelLowLatencyConvTraining(self):
    """The "low_latency_conv" architecture builds with the training flag set."""
    self._assertTrainingModelBuilds("low_latency_conv")

  @test_util.run_deprecated_v1
  def testCreateModelFullyConnectedTraining(self):
    """The "single_fc" architecture builds with the training flag set."""
    self._assertTrainingModelBuilds("single_fc")

  def testCreateModelBadArchitecture(self):
    """An unknown architecture name raises with "not recognized" in the text."""
    model_settings = self._modelSettings()
    with self.cached_session():
      fingerprint_input = self._zeroFingerprint(model_settings)
      with self.assertRaises(Exception) as e:
        models.create_model(fingerprint_input, model_settings,
                            "bad_architecture", True)
      self.assertIn("not recognized", str(e.exception))

  @test_util.run_deprecated_v1
  def testCreateModelTinyConvTraining(self):
    """The "tiny_conv" architecture builds with the training flag set."""
    self._assertTrainingModelBuilds("tiny_conv")
113
114
if __name__ == "__main__":
  # Run every test in this module through the TensorFlow test runner when the
  # file is executed directly as a script.
  test.main()
117