1 // Copyright (c) Facebook, Inc. and its affiliates. 2 // All rights reserved. 3 // 4 // Copyright 2019 Google LLC 5 // 6 // This source code is licensed under the BSD-style license found in the 7 // LICENSE file in the root directory of this source tree. 8 9 #pragma once 10 11 #include <gtest/gtest.h> 12 13 #include <algorithm> 14 #include <cmath> 15 #include <cstddef> 16 #include <cstdlib> 17 #include <functional> 18 #include <random> 19 #include <vector> 20 21 #include <xnnpack.h> 22 23 24 class AddOperatorTester { 25 public: channels(size_t channels)26 inline AddOperatorTester& channels(size_t channels) { 27 assert(channels != 0); 28 this->channels_ = channels; 29 return *this; 30 } 31 channels()32 inline size_t channels() const { 33 return this->channels_; 34 } 35 a_stride(size_t a_stride)36 inline AddOperatorTester& a_stride(size_t a_stride) { 37 assert(a_stride != 0); 38 this->a_stride_ = a_stride; 39 return *this; 40 } 41 a_stride()42 inline size_t a_stride() const { 43 if (this->a_stride_ == 0) { 44 return this->channels_; 45 } else { 46 assert(this->a_stride_ >= this->channels_); 47 return this->a_stride_; 48 } 49 } 50 b_stride(size_t b_stride)51 inline AddOperatorTester& b_stride(size_t b_stride) { 52 assert(b_stride != 0); 53 this->b_stride_ = b_stride; 54 return *this; 55 } 56 b_stride()57 inline size_t b_stride() const { 58 if (this->b_stride_ == 0) { 59 return this->channels_; 60 } else { 61 assert(this->b_stride_ >= this->channels_); 62 return this->b_stride_; 63 } 64 } 65 y_stride(size_t y_stride)66 inline AddOperatorTester& y_stride(size_t y_stride) { 67 assert(y_stride != 0); 68 this->y_stride_ = y_stride; 69 return *this; 70 } 71 y_stride()72 inline size_t y_stride() const { 73 if (this->y_stride_ == 0) { 74 return this->channels_; 75 } else { 76 assert(this->y_stride_ >= this->channels_); 77 return this->y_stride_; 78 } 79 } 80 batch_size(size_t batch_size)81 inline AddOperatorTester& batch_size(size_t batch_size) { 82 assert(batch_size != 0); 83 this->batch_size_ = batch_size; 84 return *this; 85 } 86 batch_size()87 inline size_t batch_size() const { 88 return this->batch_size_; 89 } 90 a_scale(float a_scale)91 inline AddOperatorTester& a_scale(float a_scale) { 92 assert(a_scale > 0.0f); 93 assert(std::isnormal(a_scale)); 94 this->a_scale_ = a_scale; 95 return *this; 96 } 97 a_scale()98 inline float a_scale() const { 99 return this->a_scale_; 100 } 101 a_zero_point(uint8_t a_zero_point)102 inline AddOperatorTester& a_zero_point(uint8_t a_zero_point) { 103 this->a_zero_point_ = a_zero_point; 104 return *this; 105 } 106 a_zero_point()107 inline uint8_t a_zero_point() const { 108 return this->a_zero_point_; 109 } 110 b_scale(float b_scale)111 inline AddOperatorTester& b_scale(float b_scale) { 112 assert(b_scale > 0.0f); 113 assert(std::isnormal(b_scale)); 114 this->b_scale_ = b_scale; 115 return *this; 116 } 117 b_scale()118 inline float b_scale() const { 119 return this->b_scale_; 120 } 121 b_zero_point(uint8_t b_zero_point)122 inline AddOperatorTester& b_zero_point(uint8_t b_zero_point) { 123 this->b_zero_point_ = b_zero_point; 124 return *this; 125 } 126 b_zero_point()127 inline uint8_t b_zero_point() const { 128 return this->b_zero_point_; 129 } 130 y_scale(float y_scale)131 inline AddOperatorTester& y_scale(float y_scale) { 132 assert(y_scale > 0.0f); 133 assert(std::isnormal(y_scale)); 134 this->y_scale_ = y_scale; 135 return *this; 136 } 137 y_scale()138 inline float y_scale() const { 139 return this->y_scale_; 140 } 141 y_zero_point(uint8_t y_zero_point)142 inline AddOperatorTester& y_zero_point(uint8_t y_zero_point) { 143 this->y_zero_point_ = y_zero_point; 144 return *this; 145 } 146 y_zero_point()147 inline uint8_t y_zero_point() const { 148 return this->y_zero_point_; 149 } 150 qmin(uint8_t qmin)151 inline AddOperatorTester& qmin(uint8_t qmin) { 152 this->qmin_ = qmin; 153 return *this; 154 } 155 qmin()156 inline uint8_t qmin() const { 157 return this->qmin_; 158 } 159 qmax(uint8_t qmax)160 inline AddOperatorTester& qmax(uint8_t qmax) { 161 this->qmax_ = qmax; 162 return *this; 163 } 164 qmax()165 inline uint8_t qmax() const { 166 return this->qmax_; 167 } 168 iterations(size_t iterations)169 inline AddOperatorTester& iterations(size_t iterations) { 170 this->iterations_ = iterations; 171 return *this; 172 } 173 iterations()174 inline size_t iterations() const { 175 return this->iterations_; 176 } 177 TestQ8()178 void TestQ8() const { 179 std::random_device random_device; 180 auto rng = std::mt19937(random_device()); 181 auto u8rng = std::bind(std::uniform_int_distribution<uint8_t>(), rng); 182 183 std::vector<uint8_t> a(XNN_EXTRA_BYTES / sizeof(uint8_t) + (batch_size() - 1) * a_stride() + channels()); 184 std::vector<uint8_t> b(XNN_EXTRA_BYTES / sizeof(uint8_t) + (batch_size() - 1) * b_stride() + channels()); 185 std::vector<uint8_t> y((batch_size() - 1) * y_stride() + channels()); 186 std::vector<float> y_ref(batch_size() * channels()); 187 for (size_t iteration = 0; iteration < iterations(); iteration++) { 188 std::generate(a.begin(), a.end(), std::ref(u8rng)); 189 std::generate(b.begin(), b.end(), std::ref(u8rng)); 190 std::fill(y.begin(), y.end(), 0xA5); 191 192 if (batch_size() * channels() > 3) { 193 ASSERT_NE(*std::max_element(a.cbegin(), a.cend()), *std::min_element(a.cbegin(), a.cend())); 194 ASSERT_NE(*std::max_element(b.cbegin(), b.cend()), *std::min_element(b.cbegin(), b.cend())); 195 } 196 197 // Compute reference results. 198 for (size_t i = 0; i < batch_size(); i++) { 199 for (size_t c = 0; c < channels(); c++) { 200 y_ref[i * channels() + c] = float(y_zero_point()) + 201 float(int32_t(a[i * a_stride() + c]) - int32_t(a_zero_point())) * (a_scale() / y_scale()) + 202 float(int32_t(b[i * b_stride() + c]) - int32_t(b_zero_point())) * (b_scale() / y_scale()); 203 y_ref[i * channels() + c] = std::min<float>(y_ref[i * channels() + c], float(qmax())); 204 y_ref[i * channels() + c] = std::max<float>(y_ref[i * channels() + c], float(qmin())); 205 } 206 } 207 208 // Create, setup, run, and destroy Add operator. 209 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 210 xnn_operator_t add_op = nullptr; 211 212 ASSERT_EQ(xnn_status_success, 213 xnn_create_add_nc_q8( 214 channels(), a_stride(), b_stride(), y_stride(), 215 a_zero_point(), a_scale(), 216 b_zero_point(), b_scale(), 217 y_zero_point(), y_scale(), 218 qmin(), qmax(), 219 0, &add_op)); 220 ASSERT_NE(nullptr, add_op); 221 222 // Smart pointer to automatically delete add_op. 223 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_add_op(add_op, xnn_delete_operator); 224 225 ASSERT_EQ(xnn_status_success, 226 xnn_setup_add_nc_q8( 227 add_op, 228 batch_size(), 229 a.data(), b.data(), y.data(), 230 nullptr /* thread pool */)); 231 232 ASSERT_EQ(xnn_status_success, 233 xnn_run_operator(add_op, nullptr /* thread pool */)); 234 235 /// Verify results. 236 for (size_t i = 0; i < batch_size(); i++) { 237 for (size_t c = 0; c < channels(); c++) { 238 ASSERT_LE(uint32_t(y[i * y_stride() + c]), uint32_t(qmax())); 239 ASSERT_GE(uint32_t(y[i * y_stride() + c]), uint32_t(qmin())); 240 ASSERT_NEAR(float(int32_t(y[i * y_stride() + c])), y_ref[i * channels() + c], 0.6f); 241 } 242 } 243 } 244 } 245 TestF32()246 void TestF32() const { 247 std::random_device random_device; 248 auto rng = std::mt19937(random_device()); 249 auto f32rng = std::bind(std::uniform_real_distribution<float>(0.0f, 1.0f), rng); 250 251 std::vector<float> a(XNN_EXTRA_BYTES / sizeof(float) + (batch_size() - 1) * a_stride() + channels()); 252 std::vector<float> b(XNN_EXTRA_BYTES / sizeof(float) + (batch_size() - 1) * b_stride() + channels()); 253 std::vector<float> y((batch_size() - 1) * y_stride() + channels()); 254 std::vector<float> y_ref(batch_size() * channels()); 255 for (size_t iteration = 0; iteration < iterations(); iteration++) { 256 std::generate(a.begin(), a.end(), std::ref(f32rng)); 257 std::generate(b.begin(), b.end(), std::ref(f32rng)); 258 std::fill(y.begin(), y.end(), nanf("")); 259 260 // Compute reference results. 261 for (size_t i = 0; i < batch_size(); i++) { 262 for (size_t c = 0; c < channels(); c++) { 263 y_ref[i * channels() + c] = a[i * a_stride() + c] + b[i * b_stride() + c]; 264 } 265 } 266 const float accumulated_min = *std::min_element(y_ref.cbegin(), y_ref.cend()); 267 const float accumulated_max = *std::max_element(y_ref.cbegin(), y_ref.cend()); 268 const float accumulated_range = accumulated_max - accumulated_min; 269 const float y_min = batch_size() * channels() == 1 ? 270 -std::numeric_limits<float>::infinity() : accumulated_min + accumulated_range / 255.0f * float(qmin()); 271 const float y_max = batch_size() * channels() == 1 ? 272 +std::numeric_limits<float>::infinity() : accumulated_max - accumulated_range / 255.0f * float(255 - qmax()); 273 for (size_t i = 0; i < batch_size(); i++) { 274 for (size_t c = 0; c < channels(); c++) { 275 y_ref[i * channels() + c] = std::min(std::max(y_ref[i * channels() + c], y_min), y_max); 276 } 277 } 278 279 // Create, setup, run, and destroy Add operator. 280 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 281 xnn_operator_t add_op = nullptr; 282 283 ASSERT_EQ(xnn_status_success, 284 xnn_create_add_nc_f32( 285 channels(), a_stride(), b_stride(), y_stride(), 286 y_min, y_max, 287 0, &add_op)); 288 ASSERT_NE(nullptr, add_op); 289 290 // Smart pointer to automatically delete add_op. 291 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_add_op(add_op, xnn_delete_operator); 292 293 ASSERT_EQ(xnn_status_success, 294 xnn_setup_add_nc_f32( 295 add_op, 296 batch_size(), 297 a.data(), b.data(), y.data(), 298 nullptr /* thread pool */)); 299 300 ASSERT_EQ(xnn_status_success, 301 xnn_run_operator(add_op, nullptr /* thread pool */)); 302 303 // Verify results. 304 for (size_t i = 0; i < batch_size(); i++) { 305 for (size_t c = 0; c < channels(); c++) { 306 ASSERT_NEAR(y[i * y_stride() + c], y_ref[i * channels() + c], 1.0e-6f * y_ref[i * channels() + c]) 307 << "i = " << i << ", c = " << c; 308 } 309 } 310 } 311 } 312 313 private: 314 size_t batch_size_{1}; 315 size_t channels_{1}; 316 size_t a_stride_{0}; 317 size_t b_stride_{0}; 318 size_t y_stride_{0}; 319 float a_scale_{0.75f}; 320 float b_scale_{1.25f}; 321 float y_scale_{0.96875f}; 322 uint8_t a_zero_point_{121}; 323 uint8_t b_zero_point_{127}; 324 uint8_t y_zero_point_{133}; 325 uint8_t qmin_{0}; 326 uint8_t qmax_{255}; 327 size_t iterations_{15}; 328 }; 329