• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Unit test for TFLite Micro Frontend op.
16 
17 #include "tensorflow/lite/experimental/microfrontend/audio_microfrontend.h"
18 
19 #include <memory>
20 #include <vector>
21 
22 #include <gmock/gmock.h>
23 #include <gtest/gtest.h>
24 #include "flatbuffers/flexbuffers.h"  // from @flatbuffers
25 #include "tensorflow/lite/interpreter.h"
26 #include "tensorflow/lite/kernels/test_util.h"
27 #include "tensorflow/lite/model.h"
28 
29 namespace tflite {
30 namespace ops {
31 namespace custom {
32 namespace {
33 
34 using ::testing::ElementsAreArray;
35 
36 class MicroFrontendOpModel : public SingleOpModel {
37  public:
MicroFrontendOpModel(int n_input,int n_frame,int n_frequency_per_frame,int n_left_context,int n_right_context,int n_frame_stride,const std::vector<std::vector<int>> & input_shapes)38   MicroFrontendOpModel(int n_input, int n_frame, int n_frequency_per_frame,
39                        int n_left_context, int n_right_context,
40                        int n_frame_stride,
41                        const std::vector<std::vector<int>>& input_shapes)
42       : n_input_(n_input),
43         n_frame_(n_frame),
44         n_frequency_per_frame_(n_frequency_per_frame),
45         n_left_context_(n_left_context),
46         n_right_context_(n_right_context),
47         n_frame_stride_(n_frame_stride) {
48     input_ = AddInput(TensorType_INT16);
49     output_ = AddOutput(TensorType_INT32);
50 
51     // Set up and pass in custom options using flexbuffer.
52     flexbuffers::Builder fbb;
53     fbb.Map([&]() {
54       // Parameters to initialize FFT state.
55       fbb.Int("sample_rate", 1000);
56       fbb.Int("window_size", 25);
57       fbb.Int("window_step", 10);
58       fbb.Int("num_channels", 2);
59       fbb.Float("upper_band_limit", 450.0);
60       fbb.Float("lower_band_limit", 8.0);
61       fbb.Int("smoothing_bits", 10);
62       fbb.Float("even_smoothing", 0.025);
63       fbb.Float("odd_smoothing", 0.06);
64       fbb.Float("min_signal_remaining", 0.05);
65       fbb.Bool("enable_pcan", true);
66       fbb.Float("pcan_strength", 0.95);
67       fbb.Float("pcan_offset", 80.0);
68       fbb.Int("gain_bits", 21);
69       fbb.Bool("enable_log", true);
70       fbb.Int("scale_shift", 6);
71 
72       // Parameters for micro frontend.
73       fbb.Int("left_context", n_left_context);
74       fbb.Int("right_context", n_right_context);
75       fbb.Int("frame_stride", n_frame_stride);
76       fbb.Bool("zero_padding", true);
77       fbb.Int("out_scale", 1);
78       fbb.Bool("out_float", false);
79     });
80     fbb.Finish();
81     SetCustomOp("MICRO_FRONTEND", fbb.GetBuffer(),
82                 Register_AUDIO_MICROFRONTEND);
83     BuildInterpreter(input_shapes);
84   }
85 
SetInput(const std::vector<int16_t> & data)86   void SetInput(const std::vector<int16_t>& data) {
87     PopulateTensor(input_, data);
88   }
89 
GetOutput()90   std::vector<int> GetOutput() { return ExtractVector<int>(output_); }
91 
num_inputs()92   int num_inputs() { return n_input_; }
num_frmes()93   int num_frmes() { return n_frame_; }
num_frequency_per_frame()94   int num_frequency_per_frame() { return n_frequency_per_frame_; }
num_left_context()95   int num_left_context() { return n_left_context_; }
num_right_context()96   int num_right_context() { return n_right_context_; }
num_frame_stride()97   int num_frame_stride() { return n_frame_stride_; }
98 
99  protected:
100   int input_;
101   int output_;
102   int n_input_;
103   int n_frame_;
104   int n_frequency_per_frame_;
105   int n_left_context_;
106   int n_right_context_;
107   int n_frame_stride_;
108 };
109 
110 class BaseMicroFrontendTest : public ::testing::Test {
111  protected:
112   // Micro frontend input.
113   std::vector<int16_t> micro_frontend_input_;
114 
115   // Compares output up to tolerance to the result of the micro_frontend given
116   // the input.
VerifyGoldens(const std::vector<int16_t> & input,const std::vector<std::vector<int>> & output,MicroFrontendOpModel * micro_frontend,float tolerance=1e-5)117   void VerifyGoldens(const std::vector<int16_t>& input,
118                      const std::vector<std::vector<int>>& output,
119                      MicroFrontendOpModel* micro_frontend,
120                      float tolerance = 1e-5) {
121     // Dimensionality check.
122     const int num_inputs = micro_frontend->num_inputs();
123     EXPECT_GT(num_inputs, 0);
124 
125     const int num_frames = micro_frontend->num_frmes();
126     EXPECT_GT(num_frames, 0);
127     EXPECT_EQ(num_frames, output.size());
128 
129     const int num_frequency_per_frame =
130         micro_frontend->num_frequency_per_frame();
131     EXPECT_GT(num_frequency_per_frame, 0);
132     EXPECT_EQ(num_frequency_per_frame, output[0].size());
133 
134     // Set up input.
135     micro_frontend->SetInput(input);
136 
137     // Call Invoke.
138     ASSERT_EQ(micro_frontend->Invoke(), kTfLiteOk);
139 
140     // Mimic padding behaviour with zero_padding = true.
141     std::vector<int> output_flattened;
142     int anchor;
143     for (anchor = 0; anchor < output.size();
144          anchor += micro_frontend->num_frame_stride()) {
145       int frame;
146       for (frame = anchor - micro_frontend->num_left_context();
147            frame <= anchor + micro_frontend->num_right_context(); ++frame) {
148         if (frame < 0 || frame >= output.size()) {
149           // Padding with zeros.
150           int j;
151           for (j = 0; j < num_frequency_per_frame; ++j) {
152             output_flattened.push_back(0.0);
153           }
154         } else {
155           // Copy real output.
156           for (auto data_point : output[frame]) {
157             output_flattened.push_back(data_point);
158           }
159         }
160       }
161     }
162 
163     // Validate result.
164     EXPECT_THAT(micro_frontend->GetOutput(),
165                 ElementsAreArray(output_flattened));
166   }
167 };  // namespace
168 
169 class TwoConsecutive36InputsMicroFrontendTest : public BaseMicroFrontendTest {
SetUp()170   void SetUp() override {
171     micro_frontend_input_ = {
172         0, 32767, 0, -32768, 0, 32767, 0, -32768, 0, 32767, 0, -32768,
173         0, 32767, 0, -32768, 0, 32767, 0, -32768, 0, 32767, 0, -32768,
174         0, 32767, 0, -32768, 0, 32767, 0, -32768, 0, 32767, 0, -32768};
175   }
176 };
177 
TEST_F(TwoConsecutive36InputsMicroFrontendTest, MicroFrontendBlackBoxTest) {
  // Expected dimensions: 36 input samples yield 2 frames of 2 values each.
  constexpr int kNumInputs = 36;
  constexpr int kNumFrames = 2;
  constexpr int kNumFrequenciesPerFrame = 2;

  // One frame of context on each side, emitted for every frame.
  constexpr int kLeftContext = 1;
  constexpr int kRightContext = 1;
  constexpr int kFrameStride = 1;

  MicroFrontendOpModel micro_frontend(
      kNumInputs, kNumFrames, kNumFrequenciesPerFrame, kLeftContext,
      kRightContext, kFrameStride, /*input_shapes=*/{{kNumInputs}});

  // Per-frame golden values, before context stacking.
  const std::vector<std::vector<int>> golden_frames = {{479, 425},
                                                       {436, 378}};

  // Verify the final output.
  VerifyGoldens(micro_frontend_input_, golden_frames, &micro_frontend);
}
195 
196 }  // namespace
197 }  // namespace custom
198 }  // namespace ops
199 }  // namespace tflite
200