1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // Unit test for TFLite Micro Frontend op.
16
17 #include "tensorflow/lite/experimental/microfrontend/audio_microfrontend.h"
18
19 #include <memory>
20 #include <vector>
21
22 #include <gmock/gmock.h>
23 #include <gtest/gtest.h>
24 #include "flatbuffers/flexbuffers.h" // from @flatbuffers
25 #include "tensorflow/lite/interpreter.h"
26 #include "tensorflow/lite/kernels/test_util.h"
27 #include "tensorflow/lite/model.h"
28
29 namespace tflite {
30 namespace ops {
31 namespace custom {
32 namespace {
33
34 using ::testing::ElementsAreArray;
35
36 class MicroFrontendOpModel : public SingleOpModel {
37 public:
MicroFrontendOpModel(int n_input,int n_frame,int n_frequency_per_frame,int n_left_context,int n_right_context,int n_frame_stride,const std::vector<std::vector<int>> & input_shapes)38 MicroFrontendOpModel(int n_input, int n_frame, int n_frequency_per_frame,
39 int n_left_context, int n_right_context,
40 int n_frame_stride,
41 const std::vector<std::vector<int>>& input_shapes)
42 : n_input_(n_input),
43 n_frame_(n_frame),
44 n_frequency_per_frame_(n_frequency_per_frame),
45 n_left_context_(n_left_context),
46 n_right_context_(n_right_context),
47 n_frame_stride_(n_frame_stride) {
48 input_ = AddInput(TensorType_INT16);
49 output_ = AddOutput(TensorType_INT32);
50
51 // Set up and pass in custom options using flexbuffer.
52 flexbuffers::Builder fbb;
53 fbb.Map([&]() {
54 // Parameters to initialize FFT state.
55 fbb.Int("sample_rate", 1000);
56 fbb.Int("window_size", 25);
57 fbb.Int("window_step", 10);
58 fbb.Int("num_channels", 2);
59 fbb.Float("upper_band_limit", 450.0);
60 fbb.Float("lower_band_limit", 8.0);
61 fbb.Int("smoothing_bits", 10);
62 fbb.Float("even_smoothing", 0.025);
63 fbb.Float("odd_smoothing", 0.06);
64 fbb.Float("min_signal_remaining", 0.05);
65 fbb.Bool("enable_pcan", true);
66 fbb.Float("pcan_strength", 0.95);
67 fbb.Float("pcan_offset", 80.0);
68 fbb.Int("gain_bits", 21);
69 fbb.Bool("enable_log", true);
70 fbb.Int("scale_shift", 6);
71
72 // Parameters for micro frontend.
73 fbb.Int("left_context", n_left_context);
74 fbb.Int("right_context", n_right_context);
75 fbb.Int("frame_stride", n_frame_stride);
76 fbb.Bool("zero_padding", true);
77 fbb.Int("out_scale", 1);
78 fbb.Bool("out_float", false);
79 });
80 fbb.Finish();
81 SetCustomOp("MICRO_FRONTEND", fbb.GetBuffer(),
82 Register_AUDIO_MICROFRONTEND);
83 BuildInterpreter(input_shapes);
84 }
85
SetInput(const std::vector<int16_t> & data)86 void SetInput(const std::vector<int16_t>& data) {
87 PopulateTensor(input_, data);
88 }
89
GetOutput()90 std::vector<int> GetOutput() { return ExtractVector<int>(output_); }
91
num_inputs()92 int num_inputs() { return n_input_; }
num_frmes()93 int num_frmes() { return n_frame_; }
num_frequency_per_frame()94 int num_frequency_per_frame() { return n_frequency_per_frame_; }
num_left_context()95 int num_left_context() { return n_left_context_; }
num_right_context()96 int num_right_context() { return n_right_context_; }
num_frame_stride()97 int num_frame_stride() { return n_frame_stride_; }
98
99 protected:
100 int input_;
101 int output_;
102 int n_input_;
103 int n_frame_;
104 int n_frequency_per_frame_;
105 int n_left_context_;
106 int n_right_context_;
107 int n_frame_stride_;
108 };
109
110 class BaseMicroFrontendTest : public ::testing::Test {
111 protected:
112 // Micro frontend input.
113 std::vector<int16_t> micro_frontend_input_;
114
115 // Compares output up to tolerance to the result of the micro_frontend given
116 // the input.
VerifyGoldens(const std::vector<int16_t> & input,const std::vector<std::vector<int>> & output,MicroFrontendOpModel * micro_frontend,float tolerance=1e-5)117 void VerifyGoldens(const std::vector<int16_t>& input,
118 const std::vector<std::vector<int>>& output,
119 MicroFrontendOpModel* micro_frontend,
120 float tolerance = 1e-5) {
121 // Dimensionality check.
122 const int num_inputs = micro_frontend->num_inputs();
123 EXPECT_GT(num_inputs, 0);
124
125 const int num_frames = micro_frontend->num_frmes();
126 EXPECT_GT(num_frames, 0);
127 EXPECT_EQ(num_frames, output.size());
128
129 const int num_frequency_per_frame =
130 micro_frontend->num_frequency_per_frame();
131 EXPECT_GT(num_frequency_per_frame, 0);
132 EXPECT_EQ(num_frequency_per_frame, output[0].size());
133
134 // Set up input.
135 micro_frontend->SetInput(input);
136
137 // Call Invoke.
138 ASSERT_EQ(micro_frontend->Invoke(), kTfLiteOk);
139
140 // Mimic padding behaviour with zero_padding = true.
141 std::vector<int> output_flattened;
142 int anchor;
143 for (anchor = 0; anchor < output.size();
144 anchor += micro_frontend->num_frame_stride()) {
145 int frame;
146 for (frame = anchor - micro_frontend->num_left_context();
147 frame <= anchor + micro_frontend->num_right_context(); ++frame) {
148 if (frame < 0 || frame >= output.size()) {
149 // Padding with zeros.
150 int j;
151 for (j = 0; j < num_frequency_per_frame; ++j) {
152 output_flattened.push_back(0.0);
153 }
154 } else {
155 // Copy real output.
156 for (auto data_point : output[frame]) {
157 output_flattened.push_back(data_point);
158 }
159 }
160 }
161 }
162
163 // Validate result.
164 EXPECT_THAT(micro_frontend->GetOutput(),
165 ElementsAreArray(output_flattened));
166 }
167 }; // namespace
168
169 class TwoConsecutive36InputsMicroFrontendTest : public BaseMicroFrontendTest {
SetUp()170 void SetUp() override {
171 micro_frontend_input_ = {
172 0, 32767, 0, -32768, 0, 32767, 0, -32768, 0, 32767, 0, -32768,
173 0, 32767, 0, -32768, 0, 32767, 0, -32768, 0, 32767, 0, -32768,
174 0, 32767, 0, -32768, 0, 32767, 0, -32768, 0, 32767, 0, -32768};
175 }
176 };
177
// Runs the op on the fixture's 36-sample input with a context of one frame on
// each side and a stride of one, and compares the result with the two golden
// frames below.
TEST_F(TwoConsecutive36InputsMicroFrontendTest, MicroFrontendBlackBoxTest) {
  const int n_input = 36;
  const int n_frame = 2;
  const int n_frequency_per_frame = 2;

  MicroFrontendOpModel micro_frontend(n_input, n_frame, n_frequency_per_frame,
                                      /*n_left_context=*/1,
                                      /*n_right_context=*/1,
                                      /*n_frame_stride=*/1,
                                      {
                                          {n_input},
                                      });

  // Verify the final output.
  const std::vector<std::vector<int>> micro_frontend_golden_output = {
      {479, 425}, {436, 378}};
  // Pass the model by address; VerifyGoldens takes a MicroFrontendOpModel*.
  VerifyGoldens(micro_frontend_input_, micro_frontend_golden_output,
                &micro_frontend);
}
195
196 } // namespace
197 } // namespace custom
198 } // namespace ops
199 } // namespace tflite
200