• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <type_traits>
#include <vector>

#include <xnnpack.h>
#include <xnnpack/subgraph.h>

#include "subgraph-tester.h"

19 namespace xnnpack {
20 
21 class RuntimeTester : public SubgraphTester {
22  public:
23   using SubgraphTester::SubgraphTester;
24 
25   template<typename T>
RunWithFusion()26   inline std::vector<T> RunWithFusion() {
27     Run();
28     std::vector<char>& tensor = this->external_tensors_.at(this->output_id_);
29     std::vector<float> output = std::vector<float>(tensor.size() / sizeof(float));
30     std::memcpy(output.data(), tensor.data(), tensor.size());
31     return output;
32   }
33 
34   template<typename T>
RunWithoutFusion()35   inline std::vector<T> RunWithoutFusion() {
36     Run(XNN_FLAG_NO_OPERATOR_FUSION);
37     std::vector<char>& tensor = this->external_tensors_.at(this->output_id_);
38     std::vector<float> output = std::vector<float>(tensor.size() / sizeof(float));
39     memcpy(output.data(), tensor.data(), tensor.size());
40     return output;
41   }
42 
NumOperators()43   size_t NumOperators() {
44     size_t count = 0;
45     for (size_t i = 0; i < runtime_->num_ops; i++) {
46       if (runtime_->opdata[i].operator_objects[0] != NULL) {
47         count++;
48       }
49     }
50     return count;
51   }
52 
53  private:
54   void Run(uint32_t flags = 0) {
55     xnn_runtime_t runtime = nullptr;
56     ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(this->subgraph_.get(), nullptr, nullptr, flags, &runtime));
57     ASSERT_NE(nullptr, runtime);
58     runtime_.reset(runtime);
59 
60     std::vector<xnn_external_value> externals;
61     for (auto it = this->external_tensors_.begin(); it != this->external_tensors_.end(); ++it) {
62       if (it->first == this->output_id_) {
63         // Scramble output tensor.
64         std::fill(it->second.begin(), it->second.end(), 0xA8);
65       }
66       externals.push_back(xnn_external_value{it->first, it->second.data()});
67     }
68 
69     ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, externals.size(), externals.data()));
70     ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
71   };
72 
73   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> runtime_{nullptr, xnn_delete_runtime};
74 };
75 
76 }  // namespace xnnpack
77