"""
/* Copyright (c) 2023 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

import os
import argparse
import sys

try:
    import git
    has_git = True
except ImportError:
    # GitPython is optional: without it the repo info is simply not recorded
    has_git = False

import yaml


import torch
from torch.optim.lr_scheduler import LambdaLR

from data import LPCNetDataset
from models import model_dict
from engine.lpcnet_engine import train_one_epoch, evaluate
from utils.data import load_features
from utils.wav import wavwrite16


# Training driver: reads a yaml setup, builds dataset/model/optimizer, trains
# for the configured number of epochs and writes checkpoints (and optionally
# inference test wav files) into the output folder.

# set debug to True to bypass the command line and use the fixed arguments below
debug = False
if debug:
    # attribute names must match the argparse destinations used further down
    # (in particular no_redirect, not no-redirect)
    args = argparse.Namespace(
        setup='setup.yml',
        output='testout',
        device=None,
        test_features=None,
        finalize=False,
        initial_checkpoint=None,
        no_redirect=False,
    )
else:
    parser = argparse.ArgumentParser("train_lpcnet.py")
    parser.add_argument('setup', type=str, help='setup yaml file')
    parser.add_argument('output', type=str, help='output path')
    parser.add_argument('--device', type=str, help='compute device', default=None)
    parser.add_argument('--test-features', type=str, help='test feature file in v2 format', default=None)
    parser.add_argument('--finalize', action='store_true', help='run single training round with lr=1e-5')
    parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
    parser.add_argument('--no-redirect', action='store_true', help='disables re-direction of output')

    args = parser.parse_args()


torch.set_num_threads(4)

# NOTE(review): yaml.FullLoader is not meant for untrusted input; only use
# setup files from a trusted source (or switch to yaml.safe_load if the setup
# format allows it).
with open(args.setup, 'r') as f:
    setup = yaml.load(f.read(), yaml.FullLoader)

if args.finalize:
    if args.initial_checkpoint is None:
        raise ValueError('finalization requires initial checkpoint')

    # collapse all sparsification schedules to start = stop = 0
    if 'sparsification' in setup['lpcnet']['config']:
        for sp_job in setup['lpcnet']['config']['sparsification'].values():
            sp_job['start'], sp_job['stop'] = 0, 0

    # single epoch with a small, constant learning rate
    setup['training']['lr'] = 1.0e-5
    setup['training']['lr_decay_factor'] = 0.0
    setup['training']['epochs'] = 1

    checkpoint_prefix = 'checkpoint_finalize'
    output_prefix = 'output_finalize'
    setup_name = 'setup_finalize.yml'
    output_file = 'out_finalize.txt'
else:
    checkpoint_prefix = 'checkpoint'
    output_prefix = 'output'
    setup_name = 'setup.yml'
    output_file = 'out.txt'


# check model
if 'model' not in setup['lpcnet']:
    print('warning: did not find model entry in setup, using default lpcnet')
    model_name = 'lpcnet'
else:
    model_name = setup['lpcnet']['model']

# prepare output folder
if os.path.exists(args.output) and not debug and not args.finalize:
    print("warning: output folder exists")

    reply = input('continue? (y/n): ')
    while reply not in {'y', 'n'}:
        reply = input('continue? (y/n): ')

    if reply == 'n':
        # user declined to reuse the folder: stop before anything is written
        sys.exit(0)
else:
    os.makedirs(args.output, exist_ok=True)

checkpoint_dir = os.path.join(args.output, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)


# add repo info to setup
if has_git:
    working_dir = os.path.split(__file__)[0]
    try:
        repo = git.Repo(working_dir)
        commit_hash = repo.head.object.hexsha
        urls = list(repo.remote().urls)
        is_dirty = repo.is_dirty()
    except Exception as e:
        # best effort only: e.g. the script is not inside a git repository
        # or the repository has no remote configured
        print(f"warning: could not read git repo info ({e})")
        has_git = False
    else:
        if is_dirty:
            print("warning: repo is dirty")

        # only record the entry once all fields were read successfully
        setup['repo'] = {
            'hash': commit_hash,
            'urls': urls,
            'dirty': is_dirty,
        }

# dump setup
with open(os.path.join(args.output, setup_name), 'w') as f:
    yaml.dump(setup, f)

# prepare inference test if wanted
run_inference_test = False
if args.test_features is not None:
    test_features = load_features(args.test_features)
    inference_test_dir = os.path.join(args.output, 'inference_test')
    os.makedirs(inference_test_dir, exist_ok=True)
    run_inference_test = True

# training parameters
batch_size = setup['training']['batch_size']
epochs = setup['training']['epochs']
lr = setup['training']['lr']
lr_decay_factor = setup['training']['lr_decay_factor']

# dataset options shared by the training and the validation dataset
lpcnet_config = setup['lpcnet']['config']
dataset_kwargs = dict(
    features=lpcnet_config['features'],
    input_signals=lpcnet_config['signals'],
    target=lpcnet_config['target'],
    frames_per_sample=setup['training']['frames_per_sample'],
    feature_history=lpcnet_config['feature_history'],
    feature_lookahead=lpcnet_config['feature_lookahead'],
    lpc_gamma=lpcnet_config.get('lpc_gamma', 1),
)

# load training dataset
data = LPCNetDataset(setup['dataset'], **dataset_kwargs)

# load validation dataset if given
if 'validation_dataset' in setup:
    validation_data = LPCNetDataset(setup['validation_dataset'], **dataset_kwargs)

    validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=4)

    run_validation = True
else:
    run_validation = False

# create model
model = model_dict[model_name](setup['lpcnet']['config'])

if args.initial_checkpoint is not None:
    print(f"loading state dict from {args.initial_checkpoint}...")
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    chkpt = torch.load(args.initial_checkpoint, map_location='cpu')
    model.load_state_dict(chkpt['state_dict'])

# set compute device
if args.device is None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device(args.device)

# push model to device
model.to(device)

# dataloader
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=4)

# optimizer is introduced to trainable parameters
parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(parameters, lr=lr)

# learning rate scheduler: lr is scaled by 1 / (1 + lr_decay_factor * x),
# where x is the number of scheduler steps taken so far
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x: 1 / (1 + lr_decay_factor * x))

# loss
criterion = torch.nn.NLLLoss()

# model checkpoint
checkpoint = {
    'setup': setup,
    'state_dict': model.state_dict(),
    'loss': -1
}

# optionally send everything printed from here on to a file in the output folder
original_stdout = sys.stdout
log_file = None
if not args.no_redirect:
    log_path = os.path.join(args.output, output_file)
    print(f"re-directing output to {log_path}")
    log_file = open(log_path, "w")
    sys.stdout = log_file

# lowest validation loss seen so far (only used when validation is enabled)
best_loss = float('inf')

try:
    for ep in range(1, epochs + 1):
        print(f"training epoch {ep}...")
        new_loss = train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler)

        # save checkpoint
        checkpoint['state_dict'] = model.state_dict()
        checkpoint['loss'] = new_loss

        if run_validation:
            print("running validation...")
            validation_loss = evaluate(model, criterion, validation_dataloader, device)
            checkpoint['validation_loss'] = validation_loss

            if validation_loss < best_loss:
                torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_best.pth'))
                best_loss = validation_loss

        torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
        torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_last.pth'))

        # run inference test
        if run_inference_test:
            # generation is run on the cpu; the model is moved back afterwards
            model.to("cpu")
            print("running inference test...")

            output = model.generate(test_features['features'], test_features['periods'], test_features['lpcs'])

            testfilename = os.path.join(inference_test_dir, output_prefix + f'_epoch_{ep}.wav')

            wavwrite16(testfilename, output.numpy(), 16000)

            model.to(device)

        # flush once per epoch so progress is visible in the redirected log
        print(flush=True)
finally:
    # restore stdout and close the log file even if training is interrupted
    if log_file is not None:
        sys.stdout = original_stdout
        log_file.close()