1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/testing/tflite_driver.h"
16
#include <algorithm>
#include <cmath>
#include <complex>
#include <cstring>
#include <iostream>
#include <memory>
#include <vector>
21
22 #include "absl/strings/escaping.h"
23 #include "tensorflow/lite/builtin_op_data.h"
24 #if !defined(__APPLE__)
25 #include "tensorflow/lite/delegates/flex/delegate.h"
26 #endif
27 #include "tensorflow/lite/kernels/custom_ops_register.h"
28 #include "tensorflow/lite/kernels/gradient/gradient_ops.h"
29 #include "tensorflow/lite/kernels/parse_example/parse_example.h"
30 #include "tensorflow/lite/kernels/perception/perception_ops.h"
31 #include "tensorflow/lite/kernels/register.h"
32 #include "tensorflow/lite/kernels/register_ref.h"
33 #include "tensorflow/lite/kernels/test_delegate_providers.h"
34 #include "tensorflow/lite/string_util.h"
35 #include "tensorflow/lite/testing/join.h"
36 #include "tensorflow/lite/testing/split.h"
37 #include "tensorflow/lite/tools/evaluation/utils.h"
38
39 namespace tflite {
40 namespace testing {
41
42 namespace {
// Default tolerances handed to every DataExpectation when comparing float
// outputs against the reference. Double literals are used so the thresholds
// are exactly 1e-2 and 1e-4 instead of the nearest float widened to double.
constexpr double kRelativeThreshold = 1e-2;
constexpr double kAbsoluteThreshold = 1e-4;

// For quantized tests, we use a different error measurement from float ones.
// Assumes the baseline is always a float TF model.
// Error of a quantized model compared to the baseline comes from two sources:
//   1. the math done with quantized inputs, and
//   2. quantization of the output.
// Assuming there is no error introduced by source 1, the theoretical maximum
// error allowed for the output is 0.5 * scale, because scale is equal to the
// size of the quantization bucket.
//
// As a result, we use `scale` as a unit for measuring the quantization error.
// To add the error introduced by source 1 as well, we need to relax the
// multiplier from 0.5 to a larger number, which is model/op dependent.
// The number below is good enough to account for both sources of error for
// most quantized op tests to pass.
constexpr int kQuantizationErrorMultiplier = 4;
61
// Reads element `index` from a buffer that stores values of type T.
// The caller guarantees that `data` points at enough T elements.
template <typename T>
T Value(void* data, int index) {
  const T* typed_buffer = static_cast<const T*>(data);
  return *(typed_buffer + index);
}
67
// Copies all of `values` into the raw buffer `data`, which must have room for
// values.size() elements of type T.
template <typename T>
void SetTensorData(const std::vector<T>& values, void* data) {
  std::copy_n(values.begin(), values.size(), static_cast<T*>(data));
}
73
// Owning pointer to an array whose element type has been erased; the stored
// deleter remembers the original element type so delete[] is still correct.
using unique_void_ptr = std::unique_ptr<void, void (*)(void*)>;

// Allocates `size` default-initialized elements of type T and returns them as
// a type-erased owning pointer.
template <typename T>
unique_void_ptr make_type_erased_array(size_t size) {
  void (*const typed_delete)(void*) = [](void* array) {
    delete[] static_cast<T*>(array);
  };
  T* storage = new T[size];
  return unique_void_ptr(static_cast<void*>(storage), typed_delete);
}
82
InterpretAsQuantized(const TfLiteTensor & tensor)83 bool InterpretAsQuantized(const TfLiteTensor& tensor) {
84 if (tensor.quantization.type == kTfLiteNoQuantization) return false;
85
86 // Quantized single-op models with uint8 input/output type are only used for
87 // EdgeTPU tests.
88 // EdgeTPU tests need to read the quantized values as-is to check for
89 // bit-exactness. As a result we don't interpret the tensor as quantized.
90 // TODO(b/176121243): Add an option to interpret uint8 buffers as
91 // non-quantized type and set if from the child class.
92 if (tensor.type == kTfLiteUInt8) return false;
93
94 if (tensor.quantization.params != nullptr) {
95 auto* quantization =
96 reinterpret_cast<TfLiteAffineQuantization*>(tensor.quantization.params);
97 if (quantization->scale != nullptr && quantization->scale->size == 1 &&
98 quantization->zero_point != nullptr &&
99 quantization->zero_point->size == 1) {
100 return true;
101 }
102 }
103 return false;
104 }
105 } // namespace
106
107 class TfLiteDriver::DataExpectation {
108 public:
DataExpectation(double relative_threshold,double absolute_threshold,int quantization_error_multiplier)109 DataExpectation(double relative_threshold, double absolute_threshold,
110 int quantization_error_multiplier)
111 : data_(nullptr, nullptr),
112 num_elements_(0),
113 relative_threshold_(relative_threshold),
114 absolute_threshold_(absolute_threshold),
115 quantization_error_multiplier_(quantization_error_multiplier) {}
116
117 template <typename T>
SetData(const string & csv_values)118 void SetData(const string& csv_values) {
119 const auto& values = testing::Split<T>(csv_values, ",");
120 num_elements_ = values.size();
121 data_ = make_type_erased_array<T>(num_elements_);
122 SetTensorData(values, data_.get());
123 }
124
125 bool Check(bool verbose, const TfLiteTensor& tensor);
126
127 private:
CompareTwoValuesHelper(float v1,float v2)128 bool CompareTwoValuesHelper(float v1, float v2) {
129 if (std::isnan(v1) || std::isnan(v2)) {
130 return !(std::isnan(v1) && std::isnan(v2));
131 }
132
133 float diff = std::abs(v1 - v2);
134 bool error_is_large = false;
135 // For very small numbers, try absolute error, otherwise go with
136 // relative.
137 if (std::abs(v2) < relative_threshold_) {
138 error_is_large = (diff > absolute_threshold_);
139 } else {
140 error_is_large = (diff > relative_threshold_ * std::abs(v2));
141 }
142 return error_is_large;
143 }
144
CompareTwoValuesHelper(double v1,double v2)145 bool CompareTwoValuesHelper(double v1, double v2) {
146 if (std::isnan(v1) || std::isnan(v2)) {
147 return !(std::isnan(v1) && std::isnan(v2));
148 }
149
150 double diff = std::abs(v1 - v2);
151 bool error_is_large = false;
152 // For very small numbers, try absolute error, otherwise go with
153 // relative.
154 if (std::abs(v2) < relative_threshold_) {
155 error_is_large = (diff > absolute_threshold_);
156 } else {
157 error_is_large = (diff > relative_threshold_ * std::abs(v2));
158 }
159 return error_is_large;
160 }
161
CompareTwoValues(std::complex<float> v1,std::complex<float> v2)162 bool CompareTwoValues(std::complex<float> v1, std::complex<float> v2) {
163 return CompareTwoValues(v1.real(), v2.real()) ||
164 CompareTwoValues(v1.imag(), v2.imag());
165 }
166
CompareTwoValues(std::complex<double> v1,std::complex<double> v2)167 bool CompareTwoValues(std::complex<double> v1, std::complex<double> v2) {
168 return CompareTwoValues(v1.real(), v2.real()) ||
169 CompareTwoValues(v1.imag(), v2.imag());
170 }
171
CompareTwoValues(float v1,float v2)172 bool CompareTwoValues(float v1, float v2) {
173 return CompareTwoValuesHelper(v1, v2);
174 }
175
CompareTwoValues(double v1,double v2)176 bool CompareTwoValues(double v1, double v2) {
177 return CompareTwoValuesHelper(v1, v2);
178 }
179
180 template <typename T, typename TS>
TypedCheck(bool verbose,const TfLiteTensor & tensor)181 bool TypedCheck(bool verbose, const TfLiteTensor& tensor) {
182 size_t tensor_size = tensor.bytes / sizeof(T);
183
184 if (tensor_size != num_elements_) {
185 std::cerr << "Expected a tensor with " << num_elements_
186 << " elements, got " << tensor_size << std::endl;
187 std::cerr << "while checking tensor " << tensor.name << std::endl;
188 return false;
189 }
190
191 bool good_output = true;
192 for (int i = 0; i < tensor_size; ++i) {
193 TS computed = Value<T>(tensor.data.raw, i);
194 TS reference = Value<T>(data_.get(), i);
195 if (CompareTwoValues(computed, reference)) {
196 good_output = false;
197 if (verbose) {
198 std::cerr << " index " << i << ": got " << computed
199 << ", but expected " << reference << std::endl;
200 }
201 }
202 }
203 return good_output;
204 }
205
206 bool TypedCheckString(bool verbose, const TfLiteTensor& tensor);
207 bool QuantizedCheck(bool verbose, const TfLiteTensor& tensor);
208
209 unique_void_ptr data_;
210 size_t num_elements_;
211 double relative_threshold_;
212 double absolute_threshold_;
213 int quantization_error_multiplier_;
214 };
215
216 class TfLiteDriver::ShapeExpectation {
217 public:
ShapeExpectation(const string & csv_values)218 explicit ShapeExpectation(const string& csv_values)
219 : shape_(testing::Split<int32_t>(csv_values, ",")) {}
220
CheckShape(bool verbose,const TfLiteTensor & tensor)221 bool CheckShape(bool verbose, const TfLiteTensor& tensor) {
222 bool valid = true;
223 if (tensor.dims->size == shape_.size()) {
224 for (int i = 0; i < shape_.size(); ++i) {
225 if (shape_[i] != tensor.dims->data[i]) {
226 valid = false;
227 }
228 }
229 } else {
230 valid = false;
231 }
232 if (!valid && verbose) {
233 std::cerr << "Incorrect output shape while checking tensor "
234 << tensor.name << std::endl;
235 std::cerr << "TFLite output shape: ";
236 for (int i = 0; i < tensor.dims->size; ++i) {
237 std::cerr << tensor.dims->data[i] << ", ";
238 }
239 std::cerr << std::endl;
240 std::cerr << "Expected output shape: ";
241 for (int i = 0; i < shape_.size(); ++i) {
242 std::cerr << shape_[i] << ", ";
243 }
244 std::cerr << std::endl;
245 }
246 return valid;
247 }
248
249 private:
250 std::vector<int32_t> shape_;
251 };
252
253 template <>
SetData(const string & csv_values)254 void TfLiteDriver::DataExpectation::SetData<string>(const string& csv_values) {
255 string s = absl::HexStringToBytes(csv_values);
256 data_ = make_type_erased_array<char>(s.size());
257 memcpy(data_.get(), s.data(), s.size());
258 }
259
TypedCheckString(bool verbose,const TfLiteTensor & tensor)260 bool TfLiteDriver::DataExpectation::TypedCheckString(
261 bool verbose, const TfLiteTensor& tensor) {
262 if (tensor.data.raw == nullptr) {
263 if (verbose) {
264 std::cerr << " got empty string" << std::endl;
265 }
266 return false;
267 }
268 int expected_num_strings = GetStringCount(data_.get());
269 int returned_num_strings = GetStringCount(&tensor);
270 if (expected_num_strings != returned_num_strings) {
271 if (verbose) {
272 std::cerr << " string count differ: got " << returned_num_strings
273 << ", but expected " << expected_num_strings << std::endl;
274 }
275 return false;
276 }
277 for (int i = 0; i < returned_num_strings; ++i) {
278 auto expected_ref = GetString(data_.get(), i);
279 auto returned_ref = GetString(&tensor, i);
280 if (expected_ref.len != returned_ref.len) {
281 if (verbose) {
282 std::cerr << " index " << i << ": got string of size "
283 << returned_ref.len << ", but expected size "
284 << expected_ref.len << std::endl;
285 }
286 return false;
287 }
288 if (strncmp(expected_ref.str, returned_ref.str, returned_ref.len) != 0) {
289 if (verbose) {
290 std::cerr << " index " << i << ": strings are different" << std::endl;
291 }
292 return false;
293 }
294 }
295
296 return true;
297 }
298
QuantizedCheck(bool verbose,const TfLiteTensor & tensor)299 bool TfLiteDriver::DataExpectation::QuantizedCheck(bool verbose,
300 const TfLiteTensor& tensor) {
301 auto* quantization =
302 reinterpret_cast<TfLiteAffineQuantization*>(tensor.quantization.params);
303 const float scale = quantization->scale->data[0];
304 const int32_t zero_point = quantization->zero_point->data[0];
305
306 bool good_result = true;
307 int int_size = tensor.type == kTfLiteInt8 ? 1 : 2;
308 for (int i = 0; i < tensor.bytes / int_size; i++) {
309 int32_t computed =
310 tensor.type == kTfLiteInt8 ? tensor.data.int8[i] : tensor.data.i16[i];
311 const float dequantized =
312 static_cast<float>(scale * (computed - zero_point));
313 int error_multiplier = quantization_error_multiplier_;
314 // If we are doing int16 symmetric quantization of activations, we need to
315 // bump up the potential error. Since the weights are quantized to 8 bits
316 // and the activations are 16bits, the output is could be getting
317 // effectively 8bit error instead of 16bit error. So we need to multiply the
318 // error mulitplier by 255 (the difference in number of values between a
319 // 16bit and 8bit number)
320 if (tensor.type == kTfLiteInt16) error_multiplier *= 255;
321 const float reference = Value<float>(data_.get(), i);
322 if (std::abs(dequantized - reference) > error_multiplier * scale) {
323 if (verbose) {
324 std::cerr << " index " << i << ": got " << dequantized
325 << ", but expected " << reference << std::endl;
326 }
327 good_result = false;
328 }
329 }
330 return good_result;
331 }
332
Check(bool verbose,const TfLiteTensor & tensor)333 bool TfLiteDriver::DataExpectation::Check(bool verbose,
334 const TfLiteTensor& tensor) {
335 if (InterpretAsQuantized(tensor)) {
336 return QuantizedCheck(verbose, tensor);
337 }
338
339 switch (tensor.type) {
340 case kTfLiteFloat32:
341 return TypedCheck<float, float>(verbose, tensor);
342 case kTfLiteInt32:
343 return TypedCheck<int32_t, float>(verbose, tensor);
344 case kTfLiteUInt32:
345 return TypedCheck<uint32_t, float>(verbose, tensor);
346 case kTfLiteInt64:
347 return TypedCheck<int64_t, float>(verbose, tensor);
348 case kTfLiteUInt64:
349 return TypedCheck<uint64_t, float>(verbose, tensor);
350 case kTfLiteUInt8:
351 return TypedCheck<uint8_t, float>(verbose, tensor);
352 case kTfLiteInt8:
353 return TypedCheck<int8_t, float>(verbose, tensor);
354 case kTfLiteInt16:
355 return TypedCheck<int16_t, float>(verbose, tensor);
356 case kTfLiteBool:
357 return TypedCheck<bool, float>(verbose, tensor);
358 case kTfLiteString:
359 return TypedCheckString(verbose, tensor);
360 case kTfLiteComplex64:
361 return TypedCheck<std::complex<float>, std::complex<float>>(verbose,
362 tensor);
363 case kTfLiteComplex128:
364 return TypedCheck<std::complex<double>, std::complex<double>>(verbose,
365 tensor);
366 case kTfLiteFloat64:
367 return TypedCheck<double, double>(verbose, tensor);
368 default:
369 fprintf(stderr, "Unsupported type %d in Check\n", tensor.type);
370 return false;
371 }
372 }
373
374 /* static */
InitTestDelegateProviders(int * argc,const char ** argv)375 bool TfLiteDriver::InitTestDelegateProviders(int* argc, const char** argv) {
376 return tflite::KernelTestDelegateProviders::Get()->InitFromCmdlineArgs(argc,
377 argv);
378 }
379
TfLiteDriver(DelegateType delegate_type,bool reference_kernel)380 TfLiteDriver::TfLiteDriver(DelegateType delegate_type, bool reference_kernel)
381 : delegate_(nullptr, nullptr),
382 relative_threshold_(kRelativeThreshold),
383 absolute_threshold_(kAbsoluteThreshold),
384 quantization_error_multiplier_(kQuantizationErrorMultiplier) {
385 if (reference_kernel) {
386 resolver_.reset(new ops::builtin::BuiltinRefOpResolver);
387 } else {
388 // TODO(b/168278077): change back to use BuiltinOpResolver after zip tests
389 // are fully validated against TfLite delegates.
390 resolver_.reset(
391 new ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
392 ops::builtin::BuiltinOpResolver* builtin_op_resolver_ =
393 reinterpret_cast<ops::builtin::BuiltinOpResolver*>(resolver_.get());
394 builtin_op_resolver_->AddCustom("IRFFT2D",
395 tflite::ops::custom::Register_IRFFT2D());
396 builtin_op_resolver_->AddCustom(
397 "AvgPool3D", tflite::ops::custom::Register_AVG_POOL_3D());
398 builtin_op_resolver_->AddCustom(
399 "MaxPool3D", tflite::ops::custom::Register_MAX_POOL_3D());
400 builtin_op_resolver_->AddCustom("Roll",
401 tflite::ops::custom::Register_ROLL());
402 tflite::ops::custom::AddGradientOps(builtin_op_resolver_);
403 tflite::ops::custom::AddParseExampleOp(builtin_op_resolver_);
404 tflite::ops::custom::AddPerceptionOps(builtin_op_resolver_);
405 }
406
407 switch (delegate_type) {
408 case DelegateType::kNone:
409 break;
410 case DelegateType::kNnapi:
411 delegate_ = evaluation::CreateNNAPIDelegate();
412 break;
413 case DelegateType::kGpu:
414 delegate_ = evaluation::CreateGPUDelegate();
415 break;
416 case DelegateType::kFlex:
417 #if !defined(__APPLE__)
418 delegate_ = FlexDelegate::Create();
419 #endif
420 break;
421 }
422 }
423
~TfLiteDriver()424 TfLiteDriver::~TfLiteDriver() {
425 for (auto t : tensors_to_deallocate_) {
426 DeallocateStringTensor(t.second);
427 }
428 }
429
AllocateTensors()430 void TfLiteDriver::AllocateTensors() {
431 if (must_allocate_tensors_) {
432 if (interpreter_->AllocateTensors() != kTfLiteOk) {
433 Invalidate("Failed to allocate tensors");
434 return;
435 }
436 ResetLSTMStateTensors();
437 must_allocate_tensors_ = false;
438 }
439 }
440
LoadModel(const string & bin_file_path)441 void TfLiteDriver::LoadModel(const string& bin_file_path) {
442 if (!IsValid()) return;
443
444 model_ = FlatBufferModel::BuildFromFile(GetFullPath(bin_file_path).c_str());
445 if (!model_) {
446 Invalidate("Failed to mmap model " + bin_file_path);
447 return;
448 }
449 InterpreterBuilder(*model_, *resolver_)(&interpreter_);
450 if (!interpreter_) {
451 Invalidate("Failed build interpreter");
452 return;
453 }
454 if (delegate_) {
455 if (interpreter_->ModifyGraphWithDelegate(delegate_.get()) != kTfLiteOk) {
456 Invalidate("Unable to the build graph using the delegate");
457 return;
458 }
459 } else {
460 auto* delegate_providers = tflite::KernelTestDelegateProviders::Get();
461 for (auto& one : delegate_providers->CreateAllDelegates()) {
462 if (interpreter_->ModifyGraphWithDelegate(std::move(one.delegate)) !=
463 kTfLiteOk) {
464 Invalidate(
465 "Unable to the build graph using the delegate initialized from "
466 "tflite::KernelTestDelegateProviders");
467 return;
468 }
469 }
470 }
471
472 must_allocate_tensors_ = true;
473 }
474
ResetTensor(int id)475 void TfLiteDriver::ResetTensor(int id) {
476 if (!IsValid()) return;
477 auto* tensor = interpreter_->tensor(id);
478 memset(tensor->data.raw, 0, tensor->bytes);
479 }
480
ReshapeTensor(int id,const string & csv_values)481 void TfLiteDriver::ReshapeTensor(int id, const string& csv_values) {
482 if (!IsValid()) return;
483 if (interpreter_->ResizeInputTensor(
484 id, testing::Split<int>(csv_values, ",")) != kTfLiteOk) {
485 Invalidate("Failed to resize input tensor " + std::to_string(id));
486 return;
487 }
488 must_allocate_tensors_ = true;
489 }
490
SetInput(int id,const string & csv_values)491 void TfLiteDriver::SetInput(int id, const string& csv_values) {
492 if (!IsValid()) return;
493 auto* tensor = interpreter_->tensor(id);
494 switch (tensor->type) {
495 case kTfLiteFloat32: {
496 const auto& values = testing::Split<float>(csv_values, ",");
497 if (!CheckSizes<float>(tensor->bytes, values.size())) return;
498 SetTensorData(values, tensor->data.raw);
499 break;
500 }
501 case kTfLiteInt32: {
502 const auto& values = testing::Split<int32_t>(csv_values, ",");
503 if (!CheckSizes<int32_t>(tensor->bytes, values.size())) return;
504 SetTensorData(values, tensor->data.raw);
505 break;
506 }
507 case kTfLiteUInt32: {
508 const auto& values = testing::Split<uint32_t>(csv_values, ",");
509 if (!CheckSizes<uint32_t>(tensor->bytes, values.size())) return;
510 SetTensorData(values, tensor->data.raw);
511 break;
512 }
513 case kTfLiteInt64: {
514 const auto& values = testing::Split<int64_t>(csv_values, ",");
515 if (!CheckSizes<int64_t>(tensor->bytes, values.size())) return;
516 SetTensorData(values, tensor->data.raw);
517 break;
518 }
519 case kTfLiteUInt64: {
520 const auto& values = testing::Split<uint64_t>(csv_values, ",");
521 if (!CheckSizes<uint64_t>(tensor->bytes, values.size())) return;
522 SetTensorData(values, tensor->data.raw);
523 break;
524 }
525 case kTfLiteUInt8: {
526 const auto& values = testing::Split<uint8_t>(csv_values, ",");
527 if (!CheckSizes<uint8_t>(tensor->bytes, values.size())) return;
528 SetTensorData(values, tensor->data.raw);
529 break;
530 }
531 case kTfLiteInt8: {
532 const auto& values = testing::Split<int8_t>(csv_values, ",");
533 if (!CheckSizes<int8_t>(tensor->bytes, values.size())) return;
534 SetTensorData(values, tensor->data.raw);
535 break;
536 }
537 case kTfLiteInt16: {
538 const auto& values = testing::Split<int16_t>(csv_values, ",");
539 if (!CheckSizes<int16_t>(tensor->bytes, values.size())) return;
540 SetTensorData(values, tensor->data.raw);
541 break;
542 }
543 case kTfLiteBool: {
544 const auto& values = testing::Split<bool>(csv_values, ",");
545 if (!CheckSizes<bool>(tensor->bytes, values.size())) return;
546 SetTensorData(values, tensor->data.raw);
547 break;
548 }
549 case kTfLiteString: {
550 string s = absl::HexStringToBytes(csv_values);
551
552 DeallocateStringTensor(tensors_to_deallocate_[id]);
553 AllocateStringTensor(id, s.size(), tensor);
554 memcpy(tensor->data.raw, s.data(), s.size());
555
556 break;
557 }
558 case kTfLiteComplex64: {
559 const auto& values = testing::Split<std::complex<float>>(csv_values, ",");
560 if (!CheckSizes<std::complex<float>>(tensor->bytes, values.size()))
561 return;
562 SetTensorData(values, tensor->data.raw);
563 break;
564 }
565 case kTfLiteComplex128: {
566 const auto& values =
567 testing::Split<std::complex<double>>(csv_values, ",");
568 if (!CheckSizes<std::complex<double>>(tensor->bytes, values.size()))
569 return;
570 SetTensorData(values, tensor->data.raw);
571 break;
572 }
573 default:
574 Invalidate(absl::StrCat("Unsupported tensor type ",
575 TfLiteTypeGetName(tensor->type),
576 " in TfLiteDriver::SetInput"));
577 return;
578 }
579 }
580
SetThreshold(double relative_threshold,double absolute_threshold)581 void TfLiteDriver::SetThreshold(double relative_threshold,
582 double absolute_threshold) {
583 relative_threshold_ = relative_threshold;
584 absolute_threshold_ = absolute_threshold;
585 }
586
SetQuantizationErrorMultiplier(int quantization_error_multiplier)587 void TfLiteDriver::SetQuantizationErrorMultiplier(
588 int quantization_error_multiplier) {
589 quantization_error_multiplier_ = quantization_error_multiplier;
590 }
591
SetExpectation(int id,const string & csv_values)592 void TfLiteDriver::SetExpectation(int id, const string& csv_values) {
593 if (!IsValid()) return;
594 auto* tensor = interpreter_->tensor(id);
595 if (expected_output_.count(id) != 0) {
596 Invalidate(absl::StrCat("Overridden expectation for tensor '", id, "'"));
597 }
598 expected_output_[id].reset(
599 new DataExpectation(relative_threshold_, absolute_threshold_,
600 quantization_error_multiplier_));
601
602 if (InterpretAsQuantized(*tensor)) {
603 expected_output_[id]->SetData<float>(csv_values);
604 return;
605 }
606
607 switch (tensor->type) {
608 case kTfLiteFloat32:
609 expected_output_[id]->SetData<float>(csv_values);
610 break;
611 case kTfLiteInt32:
612 expected_output_[id]->SetData<int32_t>(csv_values);
613 break;
614 case kTfLiteUInt32:
615 expected_output_[id]->SetData<uint32_t>(csv_values);
616 break;
617 case kTfLiteInt64:
618 expected_output_[id]->SetData<int64_t>(csv_values);
619 break;
620 case kTfLiteUInt64:
621 expected_output_[id]->SetData<uint64_t>(csv_values);
622 break;
623 case kTfLiteUInt8:
624 expected_output_[id]->SetData<uint8_t>(csv_values);
625 break;
626 case kTfLiteInt8:
627 expected_output_[id]->SetData<int8_t>(csv_values);
628 break;
629 case kTfLiteInt16:
630 expected_output_[id]->SetData<int16_t>(csv_values);
631 break;
632 case kTfLiteBool:
633 expected_output_[id]->SetData<bool>(csv_values);
634 break;
635 case kTfLiteString:
636 expected_output_[id]->SetData<string>(csv_values);
637 break;
638 case kTfLiteFloat64:
639 expected_output_[id]->SetData<double>(csv_values);
640 break;
641 case kTfLiteComplex64:
642 expected_output_[id]->SetData<std::complex<float>>(csv_values);
643 break;
644 case kTfLiteComplex128:
645 expected_output_[id]->SetData<std::complex<double>>(csv_values);
646 break;
647 default:
648 Invalidate(absl::StrCat("Unsupported tensor type ",
649 TfLiteTypeGetName(tensor->type),
650 " in TfLiteDriver::SetExpectation"));
651 return;
652 }
653 }
654
SetShapeExpectation(int id,const string & csv_values)655 void TfLiteDriver::SetShapeExpectation(int id, const string& csv_values) {
656 if (!IsValid()) return;
657 if (expected_output_shape_.count(id) != 0) {
658 Invalidate(
659 absl::StrCat("Overridden shape expectation for tensor '", id, "'"));
660 }
661 expected_output_shape_[id].reset(new ShapeExpectation(csv_values));
662 }
663
Invoke()664 void TfLiteDriver::Invoke() {
665 if (!IsValid()) return;
666 if (interpreter_->Invoke() != kTfLiteOk) {
667 Invalidate("Failed to invoke interpreter");
668 }
669 }
670
CheckResults()671 bool TfLiteDriver::CheckResults() {
672 if (!IsValid()) return false;
673 bool success = true;
674 for (const auto& p : expected_output_) {
675 int id = p.first;
676 auto* tensor = interpreter_->tensor(id);
677 if (!p.second->Check(/*verbose=*/false, *tensor)) {
678 // Do not invalidate anything here. Instead, simply output the
679 // differences and return false. Invalidating would prevent all
680 // subsequent invocations from running..
681 std::cerr << "There were errors in invocation '" << GetInvocationId()
682 << "', output tensor '" << id << "':" << std::endl;
683 p.second->Check(/*verbose=*/true, *tensor);
684 success = false;
685 SetOverallSuccess(false);
686 }
687 }
688 for (const auto& p : expected_output_shape_) {
689 int id = p.first;
690 auto* tensor = interpreter_->tensor(id);
691 if (!p.second->CheckShape(/*verbose=*/false, *tensor)) {
692 // Do not invalidate anything here. Instead, simply output the
693 // differences and return false. Invalidating would prevent all
694 // subsequent invocations from running..
695 std::cerr << "There were errors in invocation '" << GetInvocationId()
696 << "', output tensor '" << id << "':" << std::endl;
697 p.second->CheckShape(/*verbose=*/true, *tensor);
698 success = false;
699 SetOverallSuccess(false);
700 }
701 }
702 expected_output_.clear();
703 return success;
704 }
705
ResetLSTMStateTensors()706 void TfLiteDriver::ResetLSTMStateTensors() {
707 interpreter_->ResetVariableTensors();
708 }
709
ReadOutput(int id)710 string TfLiteDriver::ReadOutput(int id) {
711 auto* tensor = interpreter_->tensor(id);
712 int num_elements = 1;
713
714 for (int i = 0; i < tensor->dims->size; ++i) {
715 num_elements *= tensor->dims->data[i];
716 }
717
718 switch (tensor->type) {
719 case kTfLiteFloat32:
720 return JoinDefault(tensor->data.f, num_elements, ",");
721 case kTfLiteInt32:
722 return JoinDefault(tensor->data.i32, num_elements, ",");
723 case kTfLiteUInt32:
724 return JoinDefault(tensor->data.u32, num_elements, ",");
725 case kTfLiteInt64:
726 return JoinDefault(tensor->data.i64, num_elements, ",");
727 case kTfLiteUInt64:
728 return JoinDefault(tensor->data.u64, num_elements, ",");
729 case kTfLiteUInt8:
730 return Join(tensor->data.uint8, num_elements, ",");
731 case kTfLiteInt8:
732 return Join(tensor->data.int8, num_elements, ",");
733 case kTfLiteInt16:
734 return Join(tensor->data.i16, num_elements, ",");
735 case kTfLiteBool:
736 return JoinDefault(tensor->data.b, num_elements, ",");
737 default:
738 Invalidate(absl::StrCat("Unsupported tensor type ",
739 TfLiteTypeGetName(tensor->type),
740 " in TfLiteDriver::ReadOutput"));
741 return "";
742 }
743 }
744
745 } // namespace testing
746 } // namespace tflite
747