1 // Copyright 2019 Google LLC 2 // 3 // This source code is licensed under the BSD-style license found in the 4 // LICENSE file in the root directory of this source tree. 5 6 #pragma once 7 8 #include <gtest/gtest.h> 9 10 #include <algorithm> 11 #include <cassert> 12 #include <cstddef> 13 #include <cstdlib> 14 #include <functional> 15 #include <random> 16 #include <vector> 17 18 #include <fp16.h> 19 20 #include <xnnpack.h> 21 22 23 class HardSwishOperatorTester { 24 public: channels(size_t channels)25 inline HardSwishOperatorTester& channels(size_t channels) { 26 assert(channels != 0); 27 this->channels_ = channels; 28 return *this; 29 } 30 channels()31 inline size_t channels() const { 32 return this->channels_; 33 } 34 input_stride(size_t input_stride)35 inline HardSwishOperatorTester& input_stride(size_t input_stride) { 36 assert(input_stride != 0); 37 this->input_stride_ = input_stride; 38 return *this; 39 } 40 input_stride()41 inline size_t input_stride() const { 42 if (this->input_stride_ == 0) { 43 return this->channels_; 44 } else { 45 assert(this->input_stride_ >= this->channels_); 46 return this->input_stride_; 47 } 48 } 49 output_stride(size_t output_stride)50 inline HardSwishOperatorTester& output_stride(size_t output_stride) { 51 assert(output_stride != 0); 52 this->output_stride_ = output_stride; 53 return *this; 54 } 55 output_stride()56 inline size_t output_stride() const { 57 if (this->output_stride_ == 0) { 58 return this->channels_; 59 } else { 60 assert(this->output_stride_ >= this->channels_); 61 return this->output_stride_; 62 } 63 } 64 batch_size(size_t batch_size)65 inline HardSwishOperatorTester& batch_size(size_t batch_size) { 66 assert(batch_size != 0); 67 this->batch_size_ = batch_size; 68 return *this; 69 } 70 batch_size()71 inline size_t batch_size() const { 72 return this->batch_size_; 73 } 74 iterations(size_t iterations)75 inline HardSwishOperatorTester& iterations(size_t iterations) { 76 this->iterations_ = iterations; 77 return *this; 78 } 79 iterations()80 inline size_t iterations() const { 81 return this->iterations_; 82 } 83 TestF16()84 void TestF16() const { 85 std::random_device random_device; 86 auto rng = std::mt19937(random_device()); 87 auto f32rng = std::bind(std::uniform_real_distribution<float>(-4.0f, 4.0f), rng); 88 auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng); 89 90 std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) + 91 (batch_size() - 1) * input_stride() + channels()); 92 std::vector<uint16_t> output((batch_size() - 1) * output_stride() + channels()); 93 std::vector<float> output_ref(batch_size() * channels()); 94 for (size_t iteration = 0; iteration < iterations(); iteration++) { 95 std::generate(input.begin(), input.end(), std::ref(f16rng)); 96 std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */); 97 98 // Compute reference results. 99 for (size_t i = 0; i < batch_size(); i++) { 100 for (size_t c = 0; c < channels(); c++) { 101 const float x = fp16_ieee_to_fp32_value(input[i * input_stride() + c]); 102 const float y = x * std::min(std::max(x + 3.0f, 0.0f), 6.0f) / 6.0f; 103 output_ref[i * channels() + c] = y; 104 } 105 } 106 107 // Create, setup, run, and destroy HardSwish operator. 108 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 109 xnn_operator_t hardswish_op = nullptr; 110 xnn_status status = xnn_create_hardswish_nc_f16( 111 channels(), input_stride(), output_stride(), 112 0, &hardswish_op); 113 if (status == xnn_status_unsupported_hardware) { 114 GTEST_SKIP(); 115 } 116 ASSERT_NE(nullptr, hardswish_op); 117 118 // Smart pointer to automatically delete hardswish_op. 119 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_hardswish_op(hardswish_op, xnn_delete_operator); 120 121 ASSERT_EQ(xnn_status_success, 122 xnn_setup_hardswish_nc_f16( 123 hardswish_op, 124 batch_size(), 125 input.data(), output.data(), 126 nullptr /* thread pool */)); 127 128 ASSERT_EQ(xnn_status_success, 129 xnn_run_operator(hardswish_op, nullptr /* thread pool */)); 130 131 // Verify results. 132 for (size_t i = 0; i < batch_size(); i++) { 133 for (size_t c = 0; c < channels(); c++) { 134 ASSERT_NEAR(fp16_ieee_to_fp32_value(output[i * output_stride() + c]), output_ref[i * channels() + c], std::max(1.0e-3f, std::abs(output_ref[i * channels() + c]) * 1.0e-2f)) 135 << "at position " << i << ", batch size = " << batch_size() << ", channels = " << channels(); 136 } 137 } 138 } 139 } 140 TestF32()141 void TestF32() const { 142 std::random_device random_device; 143 auto rng = std::mt19937(random_device()); 144 auto f32rng = std::bind(std::uniform_real_distribution<float>(-4.0f, 4.0f), rng); 145 146 std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + 147 (batch_size() - 1) * input_stride() + channels()); 148 std::vector<float> output((batch_size() - 1) * output_stride() + channels()); 149 std::vector<float> output_ref(batch_size() * channels()); 150 for (size_t iteration = 0; iteration < iterations(); iteration++) { 151 std::generate(input.begin(), input.end(), std::ref(f32rng)); 152 std::fill(output.begin(), output.end(), std::nanf("")); 153 154 // Compute reference results. 155 for (size_t i = 0; i < batch_size(); i++) { 156 for (size_t c = 0; c < channels(); c++) { 157 const float x = input[i * input_stride() + c]; 158 const float y = x * std::min(std::max(x + 3.0f, 0.0f), 6.0f) / 6.0f; 159 output_ref[i * channels() + c] = y; 160 } 161 } 162 163 // Create, setup, run, and destroy HardSwish operator. 164 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 165 xnn_operator_t hardswish_op = nullptr; 166 167 ASSERT_EQ(xnn_status_success, 168 xnn_create_hardswish_nc_f32( 169 channels(), input_stride(), output_stride(), 170 0, &hardswish_op)); 171 ASSERT_NE(nullptr, hardswish_op); 172 173 // Smart pointer to automatically delete hardswish_op. 174 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_hardswish_op(hardswish_op, xnn_delete_operator); 175 176 ASSERT_EQ(xnn_status_success, 177 xnn_setup_hardswish_nc_f32( 178 hardswish_op, 179 batch_size(), 180 input.data(), output.data(), 181 nullptr /* thread pool */)); 182 183 ASSERT_EQ(xnn_status_success, 184 xnn_run_operator(hardswish_op, nullptr /* thread pool */)); 185 186 // Verify results. 187 for (size_t i = 0; i < batch_size(); i++) { 188 for (size_t c = 0; c < channels(); c++) { 189 ASSERT_NEAR(output_ref[i * channels() + c], output[i * output_stride() + c], std::max(1.0e-7f, std::abs(output[i * output_stride() + c]) * 1.0e-6f)) 190 << "at position " << i << ", batch size = " << batch_size() << ", channels = " << channels(); 191 } 192 } 193 } 194 } 195 196 private: 197 size_t batch_size_{1}; 198 size_t channels_{1}; 199 size_t input_stride_{0}; 200 size_t output_stride_{0}; 201 size_t iterations_{15}; 202 }; 203