1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "TestHarness.h"
18
19 #include <android-base/logging.h>
20 #include <gmock/gmock.h>
21 #include <gtest/gtest.h>
22
23 #include <algorithm>
24 #include <cmath>
25 #include <functional>
26 #include <limits>
27 #include <map>
28 #include <numeric>
29 #include <set>
30 #include <string>
31 #include <vector>
32
33 namespace test_helper {
34
35 namespace {
36
37 template <typename T>
38 constexpr bool nnIsFloat = std::is_floating_point_v<T> || std::is_same_v<T, _Float16>;
39
40 constexpr uint32_t kMaxNumberOfPrintedErrors = 10;
41
42 // TODO(b/139442217): Allow passing accuracy criteria from spec.
43 // Currently we only need relaxed accuracy criteria on mobilenet tests, so we return the quant8
44 // tolerance simply based on the current test name.
getQuant8AllowedError()45 int getQuant8AllowedError() {
46 const ::testing::TestInfo* const testInfo =
47 ::testing::UnitTest::GetInstance()->current_test_info();
48 const std::string testCaseName = testInfo->test_case_name();
49 const std::string testName = testInfo->name();
50 // We relax the quant8 precision for all tests with mobilenet:
51 // - CTS/VTS GeneratedTest and DynamicOutputShapeTest with mobilenet
52 // - VTS CompilationCachingTest and CompilationCachingSecurityTest except for TOCTOU tests
53 if (testName.find("mobilenet") != std::string::npos ||
54 (testCaseName.find("CompilationCaching") != std::string::npos &&
55 testName.find("TOCTOU") == std::string::npos)) {
56 return 3;
57 } else {
58 return 1;
59 }
60 }
61
getNumberOfElements(const TestOperand & op)62 uint32_t getNumberOfElements(const TestOperand& op) {
63 return std::reduce(op.dimensions.begin(), op.dimensions.end(), 1u, std::multiplies<uint32_t>());
64 }
65
66 // Check if the actual results meet the accuracy criterion.
67 template <typename T>
expectNear(const TestOperand & op,const TestBuffer & result,const AccuracyCriterion & criterion,bool allowInvalid=false)68 void expectNear(const TestOperand& op, const TestBuffer& result, const AccuracyCriterion& criterion,
69 bool allowInvalid = false) {
70 constexpr uint32_t kMinNumberOfElementsToTestBiasMSE = 10;
71 const T* actualBuffer = result.get<T>();
72 const T* expectedBuffer = op.data.get<T>();
73 uint32_t len = getNumberOfElements(op), numErrors = 0, numSkip = 0;
74 double bias = 0.0f, mse = 0.0f;
75 for (uint32_t i = 0; i < len; i++) {
76 // Compare all data types in double for precision and signed arithmetic.
77 double actual = static_cast<double>(actualBuffer[i]);
78 double expected = static_cast<double>(expectedBuffer[i]);
79 double tolerableRange = criterion.atol + criterion.rtol * std::fabs(expected);
80 EXPECT_FALSE(std::isnan(expected));
81
82 // Skip invalid floating point values.
83 if (allowInvalid &&
84 (std::isinf(expected) || (std::is_same_v<T, float> && std::fabs(expected) > 1e3) ||
85 (std::is_same_v<T, _Float16> || std::fabs(expected) > 1e2))) {
86 numSkip++;
87 continue;
88 }
89
90 // Accumulate bias and MSE. Use relative bias and MSE for floating point values.
91 double diff = actual - expected;
92 if constexpr (nnIsFloat<T>) {
93 diff /= std::max(1.0, std::abs(expected));
94 }
95 bias += diff;
96 mse += diff * diff;
97
98 // Print at most kMaxNumberOfPrintedErrors errors by EXPECT_NEAR.
99 if (numErrors < kMaxNumberOfPrintedErrors) {
100 EXPECT_NEAR(expected, actual, tolerableRange) << "When comparing element " << i;
101 }
102 if (std::fabs(actual - expected) > tolerableRange) numErrors++;
103 }
104 EXPECT_EQ(numErrors, 0u);
105
106 // Test bias and MSE.
107 if (len < numSkip + kMinNumberOfElementsToTestBiasMSE) return;
108 bias /= static_cast<double>(len - numSkip);
109 mse /= static_cast<double>(len - numSkip);
110 EXPECT_LE(std::fabs(bias), criterion.bias);
111 EXPECT_LE(mse, criterion.mse);
112 }
113
114 // For boolean values, we expect the number of mismatches does not exceed a certain ratio.
expectBooleanNearlyEqual(const TestOperand & op,const TestBuffer & result,float allowedErrorRatio)115 void expectBooleanNearlyEqual(const TestOperand& op, const TestBuffer& result,
116 float allowedErrorRatio) {
117 const bool8* actualBuffer = result.get<bool8>();
118 const bool8* expectedBuffer = op.data.get<bool8>();
119 uint32_t len = getNumberOfElements(op), numErrors = 0;
120 std::stringstream errorMsg;
121 for (uint32_t i = 0; i < len; i++) {
122 if (expectedBuffer[i] != actualBuffer[i]) {
123 if (numErrors < kMaxNumberOfPrintedErrors)
124 errorMsg << " Expected: " << expectedBuffer[i] << ", actual: " << actualBuffer[i]
125 << ", when comparing element " << i << "\n";
126 numErrors++;
127 }
128 }
129 // When |len| is small, the allowedErrorCount will intentionally ceil at 1, which allows for
130 // greater tolerance.
131 uint32_t allowedErrorCount = static_cast<uint32_t>(std::ceil(allowedErrorRatio * len));
132 EXPECT_LE(numErrors, allowedErrorCount) << errorMsg.str();
133 }
134
135 // Calculates the expected probability from the unnormalized log-probability of
136 // each class in the input and compares it to the actual occurrence of that class
137 // in the output.
expectMultinomialDistributionWithinTolerance(const TestModel & model,const std::vector<TestBuffer> & buffers)138 void expectMultinomialDistributionWithinTolerance(const TestModel& model,
139 const std::vector<TestBuffer>& buffers) {
140 // This function is only for RANDOM_MULTINOMIAL single-operation test.
141 CHECK_EQ(model.referenced.size(), 0u) << "Subgraphs not supported";
142 ASSERT_EQ(model.main.operations.size(), 1u);
143 ASSERT_EQ(model.main.operations[0].type, TestOperationType::RANDOM_MULTINOMIAL);
144 ASSERT_EQ(model.main.inputIndexes.size(), 1u);
145 ASSERT_EQ(model.main.outputIndexes.size(), 1u);
146 ASSERT_EQ(buffers.size(), 1u);
147
148 const auto& inputOperand = model.main.operands[model.main.inputIndexes[0]];
149 const auto& outputOperand = model.main.operands[model.main.outputIndexes[0]];
150 ASSERT_EQ(inputOperand.dimensions.size(), 2u);
151 ASSERT_EQ(outputOperand.dimensions.size(), 2u);
152
153 const int kBatchSize = inputOperand.dimensions[0];
154 const int kNumClasses = inputOperand.dimensions[1];
155 const int kNumSamples = outputOperand.dimensions[1];
156
157 const uint32_t outputLength = getNumberOfElements(outputOperand);
158 const int32_t* outputData = buffers[0].get<int32_t>();
159 std::vector<int> classCounts(kNumClasses);
160 for (uint32_t i = 0; i < outputLength; i++) {
161 classCounts[outputData[i]]++;
162 }
163
164 const uint32_t inputLength = getNumberOfElements(inputOperand);
165 std::vector<float> inputData(inputLength);
166 if (inputOperand.type == TestOperandType::TENSOR_FLOAT32) {
167 const float* inputRaw = inputOperand.data.get<float>();
168 std::copy(inputRaw, inputRaw + inputLength, inputData.begin());
169 } else if (inputOperand.type == TestOperandType::TENSOR_FLOAT16) {
170 const _Float16* inputRaw = inputOperand.data.get<_Float16>();
171 std::transform(inputRaw, inputRaw + inputLength, inputData.begin(),
172 [](_Float16 fp16) { return static_cast<float>(fp16); });
173 } else {
174 FAIL() << "Unknown input operand type for RANDOM_MULTINOMIAL.";
175 }
176
177 for (int b = 0; b < kBatchSize; ++b) {
178 float probabilitySum = 0;
179 const int batchIndex = kBatchSize * b;
180 for (int i = 0; i < kNumClasses; ++i) {
181 probabilitySum += expf(inputData[batchIndex + i]);
182 }
183 for (int i = 0; i < kNumClasses; ++i) {
184 float probability =
185 static_cast<float>(classCounts[i]) / static_cast<float>(kNumSamples);
186 float probabilityExpected = expf(inputData[batchIndex + i]) / probabilitySum;
187 EXPECT_THAT(probability,
188 ::testing::FloatNear(probabilityExpected,
189 model.expectedMultinomialDistributionTolerance));
190 }
191 }
192 }
193
194 } // namespace
195
checkResults(const TestModel & model,const std::vector<TestBuffer> & buffers,const AccuracyCriteria & criteria)196 void checkResults(const TestModel& model, const std::vector<TestBuffer>& buffers,
197 const AccuracyCriteria& criteria) {
198 ASSERT_EQ(model.main.outputIndexes.size(), buffers.size());
199 for (uint32_t i = 0; i < model.main.outputIndexes.size(); i++) {
200 const uint32_t outputIndex = model.main.outputIndexes[i];
201 SCOPED_TRACE(testing::Message()
202 << "When comparing output " << i << " (op" << outputIndex << ")");
203 const auto& operand = model.main.operands[outputIndex];
204 const auto& result = buffers[i];
205 if (operand.isIgnored) continue;
206
207 switch (operand.type) {
208 case TestOperandType::TENSOR_FLOAT32:
209 expectNear<float>(operand, result, criteria.float32, criteria.allowInvalidFpValues);
210 break;
211 case TestOperandType::TENSOR_FLOAT16:
212 expectNear<_Float16>(operand, result, criteria.float16,
213 criteria.allowInvalidFpValues);
214 break;
215 case TestOperandType::TENSOR_INT32:
216 case TestOperandType::INT32:
217 expectNear<int32_t>(operand, result, criteria.int32);
218 break;
219 case TestOperandType::TENSOR_QUANT8_ASYMM:
220 expectNear<uint8_t>(operand, result, criteria.quant8Asymm);
221 break;
222 case TestOperandType::TENSOR_QUANT8_SYMM:
223 expectNear<int8_t>(operand, result, criteria.quant8Symm);
224 break;
225 case TestOperandType::TENSOR_QUANT16_ASYMM:
226 expectNear<uint16_t>(operand, result, criteria.quant16Asymm);
227 break;
228 case TestOperandType::TENSOR_QUANT16_SYMM:
229 expectNear<int16_t>(operand, result, criteria.quant16Symm);
230 break;
231 case TestOperandType::TENSOR_BOOL8:
232 expectBooleanNearlyEqual(operand, result, criteria.bool8AllowedErrorRatio);
233 break;
234 case TestOperandType::TENSOR_QUANT8_ASYMM_SIGNED:
235 expectNear<int8_t>(operand, result, criteria.quant8AsymmSigned);
236 break;
237 default:
238 FAIL() << "Data type not supported.";
239 }
240 }
241 }
242
checkResults(const TestModel & model,const std::vector<TestBuffer> & buffers)243 void checkResults(const TestModel& model, const std::vector<TestBuffer>& buffers) {
244 // For RANDOM_MULTINOMIAL test only.
245 if (model.expectedMultinomialDistributionTolerance > 0.0f) {
246 expectMultinomialDistributionWithinTolerance(model, buffers);
247 return;
248 }
249
250 // Decide the default tolerable range.
251 //
252 // For floating-point models, we use the relaxed precision if either
253 // - relaxed computation flag is set
254 // - the model has at least one TENSOR_FLOAT16 operand
255 //
256 // The bias and MSE criteria are implicitly set to the maximum -- we do not enforce these
257 // criteria in normal generated tests.
258 //
259 // TODO: Adjust the error limit based on testing.
260 //
261 AccuracyCriteria criteria = {
262 // The relative tolerance is 5ULP of FP32.
263 .float32 = {.atol = 1e-5, .rtol = 5.0f * 1.1920928955078125e-7},
264 // Both the absolute and relative tolerance are 5ULP of FP16.
265 .float16 = {.atol = 5.0f * 0.0009765625, .rtol = 5.0f * 0.0009765625},
266 .int32 = {.atol = 1},
267 .quant8Asymm = {.atol = 1},
268 .quant8Symm = {.atol = 1},
269 .quant16Asymm = {.atol = 1},
270 .quant16Symm = {.atol = 1},
271 .bool8AllowedErrorRatio = 0.0f,
272 // Since generated tests are hand-calculated, there should be no invalid FP values.
273 .allowInvalidFpValues = false,
274 };
275 bool hasFloat16Inputs = false;
276 model.forEachSubgraph([&hasFloat16Inputs](const TestSubgraph& subgraph) {
277 if (!hasFloat16Inputs) {
278 hasFloat16Inputs = std::any_of(subgraph.operands.begin(), subgraph.operands.end(),
279 [](const TestOperand& op) {
280 return op.type == TestOperandType::TENSOR_FLOAT16;
281 });
282 }
283 });
284 if (model.isRelaxed || hasFloat16Inputs) {
285 criteria.float32 = criteria.float16;
286 }
287 const double quant8AllowedError = getQuant8AllowedError();
288 criteria.quant8Asymm.atol = quant8AllowedError;
289 criteria.quant8AsymmSigned.atol = quant8AllowedError;
290 criteria.quant8Symm.atol = quant8AllowedError;
291
292 checkResults(model, buffers, criteria);
293 }
294
// Returns a copy of |testModel| (made with TestModel::copy()) in which every TENSOR_QUANT8_ASYMM
// operand of the main and all referenced subgraphs becomes TENSOR_QUANT8_ASYMM_SIGNED. The zero
// point is lowered by 128 and each stored byte v becomes v - 128, so (value - zeroPoint) and
// therefore the represented real values are unchanged.
TestModel convertQuant8AsymmOperandsToSigned(const TestModel& testModel) {
    // Rewrites one operand in place; the byte buffer is converted element by element.
    const auto toSigned = [](TestOperand& operand) {
        operand.type = test_helper::TestOperandType::TENSOR_QUANT8_ASYMM_SIGNED;
        operand.zeroPoint -= 128;
        const uint8_t* const unsignedValues = operand.data.get<uint8_t>();
        int8_t* const signedValues = operand.data.getMutable<int8_t>();
        std::transform(unsignedValues, unsignedValues + operand.data.size(), signedValues,
                       [](uint8_t value) {
                           return static_cast<int8_t>(static_cast<int32_t>(value) - 128);
                       });
    };

    TestModel converted(testModel.copy());
    std::vector<TestSubgraph*> subgraphs = {&converted.main};
    for (TestSubgraph& referenced : converted.referenced) {
        subgraphs.push_back(&referenced);
    }
    for (TestSubgraph* subgraph : subgraphs) {
        for (TestOperand& operand : subgraph->operands) {
            if (operand.type == test_helper::TestOperandType::TENSOR_QUANT8_ASYMM) {
                toSigned(operand);
            }
        }
    }
    return converted;
}
317
isQuantizedType(TestOperandType type)318 bool isQuantizedType(TestOperandType type) {
319 static const std::set<TestOperandType> kQuantizedTypes = {
320 TestOperandType::TENSOR_QUANT8_ASYMM,
321 TestOperandType::TENSOR_QUANT8_SYMM,
322 TestOperandType::TENSOR_QUANT16_ASYMM,
323 TestOperandType::TENSOR_QUANT16_SYMM,
324 TestOperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL,
325 TestOperandType::TENSOR_QUANT8_ASYMM_SIGNED,
326 };
327 return kQuantizedTypes.count(type) > 0;
328 }
329
isFloatType(TestOperandType type)330 bool isFloatType(TestOperandType type) {
331 static const std::set<TestOperandType> kFloatTypes = {
332 TestOperandType::TENSOR_FLOAT32,
333 TestOperandType::TENSOR_FLOAT16,
334 TestOperandType::FLOAT32,
335 TestOperandType::FLOAT16,
336 };
337 return kFloatTypes.count(type) > 0;
338 }
339
isConstant(TestOperandLifeTime lifetime)340 bool isConstant(TestOperandLifeTime lifetime) {
341 return lifetime == TestOperandLifeTime::CONSTANT_COPY ||
342 lifetime == TestOperandLifeTime::CONSTANT_REFERENCE;
343 }
344
345 namespace {
346
// Spec-file names of the operation types, indexed by static_cast<int>(TestOperationType).
// An entry's position must equal the numeric value of its enumerator.
constexpr const char* kOperationTypeNames[] = {
        /*   0 */ "ADD",
        /*   1 */ "AVERAGE_POOL_2D",
        /*   2 */ "CONCATENATION",
        /*   3 */ "CONV_2D",
        /*   4 */ "DEPTHWISE_CONV_2D",
        /*   5 */ "DEPTH_TO_SPACE",
        /*   6 */ "DEQUANTIZE",
        /*   7 */ "EMBEDDING_LOOKUP",
        /*   8 */ "FLOOR",
        /*   9 */ "FULLY_CONNECTED",
        /*  10 */ "HASHTABLE_LOOKUP",
        /*  11 */ "L2_NORMALIZATION",
        /*  12 */ "L2_POOL",
        /*  13 */ "LOCAL_RESPONSE_NORMALIZATION",
        /*  14 */ "LOGISTIC",
        /*  15 */ "LSH_PROJECTION",
        /*  16 */ "LSTM",
        /*  17 */ "MAX_POOL_2D",
        /*  18 */ "MUL",
        /*  19 */ "RELU",
        /*  20 */ "RELU1",
        /*  21 */ "RELU6",
        /*  22 */ "RESHAPE",
        /*  23 */ "RESIZE_BILINEAR",
        /*  24 */ "RNN",
        /*  25 */ "SOFTMAX",
        /*  26 */ "SPACE_TO_DEPTH",
        /*  27 */ "SVDF",
        /*  28 */ "TANH",
        /*  29 */ "BATCH_TO_SPACE_ND",
        /*  30 */ "DIV",
        /*  31 */ "MEAN",
        /*  32 */ "PAD",
        /*  33 */ "SPACE_TO_BATCH_ND",
        /*  34 */ "SQUEEZE",
        /*  35 */ "STRIDED_SLICE",
        /*  36 */ "SUB",
        /*  37 */ "TRANSPOSE",
        /*  38 */ "ABS",
        /*  39 */ "ARGMAX",
        /*  40 */ "ARGMIN",
        /*  41 */ "AXIS_ALIGNED_BBOX_TRANSFORM",
        /*  42 */ "BIDIRECTIONAL_SEQUENCE_LSTM",
        /*  43 */ "BIDIRECTIONAL_SEQUENCE_RNN",
        /*  44 */ "BOX_WITH_NMS_LIMIT",
        /*  45 */ "CAST",
        /*  46 */ "CHANNEL_SHUFFLE",
        /*  47 */ "DETECTION_POSTPROCESSING",
        /*  48 */ "EQUAL",
        /*  49 */ "EXP",
        /*  50 */ "EXPAND_DIMS",
        /*  51 */ "GATHER",
        /*  52 */ "GENERATE_PROPOSALS",
        /*  53 */ "GREATER",
        /*  54 */ "GREATER_EQUAL",
        /*  55 */ "GROUPED_CONV_2D",
        /*  56 */ "HEATMAP_MAX_KEYPOINT",
        /*  57 */ "INSTANCE_NORMALIZATION",
        /*  58 */ "LESS",
        /*  59 */ "LESS_EQUAL",
        /*  60 */ "LOG",
        /*  61 */ "LOGICAL_AND",
        /*  62 */ "LOGICAL_NOT",
        /*  63 */ "LOGICAL_OR",
        /*  64 */ "LOG_SOFTMAX",
        /*  65 */ "MAXIMUM",
        /*  66 */ "MINIMUM",
        /*  67 */ "NEG",
        /*  68 */ "NOT_EQUAL",
        /*  69 */ "PAD_V2",
        /*  70 */ "POW",
        /*  71 */ "PRELU",
        /*  72 */ "QUANTIZE",
        /*  73 */ "QUANTIZED_16BIT_LSTM",
        /*  74 */ "RANDOM_MULTINOMIAL",
        /*  75 */ "REDUCE_ALL",
        /*  76 */ "REDUCE_ANY",
        /*  77 */ "REDUCE_MAX",
        /*  78 */ "REDUCE_MIN",
        /*  79 */ "REDUCE_PROD",
        /*  80 */ "REDUCE_SUM",
        /*  81 */ "ROI_ALIGN",
        /*  82 */ "ROI_POOLING",
        /*  83 */ "RSQRT",
        /*  84 */ "SELECT",
        /*  85 */ "SIN",
        /*  86 */ "SLICE",
        /*  87 */ "SPLIT",
        /*  88 */ "SQRT",
        /*  89 */ "TILE",
        /*  90 */ "TOPK_V2",
        /*  91 */ "TRANSPOSE_CONV_2D",
        /*  92 */ "UNIDIRECTIONAL_SEQUENCE_LSTM",
        /*  93 */ "UNIDIRECTIONAL_SEQUENCE_RNN",
        /*  94 */ "RESIZE_NEAREST_NEIGHBOR",
        /*  95 */ "QUANTIZED_LSTM",
        /*  96 */ "IF",
        /*  97 */ "WHILE",
        /*  98 */ "ELU",
        /*  99 */ "HARD_SWISH",
        /* 100 */ "FILL",
        /* 101 */ "RANK",
};
451
// Spec-file names of the operand types, indexed by static_cast<int>(TestOperandType).
// An entry's position must equal the numeric value of its enumerator.
constexpr const char* kOperandTypeNames[] = {
        /*  0 */ "FLOAT32",
        /*  1 */ "INT32",
        /*  2 */ "UINT32",
        /*  3 */ "TENSOR_FLOAT32",
        /*  4 */ "TENSOR_INT32",
        /*  5 */ "TENSOR_QUANT8_ASYMM",
        /*  6 */ "BOOL",
        /*  7 */ "TENSOR_QUANT16_SYMM",
        /*  8 */ "TENSOR_FLOAT16",
        /*  9 */ "TENSOR_BOOL8",
        /* 10 */ "FLOAT16",
        /* 11 */ "TENSOR_QUANT8_SYMM_PER_CHANNEL",
        /* 12 */ "TENSOR_QUANT16_ASYMM",
        /* 13 */ "TENSOR_QUANT8_SYMM",
        /* 14 */ "TENSOR_QUANT8_ASYMM_SIGNED",
};
469
isScalarType(TestOperandType type)470 bool isScalarType(TestOperandType type) {
471 static const std::vector<bool> kIsScalarOperandType = {
472 true, // TestOperandType::FLOAT32
473 true, // TestOperandType::INT32
474 true, // TestOperandType::UINT32
475 false, // TestOperandType::TENSOR_FLOAT32
476 false, // TestOperandType::TENSOR_INT32
477 false, // TestOperandType::TENSOR_QUANT8_ASYMM
478 true, // TestOperandType::BOOL
479 false, // TestOperandType::TENSOR_QUANT16_SYMM
480 false, // TestOperandType::TENSOR_FLOAT16
481 false, // TestOperandType::TENSOR_BOOL8
482 true, // TestOperandType::FLOAT16
483 false, // TestOperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL
484 false, // TestOperandType::TENSOR_QUANT16_ASYMM
485 false, // TestOperandType::TENSOR_QUANT8_SYMM
486 false, // TestOperandType::TENSOR_QUANT8_ASYMM_SIGNED
487 };
488 return kIsScalarOperandType[static_cast<int>(type)];
489 }
490
getOperandClassInSpecFile(TestOperandLifeTime lifetime)491 std::string getOperandClassInSpecFile(TestOperandLifeTime lifetime) {
492 switch (lifetime) {
493 case TestOperandLifeTime::SUBGRAPH_INPUT:
494 return "Input";
495 case TestOperandLifeTime::SUBGRAPH_OUTPUT:
496 return "Output";
497 case TestOperandLifeTime::CONSTANT_COPY:
498 case TestOperandLifeTime::CONSTANT_REFERENCE:
499 case TestOperandLifeTime::NO_VALUE:
500 return "Parameter";
501 case TestOperandLifeTime::TEMPORARY_VARIABLE:
502 return "Internal";
503 default:
504 CHECK(false);
505 return "";
506 }
507 }
508
509 template <typename T>
defaultToStringFunc(const T & value)510 std::string defaultToStringFunc(const T& value) {
511 return std::to_string(value);
512 };
513 template <>
defaultToStringFunc(const _Float16 & value)514 std::string defaultToStringFunc<_Float16>(const _Float16& value) {
515 return defaultToStringFunc(static_cast<float>(value));
516 };
517
518 // Dump floating point values in hex representation.
519 template <typename T>
520 std::string toHexFloatString(const T& value);
521 template <>
toHexFloatString(const float & value)522 std::string toHexFloatString<float>(const float& value) {
523 std::stringstream ss;
524 ss << "\"" << std::hexfloat << value << "\"";
525 return ss.str();
526 };
527 template <>
toHexFloatString(const _Float16 & value)528 std::string toHexFloatString<_Float16>(const _Float16& value) {
529 return toHexFloatString(static_cast<float>(value));
530 };
531
// Streams func(element) for every element of [begin, end), inserting |joint| between consecutive
// elements, and returns the result. An empty range yields an empty string.
template <typename Iterator, class ToStringFunc>
std::string join(const std::string& joint, Iterator begin, Iterator end, ToStringFunc func) {
    std::stringstream joined;
    bool isFirst = true;
    for (Iterator cursor = begin; cursor < end; ++cursor) {
        if (!isFirst) {
            joined << joint;
        }
        joined << func(*cursor);
        isFirst = false;
    }
    return joined.str();
}

// Convenience overload that joins every element of |range|.
template <typename T, class ToStringFunc>
std::string join(const std::string& joint, const std::vector<T>& range, ToStringFunc func) {
    const auto first = range.begin();
    const auto last = range.end();
    return join(joint, first, last, func);
}
545
546 template <typename T>
dumpTestBufferToSpecFileHelper(const TestBuffer & buffer,bool useHexFloat,std::ostream & os)547 void dumpTestBufferToSpecFileHelper(const TestBuffer& buffer, bool useHexFloat, std::ostream& os) {
548 const T* data = buffer.get<T>();
549 const uint32_t length = buffer.size() / sizeof(T);
550 if constexpr (nnIsFloat<T>) {
551 if (useHexFloat) {
552 os << "from_hex([" << join(", ", data, data + length, toHexFloatString<T>) << "])";
553 return;
554 }
555 }
556 os << "[" << join(", ", data, data + length, defaultToStringFunc<T>) << "]";
557 }
558
559 } // namespace
560
operator <<(std::ostream & os,const TestOperandType & type)561 std::ostream& operator<<(std::ostream& os, const TestOperandType& type) {
562 return os << kOperandTypeNames[static_cast<int>(type)];
563 }
564
operator <<(std::ostream & os,const TestOperationType & type)565 std::ostream& operator<<(std::ostream& os, const TestOperationType& type) {
566 return os << kOperationTypeNames[static_cast<int>(type)];
567 }
568
569 // Dump a test buffer.
dumpTestBuffer(TestOperandType type,const TestBuffer & buffer,bool useHexFloat)570 void SpecDumper::dumpTestBuffer(TestOperandType type, const TestBuffer& buffer, bool useHexFloat) {
571 switch (type) {
572 case TestOperandType::FLOAT32:
573 case TestOperandType::TENSOR_FLOAT32:
574 dumpTestBufferToSpecFileHelper<float>(buffer, useHexFloat, mOs);
575 break;
576 case TestOperandType::INT32:
577 case TestOperandType::TENSOR_INT32:
578 dumpTestBufferToSpecFileHelper<int32_t>(buffer, useHexFloat, mOs);
579 break;
580 case TestOperandType::TENSOR_QUANT8_ASYMM:
581 dumpTestBufferToSpecFileHelper<uint8_t>(buffer, useHexFloat, mOs);
582 break;
583 case TestOperandType::TENSOR_QUANT8_SYMM:
584 case TestOperandType::TENSOR_QUANT8_ASYMM_SIGNED:
585 dumpTestBufferToSpecFileHelper<int8_t>(buffer, useHexFloat, mOs);
586 break;
587 case TestOperandType::TENSOR_QUANT16_ASYMM:
588 dumpTestBufferToSpecFileHelper<uint16_t>(buffer, useHexFloat, mOs);
589 break;
590 case TestOperandType::TENSOR_QUANT16_SYMM:
591 dumpTestBufferToSpecFileHelper<int16_t>(buffer, useHexFloat, mOs);
592 break;
593 case TestOperandType::BOOL:
594 case TestOperandType::TENSOR_BOOL8:
595 dumpTestBufferToSpecFileHelper<bool8>(buffer, useHexFloat, mOs);
596 break;
597 case TestOperandType::FLOAT16:
598 case TestOperandType::TENSOR_FLOAT16:
599 dumpTestBufferToSpecFileHelper<_Float16>(buffer, useHexFloat, mOs);
600 break;
601 default:
602 CHECK(false) << "Unknown type when dumping the buffer";
603 }
604 }
605
dumpTestOperand(const TestOperand & operand,uint32_t index)606 void SpecDumper::dumpTestOperand(const TestOperand& operand, uint32_t index) {
607 mOs << "op" << index << " = " << getOperandClassInSpecFile(operand.lifetime) << "(\"op" << index
608 << "\", [\"" << operand.type << "\", ["
609 << join(", ", operand.dimensions, defaultToStringFunc<uint32_t>) << "]";
610 if (operand.scale != 0.0f || operand.zeroPoint != 0) {
611 mOs << ", float.fromhex(" << toHexFloatString(operand.scale) << "), " << operand.zeroPoint;
612 }
613 mOs << "]";
614 if (operand.lifetime == TestOperandLifeTime::CONSTANT_COPY ||
615 operand.lifetime == TestOperandLifeTime::CONSTANT_REFERENCE) {
616 mOs << ", ";
617 dumpTestBuffer(operand.type, operand.data, /*useHexFloat=*/true);
618 } else if (operand.lifetime == TestOperandLifeTime::NO_VALUE) {
619 mOs << ", value=None";
620 }
621 mOs << ")";
622 // For quantized data types, append a human-readable scale at the end.
623 if (operand.scale != 0.0f) {
624 mOs << " # scale = " << operand.scale;
625 }
626 // For float buffers, append human-readable values at the end.
627 if (isFloatType(operand.type) &&
628 (operand.lifetime == TestOperandLifeTime::CONSTANT_COPY ||
629 operand.lifetime == TestOperandLifeTime::CONSTANT_REFERENCE)) {
630 mOs << " # ";
631 dumpTestBuffer(operand.type, operand.data, /*useHexFloat=*/false);
632 }
633 mOs << "\n";
634 }
635
dumpTestOperation(const TestOperation & operation)636 void SpecDumper::dumpTestOperation(const TestOperation& operation) {
637 auto toOperandName = [](uint32_t index) { return "op" + std::to_string(index); };
638 mOs << "model = model.Operation(\"" << operation.type << "\", "
639 << join(", ", operation.inputs, toOperandName) << ").To("
640 << join(", ", operation.outputs, toOperandName) << ")\n";
641 }
642
dumpTestModel()643 void SpecDumper::dumpTestModel() {
644 CHECK_EQ(kTestModel.referenced.size(), 0u) << "Subgraphs not supported";
645 mOs << "from_hex = lambda l: [float.fromhex(i) for i in l]\n\n";
646
647 // Dump model operands.
648 mOs << "# Model operands\n";
649 for (uint32_t i = 0; i < kTestModel.main.operands.size(); i++) {
650 dumpTestOperand(kTestModel.main.operands[i], i);
651 }
652
653 // Dump model operations.
654 mOs << "\n# Model operations\nmodel = Model()\n";
655 for (const auto& operation : kTestModel.main.operations) {
656 dumpTestOperation(operation);
657 }
658
659 // Dump input/output buffers.
660 mOs << "\n# Example\nExample({\n";
661 for (uint32_t i = 0; i < kTestModel.main.operands.size(); i++) {
662 const auto& operand = kTestModel.main.operands[i];
663 if (operand.lifetime != TestOperandLifeTime::SUBGRAPH_INPUT &&
664 operand.lifetime != TestOperandLifeTime::SUBGRAPH_OUTPUT) {
665 continue;
666 }
667 // For float buffers, dump human-readable values as a comment.
668 if (isFloatType(operand.type)) {
669 mOs << " # op" << i << ": ";
670 dumpTestBuffer(operand.type, operand.data, /*useHexFloat=*/false);
671 mOs << "\n";
672 }
673 mOs << " op" << i << ": ";
674 dumpTestBuffer(operand.type, operand.data, /*useHexFloat=*/true);
675 mOs << ",\n";
676 }
677 mOs << "}).DisableLifeTimeVariation()\n";
678 }
679
dumpResults(const std::string & name,const std::vector<TestBuffer> & results)680 void SpecDumper::dumpResults(const std::string& name, const std::vector<TestBuffer>& results) {
681 CHECK_EQ(results.size(), kTestModel.main.outputIndexes.size());
682 mOs << "\n# Results from " << name << "\n{\n";
683 for (uint32_t i = 0; i < results.size(); i++) {
684 const uint32_t outputIndex = kTestModel.main.outputIndexes[i];
685 const auto& operand = kTestModel.main.operands[outputIndex];
686 // For float buffers, dump human-readable values as a comment.
687 if (isFloatType(operand.type)) {
688 mOs << " # op" << outputIndex << ": ";
689 dumpTestBuffer(operand.type, results[i], /*useHexFloat=*/false);
690 mOs << "\n";
691 }
692 mOs << " op" << outputIndex << ": ";
693 dumpTestBuffer(operand.type, results[i], /*useHexFloat=*/true);
694 mOs << ",\n";
695 }
696 mOs << "}\n";
697 }
698
699 template <typename T>
convertOperandToFloat32(const TestOperand & op)700 static TestOperand convertOperandToFloat32(const TestOperand& op) {
701 TestOperand converted = op;
702 converted.type =
703 isScalarType(op.type) ? TestOperandType::FLOAT32 : TestOperandType::TENSOR_FLOAT32;
704 converted.scale = 0.0f;
705 converted.zeroPoint = 0;
706
707 const uint32_t numberOfElements = getNumberOfElements(converted);
708 converted.data = TestBuffer(numberOfElements * sizeof(float));
709 const T* data = op.data.get<T>();
710 float* floatData = converted.data.getMutable<float>();
711
712 if (op.scale != 0.0f) {
713 std::transform(data, data + numberOfElements, floatData, [&op](T val) {
714 return (static_cast<float>(val) - op.zeroPoint) * op.scale;
715 });
716 } else {
717 std::transform(data, data + numberOfElements, floatData,
718 [](T val) { return static_cast<float>(val); });
719 }
720 return converted;
721 }
722
convertToFloat32Model(const TestModel & testModel)723 std::optional<TestModel> convertToFloat32Model(const TestModel& testModel) {
724 // Only single-operation graphs are supported.
725 if (testModel.referenced.size() > 0 || testModel.main.operations.size() > 1) {
726 return std::nullopt;
727 }
728
729 // Check for unsupported operations.
730 CHECK(!testModel.main.operations.empty());
731 const auto& operation = testModel.main.operations[0];
732 // Do not convert type-casting operations.
733 if (operation.type == TestOperationType::DEQUANTIZE ||
734 operation.type == TestOperationType::QUANTIZE ||
735 operation.type == TestOperationType::CAST) {
736 return std::nullopt;
737 }
738 // HASHTABLE_LOOKUP has different behavior in float and quant data types: float
739 // HASHTABLE_LOOKUP will output logical zero when there is a key miss, while quant
740 // HASHTABLE_LOOKUP will output byte zero.
741 if (operation.type == TestOperationType::HASHTABLE_LOOKUP) {
742 return std::nullopt;
743 }
744
745 auto convert = [&testModel, &operation](const TestOperand& op, uint32_t index) {
746 switch (op.type) {
747 case TestOperandType::TENSOR_FLOAT32:
748 case TestOperandType::FLOAT32:
749 case TestOperandType::TENSOR_BOOL8:
750 case TestOperandType::BOOL:
751 case TestOperandType::UINT32:
752 return op;
753 case TestOperandType::INT32:
754 // The third input of PAD_V2 uses INT32 to specify the padded value.
755 if (operation.type == TestOperationType::PAD_V2 && index == operation.inputs[2]) {
756 // The scale and zero point is inherited from the first input.
757 const uint32_t input0Index = operation.inputs[0];
758 const auto& input0 = testModel.main.operands[input0Index];
759 TestOperand scalarWithScaleAndZeroPoint = op;
760 scalarWithScaleAndZeroPoint.scale = input0.scale;
761 scalarWithScaleAndZeroPoint.zeroPoint = input0.zeroPoint;
762 return convertOperandToFloat32<int32_t>(scalarWithScaleAndZeroPoint);
763 }
764 return op;
765 case TestOperandType::TENSOR_INT32:
766 if (op.scale != 0.0f || op.zeroPoint != 0) {
767 return convertOperandToFloat32<int32_t>(op);
768 }
769 return op;
770 case TestOperandType::TENSOR_FLOAT16:
771 case TestOperandType::FLOAT16:
772 return convertOperandToFloat32<_Float16>(op);
773 case TestOperandType::TENSOR_QUANT8_ASYMM:
774 return convertOperandToFloat32<uint8_t>(op);
775 case TestOperandType::TENSOR_QUANT8_ASYMM_SIGNED:
776 return convertOperandToFloat32<int8_t>(op);
777 case TestOperandType::TENSOR_QUANT16_ASYMM:
778 return convertOperandToFloat32<uint16_t>(op);
779 case TestOperandType::TENSOR_QUANT16_SYMM:
780 return convertOperandToFloat32<int16_t>(op);
781 default:
782 CHECK(false) << "OperandType not supported";
783 return TestOperand{};
784 }
785 };
786
787 TestModel converted = testModel;
788 for (uint32_t i = 0; i < testModel.main.operands.size(); i++) {
789 converted.main.operands[i] = convert(testModel.main.operands[i], i);
790 }
791 return converted;
792 }
793
794 template <typename T>
setDataFromFloat32Buffer(const TestBuffer & fpBuffer,TestOperand * op)795 static void setDataFromFloat32Buffer(const TestBuffer& fpBuffer, TestOperand* op) {
796 const uint32_t numberOfElements = getNumberOfElements(*op);
797 const float* floatData = fpBuffer.get<float>();
798 T* data = op->data.getMutable<T>();
799
800 if (op->scale != 0.0f) {
801 std::transform(floatData, floatData + numberOfElements, data, [op](float val) {
802 int32_t unclamped = std::round(val / op->scale) + op->zeroPoint;
803 int32_t clamped = std::clamp<int32_t>(unclamped, std::numeric_limits<T>::min(),
804 std::numeric_limits<T>::max());
805 return static_cast<T>(clamped);
806 });
807 } else {
808 std::transform(floatData, floatData + numberOfElements, data,
809 [](float val) { return static_cast<T>(val); });
810 }
811 }
812
setExpectedOutputsFromFloat32Results(const std::vector<TestBuffer> & results,TestModel * model)813 void setExpectedOutputsFromFloat32Results(const std::vector<TestBuffer>& results,
814 TestModel* model) {
815 CHECK_EQ(model->referenced.size(), 0u) << "Subgraphs not supported";
816 CHECK_EQ(model->main.operations.size(), 1u) << "Only single-operation graph is supported";
817
818 for (uint32_t i = 0; i < results.size(); i++) {
819 uint32_t outputIndex = model->main.outputIndexes[i];
820 auto& op = model->main.operands[outputIndex];
821 switch (op.type) {
822 case TestOperandType::TENSOR_FLOAT32:
823 case TestOperandType::FLOAT32:
824 case TestOperandType::TENSOR_BOOL8:
825 case TestOperandType::BOOL:
826 case TestOperandType::INT32:
827 case TestOperandType::UINT32:
828 op.data = results[i];
829 break;
830 case TestOperandType::TENSOR_INT32:
831 if (op.scale != 0.0f) {
832 setDataFromFloat32Buffer<int32_t>(results[i], &op);
833 } else {
834 op.data = results[i];
835 }
836 break;
837 case TestOperandType::TENSOR_FLOAT16:
838 case TestOperandType::FLOAT16:
839 setDataFromFloat32Buffer<_Float16>(results[i], &op);
840 break;
841 case TestOperandType::TENSOR_QUANT8_ASYMM:
842 setDataFromFloat32Buffer<uint8_t>(results[i], &op);
843 break;
844 case TestOperandType::TENSOR_QUANT8_ASYMM_SIGNED:
845 setDataFromFloat32Buffer<int8_t>(results[i], &op);
846 break;
847 case TestOperandType::TENSOR_QUANT16_ASYMM:
848 setDataFromFloat32Buffer<uint16_t>(results[i], &op);
849 break;
850 case TestOperandType::TENSOR_QUANT16_SYMM:
851 setDataFromFloat32Buffer<int16_t>(results[i], &op);
852 break;
853 default:
854 CHECK(false) << "OperandType not supported";
855 }
856 }
857 }
858
859 } // namespace test_helper
860