• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/experimental/microfrontend/lib/frontend.h"
#include "tensorflow/lite/experimental/microfrontend/lib/frontend_util.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
21 
22 namespace tflite {
23 namespace ops {
24 namespace custom {
25 namespace audio_microfrontend {
26 
// Index of the op's single input tensor (1-D int16 audio samples).
constexpr int kInputTensor{0};
// Index of the op's single output tensor (2-D feature frames).
constexpr int kOutputTensor{0};
29 
30 typedef struct {
31   int sample_rate;
32   FrontendState* state;
33   int left_context;
34   int right_context;
35   int frame_stride;
36   bool zero_padding;
37   int out_scale;
38   bool out_float;
39 } TfLiteAudioMicrofrontendParams;
40 
Init(TfLiteContext * context,const char * buffer,size_t length)41 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
42   auto* data = new TfLiteAudioMicrofrontendParams;
43 
44   const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
45   const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
46 
47   data->sample_rate = m["sample_rate"].AsInt32();
48 
49   struct FrontendConfig config;
50   config.window.size_ms = m["window_size"].AsInt32();
51   config.window.step_size_ms = m["window_step"].AsInt32();
52   config.filterbank.num_channels = m["num_channels"].AsInt32();
53   config.filterbank.upper_band_limit = m["upper_band_limit"].AsFloat();
54   config.filterbank.lower_band_limit = m["lower_band_limit"].AsFloat();
55   config.noise_reduction.smoothing_bits = m["smoothing_bits"].AsInt32();
56   config.noise_reduction.even_smoothing = m["even_smoothing"].AsFloat();
57   config.noise_reduction.odd_smoothing = m["odd_smoothing"].AsFloat();
58   config.noise_reduction.min_signal_remaining =
59       m["min_signal_remaining"].AsFloat();
60   config.pcan_gain_control.enable_pcan = m["enable_pcan"].AsBool();
61   config.pcan_gain_control.strength = m["pcan_strength"].AsFloat();
62   config.pcan_gain_control.offset = m["pcan_offset"].AsFloat();
63   config.pcan_gain_control.gain_bits = m["gain_bits"].AsInt32();
64   config.log_scale.enable_log = m["enable_log"].AsBool();
65   config.log_scale.scale_shift = m["scale_shift"].AsInt32();
66 
67   data->state = new FrontendState;
68   FrontendPopulateState(&config, data->state, data->sample_rate);
69 
70   data->left_context = m["left_context"].AsInt32();
71   data->right_context = m["right_context"].AsInt32();
72   data->frame_stride = m["frame_stride"].AsInt32();
73   data->zero_padding = m["zero_padding"].AsBool();
74   data->out_scale = m["out_scale"].AsInt32();
75   data->out_float = m["out_float"].AsBool();
76 
77   return data;
78 }
79 
Free(TfLiteContext * context,void * buffer)80 void Free(TfLiteContext* context, void* buffer) {
81   auto* data = reinterpret_cast<TfLiteAudioMicrofrontendParams*>(buffer);
82   FrontendFreeStateContents(data->state);
83   delete data->state;
84   delete data;
85 }
86 
Prepare(TfLiteContext * context,TfLiteNode * node)87 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
88   auto* data =
89       reinterpret_cast<TfLiteAudioMicrofrontendParams*>(node->user_data);
90 
91   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
92   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
93 
94   const TfLiteTensor* input;
95   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
96   TfLiteTensor* output;
97   TF_LITE_ENSURE_OK(context,
98                     GetOutputSafe(context, node, kOutputTensor, &output));
99 
100   TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
101 
102   TF_LITE_ENSURE_EQ(context, input->type, kTfLiteInt16);
103   output->type = kTfLiteInt32;
104   if (data->out_float) {
105     output->type = kTfLiteFloat32;
106   }
107 
108   TfLiteIntArray* output_size = TfLiteIntArrayCreate(2);
109   int num_frames = 0;
110   if (input->dims->data[0] >= data->state->window.size) {
111     num_frames = (input->dims->data[0] - data->state->window.size) /
112                      data->state->window.step / data->frame_stride +
113                  1;
114   }
115   output_size->data[0] = num_frames;
116   output_size->data[1] = data->state->filterbank.num_channels *
117                          (1 + data->left_context + data->right_context);
118 
119   return context->ResizeTensor(context, output, output_size);
120 }
121 
122 template <typename T>
GenerateFeatures(TfLiteAudioMicrofrontendParams * data,const TfLiteTensor * input,TfLiteTensor * output)123 void GenerateFeatures(TfLiteAudioMicrofrontendParams* data,
124                       const TfLiteTensor* input, TfLiteTensor* output) {
125   const int16_t* audio_data = GetTensorData<int16_t>(input);
126   int64_t audio_size = input->dims->data[0];
127 
128   T* filterbanks_flat = GetTensorData<T>(output);
129 
130   int num_frames = 0;
131   if (audio_size >= data->state->window.size) {
132     num_frames = (input->dims->data[0] - data->state->window.size) /
133                      data->state->window.step +
134                  1;
135   }
136   std::vector<std::vector<T>> frame_buffer(num_frames);
137 
138   int frame_index = 0;
139   while (audio_size > 0) {
140     size_t num_samples_read;
141     struct FrontendOutput output = FrontendProcessSamples(
142         data->state, audio_data, audio_size, &num_samples_read);
143     audio_data += num_samples_read;
144     audio_size -= num_samples_read;
145 
146     if (output.values != nullptr) {
147       frame_buffer[frame_index].reserve(output.size);
148       int i;
149       for (i = 0; i < output.size; ++i) {
150         frame_buffer[frame_index].push_back(static_cast<T>(output.values[i]) /
151                                             data->out_scale);
152       }
153       ++frame_index;
154     }
155   }
156 
157   int index = 0;
158   std::vector<T> pad(data->state->filterbank.num_channels, 0);
159   int anchor;
160   for (anchor = 0; anchor < frame_buffer.size(); anchor += data->frame_stride) {
161     int frame;
162     for (frame = anchor - data->left_context;
163          frame <= anchor + data->right_context; ++frame) {
164       std::vector<T>* feature;
165       if (data->zero_padding && (frame < 0 || frame >= frame_buffer.size())) {
166         feature = &pad;
167       } else if (frame < 0) {
168         feature = &frame_buffer[0];
169       } else if (frame >= frame_buffer.size()) {
170         feature = &frame_buffer[frame_buffer.size() - 1];
171       } else {
172         feature = &frame_buffer[frame];
173       }
174       for (auto f : *feature) {
175         filterbanks_flat[index++] = f;
176       }
177     }
178   }
179 }
180 
Eval(TfLiteContext * context,TfLiteNode * node)181 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
182   auto* data =
183       reinterpret_cast<TfLiteAudioMicrofrontendParams*>(node->user_data);
184   FrontendReset(data->state);
185 
186   const TfLiteTensor* input;
187   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
188   TfLiteTensor* output;
189   TF_LITE_ENSURE_OK(context,
190                     GetOutputSafe(context, node, kOutputTensor, &output));
191 
192   if (data->out_float) {
193     GenerateFeatures<float>(data, input, output);
194   } else {
195     GenerateFeatures<int32>(data, input, output);
196   }
197 
198   return kTfLiteOk;
199 }
200 
201 }  // namespace audio_microfrontend
202 
Register_AUDIO_MICROFRONTEND()203 TfLiteRegistration* Register_AUDIO_MICROFRONTEND() {
204   static TfLiteRegistration r = {
205       audio_microfrontend::Init, audio_microfrontend::Free,
206       audio_microfrontend::Prepare, audio_microfrontend::Eval};
207   return &r;
208 }
209 
210 }  // namespace custom
211 }  // namespace ops
212 }  // namespace tflite
213