# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for mel_ops."""

from absl.testing import parameterized
import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.kernel_tests.signal import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.signal import mel_ops
from tensorflow.python.platform import test

# mel spectrum constants and functions.
_MEL_BREAK_FREQUENCY_HERTZ = 700.0
_MEL_HIGH_FREQUENCY_Q = 1127.0


def hertz_to_mel(frequencies_hertz):
  """Maps frequencies in hertz onto the mel scale with the HTK formula.

  Copied from
  https://github.com/tensorflow/models/blob/master/research/audioset/mel_features.py.

  Args:
    frequencies_hertz: Scalar or np.array of frequencies in hertz.

  Returns:
    The corresponding mel-scale values, with the same size as
    `frequencies_hertz`.
  """
  hertz_over_break = frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ
  return _MEL_HIGH_FREQUENCY_Q * np.log(1.0 + hertz_over_break)


def spectrogram_to_mel_matrix(num_mel_bins=20,
                              num_spectrogram_bins=129,
                              audio_sample_rate=8000,
                              lower_edge_hertz=125.0,
                              upper_edge_hertz=3800.0,
                              unused_dtype=None):
  """Builds the NumPy reference matrix that warps a spectrogram to mel bands.

  Copied from
  https://github.com/tensorflow/models/blob/master/research/audioset/mel_features.py.

  The result is an np.array A meant to post-multiply a matrix S of spectrogram
  values (STFT magnitudes) laid out as frames x bins, producing a
  "mel spectrogram" M of frames x num_mel_bins: M = S A.

  The classic HTK algorithm takes advantage of the complementarity of adjacent
  mel bands: every FFT bin is multiplied by a single mel weight and then
  added, with positive and negative signs, to the two neighbouring mel bands
  it contributes to. Writing the operation as a matrix multiply instead moves
  from num_fft multiplies per frame (plus around 2*num_fft adds) to around
  num_fft^2 multiplies and adds. Since all of that presumably happens inside
  one np.dot() call, it is not clear which approach is faster in Python. The
  matrix form has the attraction of being more general and flexible, and much
  easier to read.

  Args:
    num_mel_bins: How many bands in the resulting mel spectrum. This is
      the number of columns in the output matrix.
    num_spectrogram_bins: How many bins there are in the source spectrogram
      data, which is understood to be fft_size/2 + 1, i.e. the spectrogram
      only contains the nonredundant FFT bins.
    audio_sample_rate: Samples per second of the audio at the input to the
      spectrogram. It fixes the actual frequency of each spectrogram bin,
      which dictates how the bins are mapped into mel. The value is passed
      through `tensor_util.constant_value` before use.
    lower_edge_hertz: Lower bound on the frequencies to be included in the mel
      spectrum. This corresponds to the lower edge of the lowest triangular
      band.
    upper_edge_hertz: The desired top edge of the highest frequency band.
    unused_dtype: Ignored; this function never reads it. The tests pass the
      same positional arguments to this function and to
      `mel_ops.linear_to_mel_weight_matrix`, and this slot receives the dtype.

  Returns:
    An np.array with shape (num_spectrogram_bins, num_mel_bins).

  Raises:
    ValueError: if frequency edges are incorrectly ordered.
  """
  audio_sample_rate = tensor_util.constant_value(audio_sample_rate)
  nyquist_hertz = audio_sample_rate / 2.
  if lower_edge_hertz >= upper_edge_hertz:
    raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" %
                     (lower_edge_hertz, upper_edge_hertz))

  # Mel-scale position of every spectrogram bin between 0 Hz and Nyquist.
  bin_positions_mel = hertz_to_mel(
      np.linspace(0.0, nyquist_hertz, num_spectrogram_bins))

  # Each band needs a lower edge, a center and an upper edge, and adjacent
  # bands share them: band k uses edges k, k + 1 and k + 2. That takes
  # num_mel_bins + 2 equally spaced mel values in total.
  edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz),
                          hertz_to_mel(upper_edge_hertz), num_mel_bins + 2)

  # One row per spectrogram bin and one column per mel band, so the matrix
  # can post-multiply arrays whose rows hold num_spectrogram_bins values.
  weights = np.empty((num_spectrogram_bins, num_mel_bins))
  band_edge_triples = zip(edges_mel[:-2], edges_mel[1:-1], edges_mel[2:])
  for band, (left_mel, peak_mel, right_mel) in enumerate(band_edge_triples):
    # Both ramps are linear in the *mel* domain, not in hertz.
    rising = (bin_positions_mel - left_mel) / (peak_mel - left_mel)
    falling = (right_mel - bin_positions_mel) / (right_mel - peak_mel)
    # The triangle is the lower of the two ramps, floored at zero.
    weights[:, band] = np.maximum(0.0, np.minimum(rising, falling))

  # HTK excludes the spectrogram DC bin; force its coefficients to zero.
  weights[0, :] = 0.0
  return weights


