1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <cstddef>
#include <cstdint>

#include <vector>

#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/experimental/microfrontend/lib/frontend.h"
#include "tensorflow/lite/experimental/microfrontend/lib/frontend_util.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
21
22 namespace tflite {
23 namespace ops {
24 namespace custom {
25 namespace audio_microfrontend {
26
27 constexpr int kInputTensor = 0;
28 constexpr int kOutputTensor = 0;
29
// Per-instance parameters for the AUDIO_MICROFRONTEND custom op, parsed from
// the op's flexbuffer options in Init() and released in Free().
typedef struct {
  // Input audio sampling rate in Hz, forwarded to FrontendPopulateState().
  int sample_rate;
  // Heap-allocated micro-frontend state; owned by this struct (allocated in
  // Init(), freed via FrontendFreeStateContents() + delete in Free()).
  FrontendState* state;
  // Number of preceding frames concatenated onto each output frame.
  int left_context;
  // Number of following frames concatenated onto each output frame.
  int right_context;
  // Emit features for every frame_stride-th frame only.
  int frame_stride;
  // If true, out-of-range context frames are zero-filled; otherwise the
  // first/last available frame is duplicated instead.
  bool zero_padding;
  // Divisor applied to every output feature value.
  int out_scale;
  // If true the output tensor is float32, otherwise int32.
  bool out_float;
} TfLiteAudioMicrofrontendParams;
40
Init(TfLiteContext * context,const char * buffer,size_t length)41 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
42 auto* data = new TfLiteAudioMicrofrontendParams;
43
44 const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
45 const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
46
47 data->sample_rate = m["sample_rate"].AsInt32();
48
49 struct FrontendConfig config;
50 config.window.size_ms = m["window_size"].AsInt32();
51 config.window.step_size_ms = m["window_step"].AsInt32();
52 config.filterbank.num_channels = m["num_channels"].AsInt32();
53 config.filterbank.upper_band_limit = m["upper_band_limit"].AsFloat();
54 config.filterbank.lower_band_limit = m["lower_band_limit"].AsFloat();
55 config.noise_reduction.smoothing_bits = m["smoothing_bits"].AsInt32();
56 config.noise_reduction.even_smoothing = m["even_smoothing"].AsFloat();
57 config.noise_reduction.odd_smoothing = m["odd_smoothing"].AsFloat();
58 config.noise_reduction.min_signal_remaining =
59 m["min_signal_remaining"].AsFloat();
60 config.pcan_gain_control.enable_pcan = m["enable_pcan"].AsBool();
61 config.pcan_gain_control.strength = m["pcan_strength"].AsFloat();
62 config.pcan_gain_control.offset = m["pcan_offset"].AsFloat();
63 config.pcan_gain_control.gain_bits = m["gain_bits"].AsInt32();
64 config.log_scale.enable_log = m["enable_log"].AsBool();
65 config.log_scale.scale_shift = m["scale_shift"].AsInt32();
66
67 data->state = new FrontendState;
68 FrontendPopulateState(&config, data->state, data->sample_rate);
69
70 data->left_context = m["left_context"].AsInt32();
71 data->right_context = m["right_context"].AsInt32();
72 data->frame_stride = m["frame_stride"].AsInt32();
73 data->zero_padding = m["zero_padding"].AsBool();
74 data->out_scale = m["out_scale"].AsInt32();
75 data->out_float = m["out_float"].AsBool();
76
77 return data;
78 }
79
Free(TfLiteContext * context,void * buffer)80 void Free(TfLiteContext* context, void* buffer) {
81 auto* data = reinterpret_cast<TfLiteAudioMicrofrontendParams*>(buffer);
82 FrontendFreeStateContents(data->state);
83 delete data->state;
84 delete data;
85 }
86
Prepare(TfLiteContext * context,TfLiteNode * node)87 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
88 auto* data =
89 reinterpret_cast<TfLiteAudioMicrofrontendParams*>(node->user_data);
90
91 TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
92 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
93
94 const TfLiteTensor* input;
95 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
96 TfLiteTensor* output;
97 TF_LITE_ENSURE_OK(context,
98 GetOutputSafe(context, node, kOutputTensor, &output));
99
100 TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
101
102 TF_LITE_ENSURE_EQ(context, input->type, kTfLiteInt16);
103 output->type = kTfLiteInt32;
104 if (data->out_float) {
105 output->type = kTfLiteFloat32;
106 }
107
108 TfLiteIntArray* output_size = TfLiteIntArrayCreate(2);
109 int num_frames = 0;
110 if (input->dims->data[0] >= data->state->window.size) {
111 num_frames = (input->dims->data[0] - data->state->window.size) /
112 data->state->window.step / data->frame_stride +
113 1;
114 }
115 output_size->data[0] = num_frames;
116 output_size->data[1] = data->state->filterbank.num_channels *
117 (1 + data->left_context + data->right_context);
118
119 return context->ResizeTensor(context, output, output_size);
120 }
121
122 template <typename T>
GenerateFeatures(TfLiteAudioMicrofrontendParams * data,const TfLiteTensor * input,TfLiteTensor * output)123 void GenerateFeatures(TfLiteAudioMicrofrontendParams* data,
124 const TfLiteTensor* input, TfLiteTensor* output) {
125 const int16_t* audio_data = GetTensorData<int16_t>(input);
126 int64_t audio_size = input->dims->data[0];
127
128 T* filterbanks_flat = GetTensorData<T>(output);
129
130 int num_frames = 0;
131 if (audio_size >= data->state->window.size) {
132 num_frames = (input->dims->data[0] - data->state->window.size) /
133 data->state->window.step +
134 1;
135 }
136 std::vector<std::vector<T>> frame_buffer(num_frames);
137
138 int frame_index = 0;
139 while (audio_size > 0) {
140 size_t num_samples_read;
141 struct FrontendOutput output = FrontendProcessSamples(
142 data->state, audio_data, audio_size, &num_samples_read);
143 audio_data += num_samples_read;
144 audio_size -= num_samples_read;
145
146 if (output.values != nullptr) {
147 frame_buffer[frame_index].reserve(output.size);
148 int i;
149 for (i = 0; i < output.size; ++i) {
150 frame_buffer[frame_index].push_back(static_cast<T>(output.values[i]) /
151 data->out_scale);
152 }
153 ++frame_index;
154 }
155 }
156
157 int index = 0;
158 std::vector<T> pad(data->state->filterbank.num_channels, 0);
159 int anchor;
160 for (anchor = 0; anchor < frame_buffer.size(); anchor += data->frame_stride) {
161 int frame;
162 for (frame = anchor - data->left_context;
163 frame <= anchor + data->right_context; ++frame) {
164 std::vector<T>* feature;
165 if (data->zero_padding && (frame < 0 || frame >= frame_buffer.size())) {
166 feature = &pad;
167 } else if (frame < 0) {
168 feature = &frame_buffer[0];
169 } else if (frame >= frame_buffer.size()) {
170 feature = &frame_buffer[frame_buffer.size() - 1];
171 } else {
172 feature = &frame_buffer[frame];
173 }
174 for (auto f : *feature) {
175 filterbanks_flat[index++] = f;
176 }
177 }
178 }
179 }
180
Eval(TfLiteContext * context,TfLiteNode * node)181 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
182 auto* data =
183 reinterpret_cast<TfLiteAudioMicrofrontendParams*>(node->user_data);
184 FrontendReset(data->state);
185
186 const TfLiteTensor* input;
187 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
188 TfLiteTensor* output;
189 TF_LITE_ENSURE_OK(context,
190 GetOutputSafe(context, node, kOutputTensor, &output));
191
192 if (data->out_float) {
193 GenerateFeatures<float>(data, input, output);
194 } else {
195 GenerateFeatures<int32>(data, input, output);
196 }
197
198 return kTfLiteOk;
199 }
200
201 } // namespace audio_microfrontend
202
Register_AUDIO_MICROFRONTEND()203 TfLiteRegistration* Register_AUDIO_MICROFRONTEND() {
204 static TfLiteRegistration r = {
205 audio_microfrontend::Init, audio_microfrontend::Free,
206 audio_microfrontend::Prepare, audio_microfrontend::Eval};
207 return &r;
208 }
209
210 } // namespace custom
211 } // namespace ops
212 } // namespace tflite
213