1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/internal/mfcc.h"
16
17 #include <stddef.h>
18 #include <stdint.h>
19
20 #include <vector>
21
22 #include "flatbuffers/flexbuffers.h" // from @flatbuffers
23 #include "tensorflow/lite/c/common.h"
24 #include "tensorflow/lite/kernels/internal/compatibility.h"
25 #include "tensorflow/lite/kernels/internal/mfcc_dct.h"
26 #include "tensorflow/lite/kernels/internal/mfcc_mel_filterbank.h"
27 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
28 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
29 #include "tensorflow/lite/kernels/internal/tensor.h"
30 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
31 #include "tensorflow/lite/kernels/kernel_util.h"
32
33 namespace tflite {
34 namespace ops {
35 namespace custom {
36 namespace mfcc {
37
// Selects which kernel implementation Eval<> is instantiated for. Only the
// reference implementation is defined here.
enum KernelType { kReference };
41
// Per-node op parameters. Init() fills these in from the custom options and
// Eval() forwards them to the MFCC implementation.
struct TfLiteMfccParams {
  float upper_frequency_limit;   // Passed to Mfcc::set_upper_frequency_limit.
  float lower_frequency_limit;   // Passed to Mfcc::set_lower_frequency_limit.
  int filterbank_channel_count;  // Passed to Mfcc::set_filterbank_channel_count.
  int dct_coefficient_count;     // Also the size of the output's last dim.
};
48
// Positions of the op's tensors within the node's input/output lists.
constexpr int kInputTensorWav{0};   // Spectrogram input.
constexpr int kInputTensorRate{1};  // Sample-rate input.
constexpr int kOutputTensor{0};     // MFCC output.
52
Init(TfLiteContext * context,const char * buffer,size_t length)53 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
54 auto* data = new TfLiteMfccParams;
55
56 const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
57
58 const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
59 data->upper_frequency_limit = m["upper_frequency_limit"].AsInt64();
60 data->lower_frequency_limit = m["lower_frequency_limit"].AsInt64();
61 data->filterbank_channel_count = m["filterbank_channel_count"].AsInt64();
62 data->dct_coefficient_count = m["dct_coefficient_count"].AsInt64();
63 return data;
64 }
65
Free(TfLiteContext * context,void * buffer)66 void Free(TfLiteContext* context, void* buffer) {
67 delete reinterpret_cast<TfLiteMfccParams*>(buffer);
68 }
69
Prepare(TfLiteContext * context,TfLiteNode * node)70 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
71 auto* params = reinterpret_cast<TfLiteMfccParams*>(node->user_data);
72
73 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
74 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
75
76 const TfLiteTensor* input_wav;
77 TF_LITE_ENSURE_OK(context,
78 GetInputSafe(context, node, kInputTensorWav, &input_wav));
79 const TfLiteTensor* input_rate;
80 TF_LITE_ENSURE_OK(context,
81 GetInputSafe(context, node, kInputTensorRate, &input_rate));
82 TfLiteTensor* output;
83 TF_LITE_ENSURE_OK(context,
84 GetOutputSafe(context, node, kOutputTensor, &output));
85
86 TF_LITE_ENSURE_EQ(context, NumDimensions(input_wav), 3);
87 TF_LITE_ENSURE_EQ(context, NumElements(input_rate), 1);
88
89 TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
90 TF_LITE_ENSURE_TYPES_EQ(context, input_wav->type, output->type);
91 TF_LITE_ENSURE_TYPES_EQ(context, input_rate->type, kTfLiteInt32);
92
93 TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
94 output_size->data[0] = input_wav->dims->data[0];
95 output_size->data[1] = input_wav->dims->data[1];
96 output_size->data[2] = params->dct_coefficient_count;
97
98 return context->ResizeTensor(context, output, output_size);
99 }
100
101 // Input is a single squared-magnitude spectrogram frame. The input spectrum
102 // is converted to linear magnitude and weighted into bands using a
103 // triangular mel filterbank, and a discrete cosine transform (DCT) of the
104 // values is taken. Output is populated with the lowest dct_coefficient_count
105 // of these values.
106 template <KernelType kernel_type>
Eval(TfLiteContext * context,TfLiteNode * node)107 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
108 auto* params = reinterpret_cast<TfLiteMfccParams*>(node->user_data);
109
110 const TfLiteTensor* input_wav;
111 TF_LITE_ENSURE_OK(context,
112 GetInputSafe(context, node, kInputTensorWav, &input_wav));
113 const TfLiteTensor* input_rate;
114 TF_LITE_ENSURE_OK(context,
115 GetInputSafe(context, node, kInputTensorRate, &input_rate));
116 TfLiteTensor* output;
117 TF_LITE_ENSURE_OK(context,
118 GetOutputSafe(context, node, kOutputTensor, &output));
119
120 const int32 sample_rate = *GetTensorData<int>(input_rate);
121
122 const int spectrogram_channels = input_wav->dims->data[2];
123 const int spectrogram_samples = input_wav->dims->data[1];
124 const int audio_channels = input_wav->dims->data[0];
125
126 internal::Mfcc mfcc;
127 mfcc.set_upper_frequency_limit(params->upper_frequency_limit);
128 mfcc.set_lower_frequency_limit(params->lower_frequency_limit);
129 mfcc.set_filterbank_channel_count(params->filterbank_channel_count);
130 mfcc.set_dct_coefficient_count(params->dct_coefficient_count);
131
132 mfcc.Initialize(spectrogram_channels, sample_rate);
133
134 const float* spectrogram_flat = GetTensorData<float>(input_wav);
135 float* output_flat = GetTensorData<float>(output);
136
137 for (int audio_channel = 0; audio_channel < audio_channels; ++audio_channel) {
138 for (int spectrogram_sample = 0; spectrogram_sample < spectrogram_samples;
139 ++spectrogram_sample) {
140 const float* sample_data =
141 spectrogram_flat +
142 (audio_channel * spectrogram_samples * spectrogram_channels) +
143 (spectrogram_sample * spectrogram_channels);
144 std::vector<double> mfcc_input(sample_data,
145 sample_data + spectrogram_channels);
146 std::vector<double> mfcc_output;
147 mfcc.Compute(mfcc_input, &mfcc_output);
148 TF_LITE_ENSURE_EQ(context, params->dct_coefficient_count,
149 mfcc_output.size());
150 float* output_data = output_flat +
151 (audio_channel * spectrogram_samples *
152 params->dct_coefficient_count) +
153 (spectrogram_sample * params->dct_coefficient_count);
154 for (int i = 0; i < params->dct_coefficient_count; ++i) {
155 output_data[i] = mfcc_output[i];
156 }
157 }
158 }
159
160 return kTfLiteOk;
161 }
162
163 } // namespace mfcc
164
Register_MFCC()165 TfLiteRegistration* Register_MFCC() {
166 static TfLiteRegistration r = {mfcc::Init, mfcc::Free, mfcc::Prepare,
167 mfcc::Eval<mfcc::kReference>};
168 return &r;
169 }
170
171 } // namespace custom
172 } // namespace ops
173 } // namespace tflite
174