• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #pragma once
7 
#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <memory>
#include <random>
#include <vector>

#include <fp16.h>

#include <xnnpack.h>
21 
22 
23 class HardSwishOperatorTester {
24  public:
channels(size_t channels)25   inline HardSwishOperatorTester& channels(size_t channels) {
26     assert(channels != 0);
27     this->channels_ = channels;
28     return *this;
29   }
30 
channels()31   inline size_t channels() const {
32     return this->channels_;
33   }
34 
input_stride(size_t input_stride)35   inline HardSwishOperatorTester& input_stride(size_t input_stride) {
36     assert(input_stride != 0);
37     this->input_stride_ = input_stride;
38     return *this;
39   }
40 
input_stride()41   inline size_t input_stride() const {
42     if (this->input_stride_ == 0) {
43       return this->channels_;
44     } else {
45       assert(this->input_stride_ >= this->channels_);
46       return this->input_stride_;
47     }
48   }
49 
output_stride(size_t output_stride)50   inline HardSwishOperatorTester& output_stride(size_t output_stride) {
51     assert(output_stride != 0);
52     this->output_stride_ = output_stride;
53     return *this;
54   }
55 
output_stride()56   inline size_t output_stride() const {
57     if (this->output_stride_ == 0) {
58       return this->channels_;
59     } else {
60       assert(this->output_stride_ >= this->channels_);
61       return this->output_stride_;
62     }
63   }
64 
batch_size(size_t batch_size)65   inline HardSwishOperatorTester& batch_size(size_t batch_size) {
66     assert(batch_size != 0);
67     this->batch_size_ = batch_size;
68     return *this;
69   }
70 
batch_size()71   inline size_t batch_size() const {
72     return this->batch_size_;
73   }
74 
iterations(size_t iterations)75   inline HardSwishOperatorTester& iterations(size_t iterations) {
76     this->iterations_ = iterations;
77     return *this;
78   }
79 
iterations()80   inline size_t iterations() const {
81     return this->iterations_;
82   }
83 
TestF16()84   void TestF16() const {
85     std::random_device random_device;
86     auto rng = std::mt19937(random_device());
87     auto f32rng = std::bind(std::uniform_real_distribution<float>(-4.0f, 4.0f), rng);
88     auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng);
89 
90     std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) +
91       (batch_size() - 1) * input_stride() + channels());
92     std::vector<uint16_t> output((batch_size() - 1) * output_stride() + channels());
93     std::vector<float> output_ref(batch_size() * channels());
94     for (size_t iteration = 0; iteration < iterations(); iteration++) {
95       std::generate(input.begin(), input.end(), std::ref(f16rng));
96       std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);
97 
98       // Compute reference results.
99       for (size_t i = 0; i < batch_size(); i++) {
100         for (size_t c = 0; c < channels(); c++) {
101           const float x = fp16_ieee_to_fp32_value(input[i * input_stride() + c]);
102           const float y = x * std::min(std::max(x + 3.0f, 0.0f), 6.0f) / 6.0f;
103           output_ref[i * channels() + c] = y;
104         }
105       }
106 
107       // Create, setup, run, and destroy HardSwish operator.
108       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
109       xnn_operator_t hardswish_op = nullptr;
110       xnn_status status = xnn_create_hardswish_nc_f16(
111           channels(), input_stride(), output_stride(),
112           0, &hardswish_op);
113       if (status == xnn_status_unsupported_hardware) {
114         GTEST_SKIP();
115       }
116       ASSERT_NE(nullptr, hardswish_op);
117 
118       // Smart pointer to automatically delete hardswish_op.
119       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_hardswish_op(hardswish_op, xnn_delete_operator);
120 
121       ASSERT_EQ(xnn_status_success,
122         xnn_setup_hardswish_nc_f16(
123           hardswish_op,
124           batch_size(),
125           input.data(), output.data(),
126           nullptr /* thread pool */));
127 
128       ASSERT_EQ(xnn_status_success,
129         xnn_run_operator(hardswish_op, nullptr /* thread pool */));
130 
131       // Verify results.
132       for (size_t i = 0; i < batch_size(); i++) {
133         for (size_t c = 0; c < channels(); c++) {
134           ASSERT_NEAR(fp16_ieee_to_fp32_value(output[i * output_stride() + c]), output_ref[i * channels() + c], std::max(1.0e-3f, std::abs(output_ref[i * channels() + c]) * 1.0e-2f))
135             << "at position " << i << ", batch size = " << batch_size() << ", channels = " << channels();
136         }
137       }
138     }
139   }
140 
TestF32()141   void TestF32() const {
142     std::random_device random_device;
143     auto rng = std::mt19937(random_device());
144     auto f32rng = std::bind(std::uniform_real_distribution<float>(-4.0f, 4.0f), rng);
145 
146     std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) +
147       (batch_size() - 1) * input_stride() + channels());
148     std::vector<float> output((batch_size() - 1) * output_stride() + channels());
149     std::vector<float> output_ref(batch_size() * channels());
150     for (size_t iteration = 0; iteration < iterations(); iteration++) {
151       std::generate(input.begin(), input.end(), std::ref(f32rng));
152       std::fill(output.begin(), output.end(), std::nanf(""));
153 
154       // Compute reference results.
155       for (size_t i = 0; i < batch_size(); i++) {
156         for (size_t c = 0; c < channels(); c++) {
157           const float x = input[i * input_stride() + c];
158           const float y = x * std::min(std::max(x + 3.0f, 0.0f), 6.0f) / 6.0f;
159           output_ref[i * channels() + c] = y;
160         }
161       }
162 
163       // Create, setup, run, and destroy HardSwish operator.
164       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
165       xnn_operator_t hardswish_op = nullptr;
166 
167       ASSERT_EQ(xnn_status_success,
168         xnn_create_hardswish_nc_f32(
169           channels(), input_stride(), output_stride(),
170           0, &hardswish_op));
171       ASSERT_NE(nullptr, hardswish_op);
172 
173       // Smart pointer to automatically delete hardswish_op.
174       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_hardswish_op(hardswish_op, xnn_delete_operator);
175 
176       ASSERT_EQ(xnn_status_success,
177         xnn_setup_hardswish_nc_f32(
178           hardswish_op,
179           batch_size(),
180           input.data(), output.data(),
181           nullptr /* thread pool */));
182 
183       ASSERT_EQ(xnn_status_success,
184         xnn_run_operator(hardswish_op, nullptr /* thread pool */));
185 
186       // Verify results.
187       for (size_t i = 0; i < batch_size(); i++) {
188         for (size_t c = 0; c < channels(); c++) {
189           ASSERT_NEAR(output_ref[i * channels() + c], output[i * output_stride() + c], std::max(1.0e-7f, std::abs(output[i * output_stride() + c]) * 1.0e-6f))
190             << "at position " << i << ", batch size = " << batch_size() << ", channels = " << channels();
191         }
192       }
193     }
194   }
195 
196  private:
197   size_t batch_size_{1};
198   size_t channels_{1};
199   size_t input_stride_{0};
200   size_t output_stride_{0};
201   size_t iterations_{15};
202 };
203