• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
"""
/* Copyright (c) 2023 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
29
30import os
31
32from torch.utils.data import Dataset
33import numpy as np
34
35from utils.silk_features import silk_feature_factory
36from utils.pitch import hangover, calculate_acorr_window
37
38
class SilkEnhancementSet(Dataset):
    """Dataset pairing a coded signal with per-frame codec features.

    All data is read eagerly from raw binary files in ``path``. Each item spans
    ``frames_per_sample`` feature frames of ``frame_size`` (80) samples each
    (presumably 5 ms at 16 kHz -- confirm). Each item contains:

    * the features and pitch periods produced by the callable returned from
      ``silk_feature_factory``,
    * the per-4-frame bit counts (raw and smoothed), repeated per frame,
    * the first 20 columns of the LPCNet features (one row per 2 frames).

    Args:
        path: directory containing ``features_lpc.f32``, ``features_ltp.f32``,
            ``features_period.s16``, ``features_gain.f32``,
            ``features_num_bits.s32``, ``features_num_bits_smooth.f32``,
            ``features_offset.f32``, ``features_lpcnet.f32`` and ``coded.s16``.
        frames_per_sample: feature frames per item. Must be a multiple of 4,
            because bit counts are stored once per 4 frames.
        no_pitch_value, acorr_radius, pitch_hangover, num_bands_clean_spec,
        num_bands_noisy_spec, noisy_spec_scale, noisy_apply_dct, add_offset,
        add_double_lag_acorr: forwarded positionally to
            ``silk_feature_factory``. ``add_double_lag_acorr`` also doubles
            the signal history handed to the feature function (350 -> 700
            samples).
        skip: number of samples by which each item's coded-signal window is
            shifted back relative to its feature frames; must be >= 0.
            NOTE(review): the default of 91 is assumed to compensate a
            codec/decoder delay -- confirm against the pipeline that wrote
            the data.

    Raises:
        ValueError: if ``skip`` is negative or too large to leave enough
            signal history before the first item.
    """

    def __init__(self,
                 path,
                 frames_per_sample=100,
                 no_pitch_value=9,
                 acorr_radius=2,
                 pitch_hangover=8,
                 num_bands_clean_spec=64,
                 num_bands_noisy_spec=18,
                 noisy_spec_scale='opus',
                 noisy_apply_dct=True,
                 add_offset=False,
                 add_double_lag_acorr=False,
                 skip=91
                 ):

        # bit counts are stored once per 4 frames (see __getitem__), so an
        # item must cover whole groups of 4 frames
        assert frames_per_sample % 4 == 0, "frames_per_sample must be a multiple of 4"

        self.frame_size = 80
        self.frames_per_sample = frames_per_sample
        self.no_pitch_value = no_pitch_value
        self.acorr_radius = acorr_radius
        self.pitch_hangover = pitch_hangover
        self.num_bands_clean_spec = num_bands_clean_spec
        self.num_bands_noisy_spec = num_bands_noisy_spec
        self.noisy_spec_scale = noisy_spec_scale
        self.add_double_lag_acorr = add_double_lag_acorr
        self.skip = skip

        self.lpcs = self._load(path, 'features_lpc.f32', np.float32, 16)
        self.ltps = self._load(path, 'features_ltp.f32', np.float32, 5)
        self.periods = self._load(path, 'features_period.s16', np.int16)
        self.gains = self._load(path, 'features_gain.f32', np.float32)
        self.num_bits = self._load(path, 'features_num_bits.s32', np.int32)
        self.num_bits_smooth = self._load(path, 'features_num_bits_smooth.f32', np.float32)
        self.offsets = self._load(path, 'features_offset.f32', np.float32)
        self.lpcnet_features = self._load(path, 'features_lpcnet.f32', np.float32, 36)

        self.coded_signal = self._load(path, 'coded.s16', np.int16)

        self.create_features = silk_feature_factory(no_pitch_value,
                                                    acorr_radius,
                                                    pitch_hangover,
                                                    num_bands_clean_spec,
                                                    num_bands_noisy_spec,
                                                    noisy_spec_scale,
                                                    noisy_apply_dct,
                                                    add_offset,
                                                    add_double_lag_acorr)

        self.history_len = 700 if add_double_lag_acorr else 350
        # discard some leading frames so that every item has enough signal
        # history; the result is a multiple of 4, matching the bit-count grid
        self.skip_frames = 4 * ((self.history_len + 319) // 320 + 2)

        # The first item's history window must not start before sample 0:
        # a negative slice start would silently wrap around the array.
        first_history_start = self.skip_frames * self.frame_size - self.skip - self.history_len
        if self.skip < 0 or first_history_start < 0:
            raise ValueError(
                f"skip={self.skip} must be >= 0 and leave {self.history_len} samples "
                f"of history before the first item")

        self.len = self._num_items(self.coded_signal.shape[0],
                                   self.frame_size,
                                   self.skip_frames,
                                   frames_per_sample)

    @staticmethod
    def _load(path, filename, dtype, num_cols=None):
        """Read raw binary ``filename`` from ``path``; reshape to (-1, num_cols) if given."""
        data = np.fromfile(os.path.join(path, filename), dtype=dtype)
        return data if num_cols is None else data.reshape(-1, num_cols)

    @staticmethod
    def _num_items(num_samples, frame_size, skip_frames, frames_per_sample):
        """Return how many full items fit after dropping ``skip_frames`` leading frames (never negative)."""
        num_frames = num_samples // frame_size - skip_frames
        return max(0, num_frames // frames_per_sample)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        """Return item ``index`` as a dict of numpy arrays.

        Negative indices count from the end. Raises ``IndexError`` when out of
        range, which also lets plain ``for`` iteration terminate.
        """
        if index < 0:
            index += self.len
        if not 0 <= index < self.len:
            raise IndexError(f"index out of range for dataset of length {self.len}")

        frame_start = self.frames_per_sample * index + self.skip_frames
        frame_stop = frame_start + self.frames_per_sample

        signal_start = frame_start * self.frame_size - self.skip
        signal_stop = frame_stop * self.frame_size - self.skip

        # int16 samples scaled by 2**15
        coded_signal = self.coded_signal[signal_start : signal_stop].astype(np.float32) / 2**15

        # the history_len samples immediately preceding this item's window
        coded_signal_history = self.coded_signal[signal_start - self.history_len : signal_start].astype(np.float32) / 2**15

        features, periods = self.create_features(
            coded_signal,
            coded_signal_history,
            self.lpcs[frame_start : frame_stop],
            self.gains[frame_start : frame_stop],
            self.ltps[frame_start : frame_stop],
            self.periods[frame_start : frame_stop],
            self.offsets[frame_start : frame_stop]
        )

        # one LPCNet feature row per 2 frames; only the first 20 columns are used
        lpcnet_features = self.lpcnet_features[frame_start // 2 : frame_stop // 2, :20]

        # bit counts are stored once per 4 frames; repeat them to per-frame rows
        num_bits = np.repeat(self.num_bits[frame_start // 4 : frame_stop // 4], 4).astype(np.float32).reshape(-1, 1)
        num_bits_smooth = np.repeat(self.num_bits_smooth[frame_start // 4 : frame_stop // 4], 4).astype(np.float32).reshape(-1, 1)

        numbits = np.concatenate((num_bits, num_bits_smooth), axis=-1)

        return {
            'silk_features'   : features,
            'periods'         : periods.astype(np.int64),
            'numbits'         : numbits.astype(np.float32),
            'lpcnet_features' : lpcnet_features
            }
133