• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for mel_ops."""
16
17from absl.testing import parameterized
18import numpy as np
19
20from tensorflow.python.eager import context
21from tensorflow.python.framework import constant_op
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import tensor_util
25from tensorflow.python.framework import test_util as tf_test_util
26from tensorflow.python.kernel_tests.signal import test_util
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops.signal import mel_ops
29from tensorflow.python.platform import test
30
# Constants of the HTK mel-scale formula used below:
#   mel = _MEL_HIGH_FREQUENCY_Q * ln(1 + hertz / _MEL_BREAK_FREQUENCY_HERTZ)
_MEL_BREAK_FREQUENCY_HERTZ = 700.0
_MEL_HIGH_FREQUENCY_Q = 1127.0


def hertz_to_mel(frequencies_hertz):
  """Convert frequencies to mel scale using HTK formula.

  Copied from
  https://github.com/tensorflow/models/blob/master/research/audioset/mel_features.py.

  Args:
    frequencies_hertz: Scalar, np.array, or other array-like (e.g. a list) of
      frequencies in hertz.

  Returns:
    Object of same size as frequencies_hertz containing corresponding values
    on the mel scale.
  """
  # np.asarray lets plain Python sequences through the division below;
  # scalars and ndarrays yield the same values as without the conversion.
  hertz = np.asarray(frequencies_hertz)
  return _MEL_HIGH_FREQUENCY_Q * np.log(
      1.0 + (hertz / _MEL_BREAK_FREQUENCY_HERTZ))
51
52
def spectrogram_to_mel_matrix(num_mel_bins=20,
                              num_spectrogram_bins=129,
                              audio_sample_rate=8000,
                              lower_edge_hertz=125.0,
                              upper_edge_hertz=3800.0,
                              unused_dtype=None):
  """Return a matrix that can post-multiply spectrogram rows to make mel.

  Copied from
  https://github.com/tensorflow/models/blob/master/research/audioset/mel_features.py.

  Returns a np.array matrix A that can be used to post-multiply a matrix S of
  spectrogram values (STFT magnitudes) arranged as frames x bins to generate a
  "mel spectrogram" M of frames x num_mel_bins.  M = S A.

  The classic HTK algorithm exploits the complementarity of adjacent mel bands
  to multiply each FFT bin by only one mel weight, then add it, with positive
  and negative signs, to the two adjacent mel bands to which that bin
  contributes.  Here, by expressing this operation as a matrix multiply, we go
  from num_fft multiplies per frame (plus around 2*num_fft adds) to around
  num_fft^2 multiplies and adds.  However, because these are all presumably
  accomplished in a single call to np.dot(), it's not clear which approach is
  faster in Python.  The matrix multiplication has the attraction of being more
  general and flexible, and much easier to read.

  Args:
    num_mel_bins: How many bands in the resulting mel spectrum.  This is
      the number of columns in the output matrix.
    num_spectrogram_bins: How many bins there are in the source spectrogram
      data, which is understood to be fft_size/2 + 1, i.e. the spectrogram
      only contains the nonredundant FFT bins.
    audio_sample_rate: Samples per second of the audio at the input to the
      spectrogram. We need this to figure out the actual frequencies for
      each spectrogram bin, which dictates how they are mapped into mel.
      It is passed through tensor_util.constant_value first, so the tests can
      hand in a constant Tensor as well as a Python number.
    lower_edge_hertz: Lower bound on the frequencies to be included in the mel
      spectrum.  This corresponds to the lower edge of the lowest triangular
      band.
    upper_edge_hertz: The desired top edge of the highest frequency band.
    unused_dtype: Ignored.  Accepted only so that this function can be called
      with the same positional arguments as
      mel_ops.linear_to_mel_weight_matrix.

  Returns:
    An np.array with shape (num_spectrogram_bins, num_mel_bins).

  Raises:
    ValueError: if lower_edge_hertz is negative, if the frequency edges are
      incorrectly ordered, or if upper_edge_hertz exceeds the Nyquist
      frequency (audio_sample_rate / 2).
  """
  audio_sample_rate = tensor_util.constant_value(audio_sample_rate)
  nyquist_hertz = audio_sample_rate / 2.
  # Reject the same bad frequency ranges that the tests expect the TF op to
  # reject, instead of silently returning NaNs or meaningless weights.
  if lower_edge_hertz < 0.0:
    raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz)
  if lower_edge_hertz >= upper_edge_hertz:
    raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" %
                     (lower_edge_hertz, upper_edge_hertz))
  if upper_edge_hertz > nyquist_hertz:
    raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" %
                     (upper_edge_hertz, nyquist_hertz))
  spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins)
  spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz)
  # The i'th mel band (starting from i=1) has center frequency
  # band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge
  # band_edges_mel[i+1].  Thus, we need num_mel_bins + 2 values in
  # the band_edges_mel arrays.
  band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz),
                               hertz_to_mel(upper_edge_hertz), num_mel_bins + 2)
  # Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins
  # of spectrogram values.  Every column is overwritten in the loop below, so
  # an uninitialized array is sufficient.
  mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins))
  for i in range(num_mel_bins):
    lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3]
    # Calculate lower and upper slopes for every spectrogram bin.
    # Line segments are linear in the *mel* domain, not hertz.
    lower_slope = ((spectrogram_bins_mel - lower_edge_mel) /
                   (center_mel - lower_edge_mel))
    upper_slope = ((upper_edge_mel - spectrogram_bins_mel) /
                   (upper_edge_mel - center_mel))
    # .. then intersect them with each other and zero.
    mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope,
                                                          upper_slope))
  # HTK excludes the spectrogram DC bin; make sure it always gets a zero
  # coefficient.
  mel_weights_matrix[0, :] = 0.0
  return mel_weights_matrix
129
130
@tf_test_util.run_all_in_graph_and_eager_modes
class LinearToMelTest(test.TestCase, parameterized.TestCase):
  """Tests for mel_ops.linear_to_mel_weight_matrix."""

  @parameterized.parameters(
      # Defaults. Integer sample rate.
      (20, 129, 8000, False, 125.0, 3800.0, dtypes.float64),
      (20, 129, 8000, True, 125.0, 3800.0, dtypes.float64),
      # Defaults. Float sample rate.
      (20, 129, 8000.0, False, 125.0, 3800.0, dtypes.float64),
      (20, 129, 8000.0, True, 125.0, 3800.0, dtypes.float64),
      # Settings used by Tacotron (https://arxiv.org/abs/1703.10135).
      (80, 1025, 24000.0, False, 80.0, 12000.0, dtypes.float64))
  def test_matches_reference_implementation(
      self, num_mel_bins, num_spectrogram_bins, sample_rate,
      use_tensor_sample_rate, lower_edge_hertz, upper_edge_hertz, dtype):
    """The op output matches the NumPy reference spectrogram_to_mel_matrix."""
    # Optionally exercise the code path where the sample rate is a Tensor
    # rather than a Python number.
    if use_tensor_sample_rate:
      sample_rate = constant_op.constant(sample_rate)
    # The reference ignores its final (dtype) argument; it is passed only so
    # both calls share one argument list.
    mel_matrix_np = spectrogram_to_mel_matrix(
        num_mel_bins, num_spectrogram_bins, sample_rate, lower_edge_hertz,
        upper_edge_hertz, dtype)
    mel_matrix = mel_ops.linear_to_mel_weight_matrix(
        num_mel_bins, num_spectrogram_bins, sample_rate, lower_edge_hertz,
        upper_edge_hertz, dtype)
    self.assertAllClose(mel_matrix_np, mel_matrix, atol=3e-6)

  @parameterized.parameters(dtypes.float32, dtypes.float64)
  def test_dtypes(self, dtype):
    """The returned matrix has the requested floating-point dtype."""
    # LinSpace is not supported for tf.float16.
    self.assertEqual(dtype,
                     mel_ops.linear_to_mel_weight_matrix(dtype=dtype).dtype)

  def test_error(self):
    """Invalid arguments raise ValueError (checked in graph mode only)."""
    # TODO(rjryan): Error types are different under eager.
    # A plain return (not skipTest) is used so that only the eager pass is a
    # no-op.
    if context.executing_eagerly():
      return
    with self.assertRaises(ValueError):
      mel_ops.linear_to_mel_weight_matrix(num_mel_bins=0)
    with self.assertRaises(ValueError):
      mel_ops.linear_to_mel_weight_matrix(sample_rate=0.0)
    with self.assertRaises(ValueError):
      mel_ops.linear_to_mel_weight_matrix(lower_edge_hertz=-1)
    # Lower edge above upper edge.
    with self.assertRaises(ValueError):
      mel_ops.linear_to_mel_weight_matrix(lower_edge_hertz=100,
                                          upper_edge_hertz=10)
    # Upper edge above half the sample rate.
    with self.assertRaises(ValueError):
      mel_ops.linear_to_mel_weight_matrix(upper_edge_hertz=1000,
                                          sample_rate=800)
    # Non floating-point dtype.
    with self.assertRaises(ValueError):
      mel_ops.linear_to_mel_weight_matrix(dtype=dtypes.int32)

  @parameterized.parameters(dtypes.float32, dtypes.float64)
  def test_constant_folding(self, dtype):
    """Mel functions should be constant foldable."""
    if context.executing_eagerly():
      return
    # TODO(rjryan): tf.bfloat16 cannot be constant folded by Grappler.
    g = ops.Graph()
    with g.as_default():
      mel_matrix = mel_ops.linear_to_mel_weight_matrix(
          sample_rate=constant_op.constant(8000.0, dtype=dtypes.float32),
          dtype=dtype)
      rewritten_graph = test_util.grappler_optimize(g, [mel_matrix])
      # A fully folded graph is expected to consist of a single node.
      self.assertLen(rewritten_graph.node, 1)

  def test_num_spectrogram_bins_dynamic(self):
    """The op accepts num_spectrogram_bins given as a placeholder tensor."""
    num_spectrogram_bins = array_ops.placeholder_with_default(
        ops.convert_to_tensor(129, dtype=dtypes.int32), shape=())
    mel_matrix_np = spectrogram_to_mel_matrix(
        20, 129, 8000.0, 125.0, 3800.0)
    mel_matrix = mel_ops.linear_to_mel_weight_matrix(
        20, num_spectrogram_bins, 8000.0, 125.0, 3800.0)
    self.assertAllClose(mel_matrix_np, mel_matrix, atol=3e-6)
203
204
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
  test.main()
207