1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/testing/tflite_driver.h"

#include <algorithm>
#include <cmath>
#include <complex>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <iterator>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/escaping.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/lite/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/c/common.h"
#if !defined(__APPLE__)
#include "tensorflow/lite/delegates/flex/delegate.h"
#endif
#include "tensorflow/lite/kernels/custom_ops_register.h"
#include "tensorflow/lite/kernels/gradient/gradient_ops.h"
#include "tensorflow/lite/kernels/parse_example/parse_example.h"
#include "tensorflow/lite/kernels/perception/perception_ops.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/register_ref.h"
#include "tensorflow/lite/kernels/test_delegate_providers.h"
#include "tensorflow/lite/signature_runner.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/testing/join.h"
#include "tensorflow/lite/testing/split.h"
#include "tensorflow/lite/tools/evaluation/utils.h"
46
47 namespace tflite {
48 namespace testing {
49
50 namespace {
// Default tolerances used when comparing floating-point outputs with their
// expected values. They are written as double literals: a float literal such
// as `1e-2f` would store the nearest float (0.00999999977...) instead of 0.01.
constexpr double kRelativeThreshold = 1e-2;
constexpr double kAbsoluteThreshold = 1e-4;
// Signature used by LoadModel when the caller does not name one.
constexpr char kDefaultSignatureKey[] = "serving_default";

// For quantized tests, we use a different error measurement from float ones.
// Assumes the baseline is always a float TF model.
// Error of a quantized model compared to the baseline comes from two sources:
//   1. the math done with quantized inputs, and
//   2. quantization of the output.
// Assuming there is no error introduced by source 1, the theoretical maximum
// error allowed for the output is 0.5 * scale, because scale is equal to the
// size of the quantization bucket.
//
// As a result, we use `scale` as a unit for measuring the quantization error.
// To add the error introduced by source 1 as well, we need to relax the
// multiplier from 0.5 to a larger number, which is model/op dependent.
// The number below is good enough to account for both sources of error for
// most quantized op tests to pass.
constexpr int kQuantizationErrorMultiplier = 4;
70
// Reads element `index` of a buffer that holds values of type T.
// `data` must point to at least `index + 1` elements of T.
template <typename T>
T Value(void* data, int index) {
  T* const typed = static_cast<T*>(data);
  return *(typed + index);
}
76
// Copies every element of `values` into the buffer at `data`, which must have
// room for values.size() elements of type T.
template <typename T>
void SetTensorData(const std::vector<T>& values, void* data) {
  T* const destination = static_cast<T*>(data);
  std::copy_n(values.begin(), values.size(), destination);
}
82
// Owning pointer to a type-erased buffer. The deleter carries the element
// type so the buffer is released with the matching `delete[]`.
using unique_void_ptr = std::unique_ptr<void, void (*)(void*)>;

// Allocates a default-initialized array of `size` elements of type T and
// returns it wrapped in a unique_void_ptr.
template <typename T>
unique_void_ptr make_type_erased_array(size_t size) {
  void (*const array_deleter)(void*) = [](void* data) {
    delete[] static_cast<T*>(data);
  };
  T* const storage = new T[size];
  return unique_void_ptr(storage, array_deleter);
}
91
InterpretAsQuantized(const TfLiteTensor & tensor)92 bool InterpretAsQuantized(const TfLiteTensor& tensor) {
93 if (tensor.quantization.type == kTfLiteNoQuantization) return false;
94
95 // Quantized single-op models with uint8 input/output type are only used for
96 // EdgeTPU tests.
97 // EdgeTPU tests need to read the quantized values as-is to check for
98 // bit-exactness. As a result we don't interpret the tensor as quantized.
99 // TODO(b/176121243): Add an option to interpret uint8 buffers as
100 // non-quantized type and set if from the child class.
101 if (tensor.type == kTfLiteUInt8) return false;
102
103 if (tensor.quantization.params != nullptr) {
104 auto* quantization =
105 reinterpret_cast<TfLiteAffineQuantization*>(tensor.quantization.params);
106 if (quantization->scale != nullptr && quantization->scale->size == 1 &&
107 quantization->zero_point != nullptr &&
108 quantization->zero_point->size == 1) {
109 return true;
110 }
111 }
112 return false;
113 }
114 } // namespace
115
116 class TfLiteDriver::DataExpectation {
117 public:
DataExpectation(double relative_threshold,double absolute_threshold,int quantization_error_multiplier)118 DataExpectation(double relative_threshold, double absolute_threshold,
119 int quantization_error_multiplier)
120 : data_(nullptr, nullptr),
121 num_elements_(0),
122 relative_threshold_(relative_threshold),
123 absolute_threshold_(absolute_threshold),
124 quantization_error_multiplier_(quantization_error_multiplier) {}
125
126 template <typename T>
SetData(const string & csv_values)127 void SetData(const string& csv_values) {
128 const auto& values = testing::Split<T>(csv_values, ",");
129 num_elements_ = values.size();
130 data_ = make_type_erased_array<T>(num_elements_);
131 SetTensorData(values, data_.get());
132 }
133
134 bool Check(bool verbose, const TfLiteTensor& tensor);
135
136 private:
CompareTwoValuesHelper(float v1,float v2)137 bool CompareTwoValuesHelper(float v1, float v2) {
138 if (std::isnan(v1) || std::isnan(v2)) {
139 return !(std::isnan(v1) && std::isnan(v2));
140 }
141
142 float diff = std::abs(v1 - v2);
143 bool error_is_large = false;
144 // For very small numbers, try absolute error, otherwise go with
145 // relative.
146 if (std::abs(v2) < relative_threshold_) {
147 error_is_large = (diff > absolute_threshold_);
148 } else {
149 error_is_large = (diff > relative_threshold_ * std::abs(v2));
150 }
151 return error_is_large;
152 }
153
CompareTwoValuesHelper(double v1,double v2)154 bool CompareTwoValuesHelper(double v1, double v2) {
155 if (std::isnan(v1) || std::isnan(v2)) {
156 return !(std::isnan(v1) && std::isnan(v2));
157 }
158
159 double diff = std::abs(v1 - v2);
160 bool error_is_large = false;
161 // For very small numbers, try absolute error, otherwise go with
162 // relative.
163 if (std::abs(v2) < relative_threshold_) {
164 error_is_large = (diff > absolute_threshold_);
165 } else {
166 error_is_large = (diff > relative_threshold_ * std::abs(v2));
167 }
168 return error_is_large;
169 }
170
CompareTwoValues(std::complex<float> v1,std::complex<float> v2)171 bool CompareTwoValues(std::complex<float> v1, std::complex<float> v2) {
172 return CompareTwoValues(v1.real(), v2.real()) ||
173 CompareTwoValues(v1.imag(), v2.imag());
174 }
175
CompareTwoValues(std::complex<double> v1,std::complex<double> v2)176 bool CompareTwoValues(std::complex<double> v1, std::complex<double> v2) {
177 return CompareTwoValues(v1.real(), v2.real()) ||
178 CompareTwoValues(v1.imag(), v2.imag());
179 }
180
CompareTwoValues(float v1,float v2)181 bool CompareTwoValues(float v1, float v2) {
182 return CompareTwoValuesHelper(v1, v2);
183 }
184
CompareTwoValues(double v1,double v2)185 bool CompareTwoValues(double v1, double v2) {
186 return CompareTwoValuesHelper(v1, v2);
187 }
188
189 template <typename T, typename TS>
TypedCheck(bool verbose,const TfLiteTensor & tensor)190 bool TypedCheck(bool verbose, const TfLiteTensor& tensor) {
191 size_t tensor_size = tensor.bytes / sizeof(T);
192
193 if (tensor_size != num_elements_) {
194 std::cerr << "Expected a tensor with " << num_elements_
195 << " elements, got " << tensor_size << std::endl;
196 std::cerr << "while checking tensor " << tensor.name << std::endl;
197 return false;
198 }
199
200 bool good_output = true;
201 for (int i = 0; i < tensor_size; ++i) {
202 TS computed = Value<T>(tensor.data.raw, i);
203 TS reference = Value<T>(data_.get(), i);
204 if (CompareTwoValues(computed, reference)) {
205 good_output = false;
206 if (verbose) {
207 std::cerr << " Tensor[" << tensor.name << "] index " << i << ": got "
208 << computed << ", but expected " << reference << std::endl;
209 }
210 }
211 }
212 return good_output;
213 }
214
215 bool TypedCheckString(bool verbose, const TfLiteTensor& tensor);
216 bool QuantizedCheck(bool verbose, const TfLiteTensor& tensor);
217
218 unique_void_ptr data_;
219 size_t num_elements_;
220 double relative_threshold_;
221 double absolute_threshold_;
222 int quantization_error_multiplier_;
223 };
224
225 class TfLiteDriver::ShapeExpectation {
226 public:
ShapeExpectation(const string & csv_values)227 explicit ShapeExpectation(const string& csv_values)
228 : shape_(testing::Split<int32_t>(csv_values, ",")) {}
229
CheckShape(bool verbose,const TfLiteTensor & tensor)230 bool CheckShape(bool verbose, const TfLiteTensor& tensor) {
231 bool valid = true;
232 if (tensor.dims->size == shape_.size()) {
233 for (int i = 0; i < shape_.size(); ++i) {
234 if (shape_[i] != tensor.dims->data[i]) {
235 valid = false;
236 }
237 }
238 } else {
239 valid = false;
240 }
241 if (!valid && verbose) {
242 std::cerr << "Incorrect output shape while checking tensor "
243 << tensor.name << std::endl;
244 std::cerr << "TFLite output shape: ";
245 for (int i = 0; i < tensor.dims->size; ++i) {
246 std::cerr << tensor.dims->data[i] << ", ";
247 }
248 std::cerr << std::endl;
249 std::cerr << "Expected output shape: ";
250 for (int i = 0; i < shape_.size(); ++i) {
251 std::cerr << shape_[i] << ", ";
252 }
253 std::cerr << std::endl;
254 }
255 return valid;
256 }
257
258 private:
259 std::vector<int32_t> shape_;
260 };
261
262 template <>
SetData(const string & csv_values)263 void TfLiteDriver::DataExpectation::SetData<string>(const string& csv_values) {
264 string s = absl::HexStringToBytes(csv_values);
265 data_ = make_type_erased_array<char>(s.size());
266 memcpy(data_.get(), s.data(), s.size());
267 }
268
TypedCheckString(bool verbose,const TfLiteTensor & tensor)269 bool TfLiteDriver::DataExpectation::TypedCheckString(
270 bool verbose, const TfLiteTensor& tensor) {
271 if (tensor.data.raw == nullptr) {
272 if (verbose) {
273 std::cerr << " got empty string" << std::endl;
274 }
275 return false;
276 }
277 int expected_num_strings = GetStringCount(data_.get());
278 int returned_num_strings = GetStringCount(&tensor);
279 if (expected_num_strings != returned_num_strings) {
280 if (verbose) {
281 std::cerr << " string count differ: got " << returned_num_strings
282 << ", but expected " << expected_num_strings << std::endl;
283 }
284 return false;
285 }
286 for (int i = 0; i < returned_num_strings; ++i) {
287 auto expected_ref = GetString(data_.get(), i);
288 auto returned_ref = GetString(&tensor, i);
289 if (expected_ref.len != returned_ref.len) {
290 if (verbose) {
291 std::cerr << " index " << i << ": got string of size "
292 << returned_ref.len << ", but expected size "
293 << expected_ref.len << std::endl;
294 }
295 return false;
296 }
297 if (strncmp(expected_ref.str, returned_ref.str, returned_ref.len) != 0) {
298 if (verbose) {
299 std::cerr << " index " << i << ": strings are different" << std::endl;
300 }
301 return false;
302 }
303 }
304
305 return true;
306 }
307
QuantizedCheck(bool verbose,const TfLiteTensor & tensor)308 bool TfLiteDriver::DataExpectation::QuantizedCheck(bool verbose,
309 const TfLiteTensor& tensor) {
310 auto* quantization =
311 reinterpret_cast<TfLiteAffineQuantization*>(tensor.quantization.params);
312 const float scale = quantization->scale->data[0];
313 const int32_t zero_point = quantization->zero_point->data[0];
314
315 bool good_result = true;
316 int int_size = tensor.type == kTfLiteInt8 ? 1 : 2;
317 for (int i = 0; i < tensor.bytes / int_size; i++) {
318 int32_t computed =
319 tensor.type == kTfLiteInt8 ? tensor.data.int8[i] : tensor.data.i16[i];
320 const float dequantized =
321 static_cast<float>(scale * (computed - zero_point));
322 int error_multiplier = quantization_error_multiplier_;
323 // If we are doing int16 symmetric quantization of activations, we need to
324 // bump up the potential error. Since the weights are quantized to 8 bits
325 // and the activations are 16bits, the output is could be getting
326 // effectively 8bit error instead of 16bit error. So we need to multiply the
327 // error mulitplier by 255 (the difference in number of values between a
328 // 16bit and 8bit number)
329 if (tensor.type == kTfLiteInt16) error_multiplier *= 255;
330 const float reference = Value<float>(data_.get(), i);
331 if (std::abs(dequantized - reference) > error_multiplier * scale) {
332 if (verbose) {
333 std::cerr << " index " << i << ": got " << dequantized
334 << ", but expected " << reference << std::endl;
335 }
336 good_result = false;
337 }
338 }
339 return good_result;
340 }
341
Check(bool verbose,const TfLiteTensor & tensor)342 bool TfLiteDriver::DataExpectation::Check(bool verbose,
343 const TfLiteTensor& tensor) {
344 if (InterpretAsQuantized(tensor)) {
345 return QuantizedCheck(verbose, tensor);
346 }
347
348 switch (tensor.type) {
349 case kTfLiteFloat32:
350 return TypedCheck<float, float>(verbose, tensor);
351 case kTfLiteInt32:
352 return TypedCheck<int32_t, float>(verbose, tensor);
353 case kTfLiteUInt32:
354 return TypedCheck<uint32_t, float>(verbose, tensor);
355 case kTfLiteInt64:
356 return TypedCheck<int64_t, float>(verbose, tensor);
357 case kTfLiteUInt64:
358 return TypedCheck<uint64_t, float>(verbose, tensor);
359 case kTfLiteUInt8:
360 return TypedCheck<uint8_t, float>(verbose, tensor);
361 case kTfLiteInt8:
362 return TypedCheck<int8_t, float>(verbose, tensor);
363 case kTfLiteUInt16:
364 return TypedCheck<uint16_t, float>(verbose, tensor);
365 case kTfLiteInt16:
366 return TypedCheck<int16_t, float>(verbose, tensor);
367 case kTfLiteBool:
368 return TypedCheck<bool, float>(verbose, tensor);
369 case kTfLiteString:
370 return TypedCheckString(verbose, tensor);
371 case kTfLiteComplex64:
372 return TypedCheck<std::complex<float>, std::complex<float>>(verbose,
373 tensor);
374 case kTfLiteComplex128:
375 return TypedCheck<std::complex<double>, std::complex<double>>(verbose,
376 tensor);
377 case kTfLiteFloat64:
378 return TypedCheck<double, double>(verbose, tensor);
379 default:
380 fprintf(stderr, "Unsupported type %d in Check\n", tensor.type);
381 return false;
382 }
383 }
384
385 /* static */
InitTestDelegateProviders(int * argc,const char ** argv)386 bool TfLiteDriver::InitTestDelegateProviders(int* argc, const char** argv) {
387 return tflite::KernelTestDelegateProviders::Get()->InitFromCmdlineArgs(argc,
388 argv);
389 }
390
TfLiteDriver(DelegateType delegate_type,bool reference_kernel)391 TfLiteDriver::TfLiteDriver(DelegateType delegate_type, bool reference_kernel)
392 : delegate_(nullptr, nullptr),
393 relative_threshold_(kRelativeThreshold),
394 absolute_threshold_(kAbsoluteThreshold),
395 quantization_error_multiplier_(kQuantizationErrorMultiplier) {
396 if (reference_kernel) {
397 resolver_ = std::make_unique<ops::builtin::BuiltinRefOpResolver>();
398 } else {
399 // TODO(b/168278077): change back to use BuiltinOpResolver after zip tests
400 // are fully validated against TfLite delegates.
401 resolver_ = std::make_unique<
402 ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>();
403 ops::builtin::BuiltinOpResolver* builtin_op_resolver_ =
404 reinterpret_cast<ops::builtin::BuiltinOpResolver*>(resolver_.get());
405 builtin_op_resolver_->AddCustom("IRFFT2D",
406 tflite::ops::custom::Register_IRFFT2D());
407 builtin_op_resolver_->AddCustom(
408 "AvgPool3D", tflite::ops::custom::Register_AVG_POOL_3D());
409 builtin_op_resolver_->AddCustom(
410 "MaxPool3D", tflite::ops::custom::Register_MAX_POOL_3D());
411 builtin_op_resolver_->AddCustom("Roll",
412 tflite::ops::custom::Register_ROLL());
413 tflite::ops::custom::AddGradientOps(builtin_op_resolver_);
414 tflite::ops::custom::AddParseExampleOp(builtin_op_resolver_);
415 tflite::ops::custom::AddPerceptionOps(builtin_op_resolver_);
416 }
417
418 switch (delegate_type) {
419 case DelegateType::kNone:
420 break;
421 case DelegateType::kNnapi:
422 delegate_ = evaluation::CreateNNAPIDelegate();
423 break;
424 case DelegateType::kGpu:
425 delegate_ = evaluation::CreateGPUDelegate();
426 break;
427 case DelegateType::kFlex:
428 #if !defined(__APPLE__)
429 delegate_ = FlexDelegate::Create();
430 #endif
431 break;
432 }
433 }
434
~TfLiteDriver()435 TfLiteDriver::~TfLiteDriver() {
436 for (auto t : tensors_to_deallocate_) {
437 DeallocateStringTensor(t.second);
438 }
439 }
440
AllocateTensors()441 void TfLiteDriver::AllocateTensors() {
442 if (must_allocate_tensors_) {
443 if (interpreter_->AllocateTensors() != kTfLiteOk) {
444 Invalidate("Failed to allocate tensors");
445 return;
446 }
447 ResetLSTMStateTensors();
448 must_allocate_tensors_ = false;
449 }
450 }
451
LoadModel(const string & bin_file_path,const string & signature)452 void TfLiteDriver::LoadModel(const string& bin_file_path,
453 const string& signature) {
454 if (!IsValid()) return;
455
456 model_ = FlatBufferModel::BuildFromFile(GetFullPath(bin_file_path).c_str());
457 if (!model_) {
458 Invalidate("Failed to mmap model " + bin_file_path);
459 return;
460 }
461 InterpreterBuilder(*model_, *resolver_)(&interpreter_);
462 if (!interpreter_) {
463 Invalidate("Failed build interpreter");
464 return;
465 }
466 if (delegate_) {
467 if (interpreter_->ModifyGraphWithDelegate(delegate_.get()) != kTfLiteOk) {
468 Invalidate("Unable to the build graph using the delegate");
469 return;
470 }
471 } else {
472 auto* delegate_providers = tflite::KernelTestDelegateProviders::Get();
473 for (auto& one : delegate_providers->CreateAllDelegates()) {
474 if (interpreter_->ModifyGraphWithDelegate(std::move(one.delegate)) !=
475 kTfLiteOk) {
476 Invalidate(
477 "Unable to the build graph using the delegate initialized from "
478 "tflite::KernelTestDelegateProviders");
479 return;
480 }
481 }
482 }
483
484 must_allocate_tensors_ = true;
485
486 signature_runner_ = interpreter_->GetSignatureRunner(signature.c_str());
487 if (signature_runner_) {
488 signature_inputs_ = interpreter_->signature_inputs(signature.c_str());
489 signature_outputs_ = interpreter_->signature_outputs(signature.c_str());
490 } else {
491 Invalidate("Unable to the fetch signature runner.");
492 }
493 }
494
LoadModel(const string & bin_file_path)495 void TfLiteDriver::LoadModel(const string& bin_file_path) {
496 LoadModel(bin_file_path, kDefaultSignatureKey);
497 }
498
ReshapeTensor(const string & name,const string & csv_values)499 void TfLiteDriver::ReshapeTensor(const string& name, const string& csv_values) {
500 if (!IsValid()) return;
501 if (signature_runner_->ResizeInputTensor(
502 name.c_str(), testing::Split<int>(csv_values, ",")) != kTfLiteOk) {
503 Invalidate("Failed to resize input tensor " + name);
504 return;
505 }
506 must_allocate_tensors_ = true;
507 }
508
ResetTensor(const std::string & name)509 void TfLiteDriver::ResetTensor(const std::string& name) {
510 if (!IsValid()) return;
511 auto* tensor = signature_runner_->input_tensor(name.c_str());
512 memset(tensor->data.raw, 0, tensor->bytes);
513 }
514
Invoke(const std::vector<std::pair<string,string>> & inputs)515 void TfLiteDriver::Invoke(
516 const std::vector<std::pair<string, string>>& inputs) {
517 if (!IsValid()) return;
518 for (const auto& input : inputs) {
519 SetInput(input.first, input.second);
520 }
521 if (signature_runner_->Invoke() != kTfLiteOk) {
522 Invalidate("Failed to invoke interpreter");
523 }
524 }
525
// Returns output tensor `name` formatted as comma-separated values, or "" if
// the driver is invalid or the output does not exist. An unknown name
// invalidates the driver instead of dereferencing a null tensor.
string TfLiteDriver::ReadOutput(const string& name) {
  if (!IsValid()) return "";
  const TfLiteTensor* tensor = signature_runner_->output_tensor(name.c_str());
  if (tensor == nullptr) {
    Invalidate("Failed to find output tensor " + name);
    return "";
  }
  return TensorValueToCsvString(tensor);
}
530
CheckResults(const std::vector<std::pair<string,string>> & expected_outputs,const std::vector<std::pair<string,string>> & expected_output_shapes)531 bool TfLiteDriver::CheckResults(
532 const std::vector<std::pair<string, string>>& expected_outputs,
533 const std::vector<std::pair<string, string>>& expected_output_shapes) {
534 if (!IsValid()) return false;
535 bool success = true;
536 for (const auto& output : expected_outputs) {
537 SetExpectation(output.first, output.second);
538 }
539 for (const auto& shape : expected_output_shapes) {
540 SetShapeExpectation(shape.first, shape.second);
541 }
542
543 for (const auto& p : expected_output_) {
544 int id = p.first;
545 auto* tensor = interpreter_->tensor(id);
546 if (!p.second->Check(/*verbose=*/false, *tensor)) {
547 // Do not invalidate anything here. Instead, simply output the
548 // differences and return false. Invalidating would prevent all
549 // subsequent invocations from running..
550 std::cerr << "TfLiteDriver: There were errors in invocation '"
551 << GetInvocationId() << "', validating output tensor '" << id
552 << "':" << std::endl;
553 p.second->Check(/*verbose=*/true, *tensor);
554 success = false;
555 SetOverallSuccess(false);
556 }
557 }
558 for (const auto& p : expected_output_shape_) {
559 int id = p.first;
560 auto* tensor = interpreter_->tensor(id);
561 if (!p.second->CheckShape(/*verbose=*/false, *tensor)) {
562 // Do not invalidate anything here. Instead, simply output the
563 // differences and return false. Invalidating would prevent all
564 // subsequent invocations from running..
565 std::cerr << "TfLiteDriver: There were errors in invocation '"
566 << GetInvocationId()
567 << "', validating the shape of output tensor '" << id
568 << "':" << std::endl;
569 p.second->CheckShape(/*verbose=*/true, *tensor);
570 success = false;
571 SetOverallSuccess(false);
572 }
573 }
574 expected_output_.clear();
575 return success;
576 }
577
GetOutputNames()578 std::vector<string> TfLiteDriver::GetOutputNames() {
579 if (!IsValid()) return {};
580 std::vector<string> names;
581 for (const auto* name : signature_runner_->output_names()) {
582 names.push_back(name);
583 }
584 return names;
585 }
586
SetInput(const string & name,const string & csv_values)587 void TfLiteDriver::SetInput(const string& name, const string& csv_values) {
588 auto id = signature_inputs_[name];
589 auto* tensor = signature_runner_->input_tensor(name.c_str());
590 switch (tensor->type) {
591 case kTfLiteFloat64: {
592 const auto& values = testing::Split<double>(csv_values, ",");
593 if (!CheckSizes<double>(tensor->bytes, values.size())) return;
594 SetTensorData(values, tensor->data.raw);
595 break;
596 }
597 case kTfLiteFloat32: {
598 const auto& values = testing::Split<float>(csv_values, ",");
599 if (!CheckSizes<float>(tensor->bytes, values.size())) return;
600 SetTensorData(values, tensor->data.raw);
601 break;
602 }
603 case kTfLiteInt32: {
604 const auto& values = testing::Split<int32_t>(csv_values, ",");
605 if (!CheckSizes<int32_t>(tensor->bytes, values.size())) return;
606 SetTensorData(values, tensor->data.raw);
607 break;
608 }
609 case kTfLiteUInt32: {
610 const auto& values = testing::Split<uint32_t>(csv_values, ",");
611 if (!CheckSizes<uint32_t>(tensor->bytes, values.size())) return;
612 SetTensorData(values, tensor->data.raw);
613 break;
614 }
615 case kTfLiteInt64: {
616 const auto& values = testing::Split<int64_t>(csv_values, ",");
617 if (!CheckSizes<int64_t>(tensor->bytes, values.size())) return;
618 SetTensorData(values, tensor->data.raw);
619 break;
620 }
621 case kTfLiteUInt64: {
622 const auto& values = testing::Split<uint64_t>(csv_values, ",");
623 if (!CheckSizes<uint64_t>(tensor->bytes, values.size())) return;
624 SetTensorData(values, tensor->data.raw);
625 break;
626 }
627 case kTfLiteUInt8: {
628 const auto& values = testing::Split<uint8_t>(csv_values, ",");
629 if (!CheckSizes<uint8_t>(tensor->bytes, values.size())) return;
630 SetTensorData(values, tensor->data.raw);
631 break;
632 }
633 case kTfLiteInt8: {
634 const auto& values = testing::Split<int8_t>(csv_values, ",");
635 if (!CheckSizes<int8_t>(tensor->bytes, values.size())) return;
636 SetTensorData(values, tensor->data.raw);
637 break;
638 }
639 case kTfLiteInt16: {
640 const auto& values = testing::Split<int16_t>(csv_values, ",");
641 if (!CheckSizes<int16_t>(tensor->bytes, values.size())) return;
642 SetTensorData(values, tensor->data.raw);
643 break;
644 }
645 case kTfLiteUInt16: {
646 const auto& values = testing::Split<uint16_t>(csv_values, ",");
647 if (!CheckSizes<uint16_t>(tensor->bytes, values.size())) return;
648 SetTensorData(values, tensor->data.raw);
649 break;
650 }
651 case kTfLiteBool: {
652 const auto& values = testing::Split<bool>(csv_values, ",");
653 if (!CheckSizes<bool>(tensor->bytes, values.size())) return;
654 SetTensorData(values, tensor->data.raw);
655 break;
656 }
657 case kTfLiteString: {
658 string s = absl::HexStringToBytes(csv_values);
659
660 DeallocateStringTensor(tensors_to_deallocate_[id]);
661 AllocateStringTensor(id, s.size(), tensor);
662 memcpy(tensor->data.raw, s.data(), s.size());
663
664 break;
665 }
666 case kTfLiteComplex64: {
667 const auto& values = testing::Split<std::complex<float>>(csv_values, ",");
668 if (!CheckSizes<std::complex<float>>(tensor->bytes, values.size()))
669 return;
670 SetTensorData(values, tensor->data.raw);
671 break;
672 }
673 case kTfLiteComplex128: {
674 const auto& values =
675 testing::Split<std::complex<double>>(csv_values, ",");
676 if (!CheckSizes<std::complex<double>>(tensor->bytes, values.size()))
677 return;
678 SetTensorData(values, tensor->data.raw);
679 break;
680 }
681 default:
682 Invalidate(absl::StrCat("Unsupported tensor type ",
683 TfLiteTypeGetName(tensor->type),
684 " in TfLiteDriver::SetInput"));
685 return;
686 }
687 }
688
SetThreshold(double relative_threshold,double absolute_threshold)689 void TfLiteDriver::SetThreshold(double relative_threshold,
690 double absolute_threshold) {
691 relative_threshold_ = relative_threshold;
692 absolute_threshold_ = absolute_threshold;
693 }
694
SetQuantizationErrorMultiplier(int quantization_error_multiplier)695 void TfLiteDriver::SetQuantizationErrorMultiplier(
696 int quantization_error_multiplier) {
697 quantization_error_multiplier_ = quantization_error_multiplier;
698 }
699
SetExpectation(const string & name,const string & csv_values)700 void TfLiteDriver::SetExpectation(const string& name,
701 const string& csv_values) {
702 auto id = signature_outputs_[name];
703 auto* tensor = signature_runner_->output_tensor(name.c_str());
704 if (expected_output_.count(id) != 0) {
705 Invalidate(absl::StrCat("Overridden expectation for tensor '", id, "'"));
706 }
707 expected_output_[id] = std::make_unique<DataExpectation>(
708 relative_threshold_, absolute_threshold_, quantization_error_multiplier_);
709
710 if (InterpretAsQuantized(*tensor)) {
711 expected_output_[id]->SetData<float>(csv_values);
712 return;
713 }
714
715 switch (tensor->type) {
716 case kTfLiteFloat32:
717 expected_output_[id]->SetData<float>(csv_values);
718 break;
719 case kTfLiteInt32:
720 expected_output_[id]->SetData<int32_t>(csv_values);
721 break;
722 case kTfLiteUInt32:
723 expected_output_[id]->SetData<uint32_t>(csv_values);
724 break;
725 case kTfLiteInt64:
726 expected_output_[id]->SetData<int64_t>(csv_values);
727 break;
728 case kTfLiteUInt64:
729 expected_output_[id]->SetData<uint64_t>(csv_values);
730 break;
731 case kTfLiteUInt8:
732 expected_output_[id]->SetData<uint8_t>(csv_values);
733 break;
734 case kTfLiteInt8:
735 expected_output_[id]->SetData<int8_t>(csv_values);
736 break;
737 case kTfLiteUInt16:
738 expected_output_[id]->SetData<uint16_t>(csv_values);
739 break;
740 case kTfLiteInt16:
741 expected_output_[id]->SetData<int16_t>(csv_values);
742 break;
743 case kTfLiteBool:
744 expected_output_[id]->SetData<bool>(csv_values);
745 break;
746 case kTfLiteString:
747 expected_output_[id]->SetData<string>(csv_values);
748 break;
749 case kTfLiteFloat64:
750 expected_output_[id]->SetData<double>(csv_values);
751 break;
752 case kTfLiteComplex64:
753 expected_output_[id]->SetData<std::complex<float>>(csv_values);
754 break;
755 case kTfLiteComplex128:
756 expected_output_[id]->SetData<std::complex<double>>(csv_values);
757 break;
758 default:
759 Invalidate(absl::StrCat("Unsupported tensor type ",
760 TfLiteTypeGetName(tensor->type),
761 " in TfLiteDriver::SetExpectation"));
762 return;
763 }
764 }
765
SetShapeExpectation(const string & name,const string & csv_values)766 void TfLiteDriver::SetShapeExpectation(const string& name,
767 const string& csv_values) {
768 auto id = signature_outputs_[name];
769 if (expected_output_shape_.count(id) != 0) {
770 Invalidate(
771 absl::StrCat("Overridden shape expectation for tensor '", id, "'"));
772 }
773 expected_output_shape_[id] = std::make_unique<ShapeExpectation>(csv_values);
774 }
775
ResetLSTMStateTensors()776 void TfLiteDriver::ResetLSTMStateTensors() {
777 interpreter_->ResetVariableTensors();
778 }
779
TensorValueToCsvString(const TfLiteTensor * tensor)780 string TfLiteDriver::TensorValueToCsvString(const TfLiteTensor* tensor) {
781 int num_elements = 1;
782
783 for (int i = 0; i < tensor->dims->size; ++i) {
784 num_elements *= tensor->dims->data[i];
785 }
786
787 switch (tensor->type) {
788 case kTfLiteFloat32:
789 return JoinDefault(tensor->data.f, num_elements, ",");
790 case kTfLiteInt32:
791 return JoinDefault(tensor->data.i32, num_elements, ",");
792 case kTfLiteUInt32:
793 return JoinDefault(tensor->data.u32, num_elements, ",");
794 case kTfLiteInt64:
795 return JoinDefault(tensor->data.i64, num_elements, ",");
796 case kTfLiteUInt64:
797 return JoinDefault(tensor->data.u64, num_elements, ",");
798 case kTfLiteUInt8:
799 return Join(tensor->data.uint8, num_elements, ",");
800 case kTfLiteInt8:
801 return Join(tensor->data.int8, num_elements, ",");
802 case kTfLiteUInt16:
803 return Join(tensor->data.ui16, num_elements, ",");
804 case kTfLiteInt16:
805 return Join(tensor->data.i16, num_elements, ",");
806 case kTfLiteBool:
807 return JoinDefault(tensor->data.b, num_elements, ",");
808 default:
809 Invalidate(absl::StrCat("Unsupported tensor type ",
810 TfLiteTypeGetName(tensor->type),
811 " in TfLiteDriver::ReadOutput"));
812 return "";
813 }
814 }
815
816 } // namespace testing
817 } // namespace tflite
818