# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for loading and preparing data for simple speech recognition."""
import hashlib
import math
import os.path
import random
import re
import sys
import tarfile
import urllib.request

import numpy as np
import tensorflow as tf

from tensorflow.python.ops import gen_audio_ops as audio_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.platform import gfile
from tensorflow.python.util import compat

tf.compat.v1.disable_eager_execution()

# If it's available, load the specialized feature generator. If this doesn't
# work, try building with bazel instead of running the Python script directly.
try:
  from tensorflow.lite.experimental.microfrontend.python.ops import audio_microfrontend_op as frontend_op  # pylint:disable=g-import-not-at-top
except ImportError:
  frontend_op = None

MAX_NUM_WAVS_PER_CLASS = 2**27 - 1  # ~134M
SILENCE_LABEL = '_silence_'
SILENCE_INDEX = 0
UNKNOWN_WORD_LABEL = '_unknown_'
UNKNOWN_WORD_INDEX = 1
BACKGROUND_NOISE_DIR_NAME = '_background_noise_'
RANDOM_SEED = 59185


def prepare_words_list(wanted_words):
  """Prepends common tokens to the custom word list.

  Args:
    wanted_words: List of strings containing the custom words.

  Returns:
    List with the standard silence and unknown tokens added.
  """
  return [SILENCE_LABEL, UNKNOWN_WORD_LABEL] + wanted_words
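# Example (illustrative, not part of the original module): prepending the
# standard tokens to a two-word list.
#
#   prepare_words_list(['yes', 'no'])
#   # -> ['_silence_', '_unknown_', 'yes', 'no']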


def which_set(filename, validation_percentage, testing_percentage):
  """Determines which data partition the file should belong to.

  We want to keep files in the same training, validation, or testing sets even
  if new ones are added over time. This makes it less likely that testing
  samples will accidentally be reused in training when long runs are restarted,
  for example. To keep this stability, a hash of the filename is taken and used
  to determine which set it should belong to. This determination only depends on
  the name and the set proportions, so it won't change as other files are added.

  It's also useful to associate particular files as related (for example words
  spoken by the same person), so anything after '_nohash_' in a filename is
  ignored for set determination. This ensures that 'bobby_nohash_0.wav' and
  'bobby_nohash_1.wav' are always in the same set, for example.

  Args:
    filename: File path of the data sample.
    validation_percentage: How much of the data set to use for validation.
    testing_percentage: How much of the data set to use for testing.

  Returns:
    String, one of 'training', 'validation', or 'testing'.
  """
  base_name = os.path.basename(filename)
  # We want to ignore anything after '_nohash_' in the file name when
  # deciding which set to put a wav in, so the data set creator has a way of
  # grouping wavs that are close variations of each other.
  hash_name = re.sub(r'_nohash_.*$', '', base_name)
  # This looks a bit magical, but we need to decide whether this file should
  # go into the training, testing, or validation sets, and we want to keep
  # existing files in the same set even if more files are subsequently
  # added.
  # To do that, we need a stable way of deciding based on just the file name
  # itself, so we do a hash of that and then use that to generate a
  # probability value that we use to assign it.
  hash_name_hashed = hashlib.sha1(compat.as_bytes(hash_name)).hexdigest()
  percentage_hash = ((int(hash_name_hashed, 16) %
                      (MAX_NUM_WAVS_PER_CLASS + 1)) *
                     (100.0 / MAX_NUM_WAVS_PER_CLASS))
  if percentage_hash < validation_percentage:
    result = 'validation'
  elif percentage_hash < (testing_percentage + validation_percentage):
    result = 'testing'
  else:
    result = 'training'
  return result
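# Example (illustrative, not part of the original module): two clips from the
# same speaker share the same hash input once the '_nohash_' suffix is
# stripped, so they always land in the same partition. The 10% splits are an
# assumption.
#
#   which_set('bobby_nohash_0.wav', validation_percentage=10,
#             testing_percentage=10)
#   which_set('bobby_nohash_1.wav', validation_percentage=10,
#             testing_percentage=10)
#   # Both calls hash the string 'bobby' and so return the same set name.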


def load_wav_file(filename):
  """Loads an audio file and returns a float PCM-encoded array of samples.

  Args:
    filename: Path to the .wav file to load.

  Returns:
    Numpy array holding the sample data as floats between -1.0 and 1.0.
  """
  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    wav_filename_placeholder = tf.compat.v1.placeholder(tf.string, [])
    wav_loader = io_ops.read_file(wav_filename_placeholder)
    wav_decoder = tf.audio.decode_wav(wav_loader, desired_channels=1)
    return sess.run(
        wav_decoder,
        feed_dict={wav_filename_placeholder: filename}).audio.flatten()


def save_wav_file(filename, wav_data, sample_rate):
  """Saves audio sample data to a .wav audio file.

  Args:
    filename: Path to save the file to.
    wav_data: 2D array of float PCM-encoded audio data.
    sample_rate: Samples per second to encode in the file.
  """
  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    wav_filename_placeholder = tf.compat.v1.placeholder(tf.string, [])
    sample_rate_placeholder = tf.compat.v1.placeholder(tf.int32, [])
    wav_data_placeholder = tf.compat.v1.placeholder(tf.float32, [None, 1])
    wav_encoder = tf.audio.encode_wav(wav_data_placeholder,
                                      sample_rate_placeholder)
    wav_saver = io_ops.write_file(wav_filename_placeholder, wav_encoder)
    sess.run(
        wav_saver,
        feed_dict={
            wav_filename_placeholder: filename,
            sample_rate_placeholder: sample_rate,
            wav_data_placeholder: np.reshape(wav_data, (-1, 1))
        })
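# Example (a sketch, not part of the original module): round-tripping a clip
# through the helpers above. The paths and the 16000 Hz sample rate are
# illustrative assumptions.
#
#   samples = load_wav_file('/tmp/example_input.wav')
#   save_wav_file('/tmp/example_copy.wav', samples.reshape(-1, 1),
#                 sample_rate=16000)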


def get_features_range(model_settings):
  """Returns the expected min/max for generated features.

  Args:
    model_settings: Information about the current model being trained.

  Returns:
    Min/max float pair holding the range of features.

  Raises:
    Exception: If preprocessing mode isn't recognized.
  """
  # TODO(petewarden): These values have been derived from the observed ranges
  # of spectrogram and MFCC inputs. If the preprocessing pipeline changes,
  # they may need to be updated.
  if model_settings['preprocess'] == 'average':
    features_min = 0.0
    features_max = 127.5
  elif model_settings['preprocess'] == 'mfcc':
    features_min = -247.0
    features_max = 30.0
  elif model_settings['preprocess'] == 'micro':
    features_min = 0.0
    features_max = 26.0
  else:
    raise Exception('Unknown preprocess mode "%s" (should be "mfcc",'
                    ' "average", or "micro")' % (model_settings['preprocess']))
  return features_min, features_max
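# Example (a sketch, not part of the original module): the returned range can
# be used to rescale features, for instance into [0, 1] before quantization.
# The minimal model_settings dict here is an illustrative assumption.
#
#   features_min, features_max = get_features_range({'preprocess': 'mfcc'})
#   # features_min == -247.0, features_max == 30.0
#   # normalized = (features - features_min) / (features_max - features_min)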


class AudioProcessor(object):
  """Handles loading, partitioning, and preparing audio training data."""

  def __init__(self, data_url, data_dir, silence_percentage, unknown_percentage,
               wanted_words, validation_percentage, testing_percentage,
               model_settings, summaries_dir):
    if data_dir:
      self.data_dir = data_dir
      self.maybe_download_and_extract_dataset(data_url, data_dir)
      self.prepare_data_index(silence_percentage, unknown_percentage,
                              wanted_words, validation_percentage,
                              testing_percentage)
      self.prepare_background_data()
    self.prepare_processing_graph(model_settings, summaries_dir)
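  # Example (illustrative only; every value below is an assumption, not part of
  # the original module): constructing a processor once `model_settings` has
  # been built, e.g. by the training script.
  #
  #   audio_processor = AudioProcessor(
  #       data_url=None, data_dir='/tmp/speech_dataset',
  #       silence_percentage=10, unknown_percentage=10,
  #       wanted_words=['yes', 'no'], validation_percentage=10,
  #       testing_percentage=10, model_settings=model_settings,
  #       summaries_dir='/tmp/retrain_logs')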

  def maybe_download_and_extract_dataset(self, data_url, dest_directory):
    """Download and extract data set tar file.

    If the data set we're using doesn't already exist, this function
    downloads it from the TensorFlow.org website and unpacks it into a
    directory.
    If data_url is None, don't download anything and expect the data
    directory to contain the correct files already.

    Args:
      data_url: Web location of the tar file containing the data set.
      dest_directory: File path to extract data to.
    """
    if not data_url:
      return
    if not gfile.Exists(dest_directory):
      os.makedirs(dest_directory)
    filename = data_url.split('/')[-1]
    filepath = os.path.join(dest_directory, filename)
    if not gfile.Exists(filepath):

      def _progress(count, block_size, total_size):
        sys.stdout.write(
            '\r>> Downloading %s %.1f%%' %
            (filename, float(count * block_size) / float(total_size) * 100.0))
        sys.stdout.flush()

      try:
        filepath, _ = urllib.request.urlretrieve(data_url, filepath, _progress)
      except Exception:
        tf.compat.v1.logging.error(
            'Failed to download URL: {0} to folder: {1}. Please make sure you '
            'have enough free space and an internet connection'.format(
                data_url, filepath))
        raise
      print()
      statinfo = os.stat(filepath)
      tf.compat.v1.logging.info(
          'Successfully downloaded {0} ({1} bytes)'.format(
              filename, statinfo.st_size))
      tarfile.open(filepath, 'r:gz').extractall(dest_directory)

  def prepare_data_index(self, silence_percentage, unknown_percentage,
                         wanted_words, validation_percentage,
                         testing_percentage):
    """Prepares a list of the samples organized by set and label.

    The training loop needs a list of all the available data, organized by
    which partition it should belong to, and with ground truth labels attached.
    This function analyzes the folders below the `data_dir`, figures out the
    right labels for each file based on the name of the subdirectory it belongs
    to, and uses a stable hash to assign it to a data set partition.

    Args:
      silence_percentage: How much of the resulting data should be background.
      unknown_percentage: How much should be audio outside the wanted classes.
      wanted_words: Labels of the classes we want to be able to recognize.
      validation_percentage: How much of the data set to use for validation.
      testing_percentage: How much of the data set to use for testing.

    Returns:
      Dictionary containing a list of file information for each set partition,
      and a lookup map for each class to determine its numeric index.

    Raises:
      Exception: If expected files are not found.
    """
    # Make sure the shuffling and picking of unknowns is deterministic.
    random.seed(RANDOM_SEED)
    wanted_words_index = {}
    for index, wanted_word in enumerate(wanted_words):
      wanted_words_index[wanted_word] = index + 2
    self.data_index = {'validation': [], 'testing': [], 'training': []}
    unknown_index = {'validation': [], 'testing': [], 'training': []}
    all_words = {}
    # Look through all the subfolders to find audio samples
    search_path = os.path.join(self.data_dir, '*', '*.wav')
    for wav_path in gfile.Glob(search_path):
      _, word = os.path.split(os.path.dirname(wav_path))
      word = word.lower()
      # Treat the '_background_noise_' folder as a special case, since we expect
      # it to contain long audio samples we mix in to improve training.
      if word == BACKGROUND_NOISE_DIR_NAME:
        continue
      all_words[word] = True
      set_index = which_set(wav_path, validation_percentage, testing_percentage)
      # If it's a known class, store its details, otherwise add it to the list
      # we'll use to train the unknown label.
      if word in wanted_words_index:
        self.data_index[set_index].append({'label': word, 'file': wav_path})
      else:
        unknown_index[set_index].append({'label': word, 'file': wav_path})
    if not all_words:
      raise Exception('No .wavs found at ' + search_path)
    for index, wanted_word in enumerate(wanted_words):
      if wanted_word not in all_words:
        raise Exception('Expected to find ' + wanted_word +
                        ' in labels but only found ' +
                        ', '.join(all_words.keys()))
    # We need an arbitrary file to load as the input for the silence samples.
    # It's multiplied by zero later, so the content doesn't matter.
    silence_wav_path = self.data_index['training'][0]['file']
    for set_index in ['validation', 'testing', 'training']:
      set_size = len(self.data_index[set_index])
      silence_size = int(math.ceil(set_size * silence_percentage / 100))
      for _ in range(silence_size):
        self.data_index[set_index].append({
            'label': SILENCE_LABEL,
            'file': silence_wav_path
        })
      # Pick some unknowns to add to each partition of the data set.
      random.shuffle(unknown_index[set_index])
      unknown_size = int(math.ceil(set_size * unknown_percentage / 100))
      self.data_index[set_index].extend(unknown_index[set_index][:unknown_size])
    # Make sure the ordering is random.
    for set_index in ['validation', 'testing', 'training']:
      random.shuffle(self.data_index[set_index])
    # Prepare the rest of the result data structure.
    self.words_list = prepare_words_list(wanted_words)
    self.word_to_index = {}
    for word in all_words:
      if word in wanted_words_index:
        self.word_to_index[word] = wanted_words_index[word]
      else:
        self.word_to_index[word] = UNKNOWN_WORD_INDEX
    self.word_to_index[SILENCE_LABEL] = SILENCE_INDEX
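  # For reference (illustrative, not part of the original module): after this
  # method runs, self.data_index maps each partition to a list of entries such
  # as {'label': 'yes', 'file': '/tmp/speech_dataset/yes/abc_nohash_0.wav'}
  # (the path is an assumed example), and self.word_to_index maps every word
  # found on disk to its numeric index: '_silence_' to 0, unknown words to 1,
  # and wanted words to 2 and upward.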

  def prepare_background_data(self):
    """Searches a folder for background noise audio, and loads it into memory.

    It's expected that the background audio samples will be in a subdirectory
    named '_background_noise_' inside the 'data_dir' folder, as .wavs that match
    the sample rate of the training data, but can be much longer in duration.

    If the '_background_noise_' folder doesn't exist at all, this isn't an
    error, it's just taken to mean that no background noise augmentation should
    be used. If the folder does exist, but it's empty, that's treated as an
    error.

    Returns:
      List of raw PCM-encoded audio samples of background noise.

    Raises:
      Exception: If files aren't found in the folder.
    """
    self.background_data = []
    background_dir = os.path.join(self.data_dir, BACKGROUND_NOISE_DIR_NAME)
    if not gfile.Exists(background_dir):
      return self.background_data
    with tf.compat.v1.Session(graph=tf.Graph()) as sess:
      wav_filename_placeholder = tf.compat.v1.placeholder(tf.string, [])
      wav_loader = io_ops.read_file(wav_filename_placeholder)
      wav_decoder = tf.audio.decode_wav(wav_loader, desired_channels=1)
      search_path = os.path.join(self.data_dir, BACKGROUND_NOISE_DIR_NAME,
                                 '*.wav')
      for wav_path in gfile.Glob(search_path):
        wav_data = sess.run(
            wav_decoder,
            feed_dict={wav_filename_placeholder: wav_path}).audio.flatten()
        self.background_data.append(wav_data)
      if not self.background_data:
        raise Exception('No background wav files were found in ' + search_path)

  def prepare_processing_graph(self, model_settings, summaries_dir):
    """Builds a TensorFlow graph to apply the input distortions.

    Creates a graph that loads a WAVE file, decodes it, scales the volume,
    shifts it in time, adds in background noise, calculates a spectrogram, and
    then builds an MFCC fingerprint from that.

    This must be called with an active TensorFlow session running, and it
    creates multiple placeholder inputs, and one output:

      - wav_filename_placeholder_: Filename of the WAV to load.
      - foreground_volume_placeholder_: How loud the main clip should be.
      - time_shift_padding_placeholder_: Where to pad the clip.
      - time_shift_offset_placeholder_: How much to move the clip in time.
      - background_data_placeholder_: PCM sample data for background noise.
      - background_volume_placeholder_: Loudness of mixed-in background.
      - output_: Output 2D fingerprint of processed audio.

    Args:
      model_settings: Information about the current model being trained.
      summaries_dir: Path to save training summary information to.

    Raises:
      ValueError: If the preprocessing mode isn't recognized.
      Exception: If the preprocessor wasn't compiled in.
    """
    with tf.compat.v1.get_default_graph().name_scope('data'):
      desired_samples = model_settings['desired_samples']
      self.wav_filename_placeholder_ = tf.compat.v1.placeholder(
          tf.string, [], name='wav_filename')
      wav_loader = io_ops.read_file(self.wav_filename_placeholder_)
      wav_decoder = tf.audio.decode_wav(
          wav_loader, desired_channels=1, desired_samples=desired_samples)
      # Allow the audio sample's volume to be adjusted.
      self.foreground_volume_placeholder_ = tf.compat.v1.placeholder(
          tf.float32, [], name='foreground_volume')
      scaled_foreground = tf.multiply(wav_decoder.audio,
                                      self.foreground_volume_placeholder_)
      # Shift the sample's start position, and pad any gaps with zeros.
      self.time_shift_padding_placeholder_ = tf.compat.v1.placeholder(
          tf.int32, [2, 2], name='time_shift_padding')
      self.time_shift_offset_placeholder_ = tf.compat.v1.placeholder(
          tf.int32, [2], name='time_shift_offset')
      padded_foreground = tf.pad(
          tensor=scaled_foreground,
          paddings=self.time_shift_padding_placeholder_,
          mode='CONSTANT')
      sliced_foreground = tf.slice(padded_foreground,
                                   self.time_shift_offset_placeholder_,
                                   [desired_samples, -1])
      # Mix in background noise.
      self.background_data_placeholder_ = tf.compat.v1.placeholder(
          tf.float32, [desired_samples, 1], name='background_data')
      self.background_volume_placeholder_ = tf.compat.v1.placeholder(
          tf.float32, [], name='background_volume')
      background_mul = tf.multiply(self.background_data_placeholder_,
                                   self.background_volume_placeholder_)
      background_add = tf.add(background_mul, sliced_foreground)
      background_clamp = tf.clip_by_value(background_add, -1.0, 1.0)
      # Run the spectrogram and MFCC ops to get a 2D 'fingerprint' of the audio.
      spectrogram = audio_ops.audio_spectrogram(
          background_clamp,
          window_size=model_settings['window_size_samples'],
          stride=model_settings['window_stride_samples'],
          magnitude_squared=True)
      tf.compat.v1.summary.image(
          'spectrogram', tf.expand_dims(spectrogram, -1), max_outputs=1)
      # The number of buckets in each FFT row in the spectrogram will depend on
      # how many input samples there are in each window. This can be quite
      # large, with a 160 sample window producing 127 buckets for example. We
      # don't need this level of detail for classification, so we often want to
      # shrink them down to produce a smaller result. That's what this section
      # implements. One method is to use average pooling to merge adjacent
      # buckets, but a more sophisticated approach is to apply the MFCC
      # algorithm to shrink the representation.
      if model_settings['preprocess'] == 'average':
        self.output_ = tf.nn.pool(
            input=tf.expand_dims(spectrogram, -1),
            window_shape=[1, model_settings['average_window_width']],
            strides=[1, model_settings['average_window_width']],
            pooling_type='AVG',
            padding='SAME')
        tf.compat.v1.summary.image('shrunk_spectrogram',
                                   self.output_,
                                   max_outputs=1)
      elif model_settings['preprocess'] == 'mfcc':
        self.output_ = audio_ops.mfcc(
            spectrogram,
            wav_decoder.sample_rate,
            dct_coefficient_count=model_settings['fingerprint_width'])
        tf.compat.v1.summary.image(
            'mfcc', tf.expand_dims(self.output_, -1), max_outputs=1)
      elif model_settings['preprocess'] == 'micro':
        if not frontend_op:
          raise Exception(
              'Micro frontend op is currently not available when running'
              ' TensorFlow directly from Python, you need to build and run'
              ' through Bazel')
        sample_rate = model_settings['sample_rate']
        window_size_ms = (model_settings['window_size_samples'] *
                          1000) / sample_rate
        window_step_ms = (model_settings['window_stride_samples'] *
                          1000) / sample_rate
        int16_input = tf.cast(tf.multiply(background_clamp, 32768), tf.int16)
        micro_frontend = frontend_op.audio_microfrontend(
            int16_input,
            sample_rate=sample_rate,
            window_size=window_size_ms,
            window_step=window_step_ms,
            num_channels=model_settings['fingerprint_width'],
            out_scale=1,
            out_type=tf.float32)
        self.output_ = tf.multiply(micro_frontend, (10.0 / 256.0))
        tf.compat.v1.summary.image(
            'micro',
            tf.expand_dims(tf.expand_dims(self.output_, -1), 0),
            max_outputs=1)
      else:
        raise ValueError('Unknown preprocess mode "%s" (should be "mfcc",'
                         ' "average", or "micro")' %
                         (model_settings['preprocess']))

      # Merge all the summaries and write them out to /tmp/retrain_logs (by
      # default).
      self.merged_summaries_ = tf.compat.v1.summary.merge_all(scope='data')
      if summaries_dir:
        self.summary_writer_ = tf.compat.v1.summary.FileWriter(
            summaries_dir + '/data', tf.compat.v1.get_default_graph())

  def set_size(self, mode):
    """Calculates the number of samples in the dataset partition.

    Args:
      mode: Which partition, must be 'training', 'validation', or 'testing'.

    Returns:
      Number of samples in the partition.
    """
    return len(self.data_index[mode])

  def get_data(self, how_many, offset, model_settings, background_frequency,
               background_volume_range, time_shift, mode, sess):
    """Gather samples from the data set, applying transformations as needed.

    When the mode is 'training', a random selection of samples will be returned,
    otherwise the first N clips in the partition will be used. This ensures that
    validation always uses the same samples, reducing noise in the metrics.

    Args:
      how_many: Desired number of samples to return. -1 means the entire
        contents of this partition.
      offset: Where to start when fetching deterministically.
      model_settings: Information about the current model being trained.
      background_frequency: How many clips will have background noise, 0.0 to
        1.0.
      background_volume_range: How loud the background noise will be.
      time_shift: How much to randomly shift the clips by in time.
      mode: Which partition to use, must be 'training', 'validation', or
        'testing'.
      sess: TensorFlow session that was active when processor was created.

    Returns:
      List of sample data for the transformed samples, and list of label indexes.

    Raises:
      ValueError: If background samples are too short.
    """
    # Pick one of the partitions to choose samples from.
    candidates = self.data_index[mode]
    if how_many == -1:
      sample_count = len(candidates)
    else:
      sample_count = max(0, min(how_many, len(candidates) - offset))
    # Data and labels will be populated and returned.
    data = np.zeros((sample_count, model_settings['fingerprint_size']))
    labels = np.zeros(sample_count)
    desired_samples = model_settings['desired_samples']
    use_background = self.background_data and (mode == 'training')
    pick_deterministically = (mode != 'training')
    # Use the processing graph we created earlier to repeatedly generate the
    # final output sample data we'll use in training.
    for i in range(offset, offset + sample_count):
      # Pick which audio sample to use.
      if how_many == -1 or pick_deterministically:
        sample_index = i
      else:
        sample_index = np.random.randint(len(candidates))
      sample = candidates[sample_index]
      # If we're time shifting, set up the offset for this sample.
      if time_shift > 0:
        time_shift_amount = np.random.randint(-time_shift, time_shift)
      else:
        time_shift_amount = 0
      if time_shift_amount > 0:
        time_shift_padding = [[time_shift_amount, 0], [0, 0]]
        time_shift_offset = [0, 0]
      else:
        time_shift_padding = [[0, -time_shift_amount], [0, 0]]
        time_shift_offset = [-time_shift_amount, 0]
      input_dict = {
          self.wav_filename_placeholder_: sample['file'],
          self.time_shift_padding_placeholder_: time_shift_padding,
          self.time_shift_offset_placeholder_: time_shift_offset,
      }
      # Choose a section of background noise to mix in.
      if use_background or sample['label'] == SILENCE_LABEL:
        background_index = np.random.randint(len(self.background_data))
        background_samples = self.background_data[background_index]
        if len(background_samples) <= model_settings['desired_samples']:
          raise ValueError(
              'Background sample is too short! Need more than %d'
              ' samples but only %d were found' %
              (model_settings['desired_samples'], len(background_samples)))
        background_offset = np.random.randint(
            0, len(background_samples) - model_settings['desired_samples'])
        background_clipped = background_samples[background_offset:(
            background_offset + desired_samples)]
        background_reshaped = background_clipped.reshape([desired_samples, 1])
        if sample['label'] == SILENCE_LABEL:
          background_volume = np.random.uniform(0, 1)
        elif np.random.uniform(0, 1) < background_frequency:
          background_volume = np.random.uniform(0, background_volume_range)
        else:
          background_volume = 0
      else:
        background_reshaped = np.zeros([desired_samples, 1])
        background_volume = 0
      input_dict[self.background_data_placeholder_] = background_reshaped
      input_dict[self.background_volume_placeholder_] = background_volume
      # If we want silence, mute out the main sample but leave the background.
      if sample['label'] == SILENCE_LABEL:
        input_dict[self.foreground_volume_placeholder_] = 0
      else:
        input_dict[self.foreground_volume_placeholder_] = 1
      # Run the graph to produce the output audio.
      summary, data_tensor = sess.run(
          [self.merged_summaries_, self.output_], feed_dict=input_dict)
      self.summary_writer_.add_summary(summary)
      data[i - offset, :] = data_tensor.flatten()
      label_index = self.word_to_index[sample['label']]
      labels[i - offset] = label_index
    return data, labels
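  # Example (illustrative only; the batch size, augmentation settings, and
  # variable names are assumptions, not part of the original module): fetching
  # one training batch of 100 clips inside the session used to build the
  # processing graph.
  #
  #   train_fingerprints, train_ground_truth = audio_processor.get_data(
  #       100, 0, model_settings, background_frequency=0.8,
  #       background_volume_range=0.1, time_shift=1600, mode='training',
  #       sess=sess)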

  def get_features_for_wav(self, wav_filename, model_settings, sess):
    """Applies the feature transformation process to the input WAV file.

    Runs the feature generation process (generally producing a spectrogram from
    the input samples) on the WAV file. This can be useful for testing and
    verifying implementations being run on other platforms.

    Args:
      wav_filename: The path to the input audio file.
      model_settings: Information about the current model being trained.
      sess: TensorFlow session that was active when processor was created.

    Returns:
      Numpy data array containing the generated features.
    """
    desired_samples = model_settings['desired_samples']
    input_dict = {
        self.wav_filename_placeholder_: wav_filename,
        self.time_shift_padding_placeholder_: [[0, 0], [0, 0]],
        self.time_shift_offset_placeholder_: [0, 0],
        self.background_data_placeholder_: np.zeros([desired_samples, 1]),
        self.background_volume_placeholder_: 0,
        self.foreground_volume_placeholder_: 1,
    }
    # Run the graph to produce the output audio.
    data_tensor = sess.run([self.output_], feed_dict=input_dict)
    return data_tensor
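  # Example (a sketch; the file path is an illustrative assumption): generating
  # features for a single clip, e.g. to compare against another implementation.
  #
  #   features = audio_processor.get_features_for_wav(
  #       '/tmp/speech_dataset/yes/example.wav', model_settings, sess)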

  def get_unprocessed_data(self, how_many, model_settings, mode):
    """Retrieve sample data for the given partition, with no transformations.

    Args:
      how_many: Desired number of samples to return. -1 means the entire
        contents of this partition.
      model_settings: Information about the current model being trained.
      mode: Which partition to use, must be 'training', 'validation', or
        'testing'.

    Returns:
      List of sample data for the samples, and list of label strings.
    """
    candidates = self.data_index[mode]
    if how_many == -1:
      sample_count = len(candidates)
    else:
      sample_count = how_many
    desired_samples = model_settings['desired_samples']
    words_list = self.words_list
    data = np.zeros((sample_count, desired_samples))
    labels = []
    with tf.compat.v1.Session(graph=tf.Graph()) as sess:
      wav_filename_placeholder = tf.compat.v1.placeholder(tf.string, [])
      wav_loader = io_ops.read_file(wav_filename_placeholder)
      wav_decoder = tf.audio.decode_wav(
          wav_loader, desired_channels=1, desired_samples=desired_samples)
      foreground_volume_placeholder = tf.compat.v1.placeholder(tf.float32, [])
      scaled_foreground = tf.multiply(wav_decoder.audio,
                                      foreground_volume_placeholder)
      for i in range(sample_count):
        if how_many == -1:
          sample_index = i
        else:
          sample_index = np.random.randint(len(candidates))
        sample = candidates[sample_index]
        input_dict = {wav_filename_placeholder: sample['file']}
        if sample['label'] == SILENCE_LABEL:
          input_dict[foreground_volume_placeholder] = 0
        else:
          input_dict[foreground_volume_placeholder] = 1
        data[i, :] = sess.run(scaled_foreground, feed_dict=input_dict).flatten()
        label_index = self.word_to_index[sample['label']]
        labels.append(words_list[label_index])
    return data, labels

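  # Example (illustrative only; the count and partition name are assumptions,
  # not part of the original module): pulling a few untransformed clips, e.g.
  # for inspection or tests.
  #
  #   raw_data, raw_labels = audio_processor.get_unprocessed_data(
  #       10, model_settings, 'testing')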