• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import torch
2from torch import nn
3import torch.nn.functional as F
4
5from utils.complexity import _conv1d_flop_count
6from utils.softquant import soft_quant
7
class TDShaper(nn.Module):
    """Frame-wise temporal shaping of a single-channel signal.

    For every frame a per-sample gain ``alpha`` is predicted from the
    frame-wise input features and from a log-domain temporal envelope of the
    signal itself, and the signal is multiplied by that gain. With
    ``innovate=True`` an additional gain-scaled signal, predicted from the
    same inputs, is added to the shaped signal.
    """

    COUNTER = 1

    def __init__(self,
                 feature_dim,
                 frame_size=160,
                 avg_pool_k=4,
                 innovate=False,
                 pool_after=False,
                 softquant=False,
                 apply_weight_norm=False
    ):
        """

        Parameters:
        -----------

        feature_dim : int
            dimension of input features

        frame_size : int
            frame size (number of signal samples per feature frame)

        avg_pool_k : int, optional
            kernel size and stride for avg pooling; must divide frame_size

        innovate : bool, optional
            if True, create and use the additional innovation branches
            (feature_alpha1b/2b and feature_alpha1c/2c)

        pool_after : bool, optional
            if True, the envelope is computed as avg_pool(log(|x|)),
            otherwise as log(avg_pool(|x|))

        softquant : bool, optional
            if True, wrap feature_alpha1_f with soft_quant

        apply_weight_norm : bool, optional
            if True, wrap all convolutions with torch.nn.utils.weight_norm

        """

        super().__init__()

        self.feature_dim = feature_dim
        self.frame_size = frame_size
        self.avg_pool_k = avg_pool_k
        self.innovate = innovate
        self.pool_after = pool_after

        assert frame_size % avg_pool_k == 0, \
            "frame_size must be a multiple of avg_pool_k"
        # pooled envelope values per frame plus one slot for their mean
        self.env_dim = frame_size // avg_pool_k + 1

        norm = torch.nn.utils.weight_norm if apply_weight_norm else lambda x, name=None: x

        # feature transform: separate first-layer convolutions for the
        # feature input (_f) and the temporal envelope input (_t); their
        # outputs are summed in forward()
        self.feature_alpha1_f = norm(nn.Conv1d(self.feature_dim, frame_size, 2))
        self.feature_alpha1_t = norm(nn.Conv1d(self.env_dim, frame_size, 2))
        self.feature_alpha2 = norm(nn.Conv1d(frame_size, frame_size, 2))

        if softquant:
            self.feature_alpha1_f = soft_quant(self.feature_alpha1_f)

        if self.innovate:
            # these layers consume features and envelope stacked along the
            # channel axis, hence feature_dim + env_dim input channels
            self.feature_alpha1b = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2))
            self.feature_alpha1c = norm(nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2))

            self.feature_alpha2b = norm(nn.Conv1d(frame_size, frame_size, 2))
            self.feature_alpha2c = norm(nn.Conv1d(frame_size, frame_size, 2))

    def flop_count(self, rate):
        """ returns the flop count estimate of this module

        Parameters:
        -----------
        rate : int or float
            signal sample rate; it is divided by frame_size to obtain the
            frame rate passed to _conv1d_flop_count

        Returns:
        --------
        sum of the convolution flop counts plus a flat per-sample term
        (11 per sample for the shaping path, 22 more when innovate is set;
        NOTE(review): presumably these cover the element-wise operations,
        confirm against the complexity model)
        """

        frame_rate = rate / self.frame_size

        shape_layers = (self.feature_alpha1_f, self.feature_alpha1_t, self.feature_alpha2)
        shape_flops = sum(_conv1d_flop_count(layer, frame_rate) for layer in shape_layers) \
            + 11 * frame_rate * self.frame_size

        if self.innovate:
            inno_layers = (self.feature_alpha1b, self.feature_alpha2b,
                           self.feature_alpha1c, self.feature_alpha2c)
            inno_flops = sum(_conv1d_flop_count(layer, frame_rate) for layer in inno_layers) \
                + 22 * frame_rate * self.frame_size
        else:
            inno_flops = 0

        return shape_flops + inno_flops

    def envelope_transform(self, x):
        """ computes the frame-wise log-domain temporal envelope of x

        The magnitude of x is average-pooled with kernel size and stride
        avg_pool_k and mapped to the log domain (in the order selected by
        pool_after). The result is grouped into rows of env_dim - 1 values;
        each row has its mean subtracted and the mean appended as last entry.

        Parameters:
        -----------
        x : torch.tensor
            input signal; forward() passes shape (batch_size, 1, num_samples)

        Returns:
        --------
        torch.tensor of shape (batch_size, -1, env_dim)
        """

        x = torch.abs(x)
        # the offset 2**-16 keeps the logarithm finite for zero input
        if self.pool_after:
            x = torch.log(x + .5**16)
            x = F.avg_pool1d(x, self.avg_pool_k, self.avg_pool_k)
        else:
            x = F.avg_pool1d(x, self.avg_pool_k, self.avg_pool_k)
            x = torch.log(x + .5**16)

        x = x.reshape(x.size(0), -1, self.env_dim - 1)
        avg_x = torch.mean(x, -1, keepdim=True)

        x = torch.cat((x - avg_x, avg_x), dim=-1)

        return x

    def forward(self, x, features, debug=False):
        """ innovate signal parts with temporal shaping


        Parameters:
        -----------
        x : torch.tensor
            input signal of shape (batch_size, 1, num_samples); num_samples
            must equal num_frames * frame_size

        features : torch.tensor
            frame-wise features of shape (batch_size, num_frames, feature_dim)

        debug : bool, optional
            unused; kept for interface compatibility

        Returns:
        --------
        torch.tensor of shape (batch_size, 1, num_samples)

        """

        batch_size = x.size(0)
        num_frames = features.size(1)
        num_samples = x.size(2)

        # generate temporal envelope
        tenv = self.envelope_transform(x)

        # feature path; one frame of left padding per kernel-size-2
        # convolution keeps the number of frames unchanged
        f = F.pad(features.permute(0, 2, 1), [1, 0])
        t = F.pad(tenv.permute(0, 2, 1), [1, 0])
        alpha = self.feature_alpha1_f(f) + self.feature_alpha1_t(t)
        alpha = F.leaky_relu(alpha, 0.2)
        alpha = torch.exp(self.feature_alpha2(F.pad(alpha, [1, 0])))
        alpha = alpha.permute(0, 2, 1)

        if self.innovate:
            # feature_alpha1b/1c expect feature_dim + env_dim input channels,
            # so they must see features and envelope stacked along the
            # channel axis; passing only the features raises a channel
            # mismatch error
            ft = torch.cat((f, t), dim=1)

            inno_alpha = F.leaky_relu(self.feature_alpha1b(ft), 0.2)
            inno_alpha = torch.exp(self.feature_alpha2b(F.pad(inno_alpha, [1, 0])))
            inno_alpha = inno_alpha.permute(0, 2, 1)

            inno_x = F.leaky_relu(self.feature_alpha1c(ft), 0.2)
            inno_x = torch.tanh(self.feature_alpha2c(F.pad(inno_x, [1, 0])))
            inno_x = inno_x.permute(0, 2, 1)

        # signal path
        y = x.reshape(batch_size, num_frames, -1)
        y = alpha * y

        if self.innovate:
            y = y + inno_alpha * inno_x

        return y.reshape(batch_size, 1, num_samples)
146