• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/testing/tflite_driver.h"

#include <algorithm>
#include <cmath>
#include <complex>
#include <cstring>
#include <memory>
#include <vector>

#include "absl/strings/escaping.h"
#include "tensorflow/lite/builtin_op_data.h"
#if !defined(__APPLE__)
#include "tensorflow/lite/delegates/flex/delegate.h"
#endif
#include "tensorflow/lite/kernels/custom_ops_register.h"
#include "tensorflow/lite/kernels/gradient/gradient_ops.h"
#include "tensorflow/lite/kernels/parse_example/parse_example.h"
#include "tensorflow/lite/kernels/perception/perception_ops.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/register_ref.h"
#include "tensorflow/lite/kernels/test_delegate_providers.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/testing/join.h"
#include "tensorflow/lite/testing/split.h"
#include "tensorflow/lite/tools/evaluation/utils.h"
38 
39 namespace tflite {
40 namespace testing {
41 
42 namespace {
// Default tolerances for comparing floating-point outputs with their expected
// values. The literals carry an `f` suffix, so the stored doubles are the
// float approximations of 1e-2 and 1e-4.
constexpr double kRelativeThreshold = 1e-2f;
constexpr double kAbsoluteThreshold = 1e-4f;

// Quantized tests use a different error measurement from float ones.
// The baseline is assumed to always be a float TF model.
// The error of a quantized model compared to the baseline has two sources:
//   1. the math done with quantized inputs, and
//   2. quantization of the output.
// Assuming source 1 introduces no error, the theoretical maximum error allowed
// for the output is 0.5 * scale, because scale is equal to the size of the
// quantization bucket.
//
// As a result, `scale` is used as the unit for measuring quantization error.
// To also account for the error introduced by source 1, the multiplier has to
// be relaxed from 0.5 to a larger number, which is model/op dependent.
// The number below is good enough to account for both sources of error for
// most quantized op tests to pass.
constexpr int kQuantizationErrorMultiplier = 4;
61 
// Reads the element at position `index` of the buffer `data`, viewing the
// buffer as an array of T.
template <typename T>
T Value(void* data, int index) {
  const T* typed = static_cast<const T*>(data);
  return *(typed + index);
}
67 
// Copies every element of `values` into the buffer `data`, which is viewed as
// an array of T. The caller must make sure the buffer is large enough.
template <typename T>
void SetTensorData(const std::vector<T>& values, void* data) {
  std::copy_n(values.begin(), values.size(), static_cast<T*>(data));
}
73 
// Type-erased owning pointer: a unique_ptr<void> whose deleter is a plain
// function pointer that knows how to release the underlying storage.
using unique_void_ptr = std::unique_ptr<void, void (*)(void*)>;

// Allocates an array of `size` elements of type T and returns it behind a
// type-erased pointer whose deleter releases it with delete[] as T[].
template <typename T>
unique_void_ptr make_type_erased_array(size_t size) {
  T* typed_array = new T[size];
  void (*release)(void*) = [](void* data) { delete[] static_cast<T*>(data); };
  return unique_void_ptr(typed_array, release);
}
82 
InterpretAsQuantized(const TfLiteTensor & tensor)83 bool InterpretAsQuantized(const TfLiteTensor& tensor) {
84   if (tensor.quantization.type == kTfLiteNoQuantization) return false;
85 
86   // Quantized single-op models with uint8 input/output type are only used for
87   // EdgeTPU tests.
88   // EdgeTPU tests need to read the quantized values as-is to check for
89   // bit-exactness. As a result we don't interpret the tensor as quantized.
90   // TODO(b/176121243): Add an option to interpret uint8 buffers as
91   // non-quantized type and set if from the child class.
92   if (tensor.type == kTfLiteUInt8) return false;
93 
94   if (tensor.quantization.params != nullptr) {
95     auto* quantization =
96         reinterpret_cast<TfLiteAffineQuantization*>(tensor.quantization.params);
97     if (quantization->scale != nullptr && quantization->scale->size == 1 &&
98         quantization->zero_point != nullptr &&
99         quantization->zero_point->size == 1) {
100       return true;
101     }
102   }
103   return false;
104 }
105 }  // namespace
106 
107 class TfLiteDriver::DataExpectation {
108  public:
DataExpectation(double relative_threshold,double absolute_threshold,int quantization_error_multiplier)109   DataExpectation(double relative_threshold, double absolute_threshold,
110                   int quantization_error_multiplier)
111       : data_(nullptr, nullptr),
112         num_elements_(0),
113         relative_threshold_(relative_threshold),
114         absolute_threshold_(absolute_threshold),
115         quantization_error_multiplier_(quantization_error_multiplier) {}
116 
117   template <typename T>
SetData(const string & csv_values)118   void SetData(const string& csv_values) {
119     const auto& values = testing::Split<T>(csv_values, ",");
120     num_elements_ = values.size();
121     data_ = make_type_erased_array<T>(num_elements_);
122     SetTensorData(values, data_.get());
123   }
124 
125   bool Check(bool verbose, const TfLiteTensor& tensor);
126 
127  private:
CompareTwoValuesHelper(float v1,float v2)128   bool CompareTwoValuesHelper(float v1, float v2) {
129     if (std::isnan(v1) || std::isnan(v2)) {
130       return !(std::isnan(v1) && std::isnan(v2));
131     }
132 
133     float diff = std::abs(v1 - v2);
134     bool error_is_large = false;
135     // For very small numbers, try absolute error, otherwise go with
136     // relative.
137     if (std::abs(v2) < relative_threshold_) {
138       error_is_large = (diff > absolute_threshold_);
139     } else {
140       error_is_large = (diff > relative_threshold_ * std::abs(v2));
141     }
142     return error_is_large;
143   }
144 
CompareTwoValuesHelper(double v1,double v2)145   bool CompareTwoValuesHelper(double v1, double v2) {
146     if (std::isnan(v1) || std::isnan(v2)) {
147       return !(std::isnan(v1) && std::isnan(v2));
148     }
149 
150     double diff = std::abs(v1 - v2);
151     bool error_is_large = false;
152     // For very small numbers, try absolute error, otherwise go with
153     // relative.
154     if (std::abs(v2) < relative_threshold_) {
155       error_is_large = (diff > absolute_threshold_);
156     } else {
157       error_is_large = (diff > relative_threshold_ * std::abs(v2));
158     }
159     return error_is_large;
160   }
161 
CompareTwoValues(std::complex<float> v1,std::complex<float> v2)162   bool CompareTwoValues(std::complex<float> v1, std::complex<float> v2) {
163     return CompareTwoValues(v1.real(), v2.real()) ||
164            CompareTwoValues(v1.imag(), v2.imag());
165   }
166 
CompareTwoValues(std::complex<double> v1,std::complex<double> v2)167   bool CompareTwoValues(std::complex<double> v1, std::complex<double> v2) {
168     return CompareTwoValues(v1.real(), v2.real()) ||
169            CompareTwoValues(v1.imag(), v2.imag());
170   }
171 
CompareTwoValues(float v1,float v2)172   bool CompareTwoValues(float v1, float v2) {
173     return CompareTwoValuesHelper(v1, v2);
174   }
175 
CompareTwoValues(double v1,double v2)176   bool CompareTwoValues(double v1, double v2) {
177     return CompareTwoValuesHelper(v1, v2);
178   }
179 
180   template <typename T, typename TS>
TypedCheck(bool verbose,const TfLiteTensor & tensor)181   bool TypedCheck(bool verbose, const TfLiteTensor& tensor) {
182     size_t tensor_size = tensor.bytes / sizeof(T);
183 
184     if (tensor_size != num_elements_) {
185       std::cerr << "Expected a tensor with " << num_elements_
186                 << " elements, got " << tensor_size << std::endl;
187       std::cerr << "while checking tensor " << tensor.name << std::endl;
188       return false;
189     }
190 
191     bool good_output = true;
192     for (int i = 0; i < tensor_size; ++i) {
193       TS computed = Value<T>(tensor.data.raw, i);
194       TS reference = Value<T>(data_.get(), i);
195       if (CompareTwoValues(computed, reference)) {
196         good_output = false;
197         if (verbose) {
198           std::cerr << "  index " << i << ": got " << computed
199                     << ", but expected " << reference << std::endl;
200         }
201       }
202     }
203     return good_output;
204   }
205 
206   bool TypedCheckString(bool verbose, const TfLiteTensor& tensor);
207   bool QuantizedCheck(bool verbose, const TfLiteTensor& tensor);
208 
209   unique_void_ptr data_;
210   size_t num_elements_;
211   double relative_threshold_;
212   double absolute_threshold_;
213   int quantization_error_multiplier_;
214 };
215 
216 class TfLiteDriver::ShapeExpectation {
217  public:
ShapeExpectation(const string & csv_values)218   explicit ShapeExpectation(const string& csv_values)
219       : shape_(testing::Split<int32_t>(csv_values, ",")) {}
220 
CheckShape(bool verbose,const TfLiteTensor & tensor)221   bool CheckShape(bool verbose, const TfLiteTensor& tensor) {
222     bool valid = true;
223     if (tensor.dims->size == shape_.size()) {
224       for (int i = 0; i < shape_.size(); ++i) {
225         if (shape_[i] != tensor.dims->data[i]) {
226           valid = false;
227         }
228       }
229     } else {
230       valid = false;
231     }
232     if (!valid && verbose) {
233       std::cerr << "Incorrect output shape while checking tensor "
234                 << tensor.name << std::endl;
235       std::cerr << "TFLite output shape: ";
236       for (int i = 0; i < tensor.dims->size; ++i) {
237         std::cerr << tensor.dims->data[i] << ", ";
238       }
239       std::cerr << std::endl;
240       std::cerr << "Expected output shape: ";
241       for (int i = 0; i < shape_.size(); ++i) {
242         std::cerr << shape_[i] << ", ";
243       }
244       std::cerr << std::endl;
245     }
246     return valid;
247   }
248 
249  private:
250   std::vector<int32_t> shape_;
251 };
252 
253 template <>
SetData(const string & csv_values)254 void TfLiteDriver::DataExpectation::SetData<string>(const string& csv_values) {
255   string s = absl::HexStringToBytes(csv_values);
256   data_ = make_type_erased_array<char>(s.size());
257   memcpy(data_.get(), s.data(), s.size());
258 }
259 
TypedCheckString(bool verbose,const TfLiteTensor & tensor)260 bool TfLiteDriver::DataExpectation::TypedCheckString(
261     bool verbose, const TfLiteTensor& tensor) {
262   if (tensor.data.raw == nullptr) {
263     if (verbose) {
264       std::cerr << "  got empty string" << std::endl;
265     }
266     return false;
267   }
268   int expected_num_strings = GetStringCount(data_.get());
269   int returned_num_strings = GetStringCount(&tensor);
270   if (expected_num_strings != returned_num_strings) {
271     if (verbose) {
272       std::cerr << "  string count differ: got " << returned_num_strings
273                 << ", but expected " << expected_num_strings << std::endl;
274     }
275     return false;
276   }
277   for (int i = 0; i < returned_num_strings; ++i) {
278     auto expected_ref = GetString(data_.get(), i);
279     auto returned_ref = GetString(&tensor, i);
280     if (expected_ref.len != returned_ref.len) {
281       if (verbose) {
282         std::cerr << "  index " << i << ": got string of size "
283                   << returned_ref.len << ", but expected size "
284                   << expected_ref.len << std::endl;
285       }
286       return false;
287     }
288     if (strncmp(expected_ref.str, returned_ref.str, returned_ref.len) != 0) {
289       if (verbose) {
290         std::cerr << "  index " << i << ": strings are different" << std::endl;
291       }
292       return false;
293     }
294   }
295 
296   return true;
297 }
298 
QuantizedCheck(bool verbose,const TfLiteTensor & tensor)299 bool TfLiteDriver::DataExpectation::QuantizedCheck(bool verbose,
300                                                    const TfLiteTensor& tensor) {
301   auto* quantization =
302       reinterpret_cast<TfLiteAffineQuantization*>(tensor.quantization.params);
303   const float scale = quantization->scale->data[0];
304   const int32_t zero_point = quantization->zero_point->data[0];
305 
306   bool good_result = true;
307   int int_size = tensor.type == kTfLiteInt8 ? 1 : 2;
308   for (int i = 0; i < tensor.bytes / int_size; i++) {
309     int32_t computed =
310         tensor.type == kTfLiteInt8 ? tensor.data.int8[i] : tensor.data.i16[i];
311     const float dequantized =
312         static_cast<float>(scale * (computed - zero_point));
313     int error_multiplier = quantization_error_multiplier_;
314     // If we are doing int16 symmetric quantization of activations, we need to
315     // bump up the potential error. Since the weights are quantized to 8 bits
316     // and the activations are 16bits, the output is could be getting
317     // effectively 8bit error instead of 16bit error. So we need to multiply the
318     // error mulitplier by 255 (the difference in number of values between a
319     // 16bit and 8bit number)
320     if (tensor.type == kTfLiteInt16) error_multiplier *= 255;
321     const float reference = Value<float>(data_.get(), i);
322     if (std::abs(dequantized - reference) > error_multiplier * scale) {
323       if (verbose) {
324         std::cerr << "  index " << i << ": got " << dequantized
325                   << ", but expected " << reference << std::endl;
326       }
327       good_result = false;
328     }
329   }
330   return good_result;
331 }
332 
Check(bool verbose,const TfLiteTensor & tensor)333 bool TfLiteDriver::DataExpectation::Check(bool verbose,
334                                           const TfLiteTensor& tensor) {
335   if (InterpretAsQuantized(tensor)) {
336     return QuantizedCheck(verbose, tensor);
337   }
338 
339   switch (tensor.type) {
340     case kTfLiteFloat32:
341       return TypedCheck<float, float>(verbose, tensor);
342     case kTfLiteInt32:
343       return TypedCheck<int32_t, float>(verbose, tensor);
344     case kTfLiteUInt32:
345       return TypedCheck<uint32_t, float>(verbose, tensor);
346     case kTfLiteInt64:
347       return TypedCheck<int64_t, float>(verbose, tensor);
348     case kTfLiteUInt64:
349       return TypedCheck<uint64_t, float>(verbose, tensor);
350     case kTfLiteUInt8:
351       return TypedCheck<uint8_t, float>(verbose, tensor);
352     case kTfLiteInt8:
353       return TypedCheck<int8_t, float>(verbose, tensor);
354     case kTfLiteInt16:
355       return TypedCheck<int16_t, float>(verbose, tensor);
356     case kTfLiteBool:
357       return TypedCheck<bool, float>(verbose, tensor);
358     case kTfLiteString:
359       return TypedCheckString(verbose, tensor);
360     case kTfLiteComplex64:
361       return TypedCheck<std::complex<float>, std::complex<float>>(verbose,
362                                                                   tensor);
363     case kTfLiteComplex128:
364       return TypedCheck<std::complex<double>, std::complex<double>>(verbose,
365                                                                     tensor);
366     case kTfLiteFloat64:
367       return TypedCheck<double, double>(verbose, tensor);
368     default:
369       fprintf(stderr, "Unsupported type %d in Check\n", tensor.type);
370       return false;
371   }
372 }
373 
374 /* static */
InitTestDelegateProviders(int * argc,const char ** argv)375 bool TfLiteDriver::InitTestDelegateProviders(int* argc, const char** argv) {
376   return tflite::KernelTestDelegateProviders::Get()->InitFromCmdlineArgs(argc,
377                                                                          argv);
378 }
379 
TfLiteDriver(DelegateType delegate_type,bool reference_kernel)380 TfLiteDriver::TfLiteDriver(DelegateType delegate_type, bool reference_kernel)
381     : delegate_(nullptr, nullptr),
382       relative_threshold_(kRelativeThreshold),
383       absolute_threshold_(kAbsoluteThreshold),
384       quantization_error_multiplier_(kQuantizationErrorMultiplier) {
385   if (reference_kernel) {
386     resolver_.reset(new ops::builtin::BuiltinRefOpResolver);
387   } else {
388     // TODO(b/168278077): change back to use BuiltinOpResolver after zip tests
389     // are fully validated against TfLite delegates.
390     resolver_.reset(
391         new ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
392     ops::builtin::BuiltinOpResolver* builtin_op_resolver_ =
393         reinterpret_cast<ops::builtin::BuiltinOpResolver*>(resolver_.get());
394     builtin_op_resolver_->AddCustom("IRFFT2D",
395                                     tflite::ops::custom::Register_IRFFT2D());
396     builtin_op_resolver_->AddCustom(
397         "AvgPool3D", tflite::ops::custom::Register_AVG_POOL_3D());
398     builtin_op_resolver_->AddCustom(
399         "MaxPool3D", tflite::ops::custom::Register_MAX_POOL_3D());
400     builtin_op_resolver_->AddCustom("Roll",
401                                     tflite::ops::custom::Register_ROLL());
402     tflite::ops::custom::AddGradientOps(builtin_op_resolver_);
403     tflite::ops::custom::AddParseExampleOp(builtin_op_resolver_);
404     tflite::ops::custom::AddPerceptionOps(builtin_op_resolver_);
405   }
406 
407   switch (delegate_type) {
408     case DelegateType::kNone:
409       break;
410     case DelegateType::kNnapi:
411       delegate_ = evaluation::CreateNNAPIDelegate();
412       break;
413     case DelegateType::kGpu:
414       delegate_ = evaluation::CreateGPUDelegate();
415       break;
416     case DelegateType::kFlex:
417 #if !defined(__APPLE__)
418       delegate_ = FlexDelegate::Create();
419 #endif
420       break;
421   }
422 }
423 
~TfLiteDriver()424 TfLiteDriver::~TfLiteDriver() {
425   for (auto t : tensors_to_deallocate_) {
426     DeallocateStringTensor(t.second);
427   }
428 }
429 
AllocateTensors()430 void TfLiteDriver::AllocateTensors() {
431   if (must_allocate_tensors_) {
432     if (interpreter_->AllocateTensors() != kTfLiteOk) {
433       Invalidate("Failed to allocate tensors");
434       return;
435     }
436     ResetLSTMStateTensors();
437     must_allocate_tensors_ = false;
438   }
439 }
440 
LoadModel(const string & bin_file_path)441 void TfLiteDriver::LoadModel(const string& bin_file_path) {
442   if (!IsValid()) return;
443 
444   model_ = FlatBufferModel::BuildFromFile(GetFullPath(bin_file_path).c_str());
445   if (!model_) {
446     Invalidate("Failed to mmap model " + bin_file_path);
447     return;
448   }
449   InterpreterBuilder(*model_, *resolver_)(&interpreter_);
450   if (!interpreter_) {
451     Invalidate("Failed build interpreter");
452     return;
453   }
454   if (delegate_) {
455     if (interpreter_->ModifyGraphWithDelegate(delegate_.get()) != kTfLiteOk) {
456       Invalidate("Unable to the build graph using the delegate");
457       return;
458     }
459   } else {
460     auto* delegate_providers = tflite::KernelTestDelegateProviders::Get();
461     for (auto& one : delegate_providers->CreateAllDelegates()) {
462       if (interpreter_->ModifyGraphWithDelegate(std::move(one.delegate)) !=
463           kTfLiteOk) {
464         Invalidate(
465             "Unable to the build graph using the delegate initialized from "
466             "tflite::KernelTestDelegateProviders");
467         return;
468       }
469     }
470   }
471 
472   must_allocate_tensors_ = true;
473 }
474 
ResetTensor(int id)475 void TfLiteDriver::ResetTensor(int id) {
476   if (!IsValid()) return;
477   auto* tensor = interpreter_->tensor(id);
478   memset(tensor->data.raw, 0, tensor->bytes);
479 }
480 
ReshapeTensor(int id,const string & csv_values)481 void TfLiteDriver::ReshapeTensor(int id, const string& csv_values) {
482   if (!IsValid()) return;
483   if (interpreter_->ResizeInputTensor(
484           id, testing::Split<int>(csv_values, ",")) != kTfLiteOk) {
485     Invalidate("Failed to resize input tensor " + std::to_string(id));
486     return;
487   }
488   must_allocate_tensors_ = true;
489 }
490 
SetInput(int id,const string & csv_values)491 void TfLiteDriver::SetInput(int id, const string& csv_values) {
492   if (!IsValid()) return;
493   auto* tensor = interpreter_->tensor(id);
494   switch (tensor->type) {
495     case kTfLiteFloat32: {
496       const auto& values = testing::Split<float>(csv_values, ",");
497       if (!CheckSizes<float>(tensor->bytes, values.size())) return;
498       SetTensorData(values, tensor->data.raw);
499       break;
500     }
501     case kTfLiteInt32: {
502       const auto& values = testing::Split<int32_t>(csv_values, ",");
503       if (!CheckSizes<int32_t>(tensor->bytes, values.size())) return;
504       SetTensorData(values, tensor->data.raw);
505       break;
506     }
507     case kTfLiteUInt32: {
508       const auto& values = testing::Split<uint32_t>(csv_values, ",");
509       if (!CheckSizes<uint32_t>(tensor->bytes, values.size())) return;
510       SetTensorData(values, tensor->data.raw);
511       break;
512     }
513     case kTfLiteInt64: {
514       const auto& values = testing::Split<int64_t>(csv_values, ",");
515       if (!CheckSizes<int64_t>(tensor->bytes, values.size())) return;
516       SetTensorData(values, tensor->data.raw);
517       break;
518     }
519     case kTfLiteUInt64: {
520       const auto& values = testing::Split<uint64_t>(csv_values, ",");
521       if (!CheckSizes<uint64_t>(tensor->bytes, values.size())) return;
522       SetTensorData(values, tensor->data.raw);
523       break;
524     }
525     case kTfLiteUInt8: {
526       const auto& values = testing::Split<uint8_t>(csv_values, ",");
527       if (!CheckSizes<uint8_t>(tensor->bytes, values.size())) return;
528       SetTensorData(values, tensor->data.raw);
529       break;
530     }
531     case kTfLiteInt8: {
532       const auto& values = testing::Split<int8_t>(csv_values, ",");
533       if (!CheckSizes<int8_t>(tensor->bytes, values.size())) return;
534       SetTensorData(values, tensor->data.raw);
535       break;
536     }
537     case kTfLiteInt16: {
538       const auto& values = testing::Split<int16_t>(csv_values, ",");
539       if (!CheckSizes<int16_t>(tensor->bytes, values.size())) return;
540       SetTensorData(values, tensor->data.raw);
541       break;
542     }
543     case kTfLiteBool: {
544       const auto& values = testing::Split<bool>(csv_values, ",");
545       if (!CheckSizes<bool>(tensor->bytes, values.size())) return;
546       SetTensorData(values, tensor->data.raw);
547       break;
548     }
549     case kTfLiteString: {
550       string s = absl::HexStringToBytes(csv_values);
551 
552       DeallocateStringTensor(tensors_to_deallocate_[id]);
553       AllocateStringTensor(id, s.size(), tensor);
554       memcpy(tensor->data.raw, s.data(), s.size());
555 
556       break;
557     }
558     case kTfLiteComplex64: {
559       const auto& values = testing::Split<std::complex<float>>(csv_values, ",");
560       if (!CheckSizes<std::complex<float>>(tensor->bytes, values.size()))
561         return;
562       SetTensorData(values, tensor->data.raw);
563       break;
564     }
565     case kTfLiteComplex128: {
566       const auto& values =
567           testing::Split<std::complex<double>>(csv_values, ",");
568       if (!CheckSizes<std::complex<double>>(tensor->bytes, values.size()))
569         return;
570       SetTensorData(values, tensor->data.raw);
571       break;
572     }
573     default:
574       Invalidate(absl::StrCat("Unsupported tensor type ",
575                               TfLiteTypeGetName(tensor->type),
576                               " in TfLiteDriver::SetInput"));
577       return;
578   }
579 }
580 
SetThreshold(double relative_threshold,double absolute_threshold)581 void TfLiteDriver::SetThreshold(double relative_threshold,
582                                 double absolute_threshold) {
583   relative_threshold_ = relative_threshold;
584   absolute_threshold_ = absolute_threshold;
585 }
586 
SetQuantizationErrorMultiplier(int quantization_error_multiplier)587 void TfLiteDriver::SetQuantizationErrorMultiplier(
588     int quantization_error_multiplier) {
589   quantization_error_multiplier_ = quantization_error_multiplier;
590 }
591 
SetExpectation(int id,const string & csv_values)592 void TfLiteDriver::SetExpectation(int id, const string& csv_values) {
593   if (!IsValid()) return;
594   auto* tensor = interpreter_->tensor(id);
595   if (expected_output_.count(id) != 0) {
596     Invalidate(absl::StrCat("Overridden expectation for tensor '", id, "'"));
597   }
598   expected_output_[id].reset(
599       new DataExpectation(relative_threshold_, absolute_threshold_,
600                           quantization_error_multiplier_));
601 
602   if (InterpretAsQuantized(*tensor)) {
603     expected_output_[id]->SetData<float>(csv_values);
604     return;
605   }
606 
607   switch (tensor->type) {
608     case kTfLiteFloat32:
609       expected_output_[id]->SetData<float>(csv_values);
610       break;
611     case kTfLiteInt32:
612       expected_output_[id]->SetData<int32_t>(csv_values);
613       break;
614     case kTfLiteUInt32:
615       expected_output_[id]->SetData<uint32_t>(csv_values);
616       break;
617     case kTfLiteInt64:
618       expected_output_[id]->SetData<int64_t>(csv_values);
619       break;
620     case kTfLiteUInt64:
621       expected_output_[id]->SetData<uint64_t>(csv_values);
622       break;
623     case kTfLiteUInt8:
624       expected_output_[id]->SetData<uint8_t>(csv_values);
625       break;
626     case kTfLiteInt8:
627       expected_output_[id]->SetData<int8_t>(csv_values);
628       break;
629     case kTfLiteInt16:
630       expected_output_[id]->SetData<int16_t>(csv_values);
631       break;
632     case kTfLiteBool:
633       expected_output_[id]->SetData<bool>(csv_values);
634       break;
635     case kTfLiteString:
636       expected_output_[id]->SetData<string>(csv_values);
637       break;
638     case kTfLiteFloat64:
639       expected_output_[id]->SetData<double>(csv_values);
640       break;
641     case kTfLiteComplex64:
642       expected_output_[id]->SetData<std::complex<float>>(csv_values);
643       break;
644     case kTfLiteComplex128:
645       expected_output_[id]->SetData<std::complex<double>>(csv_values);
646       break;
647     default:
648       Invalidate(absl::StrCat("Unsupported tensor type ",
649                               TfLiteTypeGetName(tensor->type),
650                               " in TfLiteDriver::SetExpectation"));
651       return;
652   }
653 }
654 
SetShapeExpectation(int id,const string & csv_values)655 void TfLiteDriver::SetShapeExpectation(int id, const string& csv_values) {
656   if (!IsValid()) return;
657   if (expected_output_shape_.count(id) != 0) {
658     Invalidate(
659         absl::StrCat("Overridden shape expectation for tensor '", id, "'"));
660   }
661   expected_output_shape_[id].reset(new ShapeExpectation(csv_values));
662 }
663 
Invoke()664 void TfLiteDriver::Invoke() {
665   if (!IsValid()) return;
666   if (interpreter_->Invoke() != kTfLiteOk) {
667     Invalidate("Failed to invoke interpreter");
668   }
669 }
670 
CheckResults()671 bool TfLiteDriver::CheckResults() {
672   if (!IsValid()) return false;
673   bool success = true;
674   for (const auto& p : expected_output_) {
675     int id = p.first;
676     auto* tensor = interpreter_->tensor(id);
677     if (!p.second->Check(/*verbose=*/false, *tensor)) {
678       // Do not invalidate anything here. Instead, simply output the
679       // differences and return false. Invalidating would prevent all
680       // subsequent invocations from running..
681       std::cerr << "There were errors in invocation '" << GetInvocationId()
682                 << "', output tensor '" << id << "':" << std::endl;
683       p.second->Check(/*verbose=*/true, *tensor);
684       success = false;
685       SetOverallSuccess(false);
686     }
687   }
688   for (const auto& p : expected_output_shape_) {
689     int id = p.first;
690     auto* tensor = interpreter_->tensor(id);
691     if (!p.second->CheckShape(/*verbose=*/false, *tensor)) {
692       // Do not invalidate anything here. Instead, simply output the
693       // differences and return false. Invalidating would prevent all
694       // subsequent invocations from running..
695       std::cerr << "There were errors in invocation '" << GetInvocationId()
696                 << "', output tensor '" << id << "':" << std::endl;
697       p.second->CheckShape(/*verbose=*/true, *tensor);
698       success = false;
699       SetOverallSuccess(false);
700     }
701   }
702   expected_output_.clear();
703   return success;
704 }
705 
ResetLSTMStateTensors()706 void TfLiteDriver::ResetLSTMStateTensors() {
707   interpreter_->ResetVariableTensors();
708 }
709 
ReadOutput(int id)710 string TfLiteDriver::ReadOutput(int id) {
711   auto* tensor = interpreter_->tensor(id);
712   int num_elements = 1;
713 
714   for (int i = 0; i < tensor->dims->size; ++i) {
715     num_elements *= tensor->dims->data[i];
716   }
717 
718   switch (tensor->type) {
719     case kTfLiteFloat32:
720       return JoinDefault(tensor->data.f, num_elements, ",");
721     case kTfLiteInt32:
722       return JoinDefault(tensor->data.i32, num_elements, ",");
723     case kTfLiteUInt32:
724       return JoinDefault(tensor->data.u32, num_elements, ",");
725     case kTfLiteInt64:
726       return JoinDefault(tensor->data.i64, num_elements, ",");
727     case kTfLiteUInt64:
728       return JoinDefault(tensor->data.u64, num_elements, ",");
729     case kTfLiteUInt8:
730       return Join(tensor->data.uint8, num_elements, ",");
731     case kTfLiteInt8:
732       return Join(tensor->data.int8, num_elements, ",");
733     case kTfLiteInt16:
734       return Join(tensor->data.i16, num_elements, ",");
735     case kTfLiteBool:
736       return JoinDefault(tensor->data.b, num_elements, ",");
737     default:
738       Invalidate(absl::StrCat("Unsupported tensor type ",
739                               TfLiteTypeGetName(tensor->type),
740                               " in TfLiteDriver::ReadOutput"));
741       return "";
742   }
743 }
744 
745 }  // namespace testing
746 }  // namespace tflite
747