# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for loading and preparing audio data for simple speech recognition."""
import hashlib
import math
import os.path
import random
import re
import sys
import tarfile
import urllib.request

import numpy as np
import tensorflow as tf

from tensorflow.python.ops import gen_audio_ops as audio_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.platform import gfile
from tensorflow.python.util import compat

tf.compat.v1.disable_eager_execution()

# If it's available, load the specialized feature generator. If this doesn't
# work, try building with bazel instead of running the Python script directly.
try:
  from tensorflow.lite.experimental.microfrontend.python.ops import audio_microfrontend_op as frontend_op  # pylint:disable=g-import-not-at-top
except ImportError:
  frontend_op = None

MAX_NUM_WAVS_PER_CLASS = 2**27 - 1  # ~134M
SILENCE_LABEL = '_silence_'
SILENCE_INDEX = 0
UNKNOWN_WORD_LABEL = '_unknown_'
UNKNOWN_WORD_INDEX = 1
BACKGROUND_NOISE_DIR_NAME = '_background_noise_'
RANDOM_SEED = 59185


def prepare_words_list(wanted_words):
  """Prepends common tokens to the custom word list.

  Args:
    wanted_words: List of strings containing the custom words.

  Returns:
    List with the standard silence and unknown tokens added.
  """
  return [SILENCE_LABEL, UNKNOWN_WORD_LABEL] + wanted_words


def which_set(filename, validation_percentage, testing_percentage):
  """Determines which data partition the file should belong to.

  We want to keep files in the same training, validation, or testing sets even
  if new ones are added over time. This makes it less likely that testing
  samples will accidentally be reused in training when long runs are restarted
  for example. To keep this stability, a hash of the filename is taken and used
  to determine which set it should belong to. This determination only depends
  on the name and the set proportions, so it won't change as other files are
  added.

  It's also useful to associate particular files as related (for example words
  spoken by the same person), so anything after '_nohash_' in a filename is
  ignored for set determination. This ensures that 'bobby_nohash_0.wav' and
  'bobby_nohash_1.wav' are always in the same set, for example.

  Args:
    filename: File path of the data sample.
    validation_percentage: How much of the data set to use for validation.
    testing_percentage: How much of the data set to use for testing.

  Returns:
    String, one of 'training', 'validation', or 'testing'.
  """
  base_name = os.path.basename(filename)
  # We want to ignore anything after '_nohash_' in the file name when
  # deciding which set to put a wav in, so the data set creator has a way of
  # grouping wavs that are close variations of each other.
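  # For example, 'bobby_nohash_0.wav' and 'bobby_nohash_1.wav' both reduce to
  # 'bobby' here, so they hash identically and land in the same partition.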
  hash_name = re.sub(r'_nohash_.*$', '', base_name)
  # This looks a bit magical, but we need to decide whether this file should
  # go into the training, testing, or validation sets, and we want to keep
  # existing files in the same set even if more files are subsequently
  # added.
  # To do that, we need a stable way of deciding based on just the file name
  # itself, so we do a hash of that and then use that to generate a
  # probability value that we use to assign it.
  hash_name_hashed = hashlib.sha1(compat.as_bytes(hash_name)).hexdigest()
  percentage_hash = ((int(hash_name_hashed, 16) %
                      (MAX_NUM_WAVS_PER_CLASS + 1)) *
                     (100.0 / MAX_NUM_WAVS_PER_CLASS))
  if percentage_hash < validation_percentage:
    result = 'validation'
  elif percentage_hash < (testing_percentage + validation_percentage):
    result = 'testing'
  else:
    result = 'training'
  return result


def load_wav_file(filename):
  """Loads an audio file and returns a float PCM-encoded array of samples.

  Args:
    filename: Path to the .wav file to load.

  Returns:
    Numpy array holding the sample data as floats between -1.0 and 1.0.
  """
  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    wav_filename_placeholder = tf.compat.v1.placeholder(tf.string, [])
    wav_loader = io_ops.read_file(wav_filename_placeholder)
    wav_decoder = tf.audio.decode_wav(wav_loader, desired_channels=1)
    return sess.run(
        wav_decoder,
        feed_dict={wav_filename_placeholder: filename}).audio.flatten()


def save_wav_file(filename, wav_data, sample_rate):
  """Saves audio sample data to a .wav audio file.

  Args:
    filename: Path to save the file to.
    wav_data: 2D array of float PCM-encoded audio data.
    sample_rate: Samples per second to encode in the file.
  """
  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    wav_filename_placeholder = tf.compat.v1.placeholder(tf.string, [])
    sample_rate_placeholder = tf.compat.v1.placeholder(tf.int32, [])
    wav_data_placeholder = tf.compat.v1.placeholder(tf.float32, [None, 1])
    wav_encoder = tf.audio.encode_wav(wav_data_placeholder,
                                      sample_rate_placeholder)
    wav_saver = io_ops.write_file(wav_filename_placeholder, wav_encoder)
    sess.run(
        wav_saver,
        feed_dict={
            wav_filename_placeholder: filename,
            sample_rate_placeholder: sample_rate,
            wav_data_placeholder: np.reshape(wav_data, (-1, 1))
        })


def get_features_range(model_settings):
  """Returns the expected min/max for generated features.

  Args:
    model_settings: Information about the current model being trained.

  Returns:
    Min/max float pair holding the range of features.

  Raises:
    Exception: If preprocessing mode isn't recognized.
  """
  # TODO(petewarden): These values have been derived from the observed ranges
  # of spectrogram and MFCC inputs. If the preprocessing pipeline changes,
  # they may need to be updated.
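  # As an illustration only (not part of this module's API): a caller could use
  # the returned bounds to squeeze a feature value `x` into a byte, e.g.
  #   quantized = int(round(255 * (x - features_min) /
  #                         (features_max - features_min)))
  # The bound values below are empirical, as the TODO above notes.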
  if model_settings['preprocess'] == 'average':
    features_min = 0.0
    features_max = 127.5
  elif model_settings['preprocess'] == 'mfcc':
    features_min = -247.0
    features_max = 30.0
  elif model_settings['preprocess'] == 'micro':
    features_min = 0.0
    features_max = 26.0
  else:
    raise Exception('Unknown preprocess mode "%s" (should be "mfcc",'
                    ' "average", or "micro")' % (model_settings['preprocess']))
  return features_min, features_max


class AudioProcessor(object):
  """Handles loading, partitioning, and preparing audio training data."""

  def __init__(self, data_url, data_dir, silence_percentage, unknown_percentage,
               wanted_words, validation_percentage, testing_percentage,
               model_settings, summaries_dir):
    if data_dir:
      self.data_dir = data_dir
      self.maybe_download_and_extract_dataset(data_url, data_dir)
      self.prepare_data_index(silence_percentage, unknown_percentage,
                              wanted_words, validation_percentage,
                              testing_percentage)
      self.prepare_background_data()
    self.prepare_processing_graph(model_settings, summaries_dir)

  def maybe_download_and_extract_dataset(self, data_url, dest_directory):
    """Download and extract data set tar file.

    If the data set we're using doesn't already exist, this function
    downloads it from the TensorFlow.org website and unpacks it into a
    directory.
    If the data_url is empty, don't download anything and expect the data
    directory to contain the correct files already.

    Args:
      data_url: Web location of the tar file containing the data set.
      dest_directory: File path to extract data to.
    """
    if not data_url:
      return
    if not gfile.Exists(dest_directory):
      os.makedirs(dest_directory)
    filename = data_url.split('/')[-1]
    filepath = os.path.join(dest_directory, filename)
    if not gfile.Exists(filepath):

      def _progress(count, block_size, total_size):
        sys.stdout.write(
            '\r>> Downloading %s %.1f%%' %
            (filename, float(count * block_size) / float(total_size) * 100.0))
        sys.stdout.flush()

      try:
        filepath, _ = urllib.request.urlretrieve(data_url, filepath, _progress)
      except:
        tf.compat.v1.logging.error(
            'Failed to download URL: {0} to folder: {1}. Please make sure you '
            'have enough free space and an internet connection'.format(
                data_url, filepath))
        raise
      print()
      statinfo = os.stat(filepath)
      tf.compat.v1.logging.info(
          'Successfully downloaded {0} ({1} bytes)'.format(
              filename, statinfo.st_size))
      tarfile.open(filepath, 'r:gz').extractall(dest_directory)

  def prepare_data_index(self, silence_percentage, unknown_percentage,
                         wanted_words, validation_percentage,
                         testing_percentage):
    """Prepares a list of the samples organized by set and label.

    The training loop needs a list of all the available data, organized by
    which partition it should belong to, and with ground truth labels attached.
    This function analyzes the folders below the `data_dir`, figures out the
    right labels for each file based on the name of the subdirectory it belongs
    to, and uses a stable hash to assign it to a data set partition.

    Args:
      silence_percentage: How much of the resulting data should be background.
      unknown_percentage: How much should be audio outside the wanted classes.
      wanted_words: Labels of the classes we want to be able to recognize.
      validation_percentage: How much of the data set to use for validation.
      testing_percentage: How much of the data set to use for testing.

    Returns:
      Dictionary containing a list of file information for each set partition,
      and a lookup map for each class to determine its numeric index.

    Raises:
      Exception: If expected files are not found.
    """
    # Make sure the shuffling and picking of unknowns is deterministic.
    random.seed(RANDOM_SEED)
    wanted_words_index = {}
    for index, wanted_word in enumerate(wanted_words):
      wanted_words_index[wanted_word] = index + 2
    self.data_index = {'validation': [], 'testing': [], 'training': []}
    unknown_index = {'validation': [], 'testing': [], 'training': []}
    all_words = {}
    # Look through all the subfolders to find audio samples
    search_path = os.path.join(self.data_dir, '*', '*.wav')
    for wav_path in gfile.Glob(search_path):
      _, word = os.path.split(os.path.dirname(wav_path))
      word = word.lower()
      # Treat the '_background_noise_' folder as a special case, since we
      # expect it to contain long audio samples we mix in to improve training.
      if word == BACKGROUND_NOISE_DIR_NAME:
        continue
      all_words[word] = True
      set_index = which_set(wav_path, validation_percentage, testing_percentage)
      # If it's a known class, store its detail, otherwise add it to the list
      # we'll use to train the unknown label.
      if word in wanted_words_index:
        self.data_index[set_index].append({'label': word, 'file': wav_path})
      else:
        unknown_index[set_index].append({'label': word, 'file': wav_path})
    if not all_words:
      raise Exception('No .wavs found at ' + search_path)
    for index, wanted_word in enumerate(wanted_words):
      if wanted_word not in all_words:
        raise Exception('Expected to find ' + wanted_word +
                        ' in labels but only found ' +
                        ', '.join(all_words.keys()))
    # We need an arbitrary file to load as the input for the silence samples.
    # It's multiplied by zero later, so the content doesn't matter.
    silence_wav_path = self.data_index['training'][0]['file']
    for set_index in ['validation', 'testing', 'training']:
      set_size = len(self.data_index[set_index])
      silence_size = int(math.ceil(set_size * silence_percentage / 100))
      for _ in range(silence_size):
        self.data_index[set_index].append({
            'label': SILENCE_LABEL,
            'file': silence_wav_path
        })
      # Pick some unknowns to add to each partition of the data set.
      random.shuffle(unknown_index[set_index])
      unknown_size = int(math.ceil(set_size * unknown_percentage / 100))
      self.data_index[set_index].extend(unknown_index[set_index][:unknown_size])
    # Make sure the ordering is random.
    for set_index in ['validation', 'testing', 'training']:
      random.shuffle(self.data_index[set_index])
    # Prepare the rest of the result data structure.
    self.words_list = prepare_words_list(wanted_words)
    self.word_to_index = {}
    for word in all_words:
      if word in wanted_words_index:
        self.word_to_index[word] = wanted_words_index[word]
      else:
        self.word_to_index[word] = UNKNOWN_WORD_INDEX
    self.word_to_index[SILENCE_LABEL] = SILENCE_INDEX

  def prepare_background_data(self):
    """Searches a folder for background noise audio, and loads it into memory.

    It's expected that the background audio samples will be in a subdirectory
    named '_background_noise_' inside the 'data_dir' folder, as .wavs that
    match the sample rate of the training data, but can be much longer in
    duration.

    If the '_background_noise_' folder doesn't exist at all, this isn't an
    error, it's just taken to mean that no background noise augmentation should
    be used. If the folder does exist, but it's empty, that's treated as an
    error.

    Returns:
      List of raw PCM-encoded audio samples of background noise.

    Raises:
      Exception: If files aren't found in the folder.
    """
    self.background_data = []
    background_dir = os.path.join(self.data_dir, BACKGROUND_NOISE_DIR_NAME)
    if not gfile.Exists(background_dir):
      return self.background_data
    with tf.compat.v1.Session(graph=tf.Graph()) as sess:
      wav_filename_placeholder = tf.compat.v1.placeholder(tf.string, [])
      wav_loader = io_ops.read_file(wav_filename_placeholder)
      wav_decoder = tf.audio.decode_wav(wav_loader, desired_channels=1)
      search_path = os.path.join(self.data_dir, BACKGROUND_NOISE_DIR_NAME,
                                 '*.wav')
      for wav_path in gfile.Glob(search_path):
        wav_data = sess.run(
            wav_decoder,
            feed_dict={wav_filename_placeholder: wav_path}).audio.flatten()
        self.background_data.append(wav_data)
      if not self.background_data:
        raise Exception('No background wav files were found in ' + search_path)

  def prepare_processing_graph(self, model_settings, summaries_dir):
    """Builds a TensorFlow graph to apply the input distortions.

    Creates a graph that loads a WAVE file, decodes it, scales the volume,
    shifts it in time, adds in background noise, calculates a spectrogram, and
    then builds an MFCC fingerprint from that.

    This must be called with an active TensorFlow session running, and it
    creates multiple placeholder inputs, and one output:

      - wav_filename_placeholder_: Filename of the WAV to load.
      - foreground_volume_placeholder_: How loud the main clip should be.
      - time_shift_padding_placeholder_: Where to pad the clip.
      - time_shift_offset_placeholder_: How much to move the clip in time.
      - background_data_placeholder_: PCM sample data for background noise.
      - background_volume_placeholder_: Loudness of mixed-in background.
      - output_: Output 2D fingerprint of processed audio.

    Args:
      model_settings: Information about the current model being trained.
      summaries_dir: Path to save training summary information to.

    Raises:
      ValueError: If the preprocessing mode isn't recognized.
      Exception: If the preprocessor wasn't compiled in.
    """
    with tf.compat.v1.get_default_graph().name_scope('data'):
      desired_samples = model_settings['desired_samples']
      self.wav_filename_placeholder_ = tf.compat.v1.placeholder(
          tf.string, [], name='wav_filename')
      wav_loader = io_ops.read_file(self.wav_filename_placeholder_)
      wav_decoder = tf.audio.decode_wav(
          wav_loader, desired_channels=1, desired_samples=desired_samples)
      # Allow the audio sample's volume to be adjusted.
      self.foreground_volume_placeholder_ = tf.compat.v1.placeholder(
          tf.float32, [], name='foreground_volume')
      scaled_foreground = tf.multiply(wav_decoder.audio,
                                      self.foreground_volume_placeholder_)
      # Shift the sample's start position, and pad any gaps with zeros.
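      # For example, get_data() below feeds a shift of +100 samples as padding
      # [[100, 0], [0, 0]] with offset [0, 0] (the audio starts 100 samples
      # later), and a shift of -100 as padding [[0, 100], [0, 0]] with offset
      # [100, 0] (the first 100 samples are dropped).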
      self.time_shift_padding_placeholder_ = tf.compat.v1.placeholder(
          tf.int32, [2, 2], name='time_shift_padding')
      self.time_shift_offset_placeholder_ = tf.compat.v1.placeholder(
          tf.int32, [2], name='time_shift_offset')
      padded_foreground = tf.pad(
          tensor=scaled_foreground,
          paddings=self.time_shift_padding_placeholder_,
          mode='CONSTANT')
      sliced_foreground = tf.slice(padded_foreground,
                                   self.time_shift_offset_placeholder_,
                                   [desired_samples, -1])
      # Mix in background noise.
      self.background_data_placeholder_ = tf.compat.v1.placeholder(
          tf.float32, [desired_samples, 1], name='background_data')
      self.background_volume_placeholder_ = tf.compat.v1.placeholder(
          tf.float32, [], name='background_volume')
      background_mul = tf.multiply(self.background_data_placeholder_,
                                   self.background_volume_placeholder_)
      background_add = tf.add(background_mul, sliced_foreground)
      background_clamp = tf.clip_by_value(background_add, -1.0, 1.0)
      # Run the spectrogram and MFCC ops to get a 2D 'fingerprint' of the audio.
      spectrogram = audio_ops.audio_spectrogram(
          background_clamp,
          window_size=model_settings['window_size_samples'],
          stride=model_settings['window_stride_samples'],
          magnitude_squared=True)
      tf.compat.v1.summary.image(
          'spectrogram', tf.expand_dims(spectrogram, -1), max_outputs=1)
      # The number of buckets in each FFT row in the spectrogram will depend on
      # how many input samples there are in each window. This can be quite
      # large, with a 160 sample window producing 127 buckets for example. We
      # don't need this level of detail for classification, so we often want to
      # shrink them down to produce a smaller result. That's what this section
      # implements. One method is to use average pooling to merge adjacent
      # buckets, but a more sophisticated approach is to apply the MFCC
      # algorithm to shrink the representation.
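      # As a rough illustration of the pooling branch below: with a
      # hypothetical average_window_width of 6 and 'SAME' padding, the 127
      # buckets mentioned above would shrink to ceil(127 / 6) = 22 columns.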
      if model_settings['preprocess'] == 'average':
        self.output_ = tf.nn.pool(
            input=tf.expand_dims(spectrogram, -1),
            window_shape=[1, model_settings['average_window_width']],
            strides=[1, model_settings['average_window_width']],
            pooling_type='AVG',
            padding='SAME')
        tf.compat.v1.summary.image('shrunk_spectrogram',
                                   self.output_,
                                   max_outputs=1)
      elif model_settings['preprocess'] == 'mfcc':
        self.output_ = audio_ops.mfcc(
            spectrogram,
            wav_decoder.sample_rate,
            dct_coefficient_count=model_settings['fingerprint_width'])
        tf.compat.v1.summary.image(
            'mfcc', tf.expand_dims(self.output_, -1), max_outputs=1)
      elif model_settings['preprocess'] == 'micro':
        if not frontend_op:
          raise Exception(
              'Micro frontend op is currently not available when running'
              ' TensorFlow directly from Python, you need to build and run'
              ' through Bazel')
        sample_rate = model_settings['sample_rate']
        window_size_ms = (model_settings['window_size_samples'] *
                          1000) / sample_rate
        window_step_ms = (model_settings['window_stride_samples'] *
                          1000) / sample_rate
        int16_input = tf.cast(tf.multiply(background_clamp, 32768), tf.int16)
        micro_frontend = frontend_op.audio_microfrontend(
            int16_input,
            sample_rate=sample_rate,
            window_size=window_size_ms,
            window_step=window_step_ms,
            num_channels=model_settings['fingerprint_width'],
            out_scale=1,
            out_type=tf.float32)
        self.output_ = tf.multiply(micro_frontend, (10.0 / 256.0))
        tf.compat.v1.summary.image(
            'micro',
            tf.expand_dims(tf.expand_dims(self.output_, -1), 0),
            max_outputs=1)
      else:
        raise ValueError('Unknown preprocess mode "%s" (should be "mfcc",'
                         ' "average", or "micro")' %
                         (model_settings['preprocess']))

      # Merge all the summaries and write them out to /tmp/retrain_logs (by
      # default).
      self.merged_summaries_ = tf.compat.v1.summary.merge_all(scope='data')
      if summaries_dir:
        self.summary_writer_ = tf.compat.v1.summary.FileWriter(
            summaries_dir + '/data', tf.compat.v1.get_default_graph())

  def set_size(self, mode):
    """Calculates the number of samples in the dataset partition.

    Args:
      mode: Which partition, must be 'training', 'validation', or 'testing'.

    Returns:
      Number of samples in the partition.
    """
    return len(self.data_index[mode])

  def get_data(self, how_many, offset, model_settings, background_frequency,
               background_volume_range, time_shift, mode, sess):
    """Gather samples from the data set, applying transformations as needed.

    When the mode is 'training', a random selection of samples will be
    returned, otherwise the first N clips in the partition will be used. This
    ensures that validation always uses the same samples, reducing noise in
    the metrics.

    Args:
      how_many: Desired number of samples to return. -1 means the entire
        contents of this partition.
      offset: Where to start when fetching deterministically.
      model_settings: Information about the current model being trained.
      background_frequency: How many clips will have background noise, 0.0 to
        1.0.
      background_volume_range: How loud the background noise will be.
      time_shift: How much to randomly shift the clips by in time.
      mode: Which partition to use, must be 'training', 'validation', or
        'testing'.
      sess: TensorFlow session that was active when processor was created.

    Returns:
      List of sample data for the transformed samples, and list of label
      indexes.

    Raises:
      ValueError: If background samples are too short.
    """
    # Pick one of the partitions to choose samples from.
    candidates = self.data_index[mode]
    if how_many == -1:
      sample_count = len(candidates)
    else:
      sample_count = max(0, min(how_many, len(candidates) - offset))
    # Data and labels will be populated and returned.
    data = np.zeros((sample_count, model_settings['fingerprint_size']))
    labels = np.zeros(sample_count)
    desired_samples = model_settings['desired_samples']
    use_background = self.background_data and (mode == 'training')
    pick_deterministically = (mode != 'training')
    # Use the processing graph we created earlier repeatedly to generate the
    # final output sample data we'll use in training.
    for i in range(offset, offset + sample_count):
      # Pick which audio sample to use.
      if how_many == -1 or pick_deterministically:
        sample_index = i
      else:
        sample_index = np.random.randint(len(candidates))
      sample = candidates[sample_index]
      # If we're time shifting, set up the offset for this sample.
      if time_shift > 0:
        time_shift_amount = np.random.randint(-time_shift, time_shift)
      else:
        time_shift_amount = 0
      if time_shift_amount > 0:
        time_shift_padding = [[time_shift_amount, 0], [0, 0]]
        time_shift_offset = [0, 0]
      else:
        time_shift_padding = [[0, -time_shift_amount], [0, 0]]
        time_shift_offset = [-time_shift_amount, 0]
      input_dict = {
          self.wav_filename_placeholder_: sample['file'],
          self.time_shift_padding_placeholder_: time_shift_padding,
          self.time_shift_offset_placeholder_: time_shift_offset,
      }
      # Choose a section of background noise to mix in.
      if use_background or sample['label'] == SILENCE_LABEL:
        background_index = np.random.randint(len(self.background_data))
        background_samples = self.background_data[background_index]
        if len(background_samples) <= model_settings['desired_samples']:
          raise ValueError(
              'Background sample is too short! Need more than %d'
              ' samples but only %d were found' %
              (model_settings['desired_samples'], len(background_samples)))
        background_offset = np.random.randint(
            0, len(background_samples) - model_settings['desired_samples'])
        background_clipped = background_samples[background_offset:(
            background_offset + desired_samples)]
        background_reshaped = background_clipped.reshape([desired_samples, 1])
        if sample['label'] == SILENCE_LABEL:
          background_volume = np.random.uniform(0, 1)
        elif np.random.uniform(0, 1) < background_frequency:
          background_volume = np.random.uniform(0, background_volume_range)
        else:
          background_volume = 0
      else:
        background_reshaped = np.zeros([desired_samples, 1])
        background_volume = 0
      input_dict[self.background_data_placeholder_] = background_reshaped
      input_dict[self.background_volume_placeholder_] = background_volume
      # If we want silence, mute out the main sample but leave the background.
      if sample['label'] == SILENCE_LABEL:
        input_dict[self.foreground_volume_placeholder_] = 0
      else:
        input_dict[self.foreground_volume_placeholder_] = 1
      # Run the graph to produce the output audio.
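      # Each run yields the fingerprint tensor for this clip; flattening it
      # fills one row of the batch matrix allocated above.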
      summary, data_tensor = sess.run(
          [self.merged_summaries_, self.output_], feed_dict=input_dict)
      self.summary_writer_.add_summary(summary)
      data[i - offset, :] = data_tensor.flatten()
      label_index = self.word_to_index[sample['label']]
      labels[i - offset] = label_index
    return data, labels

  def get_features_for_wav(self, wav_filename, model_settings, sess):
    """Applies the feature transformation process to the input_wav.

    Runs the feature generation process (generally producing a spectrogram from
    the input samples) on the WAV file. This can be useful for testing and
    verifying implementations being run on other platforms.

    Args:
      wav_filename: The path to the input audio file.
      model_settings: Information about the current model being trained.
      sess: TensorFlow session that was active when processor was created.

    Returns:
      Numpy data array containing the generated features.
    """
    desired_samples = model_settings['desired_samples']
    input_dict = {
        self.wav_filename_placeholder_: wav_filename,
        self.time_shift_padding_placeholder_: [[0, 0], [0, 0]],
        self.time_shift_offset_placeholder_: [0, 0],
        self.background_data_placeholder_: np.zeros([desired_samples, 1]),
        self.background_volume_placeholder_: 0,
        self.foreground_volume_placeholder_: 1,
    }
    # Run the graph to produce the output features.
    data_tensor = sess.run([self.output_], feed_dict=input_dict)
    return data_tensor

  def get_unprocessed_data(self, how_many, model_settings, mode):
    """Retrieve sample data for the given partition, with no transformations.

    Args:
      how_many: Desired number of samples to return. -1 means the entire
        contents of this partition.
      model_settings: Information about the current model being trained.
      mode: Which partition to use, must be 'training', 'validation', or
        'testing'.

    Returns:
      List of sample data for the samples, and list of label strings.
    """
    candidates = self.data_index[mode]
    if how_many == -1:
      sample_count = len(candidates)
    else:
      sample_count = how_many
    desired_samples = model_settings['desired_samples']
    words_list = self.words_list
    data = np.zeros((sample_count, desired_samples))
    labels = []
    with tf.compat.v1.Session(graph=tf.Graph()) as sess:
      wav_filename_placeholder = tf.compat.v1.placeholder(tf.string, [])
      wav_loader = io_ops.read_file(wav_filename_placeholder)
      wav_decoder = tf.audio.decode_wav(
          wav_loader, desired_channels=1, desired_samples=desired_samples)
      foreground_volume_placeholder = tf.compat.v1.placeholder(tf.float32, [])
      scaled_foreground = tf.multiply(wav_decoder.audio,
                                      foreground_volume_placeholder)
      for i in range(sample_count):
        if how_many == -1:
          sample_index = i
        else:
          sample_index = np.random.randint(len(candidates))
        sample = candidates[sample_index]
        input_dict = {wav_filename_placeholder: sample['file']}
        if sample['label'] == SILENCE_LABEL:
          input_dict[foreground_volume_placeholder] = 0
        else:
          input_dict[foreground_volume_placeholder] = 1
        data[i, :] = sess.run(scaled_foreground, feed_dict=input_dict).flatten()
        label_index = self.word_to_index[sample['label']]
        labels.append(words_list[label_index])
    return data, labels
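

# The sketch below is illustrative only and is not invoked anywhere in this
# module: it shows one plausible way to construct an AudioProcessor and pull a
# training batch from it. All paths, wanted words, and model_settings values
# are hypothetical placeholders, not defaults defined by this file.
def _example_usage():
  """Illustrative sketch of driving AudioProcessor (assumed settings)."""
  model_settings = {
      'desired_samples': 16000,  # Assumed: one second of 16kHz audio.
      'window_size_samples': 480,  # Assumed: 30ms analysis window.
      'window_stride_samples': 160,  # Assumed: 10ms hop between windows.
      'fingerprint_width': 40,  # Assumed: feature channels per frame.
      'fingerprint_size': 40 * 98,  # Width times frame count for these values.
      'sample_rate': 16000,
      'preprocess': 'mfcc',
  }
  audio_processor = AudioProcessor(
      data_url='',  # An empty URL skips maybe_download_and_extract_dataset.
      data_dir='/tmp/speech_dataset',  # Hypothetical local dataset path.
      silence_percentage=10,
      unknown_percentage=10,
      wanted_words=['yes', 'no'],  # Hypothetical target labels.
      validation_percentage=10,
      testing_percentage=10,
      model_settings=model_settings,
      summaries_dir='/tmp/retrain_logs')
  with tf.compat.v1.Session() as sess:
    # Fetch 100 augmented training samples starting at offset 0.
    fingerprints, labels = audio_processor.get_data(
        how_many=100,
        offset=0,
        model_settings=model_settings,
        background_frequency=0.8,
        background_volume_range=0.1,
        time_shift=100,
        mode='training',
        sess=sess)
    print(fingerprints.shape, labels.shape)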