1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/testing/tflite_driver.h"
16
#include <algorithm>
#include <cmath>
#include <complex>
#include <cstring>
#include <memory>
#include <vector>
21
22 #include "absl/strings/escaping.h"
23 #include "tensorflow/lite/builtin_op_data.h"
24 #if !defined(__APPLE__)
25 #include "tensorflow/lite/delegates/flex/delegate.h"
26 #endif
27 #include "tensorflow/lite/kernels/custom_ops_register.h"
28 #include "tensorflow/lite/kernels/hashtable/hashtable_ops.h"
29 #include "tensorflow/lite/kernels/parse_example/parse_example.h"
30 #include "tensorflow/lite/kernels/perception/perception_ops.h"
31 #include "tensorflow/lite/kernels/register.h"
32 #include "tensorflow/lite/kernels/register_ref.h"
33 #include "tensorflow/lite/kernels/test_delegate_providers.h"
34 #include "tensorflow/lite/string_util.h"
35 #include "tensorflow/lite/testing/join.h"
36 #include "tensorflow/lite/testing/split.h"
37 #include "tensorflow/lite/tools/evaluation/utils.h"
38
39 namespace tflite {
40 namespace testing {
41
42 namespace {
// Initial relative / absolute error tolerances installed by the TfLiteDriver
// constructor.
// NOTE(review): both literals carry an `f` suffix, so each value is rounded to
// float before being widened to double; kept as-is so thresholds do not move.
constexpr double kRelativeThreshold = 1e-2f;
constexpr double kAbsoluteThreshold = 1e-4f;
45
// Quantized tests measure error differently from float tests. The baseline is
// assumed to always be a float TF model, so a quantized model can deviate from
// it for two reasons:
//   1. the arithmetic is performed on quantized inputs, and
//   2. the output itself is quantized.
// If source 1 contributed nothing, the largest possible output error would be
// 0.5 * scale, because `scale` is the width of one quantization bucket.
//
// `scale` is therefore the unit in which quantization error is measured. To
// leave room for source 1 as well, the multiplier is relaxed from 0.5 to a
// larger, model/op-dependent number. The value below is large enough to cover
// both sources of error for most quantized op tests.
constexpr int kQuantizationErrorMultiplier = 4;
61
// Reads element `index` of the buffer at `data`, treating the buffer as an
// array of T. No bounds checking is performed.
template <typename T>
T Value(void* data, int index) {
  const T* typed = static_cast<const T*>(data);
  return *(typed + index);
}
67
// Writes every element of `values`, in order, into the buffer at `data`, which
// is treated as an array of T. The caller must make sure the buffer is large
// enough; nothing is checked here.
template <typename T>
void SetTensorData(const std::vector<T>& values, void* data) {
  T* out = static_cast<T*>(data);
  for (auto it = values.begin(); it != values.end(); ++it, ++out) {
    *out = *it;
  }
}
73
// Type-erased owning pointer: a unique_ptr<void> paired with a function-pointer
// deleter that remembers the real element type.
using unique_void_ptr = std::unique_ptr<void, void (*)(void*)>;

// Allocates a T[size] array and returns it type-erased. The attached deleter
// casts the pointer back to T* and releases it with delete[].
template <typename T>
unique_void_ptr make_type_erased_array(size_t size) {
  T* array = new T[size];
  const auto destroy = [](void* erased) { delete[] static_cast<T*>(erased); };
  return unique_void_ptr(static_cast<void*>(array), destroy);
}
82
InterpretAsQuantized(const TfLiteTensor & tensor)83 bool InterpretAsQuantized(const TfLiteTensor& tensor) {
84 if (tensor.quantization.type == kTfLiteNoQuantization) return false;
85
86 // Quantized single-op models with uint8 input/output type are only used for
87 // EdgeTPU tests.
88 // EdgeTPU tests need to read the quantized values as-is to check for
89 // bit-exactness. As a result we don't interpret the tensor as quantized.
90 // TODO(b/176121243): Add an option to interpret uint8 buffers as
91 // non-quantized type and set if from the child class.
92 if (tensor.type == kTfLiteUInt8) return false;
93
94 if (tensor.quantization.params != nullptr) {
95 auto* quantization =
96 reinterpret_cast<TfLiteAffineQuantization*>(tensor.quantization.params);
97 if (quantization->scale != nullptr && quantization->scale->size == 1 &&
98 quantization->zero_point != nullptr &&
99 quantization->zero_point->size == 1) {
100 return true;
101 }
102 }
103 return false;
104 }
105 } // namespace
106
107 class TfLiteDriver::DataExpectation {
108 public:
DataExpectation(double relative_threshold,double absolute_threshold,int quantization_error_multiplier)109 DataExpectation(double relative_threshold, double absolute_threshold,
110 int quantization_error_multiplier)
111 : data_(nullptr, nullptr),
112 num_elements_(0),
113 relative_threshold_(relative_threshold),
114 absolute_threshold_(absolute_threshold),
115 quantization_error_multiplier_(quantization_error_multiplier) {}
116
117 template <typename T>
SetData(const string & csv_values)118 void SetData(const string& csv_values) {
119 const auto& values = testing::Split<T>(csv_values, ",");
120 num_elements_ = values.size();
121 data_ = make_type_erased_array<T>(num_elements_);
122 SetTensorData(values, data_.get());
123 }
124
125 bool Check(bool verbose, const TfLiteTensor& tensor);
126
127 private:
CompareTwoValuesHelper(float v1,float v2)128 bool CompareTwoValuesHelper(float v1, float v2) {
129 float diff = std::abs(v1 - v2);
130 bool error_is_large = false;
131 // For very small numbers, try absolute error, otherwise go with
132 // relative.
133 if (std::abs(v2) < relative_threshold_) {
134 error_is_large = (diff > absolute_threshold_);
135 } else {
136 error_is_large = (diff > relative_threshold_ * std::abs(v2));
137 }
138 return error_is_large;
139 }
140
CompareTwoValuesHelper(double v1,double v2)141 bool CompareTwoValuesHelper(double v1, double v2) {
142 double diff = std::abs(v1 - v2);
143 bool error_is_large = false;
144 // For very small numbers, try absolute error, otherwise go with
145 // relative.
146 if (std::abs(v2) < relative_threshold_) {
147 error_is_large = (diff > absolute_threshold_);
148 } else {
149 error_is_large = (diff > relative_threshold_ * std::abs(v2));
150 }
151 return error_is_large;
152 }
153
CompareTwoValues(std::complex<float> v1,std::complex<float> v2)154 bool CompareTwoValues(std::complex<float> v1, std::complex<float> v2) {
155 return CompareTwoValues(v1.real(), v2.real()) ||
156 CompareTwoValues(v1.imag(), v2.imag());
157 }
158
CompareTwoValues(std::complex<double> v1,std::complex<double> v2)159 bool CompareTwoValues(std::complex<double> v1, std::complex<double> v2) {
160 return CompareTwoValues(v1.real(), v2.real()) ||
161 CompareTwoValues(v1.imag(), v2.imag());
162 }
163
CompareTwoValues(float v1,float v2)164 bool CompareTwoValues(float v1, float v2) {
165 return CompareTwoValuesHelper(v1, v2);
166 }
167
CompareTwoValues(double v1,double v2)168 bool CompareTwoValues(double v1, double v2) {
169 return CompareTwoValuesHelper(v1, v2);
170 }
171
172 template <typename T, typename TS>
TypedCheck(bool verbose,const TfLiteTensor & tensor)173 bool TypedCheck(bool verbose, const TfLiteTensor& tensor) {
174 size_t tensor_size = tensor.bytes / sizeof(T);
175
176 if (tensor_size != num_elements_) {
177 std::cerr << "Expected a tensor with " << num_elements_
178 << " elements, got " << tensor_size << std::endl;
179 std::cerr << "while checking tensor " << tensor.name << std::endl;
180 return false;
181 }
182
183 bool good_output = true;
184 for (int i = 0; i < tensor_size; ++i) {
185 TS computed = Value<T>(tensor.data.raw, i);
186 TS reference = Value<T>(data_.get(), i);
187 if (CompareTwoValues(computed, reference)) {
188 good_output = false;
189 if (verbose) {
190 std::cerr << " index " << i << ": got " << computed
191 << ", but expected " << reference << std::endl;
192 }
193 }
194 }
195 return good_output;
196 }
197
198 bool TypedCheckString(bool verbose, const TfLiteTensor& tensor);
199 bool QuantizedCheck(bool verbose, const TfLiteTensor& tensor);
200
201 unique_void_ptr data_;
202 size_t num_elements_;
203 double relative_threshold_;
204 double absolute_threshold_;
205 int quantization_error_multiplier_;
206 };
207
208 class TfLiteDriver::ShapeExpectation {
209 public:
ShapeExpectation(const string & csv_values)210 explicit ShapeExpectation(const string& csv_values)
211 : shape_(testing::Split<int32_t>(csv_values, ",")) {}
212
CheckShape(bool verbose,const TfLiteTensor & tensor)213 bool CheckShape(bool verbose, const TfLiteTensor& tensor) {
214 bool valid = true;
215 if (tensor.dims->size == shape_.size()) {
216 for (int i = 0; i < shape_.size(); ++i) {
217 if (shape_[i] != tensor.dims->data[i]) {
218 valid = false;
219 }
220 }
221 } else {
222 valid = false;
223 }
224 if (!valid && verbose) {
225 std::cerr << "Incorrect output shape while checking tensor "
226 << tensor.name << std::endl;
227 std::cerr << "TFLite output shape: ";
228 for (int i = 0; i < tensor.dims->size; ++i) {
229 std::cerr << tensor.dims->data[i] << ", ";
230 }
231 std::cerr << std::endl;
232 std::cerr << "Expected output shape: ";
233 for (int i = 0; i < shape_.size(); ++i) {
234 std::cerr << shape_[i] << ", ";
235 }
236 std::cerr << std::endl;
237 }
238 return valid;
239 }
240
241 private:
242 std::vector<int32_t> shape_;
243 };
244
245 template <>
SetData(const string & csv_values)246 void TfLiteDriver::DataExpectation::SetData<string>(const string& csv_values) {
247 string s = absl::HexStringToBytes(csv_values);
248 data_ = make_type_erased_array<char>(s.size());
249 memcpy(data_.get(), s.data(), s.size());
250 }
251
TypedCheckString(bool verbose,const TfLiteTensor & tensor)252 bool TfLiteDriver::DataExpectation::TypedCheckString(
253 bool verbose, const TfLiteTensor& tensor) {
254 if (tensor.data.raw == nullptr) {
255 if (verbose) {
256 std::cerr << " got empty string" << std::endl;
257 }
258 return false;
259 }
260 int expected_num_strings = GetStringCount(data_.get());
261 int returned_num_strings = GetStringCount(&tensor);
262 if (expected_num_strings != returned_num_strings) {
263 if (verbose) {
264 std::cerr << " string count differ: got " << returned_num_strings
265 << ", but expected " << expected_num_strings << std::endl;
266 }
267 return false;
268 }
269 for (int i = 0; i < returned_num_strings; ++i) {
270 auto expected_ref = GetString(data_.get(), i);
271 auto returned_ref = GetString(&tensor, i);
272 if (expected_ref.len != returned_ref.len) {
273 if (verbose) {
274 std::cerr << " index " << i << ": got string of size "
275 << returned_ref.len << ", but expected size "
276 << expected_ref.len << std::endl;
277 }
278 return false;
279 }
280 if (strncmp(expected_ref.str, returned_ref.str, returned_ref.len) != 0) {
281 if (verbose) {
282 std::cerr << " index " << i << ": strings are different" << std::endl;
283 }
284 return false;
285 }
286 }
287
288 return true;
289 }
290
QuantizedCheck(bool verbose,const TfLiteTensor & tensor)291 bool TfLiteDriver::DataExpectation::QuantizedCheck(bool verbose,
292 const TfLiteTensor& tensor) {
293 auto* quantization =
294 reinterpret_cast<TfLiteAffineQuantization*>(tensor.quantization.params);
295 const float scale = quantization->scale->data[0];
296 const int32_t zero_point = quantization->zero_point->data[0];
297
298 bool good_result = true;
299 int int_size = tensor.type == kTfLiteInt8 ? 1 : 2;
300 for (int i = 0; i < tensor.bytes / int_size; i++) {
301 int32_t computed =
302 tensor.type == kTfLiteInt8 ? tensor.data.int8[i] : tensor.data.i16[i];
303 const float dequantized =
304 static_cast<float>(scale * (computed - zero_point));
305 int error_multiplier = quantization_error_multiplier_;
306 // If we are doing int16 symmetric quantization of activations, we need to
307 // bump up the potential error. Since the weights are quantized to 8 bits
308 // and the activations are 16bits, the output is could be getting
309 // effectively 8bit error instead of 16bit error. So we need to multiply the
310 // error mulitplier by 255 (the difference in number of values between a
311 // 16bit and 8bit number)
312 if (tensor.type == kTfLiteInt16) error_multiplier *= 255;
313 const float reference = Value<float>(data_.get(), i);
314 if (std::abs(dequantized - reference) > error_multiplier * scale) {
315 if (verbose) {
316 std::cerr << " index " << i << ": got " << dequantized
317 << ", but expected " << reference << std::endl;
318 }
319 good_result = false;
320 }
321 }
322 return good_result;
323 }
324
Check(bool verbose,const TfLiteTensor & tensor)325 bool TfLiteDriver::DataExpectation::Check(bool verbose,
326 const TfLiteTensor& tensor) {
327 if (InterpretAsQuantized(tensor)) {
328 return QuantizedCheck(verbose, tensor);
329 }
330
331 switch (tensor.type) {
332 case kTfLiteFloat32:
333 return TypedCheck<float, float>(verbose, tensor);
334 case kTfLiteInt32:
335 return TypedCheck<int32_t, float>(verbose, tensor);
336 case kTfLiteUInt32:
337 return TypedCheck<uint32_t, float>(verbose, tensor);
338 case kTfLiteInt64:
339 return TypedCheck<int64_t, float>(verbose, tensor);
340 case kTfLiteUInt64:
341 return TypedCheck<uint64_t, float>(verbose, tensor);
342 case kTfLiteUInt8:
343 return TypedCheck<uint8_t, float>(verbose, tensor);
344 case kTfLiteInt8:
345 return TypedCheck<int8_t, float>(verbose, tensor);
346 case kTfLiteInt16:
347 return TypedCheck<int16_t, float>(verbose, tensor);
348 case kTfLiteBool:
349 return TypedCheck<bool, float>(verbose, tensor);
350 case kTfLiteString:
351 return TypedCheckString(verbose, tensor);
352 case kTfLiteComplex64:
353 return TypedCheck<std::complex<float>, std::complex<float>>(verbose,
354 tensor);
355 case kTfLiteComplex128:
356 return TypedCheck<std::complex<double>, std::complex<double>>(verbose,
357 tensor);
358 case kTfLiteFloat64:
359 return TypedCheck<double, double>(verbose, tensor);
360 default:
361 fprintf(stderr, "Unsupported type %d in Check\n", tensor.type);
362 return false;
363 }
364 }
365
366 /* static */
InitTestDelegateProviders(int * argc,const char ** argv)367 bool TfLiteDriver::InitTestDelegateProviders(int* argc, const char** argv) {
368 return tflite::KernelTestDelegateProviders::Get()->InitFromCmdlineArgs(argc,
369 argv);
370 }
371
TfLiteDriver(DelegateType delegate_type,bool reference_kernel)372 TfLiteDriver::TfLiteDriver(DelegateType delegate_type, bool reference_kernel)
373 : delegate_(nullptr, nullptr),
374 relative_threshold_(kRelativeThreshold),
375 absolute_threshold_(kAbsoluteThreshold),
376 quantization_error_multiplier_(kQuantizationErrorMultiplier) {
377 if (reference_kernel) {
378 resolver_.reset(new ops::builtin::BuiltinRefOpResolver);
379 } else {
380 // TODO(b/168278077): change back to use BuiltinOpResolver after zip tests
381 // are fully validated against TfLite delegates.
382 resolver_.reset(
383 new ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
384 ops::builtin::BuiltinOpResolver* buildinop_resolver_ =
385 reinterpret_cast<ops::builtin::BuiltinOpResolver*>(resolver_.get());
386 tflite::ops::custom::AddHashtableOps(buildinop_resolver_);
387 tflite::ops::custom::AddParseExampleOp(buildinop_resolver_);
388 tflite::ops::custom::AddPerceptionOps(buildinop_resolver_);
389 }
390
391 switch (delegate_type) {
392 case DelegateType::kNone:
393 break;
394 case DelegateType::kNnapi:
395 delegate_ = evaluation::CreateNNAPIDelegate();
396 break;
397 case DelegateType::kGpu:
398 delegate_ = evaluation::CreateGPUDelegate();
399 break;
400 case DelegateType::kFlex:
401 #if !defined(__APPLE__)
402 delegate_ = FlexDelegate::Create();
403 #endif
404 break;
405 }
406 }
407
~TfLiteDriver()408 TfLiteDriver::~TfLiteDriver() {
409 for (auto t : tensors_to_deallocate_) {
410 DeallocateStringTensor(t.second);
411 }
412 }
413
AllocateTensors()414 void TfLiteDriver::AllocateTensors() {
415 if (must_allocate_tensors_) {
416 if (interpreter_->AllocateTensors() != kTfLiteOk) {
417 Invalidate("Failed to allocate tensors");
418 return;
419 }
420 ResetLSTMStateTensors();
421 must_allocate_tensors_ = false;
422 }
423 }
424
LoadModel(const string & bin_file_path)425 void TfLiteDriver::LoadModel(const string& bin_file_path) {
426 if (!IsValid()) return;
427
428 model_ = FlatBufferModel::BuildFromFile(GetFullPath(bin_file_path).c_str());
429 if (!model_) {
430 Invalidate("Failed to mmap model " + bin_file_path);
431 return;
432 }
433 InterpreterBuilder(*model_, *resolver_)(&interpreter_);
434 if (!interpreter_) {
435 Invalidate("Failed build interpreter");
436 return;
437 }
438 if (delegate_) {
439 if (interpreter_->ModifyGraphWithDelegate(delegate_.get()) != kTfLiteOk) {
440 Invalidate("Unable to the build graph using the delegate");
441 return;
442 }
443 } else {
444 auto* delegate_providers = tflite::KernelTestDelegateProviders::Get();
445 for (auto& one : delegate_providers->CreateAllDelegates()) {
446 if (interpreter_->ModifyGraphWithDelegate(std::move(one)) != kTfLiteOk) {
447 Invalidate(
448 "Unable to the build graph using the delegate initialized from "
449 "tflite::KernelTestDelegateProviders");
450 return;
451 }
452 }
453 }
454
455 must_allocate_tensors_ = true;
456 }
457
ResetTensor(int id)458 void TfLiteDriver::ResetTensor(int id) {
459 if (!IsValid()) return;
460 auto* tensor = interpreter_->tensor(id);
461 memset(tensor->data.raw, 0, tensor->bytes);
462 }
463
ReshapeTensor(int id,const string & csv_values)464 void TfLiteDriver::ReshapeTensor(int id, const string& csv_values) {
465 if (!IsValid()) return;
466 if (interpreter_->ResizeInputTensor(
467 id, testing::Split<int>(csv_values, ",")) != kTfLiteOk) {
468 Invalidate("Failed to resize input tensor " + std::to_string(id));
469 return;
470 }
471 must_allocate_tensors_ = true;
472 }
473
SetInput(int id,const string & csv_values)474 void TfLiteDriver::SetInput(int id, const string& csv_values) {
475 if (!IsValid()) return;
476 auto* tensor = interpreter_->tensor(id);
477 switch (tensor->type) {
478 case kTfLiteFloat32: {
479 const auto& values = testing::Split<float>(csv_values, ",");
480 if (!CheckSizes<float>(tensor->bytes, values.size())) return;
481 SetTensorData(values, tensor->data.raw);
482 break;
483 }
484 case kTfLiteInt32: {
485 const auto& values = testing::Split<int32_t>(csv_values, ",");
486 if (!CheckSizes<int32_t>(tensor->bytes, values.size())) return;
487 SetTensorData(values, tensor->data.raw);
488 break;
489 }
490 case kTfLiteUInt32: {
491 const auto& values = testing::Split<uint32_t>(csv_values, ",");
492 if (!CheckSizes<uint32_t>(tensor->bytes, values.size())) return;
493 SetTensorData(values, tensor->data.raw);
494 break;
495 }
496 case kTfLiteInt64: {
497 const auto& values = testing::Split<int64_t>(csv_values, ",");
498 if (!CheckSizes<int64_t>(tensor->bytes, values.size())) return;
499 SetTensorData(values, tensor->data.raw);
500 break;
501 }
502 case kTfLiteUInt64: {
503 const auto& values = testing::Split<uint64_t>(csv_values, ",");
504 if (!CheckSizes<uint64_t>(tensor->bytes, values.size())) return;
505 SetTensorData(values, tensor->data.raw);
506 break;
507 }
508 case kTfLiteUInt8: {
509 const auto& values = testing::Split<uint8_t>(csv_values, ",");
510 if (!CheckSizes<uint8_t>(tensor->bytes, values.size())) return;
511 SetTensorData(values, tensor->data.raw);
512 break;
513 }
514 case kTfLiteInt8: {
515 const auto& values = testing::Split<int8_t>(csv_values, ",");
516 if (!CheckSizes<int8_t>(tensor->bytes, values.size())) return;
517 SetTensorData(values, tensor->data.raw);
518 break;
519 }
520 case kTfLiteInt16: {
521 const auto& values = testing::Split<int16_t>(csv_values, ",");
522 if (!CheckSizes<int16_t>(tensor->bytes, values.size())) return;
523 SetTensorData(values, tensor->data.raw);
524 break;
525 }
526 case kTfLiteBool: {
527 const auto& values = testing::Split<bool>(csv_values, ",");
528 if (!CheckSizes<bool>(tensor->bytes, values.size())) return;
529 SetTensorData(values, tensor->data.raw);
530 break;
531 }
532 case kTfLiteString: {
533 string s = absl::HexStringToBytes(csv_values);
534
535 DeallocateStringTensor(tensors_to_deallocate_[id]);
536 AllocateStringTensor(id, s.size(), tensor);
537 memcpy(tensor->data.raw, s.data(), s.size());
538
539 break;
540 }
541 case kTfLiteComplex64: {
542 const auto& values = testing::Split<std::complex<float>>(csv_values, ",");
543 if (!CheckSizes<std::complex<float>>(tensor->bytes, values.size()))
544 return;
545 SetTensorData(values, tensor->data.raw);
546 break;
547 }
548 case kTfLiteComplex128: {
549 const auto& values =
550 testing::Split<std::complex<double>>(csv_values, ",");
551 if (!CheckSizes<std::complex<double>>(tensor->bytes, values.size()))
552 return;
553 SetTensorData(values, tensor->data.raw);
554 break;
555 }
556 default:
557 Invalidate(absl::StrCat("Unsupported tensor type ",
558 TfLiteTypeGetName(tensor->type),
559 " in TfLiteDriver::SetInput"));
560 return;
561 }
562 }
563
SetThreshold(double relative_threshold,double absolute_threshold)564 void TfLiteDriver::SetThreshold(double relative_threshold,
565 double absolute_threshold) {
566 relative_threshold_ = relative_threshold;
567 absolute_threshold_ = absolute_threshold;
568 }
569
SetQuantizationErrorMultiplier(int quantization_error_multiplier)570 void TfLiteDriver::SetQuantizationErrorMultiplier(
571 int quantization_error_multiplier) {
572 quantization_error_multiplier_ = quantization_error_multiplier;
573 }
574
SetExpectation(int id,const string & csv_values)575 void TfLiteDriver::SetExpectation(int id, const string& csv_values) {
576 if (!IsValid()) return;
577 auto* tensor = interpreter_->tensor(id);
578 if (expected_output_.count(id) != 0) {
579 Invalidate(absl::StrCat("Overridden expectation for tensor '", id, "'"));
580 }
581 expected_output_[id].reset(
582 new DataExpectation(relative_threshold_, absolute_threshold_,
583 quantization_error_multiplier_));
584
585 if (InterpretAsQuantized(*tensor)) {
586 expected_output_[id]->SetData<float>(csv_values);
587 return;
588 }
589
590 switch (tensor->type) {
591 case kTfLiteFloat32:
592 expected_output_[id]->SetData<float>(csv_values);
593 break;
594 case kTfLiteInt32:
595 expected_output_[id]->SetData<int32_t>(csv_values);
596 break;
597 case kTfLiteUInt32:
598 expected_output_[id]->SetData<uint32_t>(csv_values);
599 break;
600 case kTfLiteInt64:
601 expected_output_[id]->SetData<int64_t>(csv_values);
602 break;
603 case kTfLiteUInt64:
604 expected_output_[id]->SetData<uint64_t>(csv_values);
605 break;
606 case kTfLiteUInt8:
607 expected_output_[id]->SetData<uint8_t>(csv_values);
608 break;
609 case kTfLiteInt8:
610 expected_output_[id]->SetData<int8_t>(csv_values);
611 break;
612 case kTfLiteInt16:
613 expected_output_[id]->SetData<int16_t>(csv_values);
614 break;
615 case kTfLiteBool:
616 expected_output_[id]->SetData<bool>(csv_values);
617 break;
618 case kTfLiteString:
619 expected_output_[id]->SetData<string>(csv_values);
620 break;
621 case kTfLiteFloat64:
622 expected_output_[id]->SetData<double>(csv_values);
623 break;
624 case kTfLiteComplex64:
625 expected_output_[id]->SetData<std::complex<float>>(csv_values);
626 break;
627 case kTfLiteComplex128:
628 expected_output_[id]->SetData<std::complex<double>>(csv_values);
629 break;
630 default:
631 Invalidate(absl::StrCat("Unsupported tensor type ",
632 TfLiteTypeGetName(tensor->type),
633 " in TfLiteDriver::SetExpectation"));
634 return;
635 }
636 }
637
SetShapeExpectation(int id,const string & csv_values)638 void TfLiteDriver::SetShapeExpectation(int id, const string& csv_values) {
639 if (!IsValid()) return;
640 if (expected_output_shape_.count(id) != 0) {
641 Invalidate(
642 absl::StrCat("Overridden shape expectation for tensor '", id, "'"));
643 }
644 expected_output_shape_[id].reset(new ShapeExpectation(csv_values));
645 }
646
Invoke()647 void TfLiteDriver::Invoke() {
648 if (!IsValid()) return;
649 if (interpreter_->Invoke() != kTfLiteOk) {
650 Invalidate("Failed to invoke interpreter");
651 }
652 }
653
CheckResults()654 bool TfLiteDriver::CheckResults() {
655 if (!IsValid()) return false;
656 bool success = true;
657 for (const auto& p : expected_output_) {
658 int id = p.first;
659 auto* tensor = interpreter_->tensor(id);
660 if (!p.second->Check(/*verbose=*/false, *tensor)) {
661 // Do not invalidate anything here. Instead, simply output the
662 // differences and return false. Invalidating would prevent all
663 // subsequent invocations from running..
664 std::cerr << "There were errors in invocation '" << GetInvocationId()
665 << "', output tensor '" << id << "':" << std::endl;
666 p.second->Check(/*verbose=*/true, *tensor);
667 success = false;
668 SetOverallSuccess(false);
669 }
670 }
671 for (const auto& p : expected_output_shape_) {
672 int id = p.first;
673 auto* tensor = interpreter_->tensor(id);
674 if (!p.second->CheckShape(/*verbose=*/false, *tensor)) {
675 // Do not invalidate anything here. Instead, simply output the
676 // differences and return false. Invalidating would prevent all
677 // subsequent invocations from running..
678 std::cerr << "There were errors in invocation '" << GetInvocationId()
679 << "', output tensor '" << id << "':" << std::endl;
680 p.second->CheckShape(/*verbose=*/true, *tensor);
681 success = false;
682 SetOverallSuccess(false);
683 }
684 }
685 expected_output_.clear();
686 return success;
687 }
688
ResetLSTMStateTensors()689 void TfLiteDriver::ResetLSTMStateTensors() {
690 interpreter_->ResetVariableTensors();
691 }
692
ReadOutput(int id)693 string TfLiteDriver::ReadOutput(int id) {
694 auto* tensor = interpreter_->tensor(id);
695 int num_elements = 1;
696
697 for (int i = 0; i < tensor->dims->size; ++i) {
698 num_elements *= tensor->dims->data[i];
699 }
700
701 switch (tensor->type) {
702 case kTfLiteFloat32:
703 return JoinDefault(tensor->data.f, num_elements, ",");
704 case kTfLiteInt32:
705 return JoinDefault(tensor->data.i32, num_elements, ",");
706 case kTfLiteUInt32:
707 return JoinDefault(tensor->data.u32, num_elements, ",");
708 case kTfLiteInt64:
709 return JoinDefault(tensor->data.i64, num_elements, ",");
710 case kTfLiteUInt64:
711 return JoinDefault(tensor->data.u64, num_elements, ",");
712 case kTfLiteUInt8:
713 return Join(tensor->data.uint8, num_elements, ",");
714 case kTfLiteInt8:
715 return Join(tensor->data.int8, num_elements, ",");
716 case kTfLiteInt16:
717 return Join(tensor->data.i16, num_elements, ",");
718 case kTfLiteBool:
719 return JoinDefault(tensor->data.b, num_elements, ",");
720 default:
721 Invalidate(absl::StrCat("Unsupported tensor type ",
722 TfLiteTypeGetName(tensor->type),
723 " in TfLiteDriver::ReadOutput"));
724 return "";
725 }
726 }
727
728 } // namespace testing
729 } // namespace tflite
730