1""" 2/* Copyright (c) 2023 Amazon 3 Written by Jan Buethe */ 4/* 5 Redistribution and use in source and binary forms, with or without 6 modification, are permitted provided that the following conditions 7 are met: 8 9 - Redistributions of source code must retain the above copyright 10 notice, this list of conditions and the following disclaimer. 11 12 - Redistributions in binary form must reproduce the above copyright 13 notice, this list of conditions and the following disclaimer in the 14 documentation and/or other materials provided with the distribution. 15 16 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17 ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 20 OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 21 EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 22 PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 23 PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 24 LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 25 NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27*/ 28""" 29 30from re import sub 31import torch 32from torch import nn 33 34 35 36 37def get_subconditioner( method, 38 number_of_subsamples, 39 pcm_embedding_size, 40 state_size, 41 pcm_levels, 42 number_of_signals, 43 **kwargs): 44 45 subconditioner_dict = { 46 'additive' : AdditiveSubconditioner, 47 'concatenative' : ConcatenativeSubconditioner, 48 'modulative' : ModulativeSubconditioner 49 } 50 51 return subconditioner_dict[method](number_of_subsamples, 52 pcm_embedding_size, state_size, pcm_levels, number_of_signals, **kwargs) 53 54 55class Subconditioner(nn.Module): 56 def __init__(self): 57 """ upsampling by subconditioning 58 59 Upsamples a sequence of states conditioning on pcm signals and 60 optionally a feature vector. 
61 """ 62 super(Subconditioner, self).__init__() 63 64 def forward(self, states, signals, features=None): 65 raise Exception("Base class should not be called") 66 67 def single_step(self, index, state, signals, features): 68 raise Exception("Base class should not be called") 69 70 def get_output_dim(self, index): 71 raise Exception("Base class should not be called") 72 73 74class AdditiveSubconditioner(Subconditioner): 75 def __init__(self, 76 number_of_subsamples, 77 pcm_embedding_size, 78 state_size, 79 pcm_levels, 80 number_of_signals, 81 **kwargs): 82 """ subconditioning by addition """ 83 84 super(AdditiveSubconditioner, self).__init__() 85 86 self.number_of_subsamples = number_of_subsamples 87 self.pcm_embedding_size = pcm_embedding_size 88 self.state_size = state_size 89 self.pcm_levels = pcm_levels 90 self.number_of_signals = number_of_signals 91 92 if self.pcm_embedding_size != self.state_size: 93 raise ValueError('For additive subconditioning state and embedding ' 94 + f'sizes must match but but got {self.state_size} and {self.pcm_embedding_size}') 95 96 self.embeddings = [None] 97 for i in range(1, self.number_of_subsamples): 98 embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size) 99 self.add_module('pcm_embedding_' + str(i), embedding) 100 self.embeddings.append(embedding) 101 102 def forward(self, states, signals): 103 """ creates list of subconditioned states 104 105 Parameters: 106 ----------- 107 states : torch.tensor 108 states of shape (batch, seq_length // s, state_size) 109 signals : torch.tensor 110 signals of shape (batch, seq_length, number_of_signals) 111 112 Returns: 113 -------- 114 c_states : list of torch.tensor 115 list of s subconditioned states 116 """ 117 118 s = self.number_of_subsamples 119 120 c_states = [states] 121 new_states = states 122 for i in range(1, self.number_of_subsamples): 123 embed = self.embeddings[i](signals[:, i::s]) 124 # reduce signal dimension 125 embed = torch.sum(embed, dim=2) 126 127 new_states = new_states + embed 128 c_states.append(new_states) 129 130 return c_states 131 132 def single_step(self, index, state, signals): 133 """ carry out single step for inference 134 135 Parameters: 136 ----------- 137 index : int 138 position in subconditioning batch 139 140 state : torch.tensor 141 state to sub-condition 142 143 signals : torch.tensor 144 signals for subconditioning, all but the last dimensions 145 must match those of state 146 147 Returns: 148 c_state : torch.tensor 149 subconditioned state 150 """ 151 152 if index == 0: 153 c_state = state 154 else: 155 embed_signals = self.embeddings[index](signals) 156 c = torch.sum(embed_signals, dim=-2) 157 c_state = state + c 158 159 return c_state 160 161 def get_output_dim(self, index): 162 return self.state_size 163 164 def get_average_flops_per_step(self): 165 s = self.number_of_subsamples 166 flops = (s - 1) / s * self.number_of_signals * self.pcm_embedding_size 167 return flops 168 169 170class ConcatenativeSubconditioner(Subconditioner): 171 def __init__(self, 172 number_of_subsamples, 173 pcm_embedding_size, 174 state_size, 175 pcm_levels, 176 number_of_signals, 177 recurrent=True, 178 **kwargs): 179 """ subconditioning by concatenation """ 180 181 super(ConcatenativeSubconditioner, self).__init__() 182 183 self.number_of_subsamples = number_of_subsamples 184 self.pcm_embedding_size = pcm_embedding_size 185 self.state_size = state_size 186 self.pcm_levels = pcm_levels 187 self.number_of_signals = number_of_signals 188 self.recurrent = recurrent 189 190 
class ConcatenativeSubconditioner(Subconditioner):
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 recurrent=True,
                 **kwargs):
        """ subconditioning by concatenation """

        super(ConcatenativeSubconditioner, self).__init__()

        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals
        self.recurrent = recurrent

        self.embeddings = []
        start_index = 0
        if self.recurrent:
            start_index = 1
            self.embeddings.append(None)

        for i in range(start_index, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

    def forward(self, states, signals):
        """ creates list of subconditioned states

        Parameters:
        -----------
        states : torch.tensor
            states of shape (batch, seq_length // s, state_size)
        signals : torch.tensor
            signals of shape (batch, seq_length, number_of_signals)

        Returns:
        --------
        c_states : list of torch.tensor
            list of s subconditioned states
        """
        s = self.number_of_subsamples

        if self.recurrent:
            c_states = [states]
            start = 1
        else:
            c_states = []
            start = 0

        new_states = states
        for i in range(start, self.number_of_subsamples):
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension
            embed = torch.flatten(embed, -2)

            if self.recurrent:
                new_states = torch.cat((new_states, embed), dim=-1)
            else:
                new_states = torch.cat((states, embed), dim=-1)

            c_states.append(new_states)

        return c_states

    def single_step(self, index, state, signals):
        """ carry out single step for inference

        Parameters:
        -----------
        index : int
            position in subconditioning batch
        state : torch.tensor
            state to sub-condition
        signals : torch.tensor
            signals for subconditioning, all but the last dimensions
            must match those of state

        Returns:
        --------
        c_state : torch.tensor
            subconditioned state
        """

        if index == 0 and self.recurrent:
            c_state = state
        else:
            embed_signals = self.embeddings[index](signals)
            c = torch.flatten(embed_signals, -2)
            if not self.recurrent and index > 0:
                # overwrite previous conditioning vector
                c_state = torch.cat((state[..., :self.state_size], c), dim=-1)
            else:
                c_state = torch.cat((state, c), dim=-1)

        return c_state

    def get_average_flops_per_step(self):
        # concatenation requires no multiply-adds
        return 0

    def get_output_dim(self, index):
        if self.recurrent:
            return self.state_size + index * self.pcm_embedding_size * self.number_of_signals
        else:
            return self.state_size + self.pcm_embedding_size * self.number_of_signals
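
# Output-dimension sketch for ConcatenativeSubconditioner (assumed values):
# with state_size = 128, pcm_embedding_size = 16 and number_of_signals = 2,
#
#   recurrent=True : get_output_dim(i) = 128 + i * 16 * 2
#                    (each step appends another flattened embedding)
#   recurrent=False: get_output_dim(i) = 128 + 16 * 2 for every index
#                    (single_step overwrites the previous conditioning slice)
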
class ModulativeSubconditioner(Subconditioner):
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 state_recurrent=False,
                 **kwargs):
        """ subconditioning by modulation """

        super(ModulativeSubconditioner, self).__init__()

        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals
        self.state_recurrent = state_recurrent

        self.hidden_size = self.pcm_embedding_size * self.number_of_signals

        if self.state_recurrent:
            self.hidden_size += self.pcm_embedding_size
            self.state_transform = nn.Linear(self.state_size, self.pcm_embedding_size)

        self.embeddings = [None]
        self.alphas = [None]
        self.betas = [None]

        for i in range(1, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

            self.alphas.append(nn.Linear(self.hidden_size, self.state_size))
            self.add_module('alpha_dense_' + str(i), self.alphas[-1])

            self.betas.append(nn.Linear(self.hidden_size, self.state_size))
            self.add_module('beta_dense_' + str(i), self.betas[-1])

    def forward(self, states, signals):
        """ creates list of subconditioned states

        Parameters:
        -----------
        states : torch.tensor
            states of shape (batch, seq_length // s, state_size)
        signals : torch.tensor
            signals of shape (batch, seq_length, number_of_signals)

        Returns:
        --------
        c_states : list of torch.tensor
            list of s subconditioned states
        """
        s = self.number_of_subsamples

        c_states = [states]
        new_states = states
        for i in range(1, self.number_of_subsamples):
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension
            embed = torch.flatten(embed, -2)

            if self.state_recurrent:
                comp_states = self.state_transform(new_states)
                embed = torch.cat((embed, comp_states), dim=-1)

            alpha = torch.tanh(self.alphas[i](embed))
            beta = torch.tanh(self.betas[i](embed))

            # new state obtained by modulating previous state
            new_states = torch.tanh((1 + alpha) * new_states + beta)

            c_states.append(new_states)

        return c_states

    def single_step(self, index, state, signals):
        """ carry out single step for inference

        Parameters:
        -----------
        index : int
            position in subconditioning batch
        state : torch.tensor
            state to sub-condition
        signals : torch.tensor
            signals for subconditioning, all but the last dimensions
            must match those of state

        Returns:
        --------
        c_state : torch.tensor
            subconditioned state
        """

        if index == 0:
            c_state = state
        else:
            embed_signals = self.embeddings[index](signals)
            c = torch.flatten(embed_signals, -2)
            if self.state_recurrent:
                r_state = self.state_transform(state)
                c = torch.cat((c, r_state), dim=-1)
            alpha = torch.tanh(self.alphas[index](c))
            beta = torch.tanh(self.betas[index](c))
            c_state = torch.tanh((1 + alpha) * state + beta)

        return c_state

    def get_output_dim(self, index):
        return self.state_size

    def get_average_flops_per_step(self):
        s = self.number_of_subsamples

        # c_state = torch.tanh((1 + alpha) * state + beta):
        # 3 flops per element plus an estimated 10 flops per tanh activation
        flops = 13 * self.state_size

        # hidden size
        hidden_size = self.number_of_signals * self.pcm_embedding_size
        if self.state_recurrent:
            hidden_size += self.pcm_embedding_size

        # counting 2 * A * B flops for Linear(A, B):
        # alpha = torch.tanh(self.alphas[index](c))
        # beta  = torch.tanh(self.betas[index](c))
        flops += 4 * hidden_size * self.state_size + 20 * self.state_size

        # r_state = self.state_transform(state)
        if self.state_recurrent:
            flops += 2 * self.state_size * self.pcm_embedding_size

        # average over steps (index 0 is free)
        flops *= (s - 1) / s

        return flops
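
# Modulation sketch (notation only, matching the code above): for step i > 0,
# with e_i the flattened signal embedding of subsample phase i (optionally
# concatenated with a linear transform of the previous state when
# state_recurrent is set), ModulativeSubconditioner computes
#
#   alpha_i = tanh(alphas[i](e_i))
#   beta_i  = tanh(betas[i](e_i))
#   h_i     = tanh((1 + alpha_i) * h_{i-1} + beta_i)
#
# so the state size is preserved and get_output_dim() always returns
# state_size.
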
class ComparitiveSubconditioner(Subconditioner):
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 error_index=-1,
                 apply_gate=True,
                 normalize=False):
        """ subconditioning by comparison """

        super(ComparitiveSubconditioner, self).__init__()

        # store arguments before deriving sizes from them
        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals

        self.comparison_size = self.pcm_embedding_size
        self.error_position = error_index
        self.apply_gate = apply_gate
        self.normalize = normalize

        self.state_transform = nn.Linear(self.state_size, self.comparison_size)

        self.alpha_dense = nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size)
        self.beta_dense = nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size)

        if self.apply_gate:
            self.gate_dense = nn.Linear(self.pcm_embedding_size, self.state_size)

        # embeddings and state transforms
        self.embeddings = [None]
        self.alpha_denses = [None]
        self.beta_denses = [None]
        self.state_transforms = [nn.Linear(self.state_size, self.comparison_size)]
        self.add_module('state_transform_0', self.state_transforms[0])

        for i in range(1, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

            state_transform = nn.Linear(self.state_size, self.comparison_size)
            self.add_module('state_transform_' + str(i), state_transform)
            self.state_transforms.append(state_transform)

            self.alpha_denses.append(nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size))
            self.add_module('alpha_dense_' + str(i), self.alpha_denses[-1])

            self.beta_denses.append(nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size))
            self.add_module('beta_dense_' + str(i), self.beta_denses[-1])

    def forward(self, states, signals):
        s = self.number_of_subsamples

        c_states = [states]
        new_states = states
        for i in range(1, self.number_of_subsamples):
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension
            embed = torch.flatten(embed, -2)

            # alpha and beta have state_size features, so this modulation
            # implicitly assumes comparison_size == state_size
            comp_states = self.state_transforms[i](new_states)

            alpha = torch.tanh(self.alpha_dense(embed))
            beta = torch.tanh(self.beta_dense(embed))

            # new state obtained by modulating previous state
            new_states = torch.tanh((1 + alpha) * comp_states + beta)

            c_states.append(new_states)

        return c_states
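

if __name__ == '__main__':
    # Illustrative smoke test (not part of the original module); all values
    # below are assumptions chosen only to exercise the tensor shapes.
    batch, seq_length = 2, 16
    s, embedding_size, state_size = 4, 8, 8
    pcm_levels, n_signals = 256, 2

    states = torch.randn(batch, seq_length // s, state_size)
    signals = torch.randint(0, pcm_levels, (batch, seq_length, n_signals))

    for method in ('additive', 'concatenative', 'modulative'):
        sub = get_subconditioner(method, s, embedding_size, state_size,
                                 pcm_levels, n_signals)
        c_states = sub(states, signals)
        print(method, [tuple(c.shape) for c in c_states])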