import torch
from torch import nn
import torch.nn.functional as F

from utils.complexity import _conv1d_flop_count
from utils.softquant import soft_quant


class TDShaper(nn.Module):
    COUNTER = 1

    def __init__(self,
                 feature_dim,
                 frame_size=160,
                 avg_pool_k=4,
                 innovate=False,
                 pool_after=False,
                 softquant=False,
                 apply_weight_norm=False
                 ):
        """ Temporal-envelope shaper driven by frame-wise features.

        Parameters:
        -----------
        feature_dim : int
            dimension of input features

        frame_size : int, optional
            frame size in samples (default: 160)

        avg_pool_k : int, optional
            kernel size and stride for avg pooling in the envelope
            transform; must divide frame_size (default: 4)

        innovate : bool, optional
            if True, adds a learned innovation signal to the shaped
            output (default: False)

        pool_after : bool, optional
            if True, applies avg pooling after the log compression
            instead of before it (default: False)

        softquant : bool, optional
            if True, applies soft quantization to the first feature
            transform (default: False)

        apply_weight_norm : bool, optional
            if True, applies weight normalization to all convolutions
            (default: False)
        """

        super().__init__()

        self.feature_dim = feature_dim
        self.frame_size = frame_size
        self.avg_pool_k = avg_pool_k
        self.innovate = innovate
        self.pool_after = pool_after

        assert frame_size % avg_pool_k == 0
        self.env_dim = frame_size // avg_pool_k + 1

        norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x

        # feature transform
        self.feature_alpha1_f = norm(nn.Conv1d(self.feature_dim, frame_size, 2))
        self.feature_alpha1_t = norm(nn.Conv1d(self.env_dim, frame_size, 2))
        self.feature_alpha2 = norm(nn.Conv1d(frame_size, frame_size, 2))

        if softquant:
            self.feature_alpha1_f = soft_quant(self.feature_alpha1_f)

        if self.innovate:
            # innovation path operates on features and temporal envelope jointly,
            # hence the feature_dim + env_dim input channels
            self.feature_alpha1b = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2))
            self.feature_alpha1c = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2))

            self.feature_alpha2b = norm(nn.Conv1d(frame_size, frame_size, 2))
            self.feature_alpha2c = norm(nn.Conv1d(frame_size, frame_size, 2))


    def flop_count(self, rate):

        frame_rate = rate / self.frame_size

        # convolution FLOPs plus a per-sample estimate for activations
        # and element-wise operations
        shape_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1_f, self.feature_alpha1_t, self.feature_alpha2)]) + 11 * frame_rate * self.frame_size

        if self.innovate:
            inno_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1b, self.feature_alpha2b, self.feature_alpha1c, self.feature_alpha2c)]) + 22 * frame_rate * self.frame_size
        else:
            inno_flops = 0

        return shape_flops + inno_flops

    def envelope_transform(self, x):
        """ log-domain temporal envelope, frame-wise with mean removed """

        x = torch.abs(x)
        if self.pool_after:
            x = torch.log(x + .5**16)
            x = F.avg_pool1d(x, self.avg_pool_k, self.avg_pool_k)
        else:
            x = F.avg_pool1d(x, self.avg_pool_k, self.avg_pool_k)
            x = torch.log(x + .5**16)

        # group pooled envelope samples frame-wise and append the frame mean
        x = x.reshape(x.size(0), -1, self.env_dim - 1)
        avg_x = torch.mean(x, -1, keepdim=True)

        x = torch.cat((x - avg_x, avg_x), dim=-1)

        return x

    def forward(self, x, features, debug=False):
        """ shapes temporal envelope of x (and optionally adds innovation)

        Parameters:
        -----------
        x : torch.tensor
            input signal of shape (batch_size, 1, num_samples)

        features : torch.tensor
            frame-wise features of shape (batch_size, num_frames, feature_dim)

        Returns:
        --------
        y : torch.tensor
            shaped signal of shape (batch_size, 1, num_samples)
        """

        batch_size = x.size(0)
        num_frames = features.size(1)
        num_samples = x.size(2)

        # generate temporal envelope
        tenv = self.envelope_transform(x)

        # feature path: causal convolutions produce one gain vector per frame
        f = F.pad(features.permute(0, 2, 1), [1, 0])
        t = F.pad(tenv.permute(0, 2, 1), [1, 0])
        alpha = self.feature_alpha1_f(f) + self.feature_alpha1_t(t)
        alpha = F.leaky_relu(alpha, 0.2)
        alpha = torch.exp(self.feature_alpha2(F.pad(alpha, [1, 0])))
        alpha = alpha.permute(0, 2, 1)

        if self.innovate:
            # feature_alpha1b/1c expect feature_dim + env_dim input channels,
            # so features and envelope are concatenated along the channel dim
            ft = torch.cat((f, t), dim=1)

            inno_alpha = F.leaky_relu(self.feature_alpha1b(ft), 0.2)
            inno_alpha = torch.exp(self.feature_alpha2b(F.pad(inno_alpha, [1, 0])))
            inno_alpha = inno_alpha.permute(0, 2, 1)

            inno_x = F.leaky_relu(self.feature_alpha1c(ft), 0.2)
            inno_x = torch.tanh(self.feature_alpha2c(F.pad(inno_x, [1, 0])))
            inno_x = inno_x.permute(0, 2, 1)

        # signal path: apply per-sample gains frame by frame
        y = x.reshape(batch_size, num_frames, -1)
        y = alpha * y

        if self.innovate:
            y = y + inno_alpha * inno_x

        return y.reshape(batch_size, 1, num_samples)
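

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the batch size,
    # frame count, feature dimension, and 16 kHz rate below are arbitrary
    # example values chosen only to exercise the shape contract of forward().
    torch.manual_seed(0)

    shaper = TDShaper(feature_dim=128)

    batch_size, num_frames = 2, 4
    x = torch.randn(batch_size, 1, num_frames * shaper.frame_size)
    features = torch.randn(batch_size, num_frames, 128)

    y = shaper(x, features)
    print(y.shape)                   # torch.Size([2, 1, 640]), same as x
    print(shaper.flop_count(16000))  # complexity estimate at 16 kHz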