/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/testing/tflite_driver.h"

#include <algorithm>
#include <complex>
#include <cstring>
#include <iostream>
#include <memory>
#include <vector>

#include "absl/strings/escaping.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/lite/builtin_op_data.h"
#if !defined(__APPLE__)
#include "tensorflow/lite/delegates/flex/delegate.h"
#endif
#include "tensorflow/lite/kernels/custom_ops_register.h"
#include "tensorflow/lite/kernels/hashtable/hashtable_ops.h"
#include "tensorflow/lite/kernels/parse_example/parse_example.h"
#include "tensorflow/lite/kernels/perception/perception_ops.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/register_ref.h"
#include "tensorflow/lite/kernels/test_delegate_providers.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/testing/join.h"
#include "tensorflow/lite/testing/split.h"
#include "tensorflow/lite/tools/evaluation/utils.h"

namespace tflite {
namespace testing {

namespace {
const double kRelativeThreshold = 1e-2;
const double kAbsoluteThreshold = 1e-4;

// For quantized tests, we use a different error measurement from float ones.
// This assumes the baseline is always a float TF model.
// The error of a quantized model relative to that baseline comes from two
// sources:
//   1. the math done with quantized inputs, and
//   2. quantization of the output.
// Assuming no error is introduced by source 1, the theoretical maximum error
// allowed for the output is 0.5 * scale, because scale equals the size of
// one quantization bucket.
//
// As a result, we use `scale` as the unit for measuring quantization error.
// To account for the error introduced by source 1 as well, we need to relax
// the multiplier from 0.5 to a larger number, which is model/op dependent.
// The multiplier below is large enough to cover both sources of error for
// most quantized op tests.
const int kQuantizationErrorMultiplier = 4;
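
// Worked example (illustrative numbers, not taken from any particular test):
// for an int8 output with scale = 0.05, QuantizedCheck below allows the
// dequantized output to deviate from the float reference by up to
// kQuantizationErrorMultiplier * scale = 4 * 0.05 = 0.2.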

// Returns the value in the given position in a tensor.
template <typename T>
T Value(void* data, int index) {
  return static_cast<T*>(data)[index];
}

template <typename T>
void SetTensorData(const std::vector<T>& values, void* data) {
  T* input_ptr = static_cast<T*>(data);
  std::copy(values.begin(), values.end(), input_ptr);
}

// Implements type erasure via unique_ptr with a custom deleter, so buffers of
// arbitrary element type can be owned uniformly.
using unique_void_ptr = std::unique_ptr<void, void (*)(void*)>;

template <typename T>
unique_void_ptr make_type_erased_array(size_t size) {
  return unique_void_ptr(static_cast<void*>(new T[size]),
                         [](void* data) { delete[] static_cast<T*>(data); });
}
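
// Usage sketch (illustrative only, not part of the driver API):
//   unique_void_ptr buffer = make_type_erased_array<float>(4);
//   SetTensorData<float>({1.f, 2.f, 3.f, 4.f}, buffer.get());
// The deleter captured at construction casts back to T* so that delete[]
// runs with the correct element type.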

bool InterpretAsQuantized(const TfLiteTensor& tensor) {
  if (tensor.quantization.type == kTfLiteNoQuantization) return false;

  // Quantized single-op models with uint8 input/output type are only used for
  // EdgeTPU tests.
  // EdgeTPU tests need to read the quantized values as-is to check for
  // bit-exactness. As a result, we don't interpret the tensor as quantized.
  // TODO(b/176121243): Add an option to interpret uint8 buffers as a
  // non-quantized type and set it from the child class.
  if (tensor.type == kTfLiteUInt8) return false;

  if (tensor.quantization.params != nullptr) {
    auto* quantization =
        reinterpret_cast<TfLiteAffineQuantization*>(tensor.quantization.params);
    if (quantization->scale != nullptr && quantization->scale->size == 1 &&
        quantization->zero_point != nullptr &&
        quantization->zero_point->size == 1) {
      return true;
    }
  }
  return false;
}
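
// For example (illustrative): an int8 tensor carrying a single per-tensor
// scale and zero point is routed to QuantizedCheck, whereas a per-channel
// quantized tensor (scale->size > 1) falls through to the typed element-wise
// checks in DataExpectation::Check.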
}  // namespace

class TfLiteDriver::DataExpectation {
 public:
  DataExpectation(double relative_threshold, double absolute_threshold,
                  int quantization_error_multiplier)
      : data_(nullptr, nullptr),
        num_elements_(0),
        relative_threshold_(relative_threshold),
        absolute_threshold_(absolute_threshold),
        quantization_error_multiplier_(quantization_error_multiplier) {}

  template <typename T>
  void SetData(const string& csv_values) {
    const auto& values = testing::Split<T>(csv_values, ",");
    num_elements_ = values.size();
    data_ = make_type_erased_array<T>(num_elements_);
    SetTensorData(values, data_.get());
  }

  bool Check(bool verbose, const TfLiteTensor& tensor);

 private:
  bool CompareTwoValuesHelper(float v1, float v2) {
    float diff = std::abs(v1 - v2);
    bool error_is_large = false;
    // For very small numbers, try absolute error, otherwise go with
    // relative.
    if (std::abs(v2) < relative_threshold_) {
      error_is_large = (diff > absolute_threshold_);
    } else {
      error_is_large = (diff > relative_threshold_ * std::abs(v2));
    }
    return error_is_large;
  }

  bool CompareTwoValuesHelper(double v1, double v2) {
    double diff = std::abs(v1 - v2);
    bool error_is_large = false;
    // For very small numbers, try absolute error, otherwise go with
    // relative.
    if (std::abs(v2) < relative_threshold_) {
      error_is_large = (diff > absolute_threshold_);
    } else {
      error_is_large = (diff > relative_threshold_ * std::abs(v2));
    }
    return error_is_large;
  }
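
  // Worked example (illustrative numbers): with the default thresholds
  // relative_threshold_ = 1e-2 and absolute_threshold_ = 1e-4, a reference
  // value of 0.005 is checked with an absolute tolerance of 1e-4, while a
  // reference value of 2.0 is checked with a relative tolerance of
  // 1e-2 * 2.0 = 0.02.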

  bool CompareTwoValues(std::complex<float> v1, std::complex<float> v2) {
    return CompareTwoValues(v1.real(), v2.real()) ||
           CompareTwoValues(v1.imag(), v2.imag());
  }

  bool CompareTwoValues(std::complex<double> v1, std::complex<double> v2) {
    return CompareTwoValues(v1.real(), v2.real()) ||
           CompareTwoValues(v1.imag(), v2.imag());
  }

  bool CompareTwoValues(float v1, float v2) {
    return CompareTwoValuesHelper(v1, v2);
  }

  bool CompareTwoValues(double v1, double v2) {
    return CompareTwoValuesHelper(v1, v2);
  }

  template <typename T, typename TS>
  bool TypedCheck(bool verbose, const TfLiteTensor& tensor) {
    size_t tensor_size = tensor.bytes / sizeof(T);

    if (tensor_size != num_elements_) {
      std::cerr << "Expected a tensor with " << num_elements_
                << " elements, got " << tensor_size << std::endl;
      std::cerr << "while checking tensor " << tensor.name << std::endl;
      return false;
    }

    bool good_output = true;
    for (int i = 0; i < tensor_size; ++i) {
      TS computed = Value<T>(tensor.data.raw, i);
      TS reference = Value<T>(data_.get(), i);
      if (CompareTwoValues(computed, reference)) {
        good_output = false;
        if (verbose) {
          std::cerr << "  index " << i << ": got " << computed
                    << ", but expected " << reference << std::endl;
        }
      }
    }
    return good_output;
  }

  bool TypedCheckString(bool verbose, const TfLiteTensor& tensor);
  bool QuantizedCheck(bool verbose, const TfLiteTensor& tensor);

  unique_void_ptr data_;
  size_t num_elements_;
  double relative_threshold_;
  double absolute_threshold_;
  int quantization_error_multiplier_;
};

class TfLiteDriver::ShapeExpectation {
 public:
  explicit ShapeExpectation(const string& csv_values)
      : shape_(testing::Split<int32_t>(csv_values, ",")) {}

  bool CheckShape(bool verbose, const TfLiteTensor& tensor) {
    bool valid = true;
    if (tensor.dims->size == shape_.size()) {
      for (int i = 0; i < shape_.size(); ++i) {
        if (shape_[i] != tensor.dims->data[i]) {
          valid = false;
        }
      }
    } else {
      valid = false;
    }
    if (!valid && verbose) {
      std::cerr << "Incorrect output shape while checking tensor "
                << tensor.name << std::endl;
      std::cerr << "TFLite output shape: ";
      for (int i = 0; i < tensor.dims->size; ++i) {
        std::cerr << tensor.dims->data[i] << ", ";
      }
      std::cerr << std::endl;
      std::cerr << "Expected output shape: ";
      for (int i = 0; i < shape_.size(); ++i) {
        std::cerr << shape_[i] << ", ";
      }
      std::cerr << std::endl;
    }
    return valid;
  }

 private:
  std::vector<int32_t> shape_;
};
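
// For example (illustrative): a ShapeExpectation built from the CSV string
// "1,224,224,3" passes only for a tensor whose dims are exactly
// [1, 224, 224, 3]; any mismatch in rank or in a single dimension fails.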

template <>
void TfLiteDriver::DataExpectation::SetData<string>(const string& csv_values) {
  string s = absl::HexStringToBytes(csv_values);
  data_ = make_type_erased_array<char>(s.size());
  memcpy(data_.get(), s.data(), s.size());
}
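
// Note that string expectations arrive hex-encoded rather than as plain CSV:
// absl::HexStringToBytes decodes pairs of hex digits to raw bytes (e.g.
// "68656c6c6f" becomes "hello"), and the decoded bytes are the serialized
// string-tensor buffer that GetStringCount/GetString parse below.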

bool TfLiteDriver::DataExpectation::TypedCheckString(
    bool verbose, const TfLiteTensor& tensor) {
  if (tensor.data.raw == nullptr) {
    if (verbose) {
      std::cerr << "  got empty string" << std::endl;
    }
    return false;
  }
  int expected_num_strings = GetStringCount(data_.get());
  int returned_num_strings = GetStringCount(&tensor);
  if (expected_num_strings != returned_num_strings) {
    if (verbose) {
      std::cerr << "  string counts differ: got " << returned_num_strings
                << ", but expected " << expected_num_strings << std::endl;
    }
    return false;
  }
  for (int i = 0; i < returned_num_strings; ++i) {
    auto expected_ref = GetString(data_.get(), i);
    auto returned_ref = GetString(&tensor, i);
    if (expected_ref.len != returned_ref.len) {
      if (verbose) {
        std::cerr << "  index " << i << ": got string of size "
                  << returned_ref.len << ", but expected size "
                  << expected_ref.len << std::endl;
      }
      return false;
    }
    if (strncmp(expected_ref.str, returned_ref.str, returned_ref.len) != 0) {
      if (verbose) {
        std::cerr << "  index " << i << ": strings are different" << std::endl;
      }
      return false;
    }
  }

  return true;
}

bool TfLiteDriver::DataExpectation::QuantizedCheck(bool verbose,
                                                   const TfLiteTensor& tensor) {
  auto* quantization =
      reinterpret_cast<TfLiteAffineQuantization*>(tensor.quantization.params);
  const float scale = quantization->scale->data[0];
  const int32_t zero_point = quantization->zero_point->data[0];

  bool good_result = true;
  int int_size = tensor.type == kTfLiteInt8 ? 1 : 2;
  for (int i = 0; i < tensor.bytes / int_size; i++) {
    int32_t computed =
        tensor.type == kTfLiteInt8 ? tensor.data.int8[i] : tensor.data.i16[i];
    const float dequantized =
        static_cast<float>(scale * (computed - zero_point));
    int error_multiplier = quantization_error_multiplier_;
    // If we are doing int16 symmetric quantization of activations, we need to
    // bump up the allowed error. Since the weights are quantized to 8 bits
    // while the activations are 16 bits, the output can effectively carry
    // 8-bit error instead of 16-bit error, so we multiply the error
    // multiplier by 255 (approximately the ratio between the number of values
    // representable in 16 bits and in 8 bits).
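    // Worked example (illustrative numbers): with scale = 0.001 and the
    // default multiplier of 4, an int8 output may deviate from the float
    // reference by at most 4 * 0.001 = 0.004, while an int16 output is
    // allowed up to 4 * 255 * 0.001 = 1.02.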
    if (tensor.type == kTfLiteInt16) error_multiplier *= 255;
    const float reference = Value<float>(data_.get(), i);
    if (std::abs(dequantized - reference) > error_multiplier * scale) {
      if (verbose) {
        std::cerr << "  index " << i << ": got " << dequantized
                  << ", but expected " << reference << std::endl;
      }
      good_result = false;
    }
  }
  return good_result;
}

bool TfLiteDriver::DataExpectation::Check(bool verbose,
                                          const TfLiteTensor& tensor) {
  if (InterpretAsQuantized(tensor)) {
    return QuantizedCheck(verbose, tensor);
  }

  switch (tensor.type) {
    case kTfLiteFloat32:
      return TypedCheck<float, float>(verbose, tensor);
    case kTfLiteInt32:
      return TypedCheck<int32_t, float>(verbose, tensor);
    case kTfLiteUInt32:
      return TypedCheck<uint32_t, float>(verbose, tensor);
    case kTfLiteInt64:
      return TypedCheck<int64_t, float>(verbose, tensor);
    case kTfLiteUInt64:
      return TypedCheck<uint64_t, float>(verbose, tensor);
    case kTfLiteUInt8:
      return TypedCheck<uint8_t, float>(verbose, tensor);
    case kTfLiteInt8:
      return TypedCheck<int8_t, float>(verbose, tensor);
    case kTfLiteInt16:
      return TypedCheck<int16_t, float>(verbose, tensor);
    case kTfLiteBool:
      return TypedCheck<bool, float>(verbose, tensor);
    case kTfLiteString:
      return TypedCheckString(verbose, tensor);
    case kTfLiteComplex64:
      return TypedCheck<std::complex<float>, std::complex<float>>(verbose,
                                                                  tensor);
    case kTfLiteComplex128:
      return TypedCheck<std::complex<double>, std::complex<double>>(verbose,
                                                                    tensor);
    case kTfLiteFloat64:
      return TypedCheck<double, double>(verbose, tensor);
    default:
      fprintf(stderr, "Unsupported type %d in Check\n", tensor.type);
      return false;
  }
}

/* static */
bool TfLiteDriver::InitTestDelegateProviders(int* argc, const char** argv) {
  return tflite::KernelTestDelegateProviders::Get()->InitFromCmdlineArgs(argc,
                                                                         argv);
}
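
// Usage sketch (illustrative only): a test binary would typically forward its
// command-line arguments before constructing any driver, e.g.
//   int main(int argc, char** argv) {
//     tflite::testing::TfLiteDriver::InitTestDelegateProviders(
//         &argc, const_cast<const char**>(argv));
//     ...
//   }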

TfLiteDriver::TfLiteDriver(DelegateType delegate_type, bool reference_kernel)
    : delegate_(nullptr, nullptr),
      relative_threshold_(kRelativeThreshold),
      absolute_threshold_(kAbsoluteThreshold),
      quantization_error_multiplier_(kQuantizationErrorMultiplier) {
  if (reference_kernel) {
    resolver_.reset(new ops::builtin::BuiltinRefOpResolver);
  } else {
    // TODO(b/168278077): change back to use BuiltinOpResolver after zip tests
    // are fully validated against TfLite delegates.
    resolver_.reset(
        new ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
    ops::builtin::BuiltinOpResolver* builtin_op_resolver =
        reinterpret_cast<ops::builtin::BuiltinOpResolver*>(resolver_.get());
    tflite::ops::custom::AddHashtableOps(builtin_op_resolver);
    tflite::ops::custom::AddParseExampleOp(builtin_op_resolver);
    tflite::ops::custom::AddPerceptionOps(builtin_op_resolver);
  }

  switch (delegate_type) {
    case DelegateType::kNone:
      break;
    case DelegateType::kNnapi:
      delegate_ = evaluation::CreateNNAPIDelegate();
      break;
    case DelegateType::kGpu:
      delegate_ = evaluation::CreateGPUDelegate();
      break;
    case DelegateType::kFlex:
#if !defined(__APPLE__)
      delegate_ = FlexDelegate::Create();
#endif
      break;
  }
}
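
// Construction sketch (illustrative only; "my_model.bin" is a placeholder):
//   TfLiteDriver driver(TfLiteDriver::DelegateType::kNone,
//                       /*reference_kernel=*/false);
//   driver.LoadModel("my_model.bin");
//   driver.AllocateTensors();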

TfLiteDriver::~TfLiteDriver() {
  for (auto t : tensors_to_deallocate_) {
    DeallocateStringTensor(t.second);
  }
}

void TfLiteDriver::AllocateTensors() {
  if (must_allocate_tensors_) {
    if (interpreter_->AllocateTensors() != kTfLiteOk) {
      Invalidate("Failed to allocate tensors");
      return;
    }
    ResetLSTMStateTensors();
    must_allocate_tensors_ = false;
  }
}

void TfLiteDriver::LoadModel(const string& bin_file_path) {
  if (!IsValid()) return;

  model_ = FlatBufferModel::BuildFromFile(GetFullPath(bin_file_path).c_str());
  if (!model_) {
    Invalidate("Failed to mmap model " + bin_file_path);
    return;
  }
  InterpreterBuilder(*model_, *resolver_)(&interpreter_);
  if (!interpreter_) {
    Invalidate("Failed to build interpreter");
    return;
  }
  if (delegate_) {
    if (interpreter_->ModifyGraphWithDelegate(delegate_.get()) != kTfLiteOk) {
      Invalidate("Unable to modify the graph using the delegate");
      return;
    }
  } else {
    auto* delegate_providers = tflite::KernelTestDelegateProviders::Get();
    for (auto& one : delegate_providers->CreateAllDelegates()) {
      if (interpreter_->ModifyGraphWithDelegate(std::move(one)) != kTfLiteOk) {
        Invalidate(
            "Unable to modify the graph using the delegate initialized from "
            "tflite::KernelTestDelegateProviders");
        return;
      }
    }
  }

  must_allocate_tensors_ = true;
}

void TfLiteDriver::ResetTensor(int id) {
  if (!IsValid()) return;
  auto* tensor = interpreter_->tensor(id);
  memset(tensor->data.raw, 0, tensor->bytes);
}

void TfLiteDriver::ReshapeTensor(int id, const string& csv_values) {
  if (!IsValid()) return;
  if (interpreter_->ResizeInputTensor(
          id, testing::Split<int>(csv_values, ",")) != kTfLiteOk) {
    Invalidate("Failed to resize input tensor " + std::to_string(id));
    return;
  }
  must_allocate_tensors_ = true;
}

void TfLiteDriver::SetInput(int id, const string& csv_values) {
  if (!IsValid()) return;
  auto* tensor = interpreter_->tensor(id);
  switch (tensor->type) {
    case kTfLiteFloat32: {
      const auto& values = testing::Split<float>(csv_values, ",");
      if (!CheckSizes<float>(tensor->bytes, values.size())) return;
      SetTensorData(values, tensor->data.raw);
      break;
    }
    case kTfLiteInt32: {
      const auto& values = testing::Split<int32_t>(csv_values, ",");
      if (!CheckSizes<int32_t>(tensor->bytes, values.size())) return;
      SetTensorData(values, tensor->data.raw);
      break;
    }
    case kTfLiteUInt32: {
      const auto& values = testing::Split<uint32_t>(csv_values, ",");
      if (!CheckSizes<uint32_t>(tensor->bytes, values.size())) return;
      SetTensorData(values, tensor->data.raw);
      break;
    }
    case kTfLiteInt64: {
      const auto& values = testing::Split<int64_t>(csv_values, ",");
      if (!CheckSizes<int64_t>(tensor->bytes, values.size())) return;
      SetTensorData(values, tensor->data.raw);
      break;
    }
    case kTfLiteUInt64: {
      const auto& values = testing::Split<uint64_t>(csv_values, ",");
      if (!CheckSizes<uint64_t>(tensor->bytes, values.size())) return;
      SetTensorData(values, tensor->data.raw);
      break;
    }
    case kTfLiteUInt8: {
      const auto& values = testing::Split<uint8_t>(csv_values, ",");
      if (!CheckSizes<uint8_t>(tensor->bytes, values.size())) return;
      SetTensorData(values, tensor->data.raw);
      break;
    }
    case kTfLiteInt8: {
      const auto& values = testing::Split<int8_t>(csv_values, ",");
      if (!CheckSizes<int8_t>(tensor->bytes, values.size())) return;
      SetTensorData(values, tensor->data.raw);
      break;
    }
    case kTfLiteInt16: {
      const auto& values = testing::Split<int16_t>(csv_values, ",");
      if (!CheckSizes<int16_t>(tensor->bytes, values.size())) return;
      SetTensorData(values, tensor->data.raw);
      break;
    }
    case kTfLiteBool: {
      const auto& values = testing::Split<bool>(csv_values, ",");
      if (!CheckSizes<bool>(tensor->bytes, values.size())) return;
      SetTensorData(values, tensor->data.raw);
      break;
    }
    case kTfLiteString: {
      string s = absl::HexStringToBytes(csv_values);

      DeallocateStringTensor(tensors_to_deallocate_[id]);
      AllocateStringTensor(id, s.size(), tensor);
      memcpy(tensor->data.raw, s.data(), s.size());

      break;
    }
    case kTfLiteComplex64: {
      const auto& values = testing::Split<std::complex<float>>(csv_values, ",");
      if (!CheckSizes<std::complex<float>>(tensor->bytes, values.size()))
        return;
      SetTensorData(values, tensor->data.raw);
      break;
    }
    case kTfLiteComplex128: {
      const auto& values =
          testing::Split<std::complex<double>>(csv_values, ",");
      if (!CheckSizes<std::complex<double>>(tensor->bytes, values.size()))
        return;
      SetTensorData(values, tensor->data.raw);
      break;
    }
    default:
      Invalidate(absl::StrCat("Unsupported tensor type ",
                              TfLiteTypeGetName(tensor->type),
                              " in TfLiteDriver::SetInput"));
      return;
  }
}
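
// Input format sketch (illustrative values): for a four-element float32
// tensor, csv_values would look like "0.1,0.2,0.3,0.4"; for a string tensor
// it is a hex dump of the serialized string buffer, mirroring
// DataExpectation::SetData<string> above.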

void TfLiteDriver::SetThreshold(double relative_threshold,
                                double absolute_threshold) {
  relative_threshold_ = relative_threshold;
  absolute_threshold_ = absolute_threshold;
}

void TfLiteDriver::SetQuantizationErrorMultiplier(
    int quantization_error_multiplier) {
  quantization_error_multiplier_ = quantization_error_multiplier;
}

void TfLiteDriver::SetExpectation(int id, const string& csv_values) {
  if (!IsValid()) return;
  auto* tensor = interpreter_->tensor(id);
  if (expected_output_.count(id) != 0) {
    Invalidate(absl::StrCat("Overridden expectation for tensor '", id, "'"));
  }
  expected_output_[id].reset(
      new DataExpectation(relative_threshold_, absolute_threshold_,
                          quantization_error_multiplier_));

  if (InterpretAsQuantized(*tensor)) {
    expected_output_[id]->SetData<float>(csv_values);
    return;
  }

  switch (tensor->type) {
    case kTfLiteFloat32:
      expected_output_[id]->SetData<float>(csv_values);
      break;
    case kTfLiteInt32:
      expected_output_[id]->SetData<int32_t>(csv_values);
      break;
    case kTfLiteUInt32:
      expected_output_[id]->SetData<uint32_t>(csv_values);
      break;
    case kTfLiteInt64:
      expected_output_[id]->SetData<int64_t>(csv_values);
      break;
    case kTfLiteUInt64:
      expected_output_[id]->SetData<uint64_t>(csv_values);
      break;
    case kTfLiteUInt8:
      expected_output_[id]->SetData<uint8_t>(csv_values);
      break;
    case kTfLiteInt8:
      expected_output_[id]->SetData<int8_t>(csv_values);
      break;
    case kTfLiteInt16:
      expected_output_[id]->SetData<int16_t>(csv_values);
      break;
    case kTfLiteBool:
      expected_output_[id]->SetData<bool>(csv_values);
      break;
    case kTfLiteString:
      expected_output_[id]->SetData<string>(csv_values);
      break;
    case kTfLiteFloat64:
      expected_output_[id]->SetData<double>(csv_values);
      break;
    case kTfLiteComplex64:
      expected_output_[id]->SetData<std::complex<float>>(csv_values);
      break;
    case kTfLiteComplex128:
      expected_output_[id]->SetData<std::complex<double>>(csv_values);
      break;
    default:
      Invalidate(absl::StrCat("Unsupported tensor type ",
                              TfLiteTypeGetName(tensor->type),
                              " in TfLiteDriver::SetExpectation"));
      return;
  }
}

void TfLiteDriver::SetShapeExpectation(int id, const string& csv_values) {
  if (!IsValid()) return;
  if (expected_output_shape_.count(id) != 0) {
    Invalidate(
        absl::StrCat("Overridden shape expectation for tensor '", id, "'"));
  }
  expected_output_shape_[id].reset(new ShapeExpectation(csv_values));
}

void TfLiteDriver::Invoke() {
  if (!IsValid()) return;
  if (interpreter_->Invoke() != kTfLiteOk) {
    Invalidate("Failed to invoke interpreter");
  }
}

bool TfLiteDriver::CheckResults() {
  if (!IsValid()) return false;
  bool success = true;
  for (const auto& p : expected_output_) {
    int id = p.first;
    auto* tensor = interpreter_->tensor(id);
    if (!p.second->Check(/*verbose=*/false, *tensor)) {
      // Do not invalidate anything here. Instead, simply output the
      // differences and return false. Invalidating would prevent all
      // subsequent invocations from running.
      std::cerr << "There were errors in invocation '" << GetInvocationId()
                << "', output tensor '" << id << "':" << std::endl;
      p.second->Check(/*verbose=*/true, *tensor);
      success = false;
      SetOverallSuccess(false);
    }
  }
  for (const auto& p : expected_output_shape_) {
    int id = p.first;
    auto* tensor = interpreter_->tensor(id);
    if (!p.second->CheckShape(/*verbose=*/false, *tensor)) {
      // Do not invalidate anything here. Instead, simply output the
      // differences and return false. Invalidating would prevent all
      // subsequent invocations from running.
      std::cerr << "There were errors in invocation '" << GetInvocationId()
                << "', output tensor '" << id << "':" << std::endl;
      p.second->CheckShape(/*verbose=*/true, *tensor);
      success = false;
      SetOverallSuccess(false);
    }
  }
  expected_output_.clear();
  return success;
}

void TfLiteDriver::ResetLSTMStateTensors() {
  interpreter_->ResetVariableTensors();
}

string TfLiteDriver::ReadOutput(int id) {
  auto* tensor = interpreter_->tensor(id);
  int num_elements = 1;

  for (int i = 0; i < tensor->dims->size; ++i) {
    num_elements *= tensor->dims->data[i];
  }

  switch (tensor->type) {
    case kTfLiteFloat32:
      return JoinDefault(tensor->data.f, num_elements, ",");
    case kTfLiteInt32:
      return JoinDefault(tensor->data.i32, num_elements, ",");
    case kTfLiteUInt32:
      return JoinDefault(tensor->data.u32, num_elements, ",");
    case kTfLiteInt64:
      return JoinDefault(tensor->data.i64, num_elements, ",");
    case kTfLiteUInt64:
      return JoinDefault(tensor->data.u64, num_elements, ",");
    case kTfLiteUInt8:
      return Join(tensor->data.uint8, num_elements, ",");
    case kTfLiteInt8:
      return Join(tensor->data.int8, num_elements, ",");
    case kTfLiteInt16:
      return Join(tensor->data.i16, num_elements, ",");
    case kTfLiteBool:
      return JoinDefault(tensor->data.b, num_elements, ",");
    default:
      Invalidate(absl::StrCat("Unsupported tensor type ",
                              TfLiteTypeGetName(tensor->type),
                              " in TfLiteDriver::ReadOutput"));
      return "";
  }
}

}  // namespace testing
}  // namespace tflite