• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/testing/tflite_driver.h"

#include <algorithm>
#include <cmath>
#include <complex>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <iterator>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/escaping.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/lite/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/c/common.h"
#if !defined(__APPLE__)
#include "tensorflow/lite/delegates/flex/delegate.h"
#endif
#include "tensorflow/lite/kernels/custom_ops_register.h"
#include "tensorflow/lite/kernels/gradient/gradient_ops.h"
#include "tensorflow/lite/kernels/parse_example/parse_example.h"
#include "tensorflow/lite/kernels/perception/perception_ops.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/register_ref.h"
#include "tensorflow/lite/kernels/test_delegate_providers.h"
#include "tensorflow/lite/signature_runner.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/testing/join.h"
#include "tensorflow/lite/testing/split.h"
#include "tensorflow/lite/tools/evaluation/utils.h"
46 
47 namespace tflite {
48 namespace testing {
49 
50 namespace {
// Comparison thresholds used by DataExpectation for float outputs.
const double kRelativeThreshold = 1e-2f;
const double kAbsoluteThreshold = 1e-4f;
// Default TF SavedModel signature key used when the test does not name one.
const char kDefaultSignatureKey[] = "serving_default";

// For quantized tests, we use a different error measurement from float ones.
// Assumes the baseline is always a float TF model.
// Error of a quantized model compared to the baseline comes from two sources:
//   1. the math done with quantized inputs, and
//   2. quantization of the output.
// Assuming there is no error introduced by source 1, the theoretical maximum
// error allowed for the output is 0.5 * scale, because scale is equal to the
// size of the quantization bucket.
//
// As a result, we use `scale` as a unit for measuring the quantization error.
// To add the error introduced by source 1 as well, we need to relax the
// multiplier from 0.5 to a larger number, which is model/op dependent.
// The number below is good enough to account for both the two sources of error
// for most quantized op tests to pass.
const int kQuantizationErrorMultiplier = 4;
70 
71 // Returns the value in the given position in a tensor.
72 template <typename T>
Value(void * data,int index)73 T Value(void* data, int index) {
74   return static_cast<T*>(data)[index];
75 }
76 
// Copies `values` into the type-erased buffer `data`, which must be large
// enough to hold values.size() elements of T.
template <typename T>
void SetTensorData(const std::vector<T>& values, void* data) {
  std::copy(values.begin(), values.end(), static_cast<T*>(data));
}
82 
// Type-erased owning buffer: a unique_ptr<void> whose deleter restores the
// element type before calling delete[].
using unique_void_ptr = std::unique_ptr<void, void (*)(void*)>;

// Allocates an array of `size` Ts and erases its element type. The captured
// deleter guarantees delete[] runs with the correct type.
template <typename T>
unique_void_ptr make_type_erased_array(size_t size) {
  void* raw = static_cast<void*>(new T[size]);
  return unique_void_ptr(raw,
                         [](void* p) { delete[] static_cast<T*>(p); });
}
91 
// Returns true when `tensor` should be compared as a quantized tensor, i.e.
// it carries per-tensor affine quantization (exactly one scale and one
// zero point). uint8 tensors are deliberately excluded (see below).
bool InterpretAsQuantized(const TfLiteTensor& tensor) {
  if (tensor.quantization.type == kTfLiteNoQuantization) return false;

  // Quantized single-op models with uint8 input/output type are only used for
  // EdgeTPU tests.
  // EdgeTPU tests need to read the quantized values as-is to check for
  // bit-exactness. As a result we don't interpret the tensor as quantized.
  // TODO(b/176121243): Add an option to interpret uint8 buffers as
  // non-quantized type and set if from the child class.
  if (tensor.type == kTfLiteUInt8) return false;

  if (tensor.quantization.params != nullptr) {
    auto* quantization =
        reinterpret_cast<TfLiteAffineQuantization*>(tensor.quantization.params);
    // Only per-tensor (single scale / single zero point) quantization is
    // treated as quantized here; per-channel params fall through to false.
    if (quantization->scale != nullptr && quantization->scale->size == 1 &&
        quantization->zero_point != nullptr &&
        quantization->zero_point->size == 1) {
      return true;
    }
  }
  return false;
}
114 }  // namespace
115 
116 class TfLiteDriver::DataExpectation {
117  public:
DataExpectation(double relative_threshold,double absolute_threshold,int quantization_error_multiplier)118   DataExpectation(double relative_threshold, double absolute_threshold,
119                   int quantization_error_multiplier)
120       : data_(nullptr, nullptr),
121         num_elements_(0),
122         relative_threshold_(relative_threshold),
123         absolute_threshold_(absolute_threshold),
124         quantization_error_multiplier_(quantization_error_multiplier) {}
125 
126   template <typename T>
SetData(const string & csv_values)127   void SetData(const string& csv_values) {
128     const auto& values = testing::Split<T>(csv_values, ",");
129     num_elements_ = values.size();
130     data_ = make_type_erased_array<T>(num_elements_);
131     SetTensorData(values, data_.get());
132   }
133 
134   bool Check(bool verbose, const TfLiteTensor& tensor);
135 
136  private:
CompareTwoValuesHelper(float v1,float v2)137   bool CompareTwoValuesHelper(float v1, float v2) {
138     if (std::isnan(v1) || std::isnan(v2)) {
139       return !(std::isnan(v1) && std::isnan(v2));
140     }
141 
142     float diff = std::abs(v1 - v2);
143     bool error_is_large = false;
144     // For very small numbers, try absolute error, otherwise go with
145     // relative.
146     if (std::abs(v2) < relative_threshold_) {
147       error_is_large = (diff > absolute_threshold_);
148     } else {
149       error_is_large = (diff > relative_threshold_ * std::abs(v2));
150     }
151     return error_is_large;
152   }
153 
CompareTwoValuesHelper(double v1,double v2)154   bool CompareTwoValuesHelper(double v1, double v2) {
155     if (std::isnan(v1) || std::isnan(v2)) {
156       return !(std::isnan(v1) && std::isnan(v2));
157     }
158 
159     double diff = std::abs(v1 - v2);
160     bool error_is_large = false;
161     // For very small numbers, try absolute error, otherwise go with
162     // relative.
163     if (std::abs(v2) < relative_threshold_) {
164       error_is_large = (diff > absolute_threshold_);
165     } else {
166       error_is_large = (diff > relative_threshold_ * std::abs(v2));
167     }
168     return error_is_large;
169   }
170 
CompareTwoValues(std::complex<float> v1,std::complex<float> v2)171   bool CompareTwoValues(std::complex<float> v1, std::complex<float> v2) {
172     return CompareTwoValues(v1.real(), v2.real()) ||
173            CompareTwoValues(v1.imag(), v2.imag());
174   }
175 
CompareTwoValues(std::complex<double> v1,std::complex<double> v2)176   bool CompareTwoValues(std::complex<double> v1, std::complex<double> v2) {
177     return CompareTwoValues(v1.real(), v2.real()) ||
178            CompareTwoValues(v1.imag(), v2.imag());
179   }
180 
CompareTwoValues(float v1,float v2)181   bool CompareTwoValues(float v1, float v2) {
182     return CompareTwoValuesHelper(v1, v2);
183   }
184 
CompareTwoValues(double v1,double v2)185   bool CompareTwoValues(double v1, double v2) {
186     return CompareTwoValuesHelper(v1, v2);
187   }
188 
189   template <typename T, typename TS>
TypedCheck(bool verbose,const TfLiteTensor & tensor)190   bool TypedCheck(bool verbose, const TfLiteTensor& tensor) {
191     size_t tensor_size = tensor.bytes / sizeof(T);
192 
193     if (tensor_size != num_elements_) {
194       std::cerr << "Expected a tensor with " << num_elements_
195                 << " elements, got " << tensor_size << std::endl;
196       std::cerr << "while checking tensor " << tensor.name << std::endl;
197       return false;
198     }
199 
200     bool good_output = true;
201     for (int i = 0; i < tensor_size; ++i) {
202       TS computed = Value<T>(tensor.data.raw, i);
203       TS reference = Value<T>(data_.get(), i);
204       if (CompareTwoValues(computed, reference)) {
205         good_output = false;
206         if (verbose) {
207           std::cerr << "  Tensor[" << tensor.name << "] index " << i << ": got "
208                     << computed << ", but expected " << reference << std::endl;
209         }
210       }
211     }
212     return good_output;
213   }
214 
215   bool TypedCheckString(bool verbose, const TfLiteTensor& tensor);
216   bool QuantizedCheck(bool verbose, const TfLiteTensor& tensor);
217 
218   unique_void_ptr data_;
219   size_t num_elements_;
220   double relative_threshold_;
221   double absolute_threshold_;
222   int quantization_error_multiplier_;
223 };
224 
225 class TfLiteDriver::ShapeExpectation {
226  public:
ShapeExpectation(const string & csv_values)227   explicit ShapeExpectation(const string& csv_values)
228       : shape_(testing::Split<int32_t>(csv_values, ",")) {}
229 
CheckShape(bool verbose,const TfLiteTensor & tensor)230   bool CheckShape(bool verbose, const TfLiteTensor& tensor) {
231     bool valid = true;
232     if (tensor.dims->size == shape_.size()) {
233       for (int i = 0; i < shape_.size(); ++i) {
234         if (shape_[i] != tensor.dims->data[i]) {
235           valid = false;
236         }
237       }
238     } else {
239       valid = false;
240     }
241     if (!valid && verbose) {
242       std::cerr << "Incorrect output shape while checking tensor "
243                 << tensor.name << std::endl;
244       std::cerr << "TFLite output shape: ";
245       for (int i = 0; i < tensor.dims->size; ++i) {
246         std::cerr << tensor.dims->data[i] << ", ";
247       }
248       std::cerr << std::endl;
249       std::cerr << "Expected output shape: ";
250       for (int i = 0; i < shape_.size(); ++i) {
251         std::cerr << shape_[i] << ", ";
252       }
253       std::cerr << std::endl;
254     }
255     return valid;
256   }
257 
258  private:
259   std::vector<int32_t> shape_;
260 };
261 
262 template <>
SetData(const string & csv_values)263 void TfLiteDriver::DataExpectation::SetData<string>(const string& csv_values) {
264   string s = absl::HexStringToBytes(csv_values);
265   data_ = make_type_erased_array<char>(s.size());
266   memcpy(data_.get(), s.data(), s.size());
267 }
268 
// Compares a string tensor against the stored expectation buffer: the string
// counts must match, and every string must match in both length and content.
// Mismatch details go to stderr when `verbose` is set.
bool TfLiteDriver::DataExpectation::TypedCheckString(
    bool verbose, const TfLiteTensor& tensor) {
  if (tensor.data.raw == nullptr) {
    if (verbose) {
      std::cerr << "  got empty string" << std::endl;
    }
    return false;
  }
  // data_ holds a serialized string buffer in the same format the tensor
  // uses, so the same accessor functions work on both.
  int expected_num_strings = GetStringCount(data_.get());
  int returned_num_strings = GetStringCount(&tensor);
  if (expected_num_strings != returned_num_strings) {
    if (verbose) {
      std::cerr << "  string count differ: got " << returned_num_strings
                << ", but expected " << expected_num_strings << std::endl;
    }
    return false;
  }
  for (int i = 0; i < returned_num_strings; ++i) {
    auto expected_ref = GetString(data_.get(), i);
    auto returned_ref = GetString(&tensor, i);
    if (expected_ref.len != returned_ref.len) {
      if (verbose) {
        std::cerr << "  index " << i << ": got string of size "
                  << returned_ref.len << ", but expected size "
                  << expected_ref.len << std::endl;
      }
      return false;
    }
    // Lengths are equal, so a bounded comparison over returned_ref.len
    // covers both strings (string refs are not NUL-terminated).
    if (strncmp(expected_ref.str, returned_ref.str, returned_ref.len) != 0) {
      if (verbose) {
        std::cerr << "  index " << i << ": strings are different" << std::endl;
      }
      return false;
    }
  }

  return true;
}
307 
// Compares a quantized (int8/int16) tensor against float reference data by
// dequantizing each element and allowing an error proportional to `scale`
// (see kQuantizationErrorMultiplier for the rationale).
bool TfLiteDriver::DataExpectation::QuantizedCheck(bool verbose,
                                                   const TfLiteTensor& tensor) {
  // Caller guarantees per-tensor affine quantization (InterpretAsQuantized).
  auto* quantization =
      reinterpret_cast<TfLiteAffineQuantization*>(tensor.quantization.params);
  const float scale = quantization->scale->data[0];
  const int32_t zero_point = quantization->zero_point->data[0];

  bool good_result = true;
  // Element width in bytes: int8 -> 1, int16 -> 2.
  int int_size = tensor.type == kTfLiteInt8 ? 1 : 2;
  for (int i = 0; i < tensor.bytes / int_size; i++) {
    int32_t computed =
        tensor.type == kTfLiteInt8 ? tensor.data.int8[i] : tensor.data.i16[i];
    const float dequantized =
        static_cast<float>(scale * (computed - zero_point));
    int error_multiplier = quantization_error_multiplier_;
    // If we are doing int16 symmetric quantization of activations, we need to
    // bump up the potential error. Since the weights are quantized to 8 bits
    // and the activations are 16 bits, the output could be getting
    // effectively 8-bit error instead of 16-bit error. So we need to multiply
    // the error multiplier by 255 (the difference in the number of values
    // between a 16-bit and an 8-bit number).
    if (tensor.type == kTfLiteInt16) error_multiplier *= 255;
    const float reference = Value<float>(data_.get(), i);
    if (std::abs(dequantized - reference) > error_multiplier * scale) {
      if (verbose) {
        std::cerr << "  index " << i << ": got " << dequantized
                  << ", but expected " << reference << std::endl;
      }
      good_result = false;
    }
  }
  return good_result;
}
341 
// Dispatches the comparison based on the tensor's type. Quantized tensors
// (per-tensor affine int8/int16) are dequantized and compared against float
// references; everything else goes through TypedCheck<storage, comparison>.
bool TfLiteDriver::DataExpectation::Check(bool verbose,
                                          const TfLiteTensor& tensor) {
  if (InterpretAsQuantized(tensor)) {
    return QuantizedCheck(verbose, tensor);
  }

  switch (tensor.type) {
    case kTfLiteFloat32:
      return TypedCheck<float, float>(verbose, tensor);
    // Integer and bool tensors are compared in float space so the
    // relative/absolute thresholds apply uniformly.
    case kTfLiteInt32:
      return TypedCheck<int32_t, float>(verbose, tensor);
    case kTfLiteUInt32:
      return TypedCheck<uint32_t, float>(verbose, tensor);
    case kTfLiteInt64:
      return TypedCheck<int64_t, float>(verbose, tensor);
    case kTfLiteUInt64:
      return TypedCheck<uint64_t, float>(verbose, tensor);
    case kTfLiteUInt8:
      return TypedCheck<uint8_t, float>(verbose, tensor);
    case kTfLiteInt8:
      return TypedCheck<int8_t, float>(verbose, tensor);
    case kTfLiteUInt16:
      return TypedCheck<uint16_t, float>(verbose, tensor);
    case kTfLiteInt16:
      return TypedCheck<int16_t, float>(verbose, tensor);
    case kTfLiteBool:
      return TypedCheck<bool, float>(verbose, tensor);
    case kTfLiteString:
      return TypedCheckString(verbose, tensor);
    case kTfLiteComplex64:
      return TypedCheck<std::complex<float>, std::complex<float>>(verbose,
                                                                  tensor);
    case kTfLiteComplex128:
      return TypedCheck<std::complex<double>, std::complex<double>>(verbose,
                                                                    tensor);
    case kTfLiteFloat64:
      return TypedCheck<double, double>(verbose, tensor);
    default:
      fprintf(stderr, "Unsupported type %d in Check\n", tensor.type);
      return false;
  }
}
384 
385 /* static */
InitTestDelegateProviders(int * argc,const char ** argv)386 bool TfLiteDriver::InitTestDelegateProviders(int* argc, const char** argv) {
387   return tflite::KernelTestDelegateProviders::Get()->InitFromCmdlineArgs(argc,
388                                                                          argv);
389 }
390 
// Builds the op resolver (reference kernels or the regular builtin resolver
// without default delegates) plus the custom/test-only ops, then creates the
// requested delegate, if any.
TfLiteDriver::TfLiteDriver(DelegateType delegate_type, bool reference_kernel)
    : delegate_(nullptr, nullptr),
      relative_threshold_(kRelativeThreshold),
      absolute_threshold_(kAbsoluteThreshold),
      quantization_error_multiplier_(kQuantizationErrorMultiplier) {
  if (reference_kernel) {
    resolver_ = std::make_unique<ops::builtin::BuiltinRefOpResolver>();
  } else {
    // TODO(b/168278077): change back to use BuiltinOpResolver after zip tests
    // are fully validated against TfLite delegates.
    resolver_ = std::make_unique<
        ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>();
    ops::builtin::BuiltinOpResolver* builtin_op_resolver_ =
        reinterpret_cast<ops::builtin::BuiltinOpResolver*>(resolver_.get());
    // Register custom ops used by the generated zip tests that are not part
    // of the builtin set.
    builtin_op_resolver_->AddCustom("IRFFT2D",
                                    tflite::ops::custom::Register_IRFFT2D());
    builtin_op_resolver_->AddCustom(
        "AvgPool3D", tflite::ops::custom::Register_AVG_POOL_3D());
    builtin_op_resolver_->AddCustom(
        "MaxPool3D", tflite::ops::custom::Register_MAX_POOL_3D());
    builtin_op_resolver_->AddCustom("Roll",
                                    tflite::ops::custom::Register_ROLL());
    tflite::ops::custom::AddGradientOps(builtin_op_resolver_);
    tflite::ops::custom::AddParseExampleOp(builtin_op_resolver_);
    tflite::ops::custom::AddPerceptionOps(builtin_op_resolver_);
  }

  switch (delegate_type) {
    case DelegateType::kNone:
      break;
    case DelegateType::kNnapi:
      delegate_ = evaluation::CreateNNAPIDelegate();
      break;
    case DelegateType::kGpu:
      delegate_ = evaluation::CreateGPUDelegate();
      break;
    case DelegateType::kFlex:
      // The Flex delegate is not built on Apple platforms; kFlex silently
      // leaves delegate_ null there.
#if !defined(__APPLE__)
      delegate_ = FlexDelegate::Create();
#endif
      break;
  }
}
434 
~TfLiteDriver()435 TfLiteDriver::~TfLiteDriver() {
436   for (auto t : tensors_to_deallocate_) {
437     DeallocateStringTensor(t.second);
438   }
439 }
440 
AllocateTensors()441 void TfLiteDriver::AllocateTensors() {
442   if (must_allocate_tensors_) {
443     if (interpreter_->AllocateTensors() != kTfLiteOk) {
444       Invalidate("Failed to allocate tensors");
445       return;
446     }
447     ResetLSTMStateTensors();
448     must_allocate_tensors_ = false;
449   }
450 }
451 
LoadModel(const string & bin_file_path,const string & signature)452 void TfLiteDriver::LoadModel(const string& bin_file_path,
453                              const string& signature) {
454   if (!IsValid()) return;
455 
456   model_ = FlatBufferModel::BuildFromFile(GetFullPath(bin_file_path).c_str());
457   if (!model_) {
458     Invalidate("Failed to mmap model " + bin_file_path);
459     return;
460   }
461   InterpreterBuilder(*model_, *resolver_)(&interpreter_);
462   if (!interpreter_) {
463     Invalidate("Failed build interpreter");
464     return;
465   }
466   if (delegate_) {
467     if (interpreter_->ModifyGraphWithDelegate(delegate_.get()) != kTfLiteOk) {
468       Invalidate("Unable to the build graph using the delegate");
469       return;
470     }
471   } else {
472     auto* delegate_providers = tflite::KernelTestDelegateProviders::Get();
473     for (auto& one : delegate_providers->CreateAllDelegates()) {
474       if (interpreter_->ModifyGraphWithDelegate(std::move(one.delegate)) !=
475           kTfLiteOk) {
476         Invalidate(
477             "Unable to the build graph using the delegate initialized from "
478             "tflite::KernelTestDelegateProviders");
479         return;
480       }
481     }
482   }
483 
484   must_allocate_tensors_ = true;
485 
486   signature_runner_ = interpreter_->GetSignatureRunner(signature.c_str());
487   if (signature_runner_) {
488     signature_inputs_ = interpreter_->signature_inputs(signature.c_str());
489     signature_outputs_ = interpreter_->signature_outputs(signature.c_str());
490   } else {
491     Invalidate("Unable to the fetch signature runner.");
492   }
493 }
494 
// Convenience overload: loads the model using the default TF SavedModel
// signature key ("serving_default").
void TfLiteDriver::LoadModel(const string& bin_file_path) {
  LoadModel(bin_file_path, kDefaultSignatureKey);
}
498 
ReshapeTensor(const string & name,const string & csv_values)499 void TfLiteDriver::ReshapeTensor(const string& name, const string& csv_values) {
500   if (!IsValid()) return;
501   if (signature_runner_->ResizeInputTensor(
502           name.c_str(), testing::Split<int>(csv_values, ",")) != kTfLiteOk) {
503     Invalidate("Failed to resize input tensor " + name);
504     return;
505   }
506   must_allocate_tensors_ = true;
507 }
508 
ResetTensor(const std::string & name)509 void TfLiteDriver::ResetTensor(const std::string& name) {
510   if (!IsValid()) return;
511   auto* tensor = signature_runner_->input_tensor(name.c_str());
512   memset(tensor->data.raw, 0, tensor->bytes);
513 }
514 
Invoke(const std::vector<std::pair<string,string>> & inputs)515 void TfLiteDriver::Invoke(
516     const std::vector<std::pair<string, string>>& inputs) {
517   if (!IsValid()) return;
518   for (const auto& input : inputs) {
519     SetInput(input.first, input.second);
520   }
521   if (signature_runner_->Invoke() != kTfLiteOk) {
522     Invalidate("Failed to invoke interpreter");
523   }
524 }
525 
// Returns the named output tensor's contents as a comma-separated string,
// or the empty string if the driver is already invalid.
string TfLiteDriver::ReadOutput(const string& name) {
  if (!IsValid()) return "";
  const TfLiteTensor* tensor = signature_runner_->output_tensor(name.c_str());
  return TensorValueToCsvString(tensor);
}
530 
CheckResults(const std::vector<std::pair<string,string>> & expected_outputs,const std::vector<std::pair<string,string>> & expected_output_shapes)531 bool TfLiteDriver::CheckResults(
532     const std::vector<std::pair<string, string>>& expected_outputs,
533     const std::vector<std::pair<string, string>>& expected_output_shapes) {
534   if (!IsValid()) return false;
535   bool success = true;
536   for (const auto& output : expected_outputs) {
537     SetExpectation(output.first, output.second);
538   }
539   for (const auto& shape : expected_output_shapes) {
540     SetShapeExpectation(shape.first, shape.second);
541   }
542 
543   for (const auto& p : expected_output_) {
544     int id = p.first;
545     auto* tensor = interpreter_->tensor(id);
546     if (!p.second->Check(/*verbose=*/false, *tensor)) {
547       // Do not invalidate anything here. Instead, simply output the
548       // differences and return false. Invalidating would prevent all
549       // subsequent invocations from running..
550       std::cerr << "TfLiteDriver: There were errors in invocation '"
551                 << GetInvocationId() << "', validating output tensor '" << id
552                 << "':" << std::endl;
553       p.second->Check(/*verbose=*/true, *tensor);
554       success = false;
555       SetOverallSuccess(false);
556     }
557   }
558   for (const auto& p : expected_output_shape_) {
559     int id = p.first;
560     auto* tensor = interpreter_->tensor(id);
561     if (!p.second->CheckShape(/*verbose=*/false, *tensor)) {
562       // Do not invalidate anything here. Instead, simply output the
563       // differences and return false. Invalidating would prevent all
564       // subsequent invocations from running..
565       std::cerr << "TfLiteDriver: There were errors in invocation '"
566                 << GetInvocationId()
567                 << "', validating the shape of output tensor '" << id
568                 << "':" << std::endl;
569       p.second->CheckShape(/*verbose=*/true, *tensor);
570       success = false;
571       SetOverallSuccess(false);
572     }
573   }
574   expected_output_.clear();
575   return success;
576 }
577 
GetOutputNames()578 std::vector<string> TfLiteDriver::GetOutputNames() {
579   if (!IsValid()) return {};
580   std::vector<string> names;
581   for (const auto* name : signature_runner_->output_names()) {
582     names.push_back(name);
583   }
584   return names;
585 }
586 
// Parses `csv_values` according to the input tensor's type and writes the
// parsed values into the tensor buffer. String tensors take a hex-encoded
// serialized buffer instead of CSV, and their storage is (re)allocated and
// tracked for deallocation in the destructor. Invalidates the driver on an
// unsupported tensor type; CheckSizes invalidates on element-count mismatch.
void TfLiteDriver::SetInput(const string& name, const string& csv_values) {
  auto id = signature_inputs_[name];
  auto* tensor = signature_runner_->input_tensor(name.c_str());
  switch (tensor->type) {
    case kTfLiteFloat64: {
      const auto& values = testing::Split<double>(csv_values, ",");
      if (!CheckSizes<double>(tensor->bytes, values.size())) return;
      SetTensorData(values, tensor->data.raw);
      break;
    }
    case kTfLiteFloat32: {
      const auto& values = testing::Split<float>(csv_values, ",");
      if (!CheckSizes<float>(tensor->bytes, values.size())) return;
      SetTensorData(values, tensor->data.raw);
      break;
    }
    case kTfLiteInt32: {
      const auto& values = testing::Split<int32_t>(csv_values, ",");
      if (!CheckSizes<int32_t>(tensor->bytes, values.size())) return;
      SetTensorData(values, tensor->data.raw);
      break;
    }
    case kTfLiteUInt32: {
      const auto& values = testing::Split<uint32_t>(csv_values, ",");
      if (!CheckSizes<uint32_t>(tensor->bytes, values.size())) return;
      SetTensorData(values, tensor->data.raw);
      break;
    }
    case kTfLiteInt64: {
      const auto& values = testing::Split<int64_t>(csv_values, ",");
      if (!CheckSizes<int64_t>(tensor->bytes, values.size())) return;
      SetTensorData(values, tensor->data.raw);
      break;
    }
    case kTfLiteUInt64: {
      const auto& values = testing::Split<uint64_t>(csv_values, ",");
      if (!CheckSizes<uint64_t>(tensor->bytes, values.size())) return;
      SetTensorData(values, tensor->data.raw);
      break;
    }
    case kTfLiteUInt8: {
      const auto& values = testing::Split<uint8_t>(csv_values, ",");
      if (!CheckSizes<uint8_t>(tensor->bytes, values.size())) return;
      SetTensorData(values, tensor->data.raw);
      break;
    }
    case kTfLiteInt8: {
      const auto& values = testing::Split<int8_t>(csv_values, ",");
      if (!CheckSizes<int8_t>(tensor->bytes, values.size())) return;
      SetTensorData(values, tensor->data.raw);
      break;
    }
    case kTfLiteInt16: {
      const auto& values = testing::Split<int16_t>(csv_values, ",");
      if (!CheckSizes<int16_t>(tensor->bytes, values.size())) return;
      SetTensorData(values, tensor->data.raw);
      break;
    }
    case kTfLiteUInt16: {
      const auto& values = testing::Split<uint16_t>(csv_values, ",");
      if (!CheckSizes<uint16_t>(tensor->bytes, values.size())) return;
      SetTensorData(values, tensor->data.raw);
      break;
    }
    case kTfLiteBool: {
      const auto& values = testing::Split<bool>(csv_values, ",");
      if (!CheckSizes<bool>(tensor->bytes, values.size())) return;
      SetTensorData(values, tensor->data.raw);
      break;
    }
    case kTfLiteString: {
      // The payload is a hex-encoded, already-serialized string buffer.
      string s = absl::HexStringToBytes(csv_values);

      // Free any buffer from a previous SetInput for this tensor before
      // allocating one of the new size.
      DeallocateStringTensor(tensors_to_deallocate_[id]);
      AllocateStringTensor(id, s.size(), tensor);
      memcpy(tensor->data.raw, s.data(), s.size());

      break;
    }
    case kTfLiteComplex64: {
      const auto& values = testing::Split<std::complex<float>>(csv_values, ",");
      if (!CheckSizes<std::complex<float>>(tensor->bytes, values.size()))
        return;
      SetTensorData(values, tensor->data.raw);
      break;
    }
    case kTfLiteComplex128: {
      const auto& values =
          testing::Split<std::complex<double>>(csv_values, ",");
      if (!CheckSizes<std::complex<double>>(tensor->bytes, values.size()))
        return;
      SetTensorData(values, tensor->data.raw);
      break;
    }
    default:
      Invalidate(absl::StrCat("Unsupported tensor type ",
                              TfLiteTypeGetName(tensor->type),
                              " in TfLiteDriver::SetInput"));
      return;
  }
}
688 
// Overrides the comparison thresholds used for subsequently created
// DataExpectations (existing expectations keep the old values).
void TfLiteDriver::SetThreshold(double relative_threshold,
                                double absolute_threshold) {
  relative_threshold_ = relative_threshold;
  absolute_threshold_ = absolute_threshold;
}
694 
// Overrides the quantization error multiplier used for subsequently created
// DataExpectations (see kQuantizationErrorMultiplier).
void TfLiteDriver::SetQuantizationErrorMultiplier(
    int quantization_error_multiplier) {
  quantization_error_multiplier_ = quantization_error_multiplier;
}
699 
// Registers a value expectation for the named output tensor, parsing
// `csv_values` according to the tensor's type (quantized tensors parse the
// values as floats, since they are compared after dequantization).
// Invalidates the driver on duplicate registration or unsupported type.
void TfLiteDriver::SetExpectation(const string& name,
                                  const string& csv_values) {
  auto id = signature_outputs_[name];
  auto* tensor = signature_runner_->output_tensor(name.c_str());
  if (expected_output_.count(id) != 0) {
    Invalidate(absl::StrCat("Overridden expectation for tensor '", id, "'"));
  }
  expected_output_[id] = std::make_unique<DataExpectation>(
      relative_threshold_, absolute_threshold_, quantization_error_multiplier_);

  // Quantized outputs are checked against float references (QuantizedCheck),
  // so the expectation is always parsed as float.
  if (InterpretAsQuantized(*tensor)) {
    expected_output_[id]->SetData<float>(csv_values);
    return;
  }

  switch (tensor->type) {
    case kTfLiteFloat32:
      expected_output_[id]->SetData<float>(csv_values);
      break;
    case kTfLiteInt32:
      expected_output_[id]->SetData<int32_t>(csv_values);
      break;
    case kTfLiteUInt32:
      expected_output_[id]->SetData<uint32_t>(csv_values);
      break;
    case kTfLiteInt64:
      expected_output_[id]->SetData<int64_t>(csv_values);
      break;
    case kTfLiteUInt64:
      expected_output_[id]->SetData<uint64_t>(csv_values);
      break;
    case kTfLiteUInt8:
      expected_output_[id]->SetData<uint8_t>(csv_values);
      break;
    case kTfLiteInt8:
      expected_output_[id]->SetData<int8_t>(csv_values);
      break;
    case kTfLiteUInt16:
      expected_output_[id]->SetData<uint16_t>(csv_values);
      break;
    case kTfLiteInt16:
      expected_output_[id]->SetData<int16_t>(csv_values);
      break;
    case kTfLiteBool:
      expected_output_[id]->SetData<bool>(csv_values);
      break;
    case kTfLiteString:
      expected_output_[id]->SetData<string>(csv_values);
      break;
    case kTfLiteFloat64:
      expected_output_[id]->SetData<double>(csv_values);
      break;
    case kTfLiteComplex64:
      expected_output_[id]->SetData<std::complex<float>>(csv_values);
      break;
    case kTfLiteComplex128:
      expected_output_[id]->SetData<std::complex<double>>(csv_values);
      break;
    default:
      Invalidate(absl::StrCat("Unsupported tensor type ",
                              TfLiteTypeGetName(tensor->type),
                              " in TfLiteDriver::SetExpectation"));
      return;
  }
}
765 
SetShapeExpectation(const string & name,const string & csv_values)766 void TfLiteDriver::SetShapeExpectation(const string& name,
767                                        const string& csv_values) {
768   auto id = signature_outputs_[name];
769   if (expected_output_shape_.count(id) != 0) {
770     Invalidate(
771         absl::StrCat("Overridden shape expectation for tensor '", id, "'"));
772   }
773   expected_output_shape_[id] = std::make_unique<ShapeExpectation>(csv_values);
774 }
775 
// Resets all variable tensors (e.g. LSTM state) to their initial values so
// each invocation starts from a clean state.
void TfLiteDriver::ResetLSTMStateTensors() {
  interpreter_->ResetVariableTensors();
}
779 
// Serializes a tensor's contents to a comma-separated string. Invalidates
// the driver and returns "" for unsupported types.
string TfLiteDriver::TensorValueToCsvString(const TfLiteTensor* tensor) {
  // Element count is the product of all dimensions (1 for a scalar).
  int num_elements = 1;

  for (int i = 0; i < tensor->dims->size; ++i) {
    num_elements *= tensor->dims->data[i];
  }

  switch (tensor->type) {
    case kTfLiteFloat32:
      return JoinDefault(tensor->data.f, num_elements, ",");
    case kTfLiteInt32:
      return JoinDefault(tensor->data.i32, num_elements, ",");
    case kTfLiteUInt32:
      return JoinDefault(tensor->data.u32, num_elements, ",");
    case kTfLiteInt64:
      return JoinDefault(tensor->data.i64, num_elements, ",");
    case kTfLiteUInt64:
      return JoinDefault(tensor->data.u64, num_elements, ",");
    // 8/16-bit types use Join (not JoinDefault) so they are formatted as
    // numbers rather than raw characters.
    case kTfLiteUInt8:
      return Join(tensor->data.uint8, num_elements, ",");
    case kTfLiteInt8:
      return Join(tensor->data.int8, num_elements, ",");
    case kTfLiteUInt16:
      return Join(tensor->data.ui16, num_elements, ",");
    case kTfLiteInt16:
      return Join(tensor->data.i16, num_elements, ",");
    case kTfLiteBool:
      return JoinDefault(tensor->data.b, num_elements, ",");
    default:
      Invalidate(absl::StrCat("Unsupported tensor type ",
                              TfLiteTypeGetName(tensor->type),
                              " in TfLiteDriver::ReadOutput"));
      return "";
  }
}
815 
816 }  // namespace testing
817 }  // namespace tflite
818