// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <fp16.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>


class PReLUOperatorTester {
 public:
  enum class WeightsType {
    Default,
    FP32,
  };

  inline PReLUOperatorTester& batch_size(size_t batch_size) {
    assert(batch_size != 0);
    this->batch_size_ = batch_size;
    return *this;
  }

  inline size_t batch_size() const {
    return this->batch_size_;
  }

  inline PReLUOperatorTester& channels(size_t channels) {
    assert(channels != 0);
    this->channels_ = channels;
    return *this;
  }

  inline size_t channels() const {
    return this->channels_;
  }

  inline PReLUOperatorTester& x_stride(size_t x_stride) {
    assert(x_stride != 0);
    this->x_stride_ = x_stride;
    return *this;
  }

  // A stride of 0 (the default) denotes a tightly-packed layout, i.e. stride == channels.
  inline size_t x_stride() const {
    if (this->x_stride_ == 0) {
      return this->channels_;
    } else {
      assert(this->x_stride_ >= this->channels_);
      return this->x_stride_;
    }
  }

  inline PReLUOperatorTester& y_stride(size_t y_stride) {
    assert(y_stride != 0);
    this->y_stride_ = y_stride;
    return *this;
  }

  inline size_t y_stride() const {
    if (this->y_stride_ == 0) {
      return this->channels_;
    } else {
      assert(this->y_stride_ >= this->channels_);
      return this->y_stride_;
    }
  }

  inline PReLUOperatorTester& weights_type(WeightsType weights_type) {
    this->weights_type_ = weights_type;
    return *this;
  }

  inline WeightsType weights_type() const {
    return this->weights_type_;
  }

  inline PReLUOperatorTester& iterations(size_t iterations) {
    this->iterations_ = iterations;
    return *this;
  }

  inline size_t iterations() const {
    return this->iterations_;
  }

  void TestF16() const {
    switch (weights_type()) {
      case WeightsType::Default:
        break;
      case WeightsType::FP32:
        break;
      default:
        GTEST_FAIL() << "unexpected weights type";
    }

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32irng = std::bind(std::uniform_real_distribution<float>(-1.0f, 1.0f), rng);
    auto f16irng = std::bind(fp16_ieee_from_fp32_value, f32irng);
    auto f32wrng = std::bind(std::uniform_real_distribution<float>(0.25f, 0.75f), rng);
    auto f16wrng = std::bind(fp16_ieee_from_fp32_value, f32wrng);

    std::vector<uint16_t> x((batch_size() - 1) * x_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
    std::vector<uint16_t> w(channels());
    std::vector<float> w_as_float(channels());
    std::vector<uint16_t> y((batch_size() - 1) * y_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
    std::vector<float> y_ref(batch_size() * channels());
    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(x.begin(), x.end(), std::ref(f16irng));
      std::generate(w.begin(), w.end(), std::ref(f16wrng));
      std::transform(w.cbegin(), w.cend(), w_as_float.begin(), fp16_ieee_to_fp32_value);
      std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */);

      // Compute reference results, without clamping.
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < channels(); c++) {
          const float x_value = fp16_ieee_to_fp32_value(x[i * x_stride() + c]);
          const float w_value = w_as_float[c];
          y_ref[i * channels() + c] = std::signbit(x_value) ? x_value * w_value : x_value;
        }
      }

      // Create, setup, run, and destroy PReLU operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t prelu_op = nullptr;

      const void* negative_slope_data = w.data();
      if (weights_type() == WeightsType::FP32) {
        negative_slope_data = w_as_float.data();
      }
      uint32_t flags = 0;
      if (weights_type() == WeightsType::FP32) {
        flags |= XNN_FLAG_FP32_STATIC_WEIGHTS;
      }
      ASSERT_EQ(xnn_status_success,
        xnn_create_prelu_nc_f16(
          channels(), x_stride(), y_stride(),
          negative_slope_data,
          flags, &prelu_op));
      ASSERT_NE(nullptr, prelu_op);

      // Smart pointer to automatically delete prelu_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_prelu_op(prelu_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_prelu_nc_f16(
          prelu_op,
          batch_size(),
          x.data(), y.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(prelu_op, nullptr /* thread pool */));

      // Verify results.
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < channels(); c++) {
          ASSERT_NEAR(
              fp16_ieee_to_fp32_value(y[i * y_stride() + c]),
              y_ref[i * channels() + c],
              std::max(1.0e-4f, std::abs(y_ref[i * channels() + c]) * 1.0e-4f))
            << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels();
        }
      }
    }
  }

  void TestF32() const {
    ASSERT_EQ(weights_type(), WeightsType::Default);

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32irng = std::bind(std::uniform_real_distribution<float>(-1.0f, 1.0f), rng);
    auto f32wrng = std::bind(std::uniform_real_distribution<float>(0.25f, 0.75f), rng);

    std::vector<float> x((batch_size() - 1) * x_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<float> w(channels());
    std::vector<float> y((batch_size() - 1) * y_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<float> y_ref(batch_size() * channels());
    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(x.begin(), x.end(), std::ref(f32irng));
      std::generate(w.begin(), w.end(), std::ref(f32wrng));
      std::fill(y.begin(), y.end(), nanf(""));

      // Compute reference results, without clamping.
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < channels(); c++) {
          y_ref[i * channels() + c] = std::signbit(x[i * x_stride() + c]) ? x[i * x_stride() + c] * w[c] : x[i * x_stride() + c];
        }
      }

      // Create, setup, run, and destroy PReLU operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t prelu_op = nullptr;

      ASSERT_EQ(xnn_status_success,
        xnn_create_prelu_nc_f32(
          channels(), x_stride(), y_stride(),
          w.data(),
          0 /* flags */, &prelu_op));
      ASSERT_NE(nullptr, prelu_op);

      // Smart pointer to automatically delete prelu_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_prelu_op(prelu_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_prelu_nc_f32(
          prelu_op,
          batch_size(),
          x.data(), y.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(prelu_op, nullptr /* thread pool */));

      // Verify results.
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < channels(); c++) {
          ASSERT_NEAR(
              y[i * y_stride() + c],
              y_ref[i * channels() + c],
              std::max(1.0e-6f, std::abs(y_ref[i * channels() + c]) * 1.0e-6f))
            << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels();
        }
      }
    }
  }

 private:
  size_t batch_size_{1};
  size_t channels_{1};
  size_t x_stride_{0};
  size_t y_stride_{0};
  WeightsType weights_type_{WeightsType::Default};
  size_t iterations_{15};
};
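
// Example usage: a minimal sketch of how a gtest case might drive this tester.
// The test name and parameter values below are illustrative assumptions, not
// part of this header; each Test*() call runs the full initialize -> create ->
// setup -> run -> delete operator lifecycle against the reference computation.
//
//   TEST(PRELU_NC_F32, small_batch_with_x_stride) {
//     for (size_t channels = 1; channels <= 16; channels += 3) {
//       PReLUOperatorTester()
//         .batch_size(3)
//         .channels(channels)
//         .x_stride(23)  // row pitch of x in elements; must be >= channels
//         .iterations(3)
//         .TestF32();
//     }
//   }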