@tf_test_util.run_all_in_graph_and_eager_modes
class LinearToMelTest(test.TestCase, parameterized.TestCase):

  @parameterized.parameters(
      # Defaults. Integer sample rate.
      (20, 129, 8000, False, 125.0, 3800.0, dtypes.float64),
      (20, 129, 8000, True, 125.0, 3800.0, dtypes.float64),
      # Defaults. Float sample rate.
      (20, 129, 8000.0, False, 125.0, 3800.0, dtypes.float64),
      (20, 129, 8000.0, True, 125.0, 3800.0, dtypes.float64),
      # Settings used by Tacotron (https://arxiv.org/abs/1703.10135).
      (80, 1025, 24000.0, False, 80.0, 12000.0, dtypes.float64))
  def test_matches_reference_implementation(
      self, num_mel_bins, num_spectrogram_bins, sample_rate,
      use_tensor_sample_rate, lower_edge_hertz, upper_edge_hertz, dtype):
    if use_tensor_sample_rate:
      sample_rate = constant_op.constant(sample_rate)
    # Both implementations receive exactly the same positional arguments.
    shared_args = (num_mel_bins, num_spectrogram_bins, sample_rate,
                   lower_edge_hertz, upper_edge_hertz, dtype)
    expected = spectrogram_to_mel_matrix(*shared_args)
    actual = mel_ops.linear_to_mel_weight_matrix(*shared_args)
    self.assertAllClose(expected, actual, atol=3e-6)

  @parameterized.parameters(dtypes.float32, dtypes.float64)
  def test_dtypes(self, dtype):
    # LinSpace is not supported for tf.float16.
    mel_matrix = mel_ops.linear_to_mel_weight_matrix(dtype=dtype)
    self.assertEqual(dtype, mel_matrix.dtype)

  def test_error(self):
    # TODO(rjryan): Error types are different under eager.
    if context.executing_eagerly():
      return
    # Each entry is a set of keyword arguments that must be rejected.
    rejected_kwargs = (
        dict(num_mel_bins=0),
        dict(sample_rate=0.0),
        dict(lower_edge_hertz=-1),
        dict(lower_edge_hertz=100, upper_edge_hertz=10),
        dict(upper_edge_hertz=1000, sample_rate=800),
        dict(dtype=dtypes.int32),
    )
    for kwargs in rejected_kwargs:
      with self.assertRaises(ValueError):
        mel_ops.linear_to_mel_weight_matrix(**kwargs)

  @parameterized.parameters(dtypes.float32, dtypes.float64)
  def test_constant_folding(self, dtype):
    """Mel functions should be constant foldable."""
    if context.executing_eagerly():
      return
    # TODO(rjryan): tf.bfloat16 cannot be constant folded by Grappler.
    graph = ops.Graph()
    with graph.as_default():
      sample_rate = constant_op.constant(8000.0, dtype=dtypes.float32)
      mel_matrix = mel_ops.linear_to_mel_weight_matrix(
          sample_rate=sample_rate, dtype=dtype)
      rewritten_graph = test_util.grappler_optimize(graph, [mel_matrix])
      # The optimized graph is expected to contain exactly one node.
      self.assertLen(rewritten_graph.node, 1)

  def test_num_spectrogram_bins_dynamic(self):
    static_bins = 129
    # Same bin count, but supplied as a placeholder-backed tensor.
    dynamic_bins = array_ops.placeholder_with_default(
        ops.convert_to_tensor(static_bins, dtype=dtypes.int32), shape=())
    trailing_args = (8000.0, 125.0, 3800.0)
    expected = spectrogram_to_mel_matrix(20, static_bins, *trailing_args)
    actual = mel_ops.linear_to_mel_weight_matrix(
        20, dynamic_bins, *trailing_args)
    self.assertAllClose(expected, actual, atol=3e-6)


if __name__ == "__main__":
  test.main()