1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/micro/memory_planner/greedy_memory_planner.h"
17
18 #include "tensorflow/lite/micro/micro_error_reporter.h"
19 #include "tensorflow/lite/micro/testing/micro_test.h"
20
namespace tflite {

// ReverseSortInPlace is an internal helper of the greedy planner and is
// deliberately left out of the public header, so it is re-declared here to
// let this test call it directly.
//
// Per the expectations in TestReverseSortInPlace, it reorders the first `size`
// entries of `values` into descending order, applies the same permutation to
// the parallel `ids` array, and keeps equal values in their original relative
// order.
void ReverseSortInPlace(int* values, int* ids, int size);

}  // namespace tflite
26
namespace {

// Size in bytes of the shared scratch arena handed to each planner under test.
constexpr int kScratchBufferSize = 4096;

// The planner receives this arena as raw bytes but keeps its per-buffer
// bookkeeping inside it (TestSmallScratch shows 40 bytes hold the records for
// only one buffer). NOTE(review): those records are presumably int-based
// structs reached through pointer casts, so the arena is over-aligned to avoid
// misaligned accesses on strict-alignment targets; confirm against
// greedy_memory_planner.cc.
alignas(16) unsigned char g_scratch_buffer[kScratchBufferSize];

}  // namespace
31
32 TF_LITE_MICRO_TESTS_BEGIN
33
// Exercises the internal ReverseSortInPlace helper: `values` must end up in
// descending order with `ids` permuted in lockstep, and entries with equal
// values must keep their original relative order.
TF_LITE_MICRO_TEST(TestReverseSortInPlace) {
  // Case a: input is already descending, so nothing should move.
  constexpr int a_size = 10;
  int a_values[a_size] = {10, 9, 8, 7, 6, 5, 4, 3, 2, 1};
  int a_ids[a_size] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  const int a_expected_values[a_size] = {10, 9, 8, 7, 6, 5, 4, 3, 2, 1};
  const int a_expected_ids[a_size] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  tflite::ReverseSortInPlace(a_values, a_ids, a_size);
  for (int i = 0; i < a_size; ++i) {
    TF_LITE_MICRO_EXPECT_EQ(a_expected_values[i], a_values[i]);
    TF_LITE_MICRO_EXPECT_EQ(a_expected_ids[i], a_ids[i]);
  }

  // Case b: input is ascending, so the order is fully reversed.
  constexpr int b_size = 10;
  int b_values[b_size] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
  int b_ids[b_size] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  const int b_expected_values[b_size] = {10, 9, 8, 7, 6, 5, 4, 3, 2, 1};
  const int b_expected_ids[b_size] = {9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
  tflite::ReverseSortInPlace(b_values, b_ids, b_size);
  for (int i = 0; i < b_size; ++i) {
    TF_LITE_MICRO_EXPECT_EQ(b_expected_values[i], b_values[i]);
    TF_LITE_MICRO_EXPECT_EQ(b_expected_ids[i], b_ids[i]);
  }

  // Case c: ten repeated runs of 1..10. Each group of equal values must list
  // its ids in ascending (original) order, i.e. the sort is stable.
  constexpr int c_size = 100;
  int c_values[c_size] = {
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
  int c_ids[c_size] = {
      0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15, 16,
      17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
      34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
      51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67,
      68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84,
      85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99};
  const int c_expected_values[c_size] = {
      10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
      8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
      6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
      4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
      2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
  const int c_expected_ids[c_size] = {
      9,  19, 29, 39, 49, 59, 69, 79, 89, 99, 8,  18, 28, 38, 48, 58, 68,
      78, 88, 98, 7,  17, 27, 37, 47, 57, 67, 77, 87, 97, 6,  16, 26, 36,
      46, 56, 66, 76, 86, 96, 5,  15, 25, 35, 45, 55, 65, 75, 85, 95, 4,
      14, 24, 34, 44, 54, 64, 74, 84, 94, 3,  13, 23, 33, 43, 53, 63, 73,
      83, 93, 2,  12, 22, 32, 42, 52, 62, 72, 82, 92, 1,  11, 21, 31, 41,
      51, 61, 71, 81, 91, 0,  10, 20, 30, 40, 50, 60, 70, 80, 90};
  tflite::ReverseSortInPlace(c_values, c_ids, c_size);
  for (int i = 0; i < c_size; ++i) {
    TF_LITE_MICRO_EXPECT_EQ(c_expected_values[i], c_values[i]);
    TF_LITE_MICRO_EXPECT_EQ(c_expected_ids[i], c_ids[i]);
  }
}
92
// Two buffers whose time ranges ([0,1] and [2,3]) never overlap may share
// memory. Both are expected at offset 0, and the plan needs only as much
// memory as the larger 20-byte buffer.
TF_LITE_MICRO_TEST(TestGreedyBasics) {
  tflite::MicroErrorReporter micro_error_reporter;

  tflite::GreedyMemoryPlanner planner(g_scratch_buffer, kScratchBufferSize);
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
                          planner.AddBuffer(&micro_error_reporter, 10, 0, 1));
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
                          planner.AddBuffer(&micro_error_reporter, 20, 2, 3));

  TF_LITE_MICRO_EXPECT_EQ(false,
                          planner.DoAnyBuffersOverlap(&micro_error_reporter));

  TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(20),
                          planner.GetMaximumMemorySize());

  // The sentinel is reset before every query. A call that reports success but
  // never writes `offset` then cannot pass on a value left over from the
  // previous query.
  int offset = -1;
  TF_LITE_MICRO_EXPECT_EQ(
      kTfLiteOk, planner.GetOffsetForBuffer(&micro_error_reporter, 0, &offset));
  TF_LITE_MICRO_EXPECT_EQ(0, offset);

  offset = -1;
  TF_LITE_MICRO_EXPECT_EQ(
      kTfLiteOk, planner.GetOffsetForBuffer(&micro_error_reporter, 1, &offset));
  TF_LITE_MICRO_EXPECT_EQ(0, offset);
}
117
// Five buffers with short, partly overlapping time ranges. The expected
// layout asserted below is (byte range, time range):
//   buffer 4: [0, 50)   t=0..1
//   buffer 0: [50, 60)  t=0..1
//   buffer 3: [0, 40)   t=3..4
//   buffer 2: [40, 70)  t=2..3
//   buffer 1: [70, 90)  t=1..2   <- highest end, hence the 90-byte total
TF_LITE_MICRO_TEST(TestGreedyMedium) {
  tflite::MicroErrorReporter micro_error_reporter;

  tflite::GreedyMemoryPlanner planner(g_scratch_buffer, kScratchBufferSize);
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
                          planner.AddBuffer(&micro_error_reporter, 10, 0, 1));
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
                          planner.AddBuffer(&micro_error_reporter, 20, 1, 2));
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
                          planner.AddBuffer(&micro_error_reporter, 30, 2, 3));
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
                          planner.AddBuffer(&micro_error_reporter, 40, 3, 4));
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
                          planner.AddBuffer(&micro_error_reporter, 50, 0, 1));

  // The sentinel is reset before every query so each check stands on its own.
  // Buffers 3 and 4 both expect 0, so a stale value would otherwise mask a
  // call that never writes `offset`.
  int offset = -1;
  TF_LITE_MICRO_EXPECT_EQ(
      kTfLiteOk, planner.GetOffsetForBuffer(&micro_error_reporter, 0, &offset));
  TF_LITE_MICRO_EXPECT_EQ(50, offset);

  offset = -1;
  TF_LITE_MICRO_EXPECT_EQ(
      kTfLiteOk, planner.GetOffsetForBuffer(&micro_error_reporter, 1, &offset));
  TF_LITE_MICRO_EXPECT_EQ(70, offset);

  offset = -1;
  TF_LITE_MICRO_EXPECT_EQ(
      kTfLiteOk, planner.GetOffsetForBuffer(&micro_error_reporter, 2, &offset));
  TF_LITE_MICRO_EXPECT_EQ(40, offset);

  offset = -1;
  TF_LITE_MICRO_EXPECT_EQ(
      kTfLiteOk, planner.GetOffsetForBuffer(&micro_error_reporter, 3, &offset));
  TF_LITE_MICRO_EXPECT_EQ(0, offset);

  offset = -1;
  TF_LITE_MICRO_EXPECT_EQ(
      kTfLiteOk, planner.GetOffsetForBuffer(&micro_error_reporter, 4, &offset));
  TF_LITE_MICRO_EXPECT_EQ(0, offset);

  planner.PrintMemoryPlan(&micro_error_reporter);

  TF_LITE_MICRO_EXPECT_EQ(false,
                          planner.DoAnyBuffersOverlap(&micro_error_reporter));

  TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(90),
                          planner.GetMaximumMemorySize());
}
162
// Replays the buffer sizes and time ranges of a realistic model. It checks
// that the resulting plan has no overlapping buffers and needs less memory
// than giving every buffer its own space.
TF_LITE_MICRO_TEST(TestPersonDetectionModel) {
  tflite::MicroErrorReporter micro_error_reporter;

  tflite::GreedyMemoryPlanner planner(g_scratch_buffer, kScratchBufferSize);

  // Arguments for one AddBuffer() call.
  struct BufferSpec {
    int size;
    int first_time_used;
    int last_time_used;
  };
  // These buffer sizes and time ranges are taken from the 250KB MobileNet model
  // used in the person detection example. They are listed in registration
  // order.
  const BufferSpec kBuffers[] = {
      {9216, 0, 29},  {3, 28, 29},    {256, 27, 28},  {2304, 26, 27},
      {2304, 25, 26}, {2304, 24, 25}, {1152, 23, 24}, {4608, 22, 23},
      {4608, 21, 22}, {4608, 20, 21}, {4608, 19, 20}, {4608, 18, 19},
      {4608, 17, 18}, {4608, 16, 17}, {4608, 15, 16}, {4608, 14, 15},
      {4608, 13, 14}, {4608, 12, 13}, {2304, 11, 12}, {9216, 10, 11},
      {9216, 9, 10},  {9216, 8, 9},   {4608, 7, 8},   {18432, 6, 7},
      {18432, 5, 6},  {18432, 4, 5},  {9216, 3, 4},   {36864, 2, 3},
      {18432, 1, 2},  {18432, 0, 1},
  };

  // Sum of all buffer sizes: the memory needed if no buffers shared space.
  size_t total_size = 0;
  for (const BufferSpec& spec : kBuffers) {
    TF_LITE_MICRO_EXPECT_EQ(
        kTfLiteOk,
        planner.AddBuffer(&micro_error_reporter, spec.size,
                          spec.first_time_used, spec.last_time_used));
    total_size += static_cast<size_t>(spec.size);
  }
  // Pin the known total for this model so an accidental edit to the table is
  // caught.
  TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(241027), total_size);

  planner.PrintMemoryPlan(&micro_error_reporter);

  TF_LITE_MICRO_EXPECT_EQ(false,
                          planner.DoAnyBuffersOverlap(&micro_error_reporter));

  // Reusing memory across buffers with disjoint time ranges must yield a plan
  // smaller than the no-sharing total.
  TF_LITE_MICRO_EXPECT_GT(total_size, planner.GetMaximumMemorySize());
}
240
// Buffer 2 (time range [1,2]) overlaps in time with both buffer 0 ([0,1]) and
// buffer 1 ([2,3]), which never coexist with each other. The expected
// 120-byte peak equals the 100-byte and 20-byte buffers placed side by side.
TF_LITE_MICRO_TEST(TestOverlapCase) {
  tflite::MicroErrorReporter micro_error_reporter;

  tflite::GreedyMemoryPlanner planner(g_scratch_buffer, kScratchBufferSize);
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
                          planner.AddBuffer(&micro_error_reporter, 100, 0, 1));
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
                          planner.AddBuffer(&micro_error_reporter, 50, 2, 3));
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
                          planner.AddBuffer(&micro_error_reporter, 20, 1, 2));

  planner.PrintMemoryPlan(&micro_error_reporter);

  TF_LITE_MICRO_EXPECT_EQ(false,
                          planner.DoAnyBuffersOverlap(&micro_error_reporter));

  TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(120),
                          planner.GetMaximumMemorySize());
}
260
// A 40-byte scratch arena has room for the planner's records for one buffer
// but not for a second. The second AddBuffer() call must therefore report
// kTfLiteError.
TF_LITE_MICRO_TEST(TestSmallScratch) {
  tflite::MicroErrorReporter micro_error_reporter;

  constexpr int scratch_buffer_size = 40;
  // NOTE(review): over-aligned on the assumption that the planner stores
  // int-based records in this arena via pointer casts; confirm against
  // greedy_memory_planner.cc.
  alignas(16) unsigned char scratch_buffer[scratch_buffer_size];
  tflite::GreedyMemoryPlanner planner(scratch_buffer, scratch_buffer_size);
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
                          planner.AddBuffer(&micro_error_reporter, 100, 0, 1));
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
                          planner.AddBuffer(&micro_error_reporter, 50, 2, 3));
}
272
273 TF_LITE_MICRO_TESTS_END
274