1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15r"""Converts WAV audio files into input features for neural networks. 16 17The models used in this example take in two-dimensional spectrograms as the 18input to their neural network portions. For testing and porting purposes it's 19useful to be able to generate these spectrograms outside of the full model, so 20that on-device implementations using their own FFT and streaming code can be 21tested against the version used in training for example. The output is as a 22C source file, so it can be easily linked into an embedded test application. 23 24To use this, run: 25 26bazel run tensorflow/examples/speech_commands:wav_to_features -- \ 27--input_wav=my.wav --output_c_file=my_wav_data.c 28 29""" 30import argparse 31import os.path 32import sys 33 34import tensorflow as tf 35 36import input_data 37import models 38from tensorflow.python.platform import gfile 39 40FLAGS = None 41 42 43def wav_to_features(sample_rate, clip_duration_ms, window_size_ms, 44 window_stride_ms, feature_bin_count, quantize, preprocess, 45 input_wav, output_c_file): 46 """Converts an audio file into its corresponding feature map. 47 48 Args: 49 sample_rate: Expected sample rate of the wavs. 50 clip_duration_ms: Expected duration in milliseconds of the wavs. 51 window_size_ms: How long each spectrogram timeslice is. 52 window_stride_ms: How far to move in time between spectrogram timeslices. 53 feature_bin_count: How many bins to use for the feature fingerprint. 54 quantize: Whether to train the model for eight-bit deployment. 55 preprocess: Spectrogram processing mode; "mfcc", "average" or "micro". 56 input_wav: Path to the audio WAV file to read. 57 output_c_file: Where to save the generated C source file. 58 """ 59 60 # Start a new TensorFlow session. 61 sess = tf.compat.v1.InteractiveSession() 62 63 model_settings = models.prepare_model_settings( 64 0, sample_rate, clip_duration_ms, window_size_ms, window_stride_ms, 65 feature_bin_count, preprocess) 66 audio_processor = input_data.AudioProcessor(None, None, 0, 0, '', 0, 0, 67 model_settings, None) 68 69 results = audio_processor.get_features_for_wav(input_wav, model_settings, 70 sess) 71 features = results[0] 72 73 variable_base = os.path.splitext(os.path.basename(input_wav).lower())[0] 74 75 # Save a C source file containing the feature data as an array. 76 with gfile.GFile(output_c_file, 'w') as f: 77 f.write('/* File automatically created by\n') 78 f.write(' * tensorflow/examples/speech_commands/wav_to_features.py \\\n') 79 f.write(' * --sample_rate=%d \\\n' % sample_rate) 80 f.write(' * --clip_duration_ms=%d \\\n' % clip_duration_ms) 81 f.write(' * --window_size_ms=%d \\\n' % window_size_ms) 82 f.write(' * --window_stride_ms=%d \\\n' % window_stride_ms) 83 f.write(' * --feature_bin_count=%d \\\n' % feature_bin_count) 84 if quantize: 85 f.write(' * --quantize=1 \\\n') 86 f.write(' * --preprocess="%s" \\\n' % preprocess) 87 f.write(' * --input_wav="%s" \\\n' % input_wav) 88 f.write(' * --output_c_file="%s" \\\n' % output_c_file) 89 f.write(' */\n\n') 90 f.write('const int g_%s_width = %d;\n' % 91 (variable_base, model_settings['fingerprint_width'])) 92 f.write('const int g_%s_height = %d;\n' % 93 (variable_base, model_settings['spectrogram_length'])) 94 if quantize: 95 features_min, features_max = input_data.get_features_range(model_settings) 96 f.write('const unsigned char g_%s_data[] = {' % variable_base) 97 i = 0 98 for value in features.flatten(): 99 quantized_value = int( 100 round( 101 (255 * (value - features_min)) / (features_max - features_min))) 102 if quantized_value < 0: 103 quantized_value = 0 104 if quantized_value > 255: 105 quantized_value = 255 106 if i == 0: 107 f.write('\n ') 108 f.write('%d, ' % (quantized_value)) 109 i = (i + 1) % 10 110 else: 111 f.write('const float g_%s_data[] = {\n' % variable_base) 112 i = 0 113 for value in features.flatten(): 114 if i == 0: 115 f.write('\n ') 116 f.write('%f, ' % value) 117 i = (i + 1) % 10 118 f.write('\n};\n') 119 120 121def main(_): 122 # We want to see all the logging messages. 123 tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) 124 wav_to_features(FLAGS.sample_rate, FLAGS.clip_duration_ms, 125 FLAGS.window_size_ms, FLAGS.window_stride_ms, 126 FLAGS.feature_bin_count, FLAGS.quantize, FLAGS.preprocess, 127 FLAGS.input_wav, FLAGS.output_c_file) 128 tf.compat.v1.logging.info('Wrote to "%s"' % (FLAGS.output_c_file)) 129 130 131if __name__ == '__main__': 132 parser = argparse.ArgumentParser() 133 parser.add_argument( 134 '--sample_rate', 135 type=int, 136 default=16000, 137 help='Expected sample rate of the wavs',) 138 parser.add_argument( 139 '--clip_duration_ms', 140 type=int, 141 default=1000, 142 help='Expected duration in milliseconds of the wavs',) 143 parser.add_argument( 144 '--window_size_ms', 145 type=float, 146 default=30.0, 147 help='How long each spectrogram timeslice is.',) 148 parser.add_argument( 149 '--window_stride_ms', 150 type=float, 151 default=10.0, 152 help='How far to move in time between spectrogram timeslices.', 153 ) 154 parser.add_argument( 155 '--feature_bin_count', 156 type=int, 157 default=40, 158 help='How many bins to use for the MFCC fingerprint', 159 ) 160 parser.add_argument( 161 '--quantize', 162 type=bool, 163 default=False, 164 help='Whether to train the model for eight-bit deployment') 165 parser.add_argument( 166 '--preprocess', 167 type=str, 168 default='mfcc', 169 help='Spectrogram processing mode. Can be "mfcc", "average", or "micro"') 170 parser.add_argument( 171 '--input_wav', 172 type=str, 173 default=None, 174 help='Path to the audio WAV file to read') 175 parser.add_argument( 176 '--output_c_file', 177 type=str, 178 default=None, 179 help='Where to save the generated C source file containing the features') 180 181 FLAGS, unparsed = parser.parse_known_args() 182 tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed) 183