• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5 
6 #pragma once
7 
#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <functional>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>
20 
21 
22 class ELUOperatorTester {
23  public:
channels(size_t channels)24   inline ELUOperatorTester& channels(size_t channels) {
25     assert(channels != 0);
26     this->channels_ = channels;
27     return *this;
28   }
29 
channels()30   inline size_t channels() const {
31     return this->channels_;
32   }
33 
input_stride(size_t input_stride)34   inline ELUOperatorTester& input_stride(size_t input_stride) {
35     assert(input_stride != 0);
36     this->input_stride_ = input_stride;
37     return *this;
38   }
39 
input_stride()40   inline size_t input_stride() const {
41     if (this->input_stride_ == 0) {
42       return this->channels_;
43     } else {
44       assert(this->input_stride_ >= this->channels_);
45       return this->input_stride_;
46     }
47   }
48 
output_stride(size_t output_stride)49   inline ELUOperatorTester& output_stride(size_t output_stride) {
50     assert(output_stride != 0);
51     this->output_stride_ = output_stride;
52     return *this;
53   }
54 
output_stride()55   inline size_t output_stride() const {
56     if (this->output_stride_ == 0) {
57       return this->channels_;
58     } else {
59       assert(this->output_stride_ >= this->channels_);
60       return this->output_stride_;
61     }
62   }
63 
batch_size(size_t batch_size)64   inline ELUOperatorTester& batch_size(size_t batch_size) {
65     assert(batch_size != 0);
66     this->batch_size_ = batch_size;
67     return *this;
68   }
69 
batch_size()70   inline size_t batch_size() const {
71     return this->batch_size_;
72   }
73 
alpha(float alpha)74   inline ELUOperatorTester& alpha(float alpha) {
75     assert(alpha > 0.0f);
76     assert(alpha < 1.0f);
77     this->alpha_ = alpha;
78     return *this;
79   }
80 
alpha()81   inline float alpha() const {
82     return this->alpha_;
83   }
84 
iterations(size_t iterations)85   inline ELUOperatorTester& iterations(size_t iterations) {
86     this->iterations_ = iterations;
87     return *this;
88   }
89 
iterations()90   inline size_t iterations() const {
91     return this->iterations_;
92   }
93 
TestF32()94   void TestF32() const {
95     std::random_device random_device;
96     auto rng = std::mt19937(random_device());
97     auto f32rng = std::bind(std::uniform_real_distribution<float>(-20.0f, 20.0f), std::ref(rng));
98 
99     std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + (batch_size() - 1) * input_stride() + channels());
100     std::vector<float> output((batch_size() - 1) * output_stride() + channels());
101     std::vector<double> output_ref(batch_size() * channels());
102     for (size_t iteration = 0; iteration < iterations(); iteration++) {
103       std::generate(input.begin(), input.end(), std::ref(f32rng));
104       std::fill(output.begin(), output.end(), std::nanf(""));
105 
106       // Compute reference results.
107       for (size_t i = 0; i < batch_size(); i++) {
108         for (size_t c = 0; c < channels(); c++) {
109           const double x = double(input[i * input_stride() + c]);
110           output_ref[i * channels() + c] = std::signbit(x) ? std::expm1(x) * alpha() : x;
111         }
112       }
113 
114       // Create, setup, run, and destroy ELU operator.
115       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
116       xnn_operator_t elu_op = nullptr;
117 
118       ASSERT_EQ(xnn_status_success,
119         xnn_create_elu_nc_f32(
120           channels(), input_stride(), output_stride(),
121           alpha(),
122           0, &elu_op));
123       ASSERT_NE(nullptr, elu_op);
124 
125       // Smart pointer to automatically delete elu_op.
126       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_elu_op(elu_op, xnn_delete_operator);
127 
128       ASSERT_EQ(xnn_status_success,
129         xnn_setup_elu_nc_f32(
130           elu_op,
131           batch_size(),
132           input.data(), output.data(),
133           nullptr /* thread pool */));
134 
135       ASSERT_EQ(xnn_status_success,
136         xnn_run_operator(elu_op, nullptr /* thread pool */));
137 
138       // Verify results.
139       for (size_t i = 0; i < batch_size(); i++) {
140         for (size_t c = 0; c < channels(); c++) {
141           ASSERT_NEAR(output[i * output_stride() + c],
142                       output_ref[i * channels() + c],
143                       std::abs(output_ref[i * channels() + c]) * 1.0e-5)
144             << "at batch " << i << " / " << batch_size() << ", channel " << c << " / " << channels()
145             << ", input " << input[i * input_stride() + c] << ", alpha " << alpha();
146         }
147       }
148     }
149   }
150 
151  private:
152   size_t batch_size_{1};
153   size_t channels_{1};
154   size_t input_stride_{0};
155   size_t output_stride_{0};
156   float alpha_{0.5f};
157   size_t iterations_{15};
158 };
159