"""
/* Copyright (c) 2023 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

import torch
from torch import nn
import torch.nn.functional as F

import numpy as np

from utils.layers.silk_upsampler import SilkUpsampler
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
from utils.layers.td_shaper import TDShaper
from utils.layers.deemph import Deemph
from utils.misc import freeze_model

from models.nns_base import NNSBase
from models.silk_feature_net_pl import SilkFeatureNetPL
from models.silk_feature_net import SilkFeatureNet
from .scale_embedding import ScaleEmbedding


class ShapeUp48(NNSBase):
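    """Bandwidth extension of a 16 kHz signal to 48 kHz.

    The signal passes through two upsampling stages: a high-quality 2x
    upsampling to 32 kHz followed by a 3/2 interpolation to 48 kHz. At
    both rates the signal is refined with limited adaptive convolutions
    and temporal shaping (plus a spectrally shaped noise branch at
    32 kHz), all conditioned on features derived from SILK parameters
    (signal features, pitch lags, and bitrate).
    """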

    FRAME_SIZE16k = 80

    def __init__(self,
                 num_features=47,
                 pitch_embedding_dim=64,
                 cond_dim=256,
                 pitch_max=257,
                 kernel_size=15,
                 preemph=0.85,
                 skip=288,
                 conv_gain_limits_db=[-6, 6],
                 numbits_range=[50, 650],
                 numbits_embedding_dim=8,
                 hidden_feature_dim=64,
                 partial_lookahead=True,
                 norm_p=2,
                 target_fs=48000,
                 noise_amplitude=0,
                 prenet=None,
                 avg_pool_k=4):

        super().__init__(skip=skip, preemph=preemph)

        self.num_features           = num_features
        self.cond_dim               = cond_dim
        self.pitch_max              = pitch_max
        self.pitch_embedding_dim    = pitch_embedding_dim
        self.kernel_size            = kernel_size
        self.preemph                = preemph
        self.skip                   = skip
        self.numbits_range          = numbits_range
        self.numbits_embedding_dim  = numbits_embedding_dim
        self.hidden_feature_dim     = hidden_feature_dim
        self.partial_lookahead      = partial_lookahead
        # + .1 guards against floating-point rounding before truncation
        self.frame_size48           = int(self.FRAME_SIZE16k * target_fs / 16000 + .1)
        self.frame_size32           = self.FRAME_SIZE16k * 2
        self.noise_amplitude        = noise_amplitude
        self.prenet                 = prenet

        # freeze prenet if given
        if prenet is not None:
            freeze_model(self.prenet)
            try:
                self.deemph = Deemph(prenet.preemph)
            except AttributeError:
                print("[warning] prenet model is expected to have a preemph attribute; falling back to preemph=0")
                self.deemph = Deemph(0)

        # upsampler
        self.upsampler = SilkUpsampler()

        # pitch embedding
        self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)

        # numbits embedding
        self.numbits_embedding = ScaleEmbedding(numbits_embedding_dim, *numbits_range, logscale=True)

        # feature net
        if partial_lookahead:
            self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim)
        else:
            self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)

        # non-linear transforms
        self.tdshape1 = TDShaper(cond_dim, frame_size=self.frame_size32, avg_pool_k=avg_pool_k)
        self.tdshape2 = TDShaper(cond_dim, frame_size=self.frame_size48, avg_pool_k=avg_pool_k)

        # spectral shaping
        self.af_noise = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=[-30, 0], norm_p=norm_p)
        self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af2 = LimitedAdaptiveConv1d(3, 2, self.kernel_size, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af3 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.frame_size48, overlap_size=self.frame_size48//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)

    def flop_count(self, rate=16000, verbose=False):

        frame_rate = rate / self.FRAME_SIZE16k

        # feature net
        feature_net_flops = self.feature_net.flop_count(frame_rate)

        # af1 and af2 run on the 2x-upsampled (32 kHz) signal, af3 on the
        # 3/2-interpolated (48 kHz) signal; af_noise and the TD shapers are
        # not included in this count
        af_flops = self.af1.flop_count(2 * rate) + self.af2.flop_count(2 * rate) + self.af3.flop_count(3 * rate)

        if verbose:
            print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
            print(f"adaptive conv: {af_flops / 1e6} MFLOPS")

        return feature_net_flops + af_flops

    def forward(self, x, features, periods, numbits, debug=False):
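        """Upsample the 16 kHz input x to 48 kHz, conditioned on SILK features.

        Shapes below are inferred from the layer configuration.

        Args:
            x: input signal at 16 kHz, shape (batch, 1, num_samples).
            features: frame-wise conditioning features, shape (batch, num_frames, num_features).
            periods: integer pitch lags in [0, pitch_max], shape (batch, num_frames, 1).
            numbits: two bitrate-related values per frame, shape (batch, num_frames, 2).
            debug: forwarded to the adaptive convolutions.

        Returns:
            Output signal at 48 kHz, shape (batch, 1, 3 * num_samples).
        """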

        if self.prenet is not None:
            with torch.no_grad():
                x = self.prenet(x, features, periods, numbits)
                x = self.deemph(x)

        periods           = periods.squeeze(-1)
        pitch_embedding   = self.pitch_embedding(periods)
        numbits_embedding = self.numbits_embedding(numbits).flatten(2)

        # conditioning features for the adaptive filters and shapers
        full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
        cf = self.feature_net(full_features)

        # 2x upsampling to 32 kHz
        y32 = self.upsampler.hq_2x_up(x)

        # spectrally shaped noise branch
        noise = self.noise_amplitude * torch.randn_like(y32)
        noise = self.af_noise(noise, cf)

        # adaptive filtering and temporal shaping at 32 kHz
        y32 = self.af1(y32, cf, debug=debug)

        y32_1 = y32[:, 0:1, :]
        y32_2 = self.tdshape1(y32[:, 1:2, :], cf)
        y32 = torch.cat((y32_1, y32_2, noise), dim=1)

        y32 = self.af2(y32, cf, debug=debug)

        # 3/2 interpolation to 48 kHz
        y48 = self.upsampler.interpolate_3_2(y32)

        # adaptive filtering and temporal shaping at 48 kHz
        y48_1 = y48[:, 0:1, :]
        y48_2 = self.tdshape2(y48[:, 1:2, :], cf)
        y48 = torch.cat((y48_1, y48_2), dim=1)

        y48 = self.af3(y48, cf, debug=debug)

        return y48
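

# A minimal smoke-test sketch, not part of the original module. The tensor
# shapes are assumptions inferred from the forward pass above: x carries one
# channel of 16 kHz audio, while features, periods, and numbits hold one row
# per 5 ms frame (FRAME_SIZE16k = 80 samples at 16 kHz); periods are integer
# pitch lags in [0, pitch_max] and numbits carries two values per frame,
# assumed to lie within numbits_range.
if __name__ == "__main__":
    batch, frames = 2, 50
    num_samples16k = frames * ShapeUp48.FRAME_SIZE16k

    model = ShapeUp48()

    x        = torch.randn(batch, 1, num_samples16k)
    features = torch.randn(batch, frames, model.num_features)
    periods  = torch.randint(0, model.pitch_max + 1, (batch, frames, 1))
    numbits  = torch.full((batch, frames, 2), 300.)

    y48 = model(x, features, periods, numbits)
    print(y48.shape)  # expected: torch.Size([2, 1, 3 * num_samples16k])
    print(f"complexity: {model.flop_count(16000) / 1e6:.1f} MFLOPS")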