#
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import numpy as np
import scipy.fftpack as fftpack

import build.lc3 as lc3
import tables as T, appendix_c as C

### ------------------------------------------------------------------------ ###

class Sns:

    def __init__(self, dt, sr):

        self.dt = dt
        self.sr = sr

        (self.ind_lf, self.ind_hf, self.shape, self.gain) = \
            (None, None, None, None)

        (self.idx_a, self.ls_a, self.idx_b, self.ls_b) = \
            (None, None, None, None)

    def get_data(self):

        data = { 'lfcb' : self.ind_lf, 'hfcb' : self.ind_hf,
                 'shape' : self.shape, 'gain' : self.gain,
                 'idx_a' : self.idx_a, 'ls_a' : self.ls_a }

        if self.idx_b is not None:
            data.update({ 'idx_b' : self.idx_b, 'ls_b' : self.ls_b })

        return data

    def get_nbits(self):

        return 38

    def spectral_shaping(self, scf, inv, x):

        ## 3.3.7.4 Scale factors interpolation

        scf_i = np.empty(4*len(scf))
        scf_i[0     ] = scf[0]
        scf_i[1     ] = scf[0]
        scf_i[2:62:4] = scf[:15] + 1/8 * (scf[1:] - scf[:15])
        scf_i[3:63:4] = scf[:15] + 3/8 * (scf[1:] - scf[:15])
        scf_i[4:64:4] = scf[:15] + 5/8 * (scf[1:] - scf[:15])
        scf_i[5:64:4] = scf[:15] + 7/8 * (scf[1:] - scf[:15])
        scf_i[62    ] = scf[15] + 1/8 * (scf[15] - scf[14])
        scf_i[63    ] = scf[15] + 3/8 * (scf[15] - scf[14])

        n2 = 64 - min(len(x), 64)

        for i in range(n2):
            scf_i[i] = 0.5 * (scf_i[2*i] + scf_i[2*i+1])
        scf_i = np.append(scf_i[:n2], scf_i[2*n2:])

        g_sns = np.power(2, [ -scf_i, scf_i ][inv])

        ## 3.3.7.4 Spectral shaping

        y = np.empty(len(x))
        I = T.I[self.dt][self.sr]

        for b in range(len(g_sns)):
            y[I[b]:I[b+1]] = x[I[b]:I[b+1]] * g_sns[b]

        return y


class SnsAnalysis(Sns):

    def __init__(self, dt, sr):

        super().__init__(dt, sr)

    def compute_scale_factors(self, e, att):

        dt = self.dt

        ## 3.3.7.2.1 Padding

        n2 = 64 - len(e)

        e = np.append(np.empty(n2), e)
        for i in range(n2):
            e[2*i+0] = e[2*i+1] = e[n2+i]

        ## 3.3.7.2.2 Smoothing

        e_s = np.zeros(len(e))
        e_s[0   ] = 0.75 * e[0   ] + 0.25 * e[1   ]
        e_s[1:63] = 0.25 * e[0:62] + 0.5  * e[1:63] + 0.25 * e[2:64]
        e_s[  63] = 0.25 * e[  62] + 0.75 * e[  63]

        ## 3.3.7.2.3 Pre-emphasis

        g_tilt = [ 14, 18, 22, 26, 30 ][self.sr]
        e_p = e_s * (10 ** ((np.arange(64) * g_tilt) / 630))

        ## 3.3.7.2.4 Noise floor

        noise_floor = max(np.average(e_p) * (10 ** (-40/10)), 2 ** -32)
        e_p = np.fmax(e_p, noise_floor * np.ones(len(e)))

        ## 3.3.7.2.5 Logarithm

        e_l = np.log2(10 ** -31 + e_p) / 2

        ## 3.3.7.2.6 Band energy grouping

        w = [ 1/12, 2/12, 3/12, 3/12, 2/12, 1/12 ]

        e_4 = np.zeros(len(e_l) // 4)
        e_4[0   ] = w[0] * e_l[0] + np.sum(w[1:] * e_l[:5])
        e_4[1:15] = [ np.sum(w * e_l[4*i-1:4*i+5]) for i in range(1, 15) ]
        e_4[  15] = np.sum(w[:5] * e_l[59:64]) + w[5] * e_l[63]
        ## 3.3.7.2.7 Mean removal and scaling, attack handling

        scf = 0.85 * (e_4 - np.average(e_4))

        scf_a = np.zeros(len(scf))
        scf_a[0   ] = np.average(scf[:3])
        scf_a[1   ] = np.average(scf[:4])
        scf_a[2:14] = [ np.average(scf[i:i+5]) for i in range(12) ]
        scf_a[  14] = np.average(scf[12:])
        scf_a[  15] = np.average(scf[13:])

        scf_a = (0.5 if self.dt == T.DT_10M else 0.3) * \
                (scf_a - np.average(scf_a))

        return scf_a if att else scf

    def enum_mpvq(self, v):

        sign = None
        index = 0
        x = 0

        for (n, vn) in enumerate(v[::-1]):

            if sign is not None and vn != 0:
                index = 2*index + sign
            if vn != 0:
                sign = 1 if vn < 0 else 0

            index += T.SNS_MPVQ_OFFSETS[n][x]
            x += abs(vn)

        return (index, bool(sign))

    def quantize(self, scf):

        ## 3.3.7.3.2 Stage 1

        dmse_lf = [ np.sum((scf[:8] - T.SNS_LFCB[i]) ** 2) for i in range(32) ]
        dmse_hf = [ np.sum((scf[8:] - T.SNS_HFCB[i]) ** 2) for i in range(32) ]

        self.ind_lf = np.argmin(dmse_lf)
        self.ind_hf = np.argmin(dmse_hf)

        st1 = np.append(T.SNS_LFCB[self.ind_lf], T.SNS_HFCB[self.ind_hf])
        r1 = scf - st1

        ## 3.3.7.3.3 Stage 2

        t2_rot = fftpack.dct(r1, norm = 'ortho')
        x = np.abs(t2_rot)

        ## 3.3.7.3.3 Stage 2 Shape search, step 1

        K = 6

        proj_fac = (K - 1) / sum(np.abs(t2_rot))
        y3 = np.floor(x * proj_fac).astype(int)

        ## 3.3.7.3.3 Stage 2 Shape search, step 2

        corr_xy = np.sum(y3 * x)
        energy_y = np.sum(y3 * y3)

        k0 = sum(y3)
        for k in range(k0, K):
            q_pvq = ((corr_xy + x) ** 2) / (energy_y + 2*y3 + 1)
            n_best = np.argmax(q_pvq)

            corr_xy += x[n_best]
            energy_y += 2*y3[n_best] + 1
            y3[n_best] += 1

        ## 3.3.7.3.3 Stage 2 Shape search, step 3

        K = 8

        y2 = y3.copy()

        for k in range(sum(y2), K):
            q_pvq = ((corr_xy + x) ** 2) / (energy_y + 2*y2 + 1)
            n_best = np.argmax(q_pvq)

            corr_xy += x[n_best]
            energy_y += 2*y2[n_best] + 1
            y2[n_best] += 1

        ## 3.3.7.3.3 Stage 2 Shape search, step 4

        y1 = np.append(y2[:10], [0] * 6)

        ## 3.3.7.3.3 Stage 2 Shape search, step 5

        corr_xy -= sum(y2[10:] * x[10:])
        energy_y -= sum(y2[10:] * y2[10:])

        ## 3.3.7.3.3 Stage 2 Shape search, step 6

        K = 10

        for k in range(sum(y1), K):
            q_pvq = ((corr_xy + x[:10]) ** 2) / (energy_y + 2*y1[:10] + 1)
            n_best = np.argmax(q_pvq)

            corr_xy += x[n_best]
            energy_y += 2*y1[n_best] + 1
            y1[n_best] += 1

        ## 3.3.7.3.3 Stage 2 Shape search, step 7

        y0 = np.append(y1[:10], [ 0 ] * 6)

        q_pvq = ((corr_xy + x[10:]) ** 2) / (energy_y + 2*y0[10:] + 1)
        n_best = 10 + np.argmax(q_pvq)

        y0[n_best] += 1

        ## 3.3.7.3.3 Stage 2 Shape search, step 8

        y0 *= np.sign(t2_rot).astype(int)
        y1 *= np.sign(t2_rot).astype(int)
        y2 *= np.sign(t2_rot).astype(int)
        y3 *= np.sign(t2_rot).astype(int)

        ## 3.3.7.3.3 Stage 2 Shape search, step 9

        xq = [ y / np.sqrt(sum(y ** 2)) for y in (y0, y1, y2, y3) ]

        ## 3.3.7.3.3 Shape and gain combination determination

        G = [ T.SNS_VQ_REG_ADJ_GAINS, T.SNS_VQ_REG_LF_ADJ_GAINS,
              T.SNS_VQ_NEAR_ADJ_GAINS, T.SNS_VQ_FAR_ADJ_GAINS ]

        dMSE = [ [ sum((t2_rot - G[j][i] * xq[j]) ** 2)
                   for i in range(len(G[j])) ] for j in range(4) ]

        self.shape = np.argmin([ np.min(dMSE[j]) for j in range(4) ])
        self.gain = np.argmin(dMSE[self.shape])

        gain = G[self.shape][self.gain]

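        # Shape 0 splits the pulse vector into a 10-element set A and a
        # 6-element set B, each with its own MPVQ index and leading sign;
        # shape 1 only codes set A (its last 6 coefficients are zero), and
        # shapes 2 and 3 enumerate the full 16-element vector as one index.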
        ## 3.3.7.3.3 Enumeration of the selected PVQ pulse configurations

        if self.shape == 0:
            (self.idx_a, self.ls_a) = self.enum_mpvq(y0[:10])
            (self.idx_b, self.ls_b) = self.enum_mpvq(y0[10:])
        elif self.shape == 1:
            (self.idx_a, self.ls_a) = self.enum_mpvq(y1[:10])
            (self.idx_b, self.ls_b) = (None, None)
        elif self.shape == 2:
            (self.idx_a, self.ls_a) = self.enum_mpvq(y2)
            (self.idx_b, self.ls_b) = (None, None)
        elif self.shape == 3:
            (self.idx_a, self.ls_a) = self.enum_mpvq(y3)
            (self.idx_b, self.ls_b) = (None, None)

        ## 3.3.7.3.4 Synthesis of the quantized scale factors

        scf_q = st1 + gain * fftpack.idct(xq[self.shape], norm = 'ortho')

        return scf_q

    def run(self, eb, att, x):

        scf = self.compute_scale_factors(eb, att)
        scf_q = self.quantize(scf)
        y = self.spectral_shaping(scf_q, False, x)

        return y

    def store(self, b):

        shape = self.shape
        gain_msb_bits = np.array([ 1, 1, 2, 2 ])[shape]
        gain_lsb_bits = np.array([ 0, 1, 0, 1 ])[shape]

        b.write_uint(self.ind_lf, 5)
        b.write_uint(self.ind_hf, 5)

        b.write_bit(shape >> 1)

        b.write_uint(self.gain >> gain_lsb_bits, gain_msb_bits)

        b.write_bit(self.ls_a)

        if self.shape == 0:
            sz_shape_a = 2390004
            index_joint = self.idx_a + \
                (2 * self.idx_b + self.ls_b + 2) * sz_shape_a

        elif self.shape == 1:
            sz_shape_a = 2390004
            index_joint = self.idx_a + (self.gain & 1) * sz_shape_a

        elif self.shape == 2:
            index_joint = self.idx_a

        elif self.shape == 3:
            sz_shape_a = 15158272
            index_joint = sz_shape_a + (self.gain & 1) + 2 * self.idx_a

        b.write_uint(index_joint, 14 - gain_msb_bits)
        b.write_uint(index_joint >> (14 - gain_msb_bits), 12)


class SnsSynthesis(Sns):

    def __init__(self, dt, sr):

        super().__init__(dt, sr)

    def deenum_mpvq(self, index, ls, npulses, n):

        y = np.zeros(n, dtype=int)
        pos = 0

        for i in range(len(y)-1, -1, -1):

            if index > 0:
                yi = 0
                while index < T.SNS_MPVQ_OFFSETS[i][npulses - yi]: yi += 1
                index -= T.SNS_MPVQ_OFFSETS[i][npulses - yi]
            else:
                yi = npulses

            y[pos] = [ yi, -yi ][int(ls)]
            pos += 1

            npulses -= yi
            if npulses <= 0:
                break

            if yi > 0:
                ls = index & 1
                index >>= 1

        return y

    def unquantize(self):

        ## 3.7.4.2.1-2 SNS VQ Decoding

        y = np.empty(16, dtype=int)

        if self.shape == 0:
            y[:10] = self.deenum_mpvq(self.idx_a, self.ls_a, 10, 10)
            y[10:] = self.deenum_mpvq(self.idx_b, self.ls_b, 1, 6)
        elif self.shape == 1:
            y[:10] = self.deenum_mpvq(self.idx_a, self.ls_a, 10, 10)
            y[10:] = np.zeros(6, dtype=int)
        elif self.shape == 2:
            y = self.deenum_mpvq(self.idx_a, self.ls_a, 8, 16)
        elif self.shape == 3:
            y = self.deenum_mpvq(self.idx_a, self.ls_a, 6, 16)

        ## 3.7.4.2.3 Unit energy normalization

        y = y / np.sqrt(sum(y ** 2))

        ## 3.7.4.2.4 Reconstruction of the quantized scale factors

        G = [ T.SNS_VQ_REG_ADJ_GAINS, T.SNS_VQ_REG_LF_ADJ_GAINS,
              T.SNS_VQ_NEAR_ADJ_GAINS, T.SNS_VQ_FAR_ADJ_GAINS ]

        gain = G[self.shape][self.gain]

        scf = np.append(T.SNS_LFCB[self.ind_lf], T.SNS_HFCB[self.ind_hf]) \
            + gain * fftpack.idct(y, norm = 'ortho')

        return scf

    def load(self, b):

        self.ind_lf = b.read_uint(5)
        self.ind_hf = b.read_uint(5)

        shape_msb = b.read_bit()

        gain_msb_bits = 1 + shape_msb
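        # Only the gain MSBs are read here; for shapes 1 and 3 one extra
        # gain LSB is recovered from the joint index further below.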
        self.gain = b.read_uint(gain_msb_bits)

        self.ls_a = b.read_bit()

        index_joint = b.read_uint(14 - gain_msb_bits)
        index_joint |= b.read_uint(12) << (14 - gain_msb_bits)

        if shape_msb == 0:
            sz_shape_a = 2390004

            if index_joint >= sz_shape_a * 14:
                raise ValueError('Invalid SNS joint index')

            self.idx_a = index_joint % sz_shape_a
            index_joint = index_joint // sz_shape_a
            if index_joint >= 2:
                self.shape = 0
                self.idx_b = (index_joint - 2) // 2
                self.ls_b = (index_joint - 2) % 2
            else:
                self.shape = 1
                self.gain = (self.gain << 1) + (index_joint & 1)

        else:
            sz_shape_a = 15158272
            if index_joint >= sz_shape_a + 1549824:
                raise ValueError('Invalid SNS joint index')

            if index_joint < sz_shape_a:
                self.shape = 2
                self.idx_a = index_joint
            else:
                self.shape = 3
                index_joint -= sz_shape_a
                self.gain = (self.gain << 1) + (index_joint % 2)
                self.idx_a = index_joint // 2

    def run(self, x):

        scf = self.unquantize()
        y = self.spectral_shaping(scf, True, x)

        return y

### ------------------------------------------------------------------------ ###

def check_analysis(rng, dt, sr):

    ok = True

    analysis = SnsAnalysis(dt, sr)

    for i in range(10):
        x = rng.random(T.NE[dt][sr]) * 1e4
        e = rng.random(min(len(x), 64)) * 1e10

        for att in (0, 1):
            y = analysis.run(e, att, x)
            data = analysis.get_data()

            (y_c, data_c) = lc3.sns_analyze(dt, sr, e, att, x)

            for k in data.keys():
                ok = ok and data_c[k] == data[k]

            ok = ok and lc3.sns_get_nbits() == analysis.get_nbits()
            ok = ok and np.amax(np.abs(y - y_c)) < 1e-1

    return ok

def check_synthesis(rng, dt, sr):

    ok = True

    synthesis = SnsSynthesis(dt, sr)

    for i in range(100):

        synthesis.ind_lf = rng.integers(0, 32)
        synthesis.ind_hf = rng.integers(0, 32)

        shape = rng.integers(0, 4)
        sz_shape_a = [ 2390004, 2390004, 15158272, 774912 ][shape]
        sz_shape_b = [ 6, 1, 0, 0 ][shape]
        synthesis.shape = shape
        synthesis.gain = rng.integers(0, [ 2, 4, 4, 8 ][shape])
        synthesis.idx_a = rng.integers(0, sz_shape_a, endpoint=True)
        synthesis.ls_a = bool(rng.integers(0, 1, endpoint=True))
        synthesis.idx_b = rng.integers(0, sz_shape_b, endpoint=True)
        synthesis.ls_b = bool(rng.integers(0, 1, endpoint=True))

        x = rng.random(T.NE[dt][sr]) * 1e4

        y = synthesis.run(x)
        y_c = lc3.sns_synthesize(dt, sr, synthesis.get_data(), x)
        ok = ok and np.amax(np.abs(y - y_c)) < 1e0

    return ok

def check_analysis_appendix_c(dt):

    sr = T.SRATE_16K
    ok = True

    for i in range(len(C.E_B[dt])):

        scf = lc3.sns_compute_scale_factors(dt, sr, C.E_B[dt][i], False)
        ok = ok and np.amax(np.abs(scf - C.SCF[dt][i])) < 1e-4

        (lf, hf) = lc3.sns_resolve_codebooks(scf)
        ok = ok and lf == C.IND_LF[dt][i] and hf == C.IND_HF[dt][i]

        (y, yn, shape, gain) = lc3.sns_quantize(scf, lf, hf)
        ok = ok and np.all(y[0][:16] == C.SNS_Y0[dt][i])
        ok = ok and np.all(y[1][:10] == C.SNS_Y1[dt][i])
        ok = ok and np.all(y[2][:16] == C.SNS_Y2[dt][i])
        ok = ok and np.all(y[3][:16] == C.SNS_Y3[dt][i])
        ok = ok and shape == 2*C.SUBMODE_MSB[dt][i] + C.SUBMODE_LSB[dt][i]
        ok = ok and gain == C.G_IND[dt][i]

        scf_q = lc3.sns_unquantize(lf, hf, yn[shape], shape, gain)
        ok = ok and np.amax(np.abs(scf_q - C.SCF_Q[dt][i])) < 1e-5
        x = lc3.sns_spectral_shaping(dt, sr, C.SCF_Q[dt][i], False, C.X[dt][i])
        ok = ok and np.amax(np.abs(1 - x/C.X_S[dt][i])) < 1e-5

        (x, data) = lc3.sns_analyze(dt, sr, C.E_B[dt][i], False, C.X[dt][i])
        ok = ok and data['lfcb'] == C.IND_LF[dt][i]
        ok = ok and data['hfcb'] == C.IND_HF[dt][i]
        ok = ok and data['shape'] == \
            2*C.SUBMODE_MSB[dt][i] + C.SUBMODE_LSB[dt][i]
        ok = ok and data['gain'] == C.G_IND[dt][i]
        ok = ok and data['idx_a'] == C.IDX_A[dt][i]
        ok = ok and data['ls_a'] == C.LS_IND_A[dt][i]
        ok = ok and (C.IDX_B[dt][i] is None or
            data['idx_b'] == C.IDX_B[dt][i])
        ok = ok and (C.LS_IND_B[dt][i] is None or
            data['ls_b'] == C.LS_IND_B[dt][i])
        ok = ok and np.amax(np.abs(1 - x/C.X_S[dt][i])) < 1e-5

    return ok

def check_synthesis_appendix_c(dt):

    sr = T.SRATE_16K
    ok = True

    for i in range(len(C.X_HAT_TNS[dt])):

        data = {
            'lfcb'  : C.IND_LF[dt][i], 'hfcb' : C.IND_HF[dt][i],
            'shape' : 2*C.SUBMODE_MSB[dt][i] + C.SUBMODE_LSB[dt][i],
            'gain'  : C.G_IND[dt][i],
            'idx_a' : C.IDX_A[dt][i],
            'ls_a'  : C.LS_IND_A[dt][i],
            'idx_b' : C.IDX_B[dt][i] if C.IDX_B[dt][i] is not None else 0,
            'ls_b'  : C.LS_IND_B[dt][i] if C.LS_IND_B[dt][i] is not None else 0,
        }

        x = lc3.sns_synthesize(dt, sr, data, C.X_HAT_TNS[dt][i])
        ok = ok and np.amax(np.abs(x - C.X_HAT_SNS[dt][i])) < 1e0

    return ok

def check():

    rng = np.random.default_rng(1234)
    ok = True

    for dt in range(T.NUM_DT):
        for sr in range(T.NUM_SRATE):
            ok = ok and check_analysis(rng, dt, sr)
            ok = ok and check_synthesis(rng, dt, sr)

    for dt in range(T.NUM_DT):
        ok = ok and check_analysis_appendix_c(dt)
        ok = ok and check_synthesis_appendix_c(dt)

    return ok

### ------------------------------------------------------------------------ ###
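# Minimal standalone driver (an illustrative sketch, assuming the compiled
# `build.lc3` module and the `tables` / `appendix_c` helpers are importable
# from the working directory): runs the full set of SNS checks and reports
# the overall result.

if __name__ == '__main__':
    print('SNS check:', 'PASS' if check() else 'FAIL')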