"""
/* Copyright (c) 2023 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

import torch
from torch import nn
import torch.nn.functional as F

import numpy as np

from utils.layers.silk_upsampler import SilkUpsampler
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
from utils.layers.td_shaper import TDShaper
from utils.layers.deemph import Deemph
from utils.misc import freeze_model

from models.nns_base import NNSBase
from models.silk_feature_net_pl import SilkFeatureNetPL
from models.silk_feature_net import SilkFeatureNet
from .scale_embedding import ScaleEmbedding


class ShapeUp48(NNSBase):
    """Upsamples a 16 kHz signal to 48 kHz (2x upsampling followed by 3/2
    interpolation), with adaptive spectral and temporal shaping conditioned
    on SILK features."""

    FRAME_SIZE16k = 80

    def __init__(self,
                 num_features=47,
                 pitch_embedding_dim=64,
                 cond_dim=256,
                 pitch_max=257,
                 kernel_size=15,
                 preemph=0.85,
                 skip=288,
                 conv_gain_limits_db=[-6, 6],
                 numbits_range=[50, 650],
                 numbits_embedding_dim=8,
                 hidden_feature_dim=64,
                 partial_lookahead=True,
                 norm_p=2,
                 target_fs=48000,
                 noise_amplitude=0,
                 prenet=None,
                 avg_pool_k=4):

        super().__init__(skip=skip, preemph=preemph)

        self.num_features = num_features
        self.cond_dim = cond_dim
        self.pitch_max = pitch_max
        self.pitch_embedding_dim = pitch_embedding_dim
        self.kernel_size = kernel_size
        self.preemph = preemph
        self.skip = skip
        self.numbits_range = numbits_range
        self.numbits_embedding_dim = numbits_embedding_dim
        self.hidden_feature_dim = hidden_feature_dim
        self.partial_lookahead = partial_lookahead
        self.frame_size48 = int(self.FRAME_SIZE16k * target_fs / 16000 + .1)
        self.frame_size32 = self.FRAME_SIZE16k * 2
        self.noise_amplitude = noise_amplitude
        self.prenet = prenet

        # freeze prenet if given
        if prenet is not None:
            freeze_model(self.prenet)
            try:
                self.deemph = Deemph(prenet.preemph)
            except AttributeError:
                print("[warning] prenet model is expected to have a preemph attribute")
                self.deemph = Deemph(0)

        # upsampler
        self.upsampler = SilkUpsampler()

        # pitch embedding
        self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)

        # numbits embedding
        self.numbits_embedding = ScaleEmbedding(numbits_embedding_dim, *numbits_range, logscale=True)

        # feature net
        if partial_lookahead:
            self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim)
        else:
            self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)

        # non-linear transforms
        self.tdshape1 = TDShaper(cond_dim, frame_size=self.frame_size32, avg_pool_k=avg_pool_k)
        self.tdshape2 = TDShaper(cond_dim, frame_size=self.frame_size48, avg_pool_k=avg_pool_k)

        # spectral shaping
        self.af_noise = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32 // 2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=[-30, 0], norm_p=norm_p)
        self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32 // 2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af2 = LimitedAdaptiveConv1d(3, 2, self.kernel_size, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32 // 2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af3 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.frame_size48, overlap_size=self.frame_size48 // 2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)

    def flop_count(self, rate=16000, verbose=False):

        frame_rate = rate / self.FRAME_SIZE16k

        # feature net
        feature_net_flops = self.feature_net.flop_count(frame_rate)
        # af1 and af2 run on the 2x-upsampled (32 kHz) signal, af3 on the
        # 3x (48 kHz) signal
        af_flops = self.af1.flop_count(2 * rate) + self.af2.flop_count(2 * rate) + self.af3.flop_count(3 * rate)

        if verbose:
            print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
            print(f"adaptive conv: {af_flops / 1e6} MFLOPS")

        return feature_net_flops + af_flops

    def forward(self, x, features, periods, numbits, debug=False):

        # optionally run a frozen prenet first and undo its pre-emphasis
        if self.prenet is not None:
            with torch.no_grad():
                x = self.prenet(x, features, periods, numbits)
                x = self.deemph(x)

        # embed pitch lags and bitrate, then derive conditioning features
        periods = periods.squeeze(-1)
        pitch_embedding = self.pitch_embedding(periods)
        numbits_embedding = self.numbits_embedding(numbits).flatten(2)

        full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
        cf = self.feature_net(full_features)

        # 16 kHz -> 32 kHz
        y32 = self.upsampler.hq_2x_up(x)

        # spectrally shaped noise branch
        noise = self.noise_amplitude * torch.randn_like(y32)
        noise = self.af_noise(noise, cf)

        y32 = self.af1(y32, cf, debug=debug)

        y32_1 = y32[:, 0:1, :]
        y32_2 = self.tdshape1(y32[:, 1:2, :], cf)
        y32 = torch.cat((y32_1, y32_2, noise), dim=1)

        y32 = self.af2(y32, cf, debug=debug)

        # 32 kHz -> 48 kHz
        y48 = self.upsampler.interpolate_3_2(y32)

        y48_1 = y48[:, 0:1, :]
        y48_2 = self.tdshape2(y48[:, 1:2, :], cf)
        y48 = torch.cat((y48_1, y48_2), dim=1)

        y48 = self.af3(y48, cf, debug=debug)

        return y48
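
# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (editor addition, not part of the original
# module). The tensor shapes below are assumptions inferred from forward():
# x is a mono 16 kHz signal (batch, 1, samples), features carries
# num_features values per 80-sample frame, periods carries one integer pitch
# lag in [0, pitch_max] per frame, and numbits carries two bitrate values per
# frame (matching the 2 * numbits_embedding_dim term in the feature-net
# input width).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch, frames = 1, 100
    frame_size16 = ShapeUp48.FRAME_SIZE16k  # 80 samples per 16 kHz frame

    model = ShapeUp48()

    x = torch.randn(batch, 1, frames * frame_size16)              # 16 kHz input signal
    features = torch.randn(batch, frames, model.num_features)     # per-frame SILK features
    periods = torch.randint(0, model.pitch_max + 1, (batch, frames, 1))  # pitch lags
    numbits = 300. * torch.ones(batch, frames, 2)                 # values inside numbits_range

    y48 = model(x, features, periods, numbits)

    # 2x upsampling followed by 3/2 interpolation: 16 kHz -> 48 kHz, so
    # roughly 3x the input length is expected at the output
    print("output shape:", y48.shape)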