• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #pragma once
7 
#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <random>
#include <vector>

#include <xnnpack.h>
#include <xnnpack/params-init.h>
#include <xnnpack/params.h>
21 
22 
23 class VUnOpMicrokernelTester {
24  public:
25   enum class OpType {
26     Abs,
27     ELU,
28     LeakyReLU,
29     Negate,
30     ReLU,
31     RoundToNearestEven,
32     RoundTowardsZero,
33     RoundUp,
34     RoundDown,
35     Square,
36     SquareRoot,
37     Sigmoid,
38   };
39 
40   enum class Variant {
41     Native,
42     Scalar,
43   };
44 
batch_size(size_t batch_size)45   inline VUnOpMicrokernelTester& batch_size(size_t batch_size) {
46     assert(batch_size != 0);
47     this->batch_size_ = batch_size;
48     return *this;
49   }
50 
batch_size()51   inline size_t batch_size() const {
52     return this->batch_size_;
53   }
54 
inplace(bool inplace)55   inline VUnOpMicrokernelTester& inplace(bool inplace) {
56     this->inplace_ = inplace;
57     return *this;
58   }
59 
inplace()60   inline bool inplace() const {
61     return this->inplace_;
62   }
63 
slope(float slope)64   inline VUnOpMicrokernelTester& slope(float slope) {
65     this->slope_ = slope;
66     return *this;
67   }
68 
slope()69   inline float slope() const {
70     return this->slope_;
71   }
72 
prescale(float prescale)73   inline VUnOpMicrokernelTester& prescale(float prescale) {
74     this->prescale_ = prescale;
75     return *this;
76   }
77 
prescale()78   inline float prescale() const {
79     return this->prescale_;
80   }
81 
alpha(float alpha)82   inline VUnOpMicrokernelTester& alpha(float alpha) {
83     this->alpha_ = alpha;
84     return *this;
85   }
86 
alpha()87   inline float alpha() const {
88     return this->alpha_;
89   }
90 
beta(float beta)91   inline VUnOpMicrokernelTester& beta(float beta) {
92     this->beta_ = beta;
93     return *this;
94   }
95 
beta()96   inline float beta() const {
97     return this->beta_;
98   }
99 
qmin(uint8_t qmin)100   inline VUnOpMicrokernelTester& qmin(uint8_t qmin) {
101     this->qmin_ = qmin;
102     return *this;
103   }
104 
qmin()105   inline uint8_t qmin() const {
106     return this->qmin_;
107   }
108 
qmax(uint8_t qmax)109   inline VUnOpMicrokernelTester& qmax(uint8_t qmax) {
110     this->qmax_ = qmax;
111     return *this;
112   }
113 
qmax()114   inline uint8_t qmax() const {
115     return this->qmax_;
116   }
117 
iterations(size_t iterations)118   inline VUnOpMicrokernelTester& iterations(size_t iterations) {
119     this->iterations_ = iterations;
120     return *this;
121   }
122 
iterations()123   inline size_t iterations() const {
124     return this->iterations_;
125   }
126 
127   void Test(xnn_f32_vunary_ukernel_function vunary, OpType op_type, Variant variant = Variant::Native) const {
128     std::random_device random_device;
129     auto rng = std::mt19937(random_device());
130     auto distribution = std::uniform_real_distribution<float>(-125.0f, 125.0f);
131     switch (op_type) {
132       case OpType::ELU:
133         distribution = std::uniform_real_distribution<float>(-20.0f, 20.0f);
134         break;
135       case OpType::SquareRoot:
136         distribution = std::uniform_real_distribution<float>(0.0f, 10.0f);
137         break;
138       default:
139         break;
140     }
141     auto f32rng = std::bind(distribution, std::ref(rng));
142 
143     std::vector<float> x(batch_size() + XNN_EXTRA_BYTES / sizeof(float));
144     std::vector<float> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(float) : 0));
145     std::vector<double> y_ref(batch_size());
146     for (size_t iteration = 0; iteration < iterations(); iteration++) {
147       if (inplace()) {
148         std::generate(y.begin(), y.end(), std::ref(f32rng));
149       } else {
150         std::generate(x.begin(), x.end(), std::ref(f32rng));
151         std::fill(y.begin(), y.end(), nanf(""));
152       }
153       const float* x_data = inplace() ? y.data() : x.data();
154 
155       // Compute reference results.
156       for (size_t i = 0; i < batch_size(); i++) {
157         switch (op_type) {
158           case OpType::Abs:
159             y_ref[i] = std::abs(x_data[i]);
160             break;
161           case OpType::ELU:
162           {
163             y_ref[i] = std::signbit(x_data[i]) ? alpha() * std::expm1(double(x_data[i]) * prescale()) : double(x_data[i]) * beta();
164             break;
165           }
166           case OpType::LeakyReLU:
167             y_ref[i] = std::signbit(x_data[i]) ? x_data[i] * slope() : x_data[i];
168             break;
169           case OpType::Negate:
170             y_ref[i] = -x_data[i];
171             break;
172           case OpType::ReLU:
173             y_ref[i] = std::max(x_data[i], 0.0f);
174             break;
175           case OpType::RoundToNearestEven:
176             y_ref[i] = std::nearbyint(double(x_data[i]));
177             break;
178           case OpType::RoundTowardsZero:
179             y_ref[i] = std::trunc(double(x_data[i]));
180             break;
181           case OpType::RoundUp:
182             y_ref[i] = std::ceil(double(x_data[i]));
183             break;
184           case OpType::RoundDown:
185             y_ref[i] = std::floor(double(x_data[i]));
186             break;
187           case OpType::Square:
188             y_ref[i] = double(x_data[i]) * double(x_data[i]);
189             break;
190           case OpType::SquareRoot:
191             y_ref[i] = std::sqrt(double(x_data[i]));
192             break;
193           case OpType::Sigmoid:
194           {
195             const double e = std::exp(double(x_data[i]));
196             y_ref[i] = e / (1.0 + e);
197             break;
198           }
199         }
200       }
201 
202       // Prepare parameters.
203       union {
204         union xnn_f32_abs_params abs;
205         union xnn_f32_elu_params elu;
206         union xnn_f32_relu_params relu;
207         union xnn_f32_lrelu_params lrelu;
208         union xnn_f32_neg_params neg;
209         union xnn_f32_rnd_params rnd;
210         union xnn_f32_sqrt_params sqrt;
211       } params;
212       switch (op_type) {
213         case OpType::Abs:
214           switch (variant) {
215             case Variant::Native:
216               params.abs = xnn_init_f32_abs_params();
217               break;
218             case Variant::Scalar:
219               params.abs = xnn_init_scalar_f32_abs_params();
220               break;
221           }
222           break;
223         case OpType::ELU:
224           switch (variant) {
225             case Variant::Native:
226               params.elu = xnn_init_f32_elu_params(prescale(), alpha(), beta());
227               break;
228             case Variant::Scalar:
229               params.elu = xnn_init_scalar_f32_elu_params(prescale(), alpha(), beta());
230               break;
231           }
232           break;
233         case OpType::LeakyReLU:
234           switch (variant) {
235             case Variant::Native:
236               params.lrelu = xnn_init_f32_lrelu_params(slope());
237               break;
238             case Variant::Scalar:
239               params.lrelu = xnn_init_scalar_f32_lrelu_params(slope());
240               break;
241           }
242           break;
243         case OpType::Negate:
244           switch (variant) {
245             case Variant::Native:
246               params.neg = xnn_init_f32_neg_params();
247               break;
248             case Variant::Scalar:
249               params.neg = xnn_init_scalar_f32_neg_params();
250               break;
251           }
252           break;
253         case OpType::RoundToNearestEven:
254         case OpType::RoundTowardsZero:
255         case OpType::RoundUp:
256         case OpType::RoundDown:
257           switch (variant) {
258             case Variant::Native:
259               params.rnd = xnn_init_f32_rnd_params();
260               break;
261             case Variant::Scalar:
262               params.rnd = xnn_init_scalar_f32_rnd_params();
263               break;
264           }
265           break;
266         case OpType::ReLU:
267         case OpType::Sigmoid:
268         case OpType::Square:
269           break;
270         case OpType::SquareRoot:
271           switch (variant) {
272             case Variant::Native:
273               params.sqrt = xnn_init_f32_sqrt_params();
274               break;
275             case Variant::Scalar:
276               params.sqrt = xnn_init_scalar_f32_sqrt_params();
277               break;
278           }
279           break;
280       }
281 
282       // Call optimized micro-kernel.
283       vunary(batch_size() * sizeof(float), x_data, y.data(), &params);
284 
285       // Verify results.
286       for (size_t i = 0; i < batch_size(); i++) {
287         ASSERT_NEAR(y[i], y_ref[i], std::max(5.0e-6, std::abs(y_ref[i]) * 1.0e-5))
288           << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << x[i];
289       }
290     }
291   }
292 
293  private:
294   size_t batch_size_ = 1;
295   bool inplace_ = false;
296   float slope_ = 0.5f;
297   float prescale_ = 1.0f;
298   float alpha_ = 1.0f;
299   float beta_ = 1.0f;
300   uint8_t qmin_ = 0;
301   uint8_t qmax_ = 255;
302   size_t iterations_ = 15;
303 };
304