# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for speech commands models."""

import tensorflow as tf

from tensorflow.examples.speech_commands import models
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


class ModelsTest(test.TestCase):
  """Smoke tests for `models.prepare_model_settings` and `models.create_model`."""

  def _modelSettings(self):
    """Returns the model settings shared by every test in this class."""
    return models.prepare_model_settings(
        label_count=10,
        sample_rate=16000,
        clip_duration_ms=1000,
        window_size_ms=20,
        window_stride_ms=10,
        feature_bin_count=40,
        preprocess="mfcc")

  def _fingerprintInput(self, model_settings):
    """Builds an all-zero input tensor of shape [1, fingerprint_size]."""
    return tf.zeros([1, model_settings["fingerprint_size"]])

  def _checkTrainingModel(self, architecture):
    """Builds `architecture` in training mode and checks its outputs.

    In training mode the result of `models.create_model` is unpacked into
    two values (logits and a dropout-rate tensor). Both must be non-None and
    must be retrievable by name from the session's graph.

    Must be called from a test running in graph mode (i.e. one decorated
    with `@test_util.run_deprecated_v1`), like the original inline tests.

    Args:
      architecture: Model architecture name passed to `models.create_model`.
    """
    model_settings = self._modelSettings()
    with self.cached_session() as sess:
      fingerprint_input = self._fingerprintInput(model_settings)
      logits, dropout_rate = models.create_model(
          fingerprint_input, model_settings, architecture, True)
      self.assertIsNotNone(logits)
      self.assertIsNotNone(dropout_rate)
      # Confirms both tensors were actually added to the session's graph.
      self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))
      self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_rate.name))

  def testPrepareModelSettings(self):
    model_settings = self._modelSettings()
    self.assertIsNotNone(model_settings)
    # Every create_model test below reads this key to size its input.
    self.assertIn("fingerprint_size", model_settings)

  @test_util.run_deprecated_v1
  def testCreateModelConvTraining(self):
    self._checkTrainingModel("conv")

  @test_util.run_deprecated_v1
  def testCreateModelConvInference(self):
    model_settings = self._modelSettings()
    with self.cached_session() as sess:
      fingerprint_input = self._fingerprintInput(model_settings)
      # With is_training=False the result is used directly as a single
      # tensor (no unpacking), so only the logits are checked.
      logits = models.create_model(fingerprint_input, model_settings, "conv",
                                   False)
      self.assertIsNotNone(logits)
      self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name))

  @test_util.run_deprecated_v1
  def testCreateModelLowLatencyConvTraining(self):
    self._checkTrainingModel("low_latency_conv")

  @test_util.run_deprecated_v1
  def testCreateModelFullyConnectedTraining(self):
    self._checkTrainingModel("single_fc")

  @test_util.run_deprecated_v1
  def testCreateModelBadArchitecture(self):
    model_settings = self._modelSettings()
    with self.cached_session():
      fingerprint_input = self._fingerprintInput(model_settings)
      # Only the base Exception type is pinned here, so the message text is
      # what distinguishes an unknown-architecture failure from other errors.
      with self.assertRaisesRegex(Exception, "not recognized"):
        models.create_model(fingerprint_input, model_settings,
                            "bad_architecture", True)

  @test_util.run_deprecated_v1
  def testCreateModelTinyConvTraining(self):
    self._checkTrainingModel("tiny_conv")


if __name__ == "__main__":
  test.main()