# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15r"""Converts WAV audio files into input features for neural networks.
16
17The models used in this example take in two-dimensional spectrograms as the
18input to their neural network portions. For testing and porting purposes it's
19useful to be able to generate these spectrograms outside of the full model, so
20that on-device implementations using their own FFT and streaming code can be
21tested against the version used in training for example. The output is as a
22C source file, so it can be easily linked into an embedded test application.
23
24To use this, run:
25
26bazel run tensorflow/examples/speech_commands:wav_to_features -- \
27--input_wav=my.wav --output_c_file=my_wav_data.c
28
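The generated file defines the feature array and its dimensions, with names
derived from the input filename. For my.wav the output looks roughly like:

const int g_my_width = ...;
const int g_my_height = ...;
const float g_my_data[] = { ... };

with an unsigned char array instead of a float one if --quantize is set.
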
29"""
import argparse
import os.path
import sys

import tensorflow as tf

import input_data
import models
from tensorflow.python.platform import gfile

FLAGS = None


def wav_to_features(sample_rate, clip_duration_ms, window_size_ms,
                    window_stride_ms, feature_bin_count, quantize, preprocess,
                    input_wav, output_c_file):
  """Converts an audio file into its corresponding feature map.

  Args:
    sample_rate: Expected sample rate of the wavs.
    clip_duration_ms: Expected duration in milliseconds of the wavs.
    window_size_ms: How long each spectrogram timeslice is.
    window_stride_ms: How far to move in time between spectrogram timeslices.
    feature_bin_count: How many bins to use for the feature fingerprint.
    quantize: Whether to quantize the features to eight bits for deployment.
    preprocess: Spectrogram processing mode; "mfcc", "average" or "micro".
    input_wav: Path to the audio WAV file to read.
    output_c_file: Where to save the generated C source file.
  """

  # Start a new TensorFlow session.
  sess = tf.compat.v1.InteractiveSession()

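  # The leading 0 is the label count, which has no effect here since we only
  # need the audio feature settings, not a classifier.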
  model_settings = models.prepare_model_settings(
      0, sample_rate, clip_duration_ms, window_size_ms, window_stride_ms,
      feature_bin_count, preprocess)
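  # Build an AudioProcessor with no dataset directory or train/validation
  # splits, since it's only used to compute features for a single file.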
  audio_processor = input_data.AudioProcessor(None, None, 0, 0, '', 0, 0,
                                              model_settings, None)

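  # get_features_for_wav returns its results as a list; the first element
  # holds the feature array for the clip.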
  results = audio_processor.get_features_for_wav(input_wav, model_settings,
                                                 sess)
  features = results[0]

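  # Name the C variables after the input file, lowercased and with the
  # extension removed, so my.wav produces g_my_width, g_my_height, g_my_data.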
  variable_base = os.path.splitext(os.path.basename(input_wav).lower())[0]

  # Save a C source file containing the feature data as an array.
  with gfile.GFile(output_c_file, 'w') as f:
    f.write('/* File automatically created by\n')
    f.write(' * tensorflow/examples/speech_commands/wav_to_features.py \\\n')
    f.write(' * --sample_rate=%d \\\n' % sample_rate)
    f.write(' * --clip_duration_ms=%d \\\n' % clip_duration_ms)
    f.write(' * --window_size_ms=%d \\\n' % window_size_ms)
    f.write(' * --window_stride_ms=%d \\\n' % window_stride_ms)
    f.write(' * --feature_bin_count=%d \\\n' % feature_bin_count)
    if quantize:
      f.write(' * --quantize \\\n')
    f.write(' * --preprocess="%s" \\\n' % preprocess)
    f.write(' * --input_wav="%s" \\\n' % input_wav)
    f.write(' * --output_c_file="%s" \\\n' % output_c_file)
    f.write(' */\n\n')
    f.write('const int g_%s_width = %d;\n' %
            (variable_base, model_settings['fingerprint_width']))
    f.write('const int g_%s_height = %d;\n' %
            (variable_base, model_settings['spectrogram_length']))
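    # When quantizing, each floating-point feature is mapped linearly onto
    # 0-255 using the expected range for this preprocessing mode and then
    # clamped; for example, a value halfway between features_min and
    # features_max becomes round(255 * 0.5) = 128. Values are emitted ten
    # per line to keep the generated file readable.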
    if quantize:
      features_min, features_max = input_data.get_features_range(model_settings)
      f.write('const unsigned char g_%s_data[] = {' % variable_base)
      i = 0
      for value in features.flatten():
        quantized_value = int(
            round(
                (255 * (value - features_min)) / (features_max - features_min)))
        if quantized_value < 0:
          quantized_value = 0
        if quantized_value > 255:
          quantized_value = 255
        if i == 0:
          f.write('\n  ')
        f.write('%d, ' % (quantized_value))
        i = (i + 1) % 10
    else:
      f.write('const float g_%s_data[] = {' % variable_base)
      i = 0
      for value in features.flatten():
        if i == 0:
          f.write('\n  ')
        f.write('%f, ' % value)
        i = (i + 1) % 10
    f.write('\n};\n')


def main(_):
  # We want to see all the logging messages.
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  wav_to_features(FLAGS.sample_rate, FLAGS.clip_duration_ms,
                  FLAGS.window_size_ms, FLAGS.window_stride_ms,
                  FLAGS.feature_bin_count, FLAGS.quantize, FLAGS.preprocess,
                  FLAGS.input_wav, FLAGS.output_c_file)
  tf.compat.v1.logging.info('Wrote to "%s"' % (FLAGS.output_c_file))


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--sample_rate',
      type=int,
      default=16000,
      help='Expected sample rate of the wavs',)
  parser.add_argument(
      '--clip_duration_ms',
      type=int,
      default=1000,
      help='Expected duration in milliseconds of the wavs',)
  parser.add_argument(
      '--window_size_ms',
      type=float,
      default=30.0,
      help='How long each spectrogram timeslice is.',)
  parser.add_argument(
      '--window_stride_ms',
      type=float,
      default=10.0,
      help='How far to move in time between spectrogram timeslices.',
  )
  parser.add_argument(
      '--feature_bin_count',
      type=int,
      default=40,
      help='How many bins to use for the MFCC fingerprint',
  )
  parser.add_argument(
      '--quantize',
      action='store_true',
      default=False,
      help='Whether to quantize the features to eight bits for deployment')
  parser.add_argument(
      '--preprocess',
      type=str,
      default='mfcc',
      help='Spectrogram processing mode. Can be "mfcc", "average", or "micro"')
  parser.add_argument(
      '--input_wav',
      type=str,
      default=None,
      help='Path to the audio WAV file to read')
  parser.add_argument(
      '--output_c_file',
      type=str,
      default=None,
      help='Where to save the generated C source file containing the features')

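  # parse_known_args leaves unrecognized flags in unparsed so they can be
  # forwarded through app.run to TensorFlow's own flag handling.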
  FLAGS, unparsed = parser.parse_known_args()
  tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)