• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <math.h>
#include <stddef.h>
#include <stdint.h>

#include <limits>
#include <vector>

#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/spectrogram.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
30 
31 namespace tflite {
32 namespace ops {
33 namespace custom {
34 namespace audio_spectrogram {
35 
// Index of the node's sole input tensor and sole output tensor (Prepare()
// enforces exactly one of each).
constexpr int kInputTensor{0};
constexpr int kOutputTensor{0};
38 
// Kernel implementations selectable through Eval's template parameter.
// Only the reference implementation is provided.
enum KernelType { kReference };
42 
43 typedef struct {
44   int window_size;
45   int stride;
46   bool magnitude_squared;
47   int output_height;
48   internal::Spectrogram* spectrogram;
49 } TfLiteAudioSpectrogramParams;
50 
Init(TfLiteContext * context,const char * buffer,size_t length)51 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
52   auto* data = new TfLiteAudioSpectrogramParams;
53 
54   const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
55 
56   const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
57   data->window_size = m["window_size"].AsInt64();
58   data->stride = m["stride"].AsInt64();
59   data->magnitude_squared = m["magnitude_squared"].AsBool();
60 
61   data->spectrogram = new internal::Spectrogram;
62 
63   return data;
64 }
65 
Free(TfLiteContext * context,void * buffer)66 void Free(TfLiteContext* context, void* buffer) {
67   auto* params = reinterpret_cast<TfLiteAudioSpectrogramParams*>(buffer);
68   delete params->spectrogram;
69   delete params;
70 }
71 
Prepare(TfLiteContext * context,TfLiteNode * node)72 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
73   auto* params =
74       reinterpret_cast<TfLiteAudioSpectrogramParams*>(node->user_data);
75 
76   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
77   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
78 
79   const TfLiteTensor* input;
80   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
81   TfLiteTensor* output;
82   TF_LITE_ENSURE_OK(context,
83                     GetOutputSafe(context, node, kOutputTensor, &output));
84 
85   TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2);
86 
87   TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
88   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
89 
90   TF_LITE_ENSURE(context, params->spectrogram->Initialize(params->window_size,
91                                                           params->stride));
92   const int64_t sample_count = input->dims->data[0];
93   const int64_t length_minus_window = (sample_count - params->window_size);
94   if (length_minus_window < 0) {
95     params->output_height = 0;
96   } else {
97     params->output_height = 1 + (length_minus_window / params->stride);
98   }
99   TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
100   output_size->data[0] = input->dims->data[1];
101   output_size->data[1] = params->output_height;
102   output_size->data[2] = params->spectrogram->output_frequency_channels();
103 
104   return context->ResizeTensor(context, output, output_size);
105 }
106 
107 template <KernelType kernel_type>
Eval(TfLiteContext * context,TfLiteNode * node)108 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
109   auto* params =
110       reinterpret_cast<TfLiteAudioSpectrogramParams*>(node->user_data);
111 
112   const TfLiteTensor* input;
113   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
114   TfLiteTensor* output;
115   TF_LITE_ENSURE_OK(context,
116                     GetOutputSafe(context, node, kOutputTensor, &output));
117 
118   TF_LITE_ENSURE(context, params->spectrogram->Initialize(params->window_size,
119                                                           params->stride));
120 
121   const float* input_data = GetTensorData<float>(input);
122 
123   const int64_t sample_count = input->dims->data[0];
124   const int64_t channel_count = input->dims->data[1];
125 
126   const int64_t output_width = params->spectrogram->output_frequency_channels();
127 
128   float* output_flat = GetTensorData<float>(output);
129 
130   std::vector<float> input_for_channel(sample_count);
131   for (int64_t channel = 0; channel < channel_count; ++channel) {
132     float* output_slice =
133         output_flat + (channel * params->output_height * output_width);
134     for (int i = 0; i < sample_count; ++i) {
135       input_for_channel[i] = input_data[i * channel_count + channel];
136     }
137     std::vector<std::vector<float>> spectrogram_output;
138     TF_LITE_ENSURE(context,
139                    params->spectrogram->ComputeSquaredMagnitudeSpectrogram(
140                        input_for_channel, &spectrogram_output));
141     TF_LITE_ENSURE_EQ(context, spectrogram_output.size(),
142                       params->output_height);
143     TF_LITE_ENSURE(context, spectrogram_output.empty() ||
144                                 (spectrogram_output[0].size() == output_width));
145     for (int row_index = 0; row_index < params->output_height; ++row_index) {
146       const std::vector<float>& spectrogram_row = spectrogram_output[row_index];
147       TF_LITE_ENSURE_EQ(context, spectrogram_row.size(), output_width);
148       float* output_row = output_slice + (row_index * output_width);
149       if (params->magnitude_squared) {
150         for (int i = 0; i < output_width; ++i) {
151           output_row[i] = spectrogram_row[i];
152         }
153       } else {
154         for (int i = 0; i < output_width; ++i) {
155           output_row[i] = sqrtf(spectrogram_row[i]);
156         }
157       }
158     }
159   }
160   return kTfLiteOk;
161 }
162 
163 }  // namespace audio_spectrogram
164 
Register_AUDIO_SPECTROGRAM()165 TfLiteRegistration* Register_AUDIO_SPECTROGRAM() {
166   static TfLiteRegistration r = {
167       audio_spectrogram::Init, audio_spectrogram::Free,
168       audio_spectrogram::Prepare,
169       audio_spectrogram::Eval<audio_spectrogram::kReference>};
170   return &r;
171 }
172 
173 }  // namespace custom
174 }  // namespace ops
175 }  // namespace tflite
176