1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""
29
import os
import argparse
import sys

try:
    import git
    has_git = True
except ImportError:
    has_git = False

import yaml


import torch
from torch.optim.lr_scheduler import LambdaLR

from data import LPCNetDataset
from models import model_dict
from engine.lpcnet_engine import train_one_epoch, evaluate
from utils.data import load_features
from utils.wav import wavwrite16


debug = False
if debug:
    args = type('dummy', (object,),
    {
        'setup' : 'setup.yml',
        'output' : 'testout',
        'device' : None,
        'test_features' : None,
        'finalize': False,
        'initial_checkpoint': None,
        'no_redirect': False  # attribute name must match args.no_redirect used below
    })()
else:
    parser = argparse.ArgumentParser("train_lpcnet.py")
    parser.add_argument('setup', type=str, help='setup yaml file')
    parser.add_argument('output', type=str, help='output path')
    parser.add_argument('--device', type=str, help='compute device', default=None)
    parser.add_argument('--test-features', type=str, help='test feature file in v2 format', default=None)
    parser.add_argument('--finalize', action='store_true', help='run single training round with lr=1e-5')
    parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint', default=None)
    parser.add_argument('--no-redirect', action='store_true', help='disables redirection of output')

    args = parser.parse_args()

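# example invocation (paths below are placeholders, not part of the repo):
#
#   python train_lpcnet.py setup.yml my_output_dir --device cuda:0
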

torch.set_num_threads(4)

with open(args.setup, 'r') as f:
    setup = yaml.load(f.read(), yaml.FullLoader)
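
# sketch of a minimal setup.yml, inferred from the keys this script accesses
# (values are illustrative only; '...' marks entries whose format is defined
# by LPCNetDataset and the model config, not by this script):
#
#   dataset: /path/to/training/data
#   validation_dataset: /path/to/validation/data   # optional
#   lpcnet:
#     model: lpcnet                                # optional, defaults to lpcnet
#     config:
#       features: ...
#       signals: ...
#       target: ...
#       feature_history: ...
#       feature_lookahead: ...
#   training:
#     batch_size: 256
#     epochs: 20
#     lr: 1.0e-3
#     lr_decay_factor: 2.5e-5
#     frames_per_sample: 15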

if args.finalize:
    if args.initial_checkpoint is None:
        raise ValueError('finalization requires initial checkpoint')

    if 'sparsification' in setup['lpcnet']['config']:
        for sp_job in setup['lpcnet']['config']['sparsification'].values():
            sp_job['start'], sp_job['stop'] = 0, 0

    setup['training']['lr'] = 1.0e-5
    setup['training']['lr_decay_factor'] = 0.0
    setup['training']['epochs'] = 1

    checkpoint_prefix = 'checkpoint_finalize'
    output_prefix = 'output_finalize'
    setup_name = 'setup_finalize.yml'
    output_file = 'out_finalize.txt'
else:
    checkpoint_prefix = 'checkpoint'
    output_prefix = 'output'
    setup_name = 'setup.yml'
    output_file = 'out.txt'


# check model
if 'model' not in setup['lpcnet']:
    print('warning: did not find model entry in setup, using default lpcnet')
    model_name = 'lpcnet'
else:
    model_name = setup['lpcnet']['model']

# prepare output folder
if os.path.exists(args.output) and not debug and not args.finalize:
    print("warning: output folder exists")

    reply = input('continue? (y/n): ')
    while reply not in {'y', 'n'}:
        reply = input('continue? (y/n): ')

    if reply == 'n':
        sys.exit(0)
else:
    os.makedirs(args.output, exist_ok=True)

checkpoint_dir = os.path.join(args.output, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)


# add repo info to setup
if has_git:
    working_dir = os.path.split(__file__)[0]
    try:
        repo = git.Repo(working_dir)
        setup['repo'] = dict()
        commit_hash = repo.head.object.hexsha
        urls = list(repo.remote().urls)
        is_dirty = repo.is_dirty()

        if is_dirty:
            print("warning: repo is dirty")

        setup['repo']['hash'] = commit_hash
        setup['repo']['urls'] = urls
        setup['repo']['dirty'] = is_dirty
    except Exception:
        has_git = False

# dump setup
with open(os.path.join(args.output, setup_name), 'w') as f:
    yaml.dump(setup, f)

# prepare inference test if requested
run_inference_test = False
if args.test_features is not None:
    test_features = load_features(args.test_features)
    inference_test_dir = os.path.join(args.output, 'inference_test')
    os.makedirs(inference_test_dir, exist_ok=True)
    run_inference_test = True

# training parameters
batch_size      = setup['training']['batch_size']
epochs          = setup['training']['epochs']
lr              = setup['training']['lr']
lr_decay_factor = setup['training']['lr_decay_factor']

# load training dataset
lpcnet_config = setup['lpcnet']['config']
data = LPCNetDataset(setup['dataset'],
                     features=lpcnet_config['features'],
                     input_signals=lpcnet_config['signals'],
                     target=lpcnet_config['target'],
                     frames_per_sample=setup['training']['frames_per_sample'],
                     feature_history=lpcnet_config['feature_history'],
                     feature_lookahead=lpcnet_config['feature_lookahead'],
                     lpc_gamma=lpcnet_config.get('lpc_gamma', 1))

# load validation dataset if given
if 'validation_dataset' in setup:
    validation_data = LPCNetDataset(setup['validation_dataset'],
                                    features=lpcnet_config['features'],
                                    input_signals=lpcnet_config['signals'],
                                    target=lpcnet_config['target'],
                                    frames_per_sample=setup['training']['frames_per_sample'],
                                    feature_history=lpcnet_config['feature_history'],
                                    feature_lookahead=lpcnet_config['feature_lookahead'],
                                    lpc_gamma=lpcnet_config.get('lpc_gamma', 1))

    validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, drop_last=True, num_workers=4)

    run_validation = True
else:
    run_validation = False

# create model
model = model_dict[model_name](setup['lpcnet']['config'])

if args.initial_checkpoint is not None:
    print(f"loading state dict from {args.initial_checkpoint}...")
    chkpt = torch.load(args.initial_checkpoint, map_location='cpu')
    model.load_state_dict(chkpt['state_dict'])

# set compute device
if args.device is None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device(args.device)

# push model to device
model.to(device)

# dataloader
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=4)
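# drop_last=True discards a final incomplete batch so every batch holds exactly
# batch_size samples; shuffle=True reshuffles the training data each epoch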

# optimizer over trainable parameters only
parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(parameters, lr=lr)

# learning rate scheduler
scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))
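# the lambda above makes the learning rate after scheduler step x equal to
#   lr / (1 + lr_decay_factor * x)
# e.g. with illustrative values lr=1e-3 and lr_decay_factor=2.5e-5, the rate
# halves after 40000 scheduler steps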

# loss
criterion = torch.nn.NLLLoss()
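# torch.nn.NLLLoss expects log-probabilities as input, so the model is assumed
# to emit log-probabilities over the target classes (e.g. via a final
# log-softmax)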

# model checkpoint
checkpoint = {
    'setup'         : setup,
    'state_dict'    : model.state_dict(),
    'loss'          : -1
}
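# this is the same layout read back above when --initial-checkpoint is given
# (torch.load(...)['state_dict'])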

if not args.no_redirect:
    print(f"redirecting output to {os.path.join(args.output, output_file)}")
    sys.stdout = open(os.path.join(args.output, output_file), "w")

best_loss = 1e9

for ep in range(1, epochs + 1):
    print(f"training epoch {ep}...")
    new_loss = train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler)

    # save checkpoint
    checkpoint['state_dict'] = model.state_dict()
    checkpoint['loss']       = new_loss

    if run_validation:
        print("running validation...")
        validation_loss = evaluate(model, criterion, validation_dataloader, device)
        checkpoint['validation_loss'] = validation_loss

        if validation_loss < best_loss:
            torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_best.pth'))
            best_loss = validation_loss

    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + f'_epoch_{ep}.pth'))
    torch.save(checkpoint, os.path.join(checkpoint_dir, checkpoint_prefix + '_last.pth'))
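    # any of the saved .pth files can be fed back into this script via
    # --initial-checkpoint, e.g. as the starting point for --finalize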

    # run inference test
    if run_inference_test:
        model.to("cpu")
        print("running inference test...")

        output = model.generate(test_features['features'], test_features['periods'], test_features['lpcs'])

        testfilename = os.path.join(inference_test_dir, output_prefix + f'_epoch_{ep}.wav')

        wavwrite16(testfilename, output.numpy(), 16000)

        model.to(device)

    print()