// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>


class SoftMaxOperatorTester {
 public:
  inline SoftMaxOperatorTester& channels(size_t channels) {
    assert(channels != 0);
    this->channels_ = channels;
    return *this;
  }

  inline size_t channels() const {
    return this->channels_;
  }

  inline SoftMaxOperatorTester& input_stride(size_t input_stride) {
    assert(input_stride != 0);
    this->input_stride_ = input_stride;
    return *this;
  }

  inline size_t input_stride() const {
    if (this->input_stride_ == 0) {
      return this->channels_;
    } else {
      assert(this->input_stride_ >= this->channels_);
      return this->input_stride_;
    }
  }

  inline SoftMaxOperatorTester& output_stride(size_t output_stride) {
    assert(output_stride != 0);
    this->output_stride_ = output_stride;
    return *this;
  }

  inline size_t output_stride() const {
    if (this->output_stride_ == 0) {
      return this->channels_;
    } else {
      assert(this->output_stride_ >= this->channels_);
      return this->output_stride_;
    }
  }

  inline SoftMaxOperatorTester& batch_size(size_t batch_size) {
    assert(batch_size != 0);
    this->batch_size_ = batch_size;
    return *this;
  }

  inline size_t batch_size() const {
    return this->batch_size_;
  }

  inline SoftMaxOperatorTester& input_scale(float input_scale) {
    assert(input_scale > 0.0f);
    assert(std::isnormal(input_scale));
    this->input_scale_ = input_scale;
    return *this;
  }

  inline float input_scale() const {
    return this->input_scale_;
  }

  inline SoftMaxOperatorTester& input_zero_point(uint8_t input_zero_point) {
    this->input_zero_point_ = input_zero_point;
    return *this;
  }

  inline uint8_t input_zero_point() const {
    return this->input_zero_point_;
  }

  inline float output_scale() const {
    return 1.0f / 256.0f;
  }

  inline uint8_t output_zero_point() const {
    return 0;
  }

  inline SoftMaxOperatorTester& iterations(size_t iterations) {
    this->iterations_ = iterations;
    return *this;
  }

  inline size_t iterations() const {
    return this->iterations_;
  }

  void TestQU8() const {
    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto u8rng = std::bind(std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), rng);

    std::vector<uint8_t> input((batch_size() - 1) * input_stride() + channels());
    std::vector<uint8_t> output((batch_size() - 1) * output_stride() + channels());
    std::vector<float> output_ref(batch_size() * channels());
    for (size_t iteration = 0; iteration < iterations(); iteration++) {
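      // Each iteration draws fresh random inputs and resets the output buffer
      // to the 0xA5 sentinel before running the operator. The reference values
      // computed below are expressed in units of the quantized output (the
      // softmax probability divided by output_scale() == 1/256, clamped to
      // 255), so they can be compared directly against the raw uint8 results.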
      std::generate(input.begin(), input.end(), std::ref(u8rng));
      std::fill(output.begin(), output.end(), 0xA5);

      // Compute reference results.
      for (size_t i = 0; i < batch_size(); i++) {
        const int32_t max_input = *std::max_element(
          input.data() + i * input_stride(),
          input.data() + i * input_stride() + channels());
        float sum_exp = 0.0f;
        for (size_t c = 0; c < channels(); c++) {
          sum_exp +=
              std::exp((int32_t(input[i * input_stride() + c]) - max_input) *
                       input_scale());
        }
        for (size_t c = 0; c < channels(); c++) {
          output_ref[i * channels() + c] =
              std::exp((int32_t(input[i * input_stride() + c]) - max_input) *
                       input_scale()) /
              (sum_exp * output_scale());
          output_ref[i * channels() + c] = std::min(output_ref[i * channels() + c], 255.0f);
        }
      }

      // Create, setup, run, and destroy SoftMax operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t softmax_op = nullptr;

      ASSERT_EQ(xnn_status_success,
        xnn_create_softmax_nc_qu8(
          channels(), input_stride(), output_stride(),
          input_scale(),
          output_zero_point(), output_scale(),
          0, &softmax_op));
      ASSERT_NE(nullptr, softmax_op);

      // Smart pointer to automatically delete softmax_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_softmax_op(softmax_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_softmax_nc_qu8(
          softmax_op,
          batch_size(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(softmax_op, nullptr /* thread pool */));

      // Verify results.
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < channels(); c++) {
          ASSERT_NEAR(float(int32_t(output[i * output_stride() + c])), output_ref[i * channels() + c], 0.6f);
        }
      }
    }
  }

  void TestF32() const {
    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);

    std::vector<float> input((batch_size() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<float> output((batch_size() - 1) * output_stride() + channels());
    std::vector<double> output_ref(batch_size() * channels());
    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), std::ref(f32rng));
      std::fill(output.begin(), output.end(), std::nanf(""));

      // Compute reference results.
      for (size_t i = 0; i < batch_size(); i++) {
        const double max_input = *std::max_element(
          input.data() + i * input_stride(),
          input.data() + i * input_stride() + channels());
        double sum_exp = 0.0;
        for (size_t c = 0; c < channels(); c++) {
          sum_exp += std::exp(double(input[i * input_stride() + c]) - max_input);
        }
        for (size_t c = 0; c < channels(); c++) {
          output_ref[i * channels() + c] =
            std::exp(double(input[i * input_stride() + c]) - max_input) / sum_exp;
        }
      }

      // Create, setup, run, and destroy SoftMax operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t softmax_op = nullptr;

      ASSERT_EQ(xnn_status_success,
        xnn_create_softmax_nc_f32(
          channels(), input_stride(), output_stride(),
          0, &softmax_op));
      ASSERT_NE(nullptr, softmax_op);

      // Smart pointer to automatically delete softmax_op.
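      // Its deleter runs on scope exit, so xnn_delete_operator is still called
      // if one of the ASSERT_* checks below fails and returns from the test early.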
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_softmax_op(softmax_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_softmax_nc_f32(
          softmax_op,
          batch_size(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(softmax_op, nullptr /* thread pool */));

      // Verify results.
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < channels(); c++) {
          ASSERT_NEAR(
            double(output[i * output_stride() + c]),
            output_ref[i * channels() + c],
            output_ref[i * channels() + c] * 1.0e-4);
        }
      }
    }
  }

 private:
  size_t batch_size_{1};
  size_t channels_{1};
  size_t input_stride_{0};
  size_t output_stride_{0};
  float input_scale_{0.176080093f};
  uint8_t input_zero_point_{121};
  size_t iterations_{15};
};
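// Minimal usage sketch (illustrative only; the suite/test names and parameter
// values below are made up, not taken from the actual test sources): a gtest
// case configures the tester with the fluent setters above and then calls one
// of the Test* methods, e.g.
//
//   TEST(SOFTMAX_NC_F32, batched_with_strides) {
//     SoftMaxOperatorTester()
//       .batch_size(3)
//       .channels(100)
//       .input_stride(129)
//       .output_stride(133)
//       .TestF32();
//   }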