• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5 
6 #pragma once
7 
#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <functional>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>
19 
20 
21 class NegateOperatorTester {
22  public:
channels(size_t channels)23   inline NegateOperatorTester& channels(size_t channels) {
24     assert(channels != 0);
25     this->channels_ = channels;
26     return *this;
27   }
28 
channels()29   inline size_t channels() const {
30     return this->channels_;
31   }
32 
input_stride(size_t input_stride)33   inline NegateOperatorTester& input_stride(size_t input_stride) {
34     assert(input_stride != 0);
35     this->input_stride_ = input_stride;
36     return *this;
37   }
38 
input_stride()39   inline size_t input_stride() const {
40     if (this->input_stride_ == 0) {
41       return this->channels_;
42     } else {
43       assert(this->input_stride_ >= this->channels_);
44       return this->input_stride_;
45     }
46   }
47 
output_stride(size_t output_stride)48   inline NegateOperatorTester& output_stride(size_t output_stride) {
49     assert(output_stride != 0);
50     this->output_stride_ = output_stride;
51     return *this;
52   }
53 
output_stride()54   inline size_t output_stride() const {
55     if (this->output_stride_ == 0) {
56       return this->channels_;
57     } else {
58       assert(this->output_stride_ >= this->channels_);
59       return this->output_stride_;
60     }
61   }
62 
batch_size(size_t batch_size)63   inline NegateOperatorTester& batch_size(size_t batch_size) {
64     assert(batch_size != 0);
65     this->batch_size_ = batch_size;
66     return *this;
67   }
68 
batch_size()69   inline size_t batch_size() const {
70     return this->batch_size_;
71   }
72 
iterations(size_t iterations)73   inline NegateOperatorTester& iterations(size_t iterations) {
74     this->iterations_ = iterations;
75     return *this;
76   }
77 
iterations()78   inline size_t iterations() const {
79     return this->iterations_;
80   }
81 
TestF32()82   void TestF32() const {
83     std::random_device random_device;
84     auto rng = std::mt19937(random_device());
85     auto f32rng = std::bind(std::uniform_real_distribution<float>(-1.0f, 1.0f), rng);
86 
87     std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) +
88       (batch_size() - 1) * input_stride() + channels());
89     std::vector<float> output((batch_size() - 1) * output_stride() + channels());
90     std::vector<float> output_ref(batch_size() * channels());
91     for (size_t iteration = 0; iteration < iterations(); iteration++) {
92       std::generate(input.begin(), input.end(), std::ref(f32rng));
93       std::fill(output.begin(), output.end(), std::nanf(""));
94 
95       // Compute reference results.
96       for (size_t i = 0; i < batch_size(); i++) {
97         for (size_t c = 0; c < channels(); c++) {
98           output_ref[i * channels() + c] = -input[i * input_stride() + c];
99         }
100       }
101 
102       // Create, setup, run, and destroy Negate operator.
103       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
104       xnn_operator_t negate_op = nullptr;
105 
106       ASSERT_EQ(xnn_status_success,
107         xnn_create_negate_nc_f32(
108           channels(), input_stride(), output_stride(),
109           0, &negate_op));
110       ASSERT_NE(nullptr, negate_op);
111 
112       // Smart pointer to automatically delete negate_op.
113       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_negate_op(negate_op, xnn_delete_operator);
114 
115       ASSERT_EQ(xnn_status_success,
116         xnn_setup_negate_nc_f32(
117           negate_op,
118           batch_size(),
119           input.data(), output.data(),
120           nullptr /* thread pool */));
121 
122       ASSERT_EQ(xnn_status_success,
123         xnn_run_operator(negate_op, nullptr /* thread pool */));
124 
125       // Verify results.
126       for (size_t i = 0; i < batch_size(); i++) {
127         for (size_t c = 0; c < channels(); c++) {
128           ASSERT_EQ(output_ref[i * channels() + c], output[i * output_stride() + c])
129             << "at batch " << i << " / " << batch_size() << ", channel " << c << " / " << channels();
130         }
131       }
132     }
133   }
134 
135  private:
136   size_t batch_size_{1};
137   size_t channels_{1};
138   size_t input_stride_{0};
139   size_t output_stride_{0};
140   size_t iterations_{15};
141 };
142