1 // Copyright 2020 Google LLC 2 // 3 // This source code is licensed under the BSD-style license found in the 4 // LICENSE file in the root directory of this source tree. 5 6 #pragma once 7 8 #include <algorithm> 9 #include <cstddef> 10 #include <cstdlib> 11 #include <cstring> 12 #include <vector> 13 14 #include <xnnpack.h> 15 #include <xnnpack/subgraph.h> 16 17 #include "subgraph-tester.h" 18 19 namespace xnnpack { 20 21 class RuntimeTester : public SubgraphTester { 22 public: 23 using SubgraphTester::SubgraphTester; 24 25 template<typename T> RunWithFusion()26 inline std::vector<T> RunWithFusion() { 27 Run(); 28 std::vector<char>& tensor = this->external_tensors_.at(this->output_id_); 29 std::vector<float> output = std::vector<float>(tensor.size() / sizeof(float)); 30 std::memcpy(output.data(), tensor.data(), tensor.size()); 31 return output; 32 } 33 34 template<typename T> RunWithoutFusion()35 inline std::vector<T> RunWithoutFusion() { 36 Run(XNN_FLAG_NO_OPERATOR_FUSION); 37 std::vector<char>& tensor = this->external_tensors_.at(this->output_id_); 38 std::vector<float> output = std::vector<float>(tensor.size() / sizeof(float)); 39 memcpy(output.data(), tensor.data(), tensor.size()); 40 return output; 41 } 42 NumOperators()43 size_t NumOperators() { 44 size_t count = 0; 45 for (size_t i = 0; i < runtime_->num_ops; i++) { 46 if (runtime_->opdata[i].operator_objects[0] != NULL) { 47 count++; 48 } 49 } 50 return count; 51 } 52 53 private: 54 void Run(uint32_t flags = 0) { 55 xnn_runtime_t runtime = nullptr; 56 ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(this->subgraph_.get(), nullptr, nullptr, flags, &runtime)); 57 ASSERT_NE(nullptr, runtime); 58 runtime_.reset(runtime); 59 60 std::vector<xnn_external_value> externals; 61 for (auto it = this->external_tensors_.begin(); it != this->external_tensors_.end(); ++it) { 62 if (it->first == this->output_id_) { 63 // Scramble output tensor. 64 std::fill(it->second.begin(), it->second.end(), 0xA8); 65 } 66 externals.push_back(xnn_external_value{it->first, it->second.data()}); 67 } 68 69 ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, externals.size(), externals.data())); 70 ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); 71 }; 72 73 std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> runtime_{nullptr, xnn_delete_runtime}; 74 }; 75 76 } // namespace xnnpack 77