1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""

import torch
from torch import nn


def get_subconditioner(method,
                       number_of_subsamples,
                       pcm_embedding_size,
                       state_size,
                       pcm_levels,
                       number_of_signals,
                       **kwargs):

    subconditioner_dict = {
        'additive'      : AdditiveSubconditioner,
        'concatenative' : ConcatenativeSubconditioner,
        'modulative'    : ModulativeSubconditioner
    }

    return subconditioner_dict[method](number_of_subsamples,
        pcm_embedding_size, state_size, pcm_levels, number_of_signals, **kwargs)


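
# A minimal usage sketch of the factory above (illustrative only, not part of
# the original file; the helper name and all sizes are made up). forward()
# returns one state sequence per subsample position.
def _example_factory_usage():
    sub = get_subconditioner('modulative',
                             number_of_subsamples=2,
                             pcm_embedding_size=64,
                             state_size=64,
                             pcm_levels=256,
                             number_of_signals=3)

    # states come at the low rate, signals at the full rate (factor s apart)
    states  = torch.zeros((1, 5, 64))            # (batch, seq_length // s, state_size)
    signals = torch.randint(0, 256, (1, 10, 3))  # (batch, seq_length, number_of_signals)

    c_states = sub(states, signals)
    assert len(c_states) == 2
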
class Subconditioner(nn.Module):
    def __init__(self):
        """ upsampling by subconditioning

            Upsamples a sequence of states by conditioning on PCM signals and
            optionally on a feature vector.
        """
        super(Subconditioner, self).__init__()

    def forward(self, states, signals, features=None):
        raise NotImplementedError("base class should not be called")

    def single_step(self, index, state, signals, features):
        raise NotImplementedError("base class should not be called")

    def get_output_dim(self, index):
        raise NotImplementedError("base class should not be called")


class AdditiveSubconditioner(Subconditioner):
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 **kwargs):
        """ subconditioning by addition """

        super(AdditiveSubconditioner, self).__init__()

        self.number_of_subsamples    = number_of_subsamples
        self.pcm_embedding_size      = pcm_embedding_size
        self.state_size              = state_size
        self.pcm_levels              = pcm_levels
        self.number_of_signals       = number_of_signals

        if self.pcm_embedding_size != self.state_size:
            raise ValueError('For additive subconditioning, state and embedding '
                + f'sizes must match but got {self.state_size} and {self.pcm_embedding_size}')

        # the first subsample position passes the state through unchanged,
        # so no embedding is needed at index 0
        self.embeddings = [None]
        for i in range(1, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

    def forward(self, states, signals):
        """ creates list of subconditioned states

            Parameters:
            -----------
            states : torch.tensor
                states of shape (batch, seq_length // s, state_size)
            signals : torch.tensor
                signals of shape (batch, seq_length, number_of_signals)

            Returns:
            --------
            c_states : list of torch.tensor
                list of s subconditioned states
        """

        s = self.number_of_subsamples

        c_states = [states]
        new_states = states
        for i in range(1, s):
            # embed the signals for subsample position i and sum over the
            # signal dimension to match the state size
            embed = self.embeddings[i](signals[:, i::s])
            embed = torch.sum(embed, dim=2)

            new_states = new_states + embed
            c_states.append(new_states)

        return c_states

    def single_step(self, index, state, signals):
        """ carry out a single step for inference

            Parameters:
            -----------
            index : int
                position in subconditioning batch

            state : torch.tensor
                state to sub-condition

            signals : torch.tensor
                signals for subconditioning; all but the last dimension
                must match those of state

            Returns:
            --------
            c_state : torch.tensor
                subconditioned state
        """

        if index == 0:
            c_state = state
        else:
            embed_signals = self.embeddings[index](signals)
            # sum over the signal dimension before adding to the state
            c = torch.sum(embed_signals, dim=-2)
            c_state = state + c

        return c_state

    def get_output_dim(self, index):
        return self.state_size

    def get_average_flops_per_step(self):
        # embedding lookups are free; only the (s - 1) additive updates count
        s = self.number_of_subsamples
        flops = (s - 1) / s * self.number_of_signals * self.pcm_embedding_size
        return flops


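
# Illustrative sketch, not part of the original file: additive subconditioning
# adds a summed signal embedding onto the running state, which is why
# pcm_embedding_size must equal state_size. signals[:, i::s] selects every
# s-th signal vector starting at subsample position i. Sizes are made up.
def _example_additive():
    s = 4
    sub = AdditiveSubconditioner(number_of_subsamples=s,
                                 pcm_embedding_size=16,
                                 state_size=16,
                                 pcm_levels=256,
                                 number_of_signals=2)

    states  = torch.zeros((8, 25, 16))            # (batch, seq_length // s, state_size)
    signals = torch.randint(0, 256, (8, 100, 2))  # (batch, seq_length, number_of_signals)

    # one state tensor per subsample position, all with the same width
    c_states = sub(states, signals)
    assert all(c.shape == (8, 25, 16) for c in c_states)

    # single_step mirrors one subsample position during inference
    c_state = sub.single_step(1, states[:, 0], signals[:, 1])
    assert c_state.shape == (8, 16)
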
class ConcatenativeSubconditioner(Subconditioner):
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 recurrent=True,
                 **kwargs):
        """ subconditioning by concatenation """

        super(ConcatenativeSubconditioner, self).__init__()

        self.number_of_subsamples    = number_of_subsamples
        self.pcm_embedding_size      = pcm_embedding_size
        self.state_size              = state_size
        self.pcm_levels              = pcm_levels
        self.number_of_signals       = number_of_signals
        self.recurrent               = recurrent

        # in recurrent mode the first subsample position passes the state
        # through unchanged, so no embedding is needed at index 0
        self.embeddings = []
        start_index = 0
        if self.recurrent:
            start_index = 1
            self.embeddings.append(None)

        for i in range(start_index, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

    def forward(self, states, signals):
        """ creates list of subconditioned states

            Parameters:
            -----------
            states : torch.tensor
                states of shape (batch, seq_length // s, state_size)
            signals : torch.tensor
                signals of shape (batch, seq_length, number_of_signals)

            Returns:
            --------
            c_states : list of torch.tensor
                list of s subconditioned states
        """
        s = self.number_of_subsamples

        if self.recurrent:
            c_states = [states]
            start = 1
        else:
            c_states = []
            start = 0

        new_states = states
        for i in range(start, s):
            # embed the signals for subsample position i and flatten the
            # signal dimension into the embedding dimension
            embed = self.embeddings[i](signals[:, i::s])
            embed = torch.flatten(embed, -2)

            if self.recurrent:
                # concatenated embeddings accumulate on the state
                new_states = torch.cat((new_states, embed), dim=-1)
            else:
                new_states = torch.cat((states, embed), dim=-1)

            c_states.append(new_states)

        return c_states

    def single_step(self, index, state, signals):
        """ carry out a single step for inference

            Parameters:
            -----------
            index : int
                position in subconditioning batch

            state : torch.tensor
                state to sub-condition

            signals : torch.tensor
                signals for subconditioning; all but the last dimension
                must match those of state

            Returns:
            --------
            c_state : torch.tensor
                subconditioned state
        """

        if index == 0 and self.recurrent:
            c_state = state
        else:
            embed_signals = self.embeddings[index](signals)
            c = torch.flatten(embed_signals, -2)
            if not self.recurrent and index > 0:
                # overwrite the previous conditioning vector
                c_state = torch.cat((state[..., :self.state_size], c), dim=-1)
            else:
                c_state = torch.cat((state, c), dim=-1)

        return c_state

    def get_average_flops_per_step(self):
        # concatenation itself requires no arithmetic
        return 0

    def get_output_dim(self, index):
        if self.recurrent:
            return self.state_size + index * self.pcm_embedding_size * self.number_of_signals
        else:
            return self.state_size + self.pcm_embedding_size * self.number_of_signals

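
# Illustrative sketch, not part of the original file: with recurrent=True the
# concatenated embeddings accumulate, so the state widens by
# pcm_embedding_size * number_of_signals at each subsample position,
# matching get_output_dim. Sizes are made up.
def _example_concatenative():
    s = 3
    sub = ConcatenativeSubconditioner(number_of_subsamples=s,
                                      pcm_embedding_size=8,
                                      state_size=32,
                                      pcm_levels=256,
                                      number_of_signals=2,
                                      recurrent=True)

    states  = torch.zeros((4, 10, 32))           # (batch, seq_length // s, state_size)
    signals = torch.randint(0, 256, (4, 30, 2))  # (batch, seq_length, number_of_signals)

    # widths grow as 32, 48, 64
    c_states = sub(states, signals)
    assert [c.shape[-1] for c in c_states] == [sub.get_output_dim(i) for i in range(s)]
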
class ModulativeSubconditioner(Subconditioner):
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 state_recurrent=False,
                 **kwargs):
        """ subconditioning by modulation """

        super(ModulativeSubconditioner, self).__init__()

        self.number_of_subsamples    = number_of_subsamples
        self.pcm_embedding_size      = pcm_embedding_size
        self.state_size              = state_size
        self.pcm_levels              = pcm_levels
        self.number_of_signals       = number_of_signals
        self.state_recurrent         = state_recurrent

        self.hidden_size = self.pcm_embedding_size * self.number_of_signals

        if self.state_recurrent:
            self.hidden_size += self.pcm_embedding_size
            self.state_transform = nn.Linear(self.state_size, self.pcm_embedding_size)

        # no modulation is applied at subsample position 0
        self.embeddings = [None]
        self.alphas     = [None]
        self.betas      = [None]

        for i in range(1, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

            self.alphas.append(nn.Linear(self.hidden_size, self.state_size))
            self.add_module('alpha_dense_' + str(i), self.alphas[-1])

            self.betas.append(nn.Linear(self.hidden_size, self.state_size))
            self.add_module('beta_dense_' + str(i), self.betas[-1])

    def forward(self, states, signals):
        """ creates list of subconditioned states

            Parameters:
            -----------
            states : torch.tensor
                states of shape (batch, seq_length // s, state_size)
            signals : torch.tensor
                signals of shape (batch, seq_length, number_of_signals)

            Returns:
            --------
            c_states : list of torch.tensor
                list of s subconditioned states
        """
        s = self.number_of_subsamples

        c_states = [states]
        new_states = states
        for i in range(1, s):
            # embed the signals for subsample position i and flatten the
            # signal dimension into the embedding dimension
            embed = self.embeddings[i](signals[:, i::s])
            embed = torch.flatten(embed, -2)

            if self.state_recurrent:
                comp_states = self.state_transform(new_states)
                embed = torch.cat((embed, comp_states), dim=-1)

            alpha = torch.tanh(self.alphas[i](embed))
            beta  = torch.tanh(self.betas[i](embed))

            # new state obtained by modulating the previous state
            new_states = torch.tanh((1 + alpha) * new_states + beta)

            c_states.append(new_states)

        return c_states

    def single_step(self, index, state, signals):
        """ carry out a single step for inference

            Parameters:
            -----------
            index : int
                position in subconditioning batch

            state : torch.tensor
                state to sub-condition

            signals : torch.tensor
                signals for subconditioning; all but the last dimension
                must match those of state

            Returns:
            --------
            c_state : torch.tensor
                subconditioned state
        """

        if index == 0:
            c_state = state
        else:
            embed_signals = self.embeddings[index](signals)
            c = torch.flatten(embed_signals, -2)
            if self.state_recurrent:
                r_state = self.state_transform(state)
                c = torch.cat((c, r_state), dim=-1)
            alpha = torch.tanh(self.alphas[index](c))
            beta  = torch.tanh(self.betas[index](c))
            c_state = torch.tanh((1 + alpha) * state + beta)

        return c_state

    def get_output_dim(self, index):
        return self.state_size

    def get_average_flops_per_step(self):
        s = self.number_of_subsamples

        # c_state = torch.tanh((1 + alpha) * state + beta):
        # estimating tanh at 10 flops gives 13 flops per state entry
        flops = 13 * self.state_size

        # hidden size of the conditioning vector
        hidden_size = self.number_of_signals * self.pcm_embedding_size
        if self.state_recurrent:
            hidden_size += self.pcm_embedding_size

        # alpha = torch.tanh(self.alphas[index](c))
        # beta  = torch.tanh(self.betas[index](c))
        # counting 2 * A * B flops for Linear(A, B) plus 10 per tanh
        flops += 4 * hidden_size * self.state_size + 20 * self.state_size

        # r_state = self.state_transform(state)
        if self.state_recurrent:
            flops += 2 * self.state_size * self.pcm_embedding_size

        # position 0 is free, so average the cost over all s steps
        flops *= (s - 1) / s

        return flops

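
# Illustrative sketch, not part of the original file: the modulative update is
# FiLM-like, new_state = tanh((1 + alpha) * state + beta), so the state width
# never changes. The arithmetic below just walks through the flop estimate
# for made-up sizes: hidden = 3 * 32 = 96, so
# 13 * 64 + 4 * 96 * 64 + 20 * 64 = 26688 flops, averaged over s steps.
def _example_modulative():
    s = 2
    sub = ModulativeSubconditioner(number_of_subsamples=s,
                                   pcm_embedding_size=32,
                                   state_size=64,
                                   pcm_levels=256,
                                   number_of_signals=3)

    states  = torch.zeros((2, 6, 64))            # (batch, seq_length // s, state_size)
    signals = torch.randint(0, 256, (2, 12, 3))  # (batch, seq_length, number_of_signals)

    c_states = sub(states, signals)
    assert all(c.shape[-1] == sub.get_output_dim(i) for i, c in enumerate(c_states))
    assert sub.get_average_flops_per_step() == (s - 1) / s * 26688
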
class ComparitiveSubconditioner(Subconditioner):
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 error_index=-1,
                 apply_gate=True,
                 normalize=False,
                 **kwargs):
        """ subconditioning by comparison """

        super(ComparitiveSubconditioner, self).__init__()

        self.number_of_subsamples    = number_of_subsamples
        self.pcm_embedding_size      = pcm_embedding_size
        self.state_size              = state_size
        self.pcm_levels              = pcm_levels
        self.number_of_signals       = number_of_signals

        self.comparison_size = self.pcm_embedding_size
        self.error_position  = error_index
        self.apply_gate      = apply_gate
        self.normalize       = normalize

        self.state_transform = nn.Linear(self.state_size, self.comparison_size)

        self.alpha_dense     = nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size)
        self.beta_dense      = nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size)

        if self.apply_gate:
            self.gate_dense  = nn.Linear(self.pcm_embedding_size, self.state_size)

        # embeddings and per-position state transforms; position 0 only
        # receives a state transform
        self.embeddings   = [None]
        self.alpha_denses = [None]
        self.beta_denses  = [None]
        self.state_transforms = [nn.Linear(self.state_size, self.comparison_size)]
        self.add_module('state_transform_0', self.state_transforms[0])

        for i in range(1, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

            state_transform = nn.Linear(self.state_size, self.comparison_size)
            self.add_module('state_transform_' + str(i), state_transform)
            self.state_transforms.append(state_transform)

            self.alpha_denses.append(nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size))
            self.add_module('alpha_dense_' + str(i), self.alpha_denses[-1])

            self.beta_denses.append(nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size))
            self.add_module('beta_dense_' + str(i), self.beta_denses[-1])

    def forward(self, states, signals):
        s = self.number_of_subsamples

        c_states = [states]
        new_states = states
        for i in range(1, s):
            # embed the signals for subsample position i and flatten the
            # signal dimension into the embedding dimension
            embed = self.embeddings[i](signals[:, i::s])
            embed = torch.flatten(embed, -2)

            comp_states = self.state_transforms[i](new_states)

            alpha = torch.tanh(self.alpha_denses[i](embed))
            beta  = torch.tanh(self.beta_denses[i](embed))

            # new state obtained by modulating the transformed state
            new_states = torch.tanh((1 + alpha) * comp_states + beta)

            c_states.append(new_states)

        return c_states

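
# Illustrative sketch, not part of the original file: the comparitive update
# modulates the transformed state, which has comparison_size (equal to
# pcm_embedding_size) entries, with gains and offsets of state_size entries,
# so the broadcast only works when those two sizes match, as they do here.
def _example_comparitive():
    sub = ComparitiveSubconditioner(number_of_subsamples=2,
                                    pcm_embedding_size=32,
                                    state_size=32,
                                    pcm_levels=256,
                                    number_of_signals=2)

    states  = torch.zeros((2, 5, 32))            # (batch, seq_length // s, state_size)
    signals = torch.randint(0, 256, (2, 10, 2))  # (batch, seq_length, number_of_signals)

    c_states = sub(states, signals)
    assert len(c_states) == 2
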