// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>

26 class SoftMaxOperatorTester {
27  public:
channels(size_t channels)28   inline SoftMaxOperatorTester& channels(size_t channels) {
29     assert(channels != 0);
30     this->channels_ = channels;
31     return *this;
32   }
33 
channels()34   inline size_t channels() const {
35     return this->channels_;
36   }
37 
input_stride(size_t input_stride)38   inline SoftMaxOperatorTester& input_stride(size_t input_stride) {
39     assert(input_stride != 0);
40     this->input_stride_ = input_stride;
41     return *this;
42   }
43 
input_stride()44   inline size_t input_stride() const {
45     if (this->input_stride_ == 0) {
46       return this->channels_;
47     } else {
48       assert(this->input_stride_ >= this->channels_);
49       return this->input_stride_;
50     }
51   }
52 
output_stride(size_t output_stride)53   inline SoftMaxOperatorTester& output_stride(size_t output_stride) {
54     assert(output_stride != 0);
55     this->output_stride_ = output_stride;
56     return *this;
57   }
58 
output_stride()59   inline size_t output_stride() const {
60     if (this->output_stride_ == 0) {
61       return this->channels_;
62     } else {
63       assert(this->output_stride_ >= this->channels_);
64       return this->output_stride_;
65     }
66   }
67 
batch_size(size_t batch_size)68   inline SoftMaxOperatorTester& batch_size(size_t batch_size) {
69     assert(batch_size != 0);
70     this->batch_size_ = batch_size;
71     return *this;
72   }
73 
batch_size()74   inline size_t batch_size() const {
75     return this->batch_size_;
76   }
77 
input_scale(float input_scale)78   inline SoftMaxOperatorTester& input_scale(float input_scale) {
79     assert(input_scale > 0.0f);
80     assert(std::isnormal(input_scale));
81     this->input_scale_ = input_scale;
82     return *this;
83   }
84 
input_scale()85   inline float input_scale() const {
86     return this->input_scale_;
87   }
88 
input_zero_point(uint8_t input_zero_point)89   inline SoftMaxOperatorTester& input_zero_point(uint8_t input_zero_point) {
90     this->input_zero_point_ = input_zero_point;
91     return *this;
92   }
93 
input_zero_point()94   inline uint8_t input_zero_point() const {
95     return this->input_zero_point_;
96   }
97 
output_scale()98   inline float output_scale() const {
99     return 1.0f / 256.0f;
100   }
101 
output_zero_point()102   inline uint8_t output_zero_point() const {
103     return 0;
104   }
105 
iterations(size_t iterations)106   inline SoftMaxOperatorTester& iterations(size_t iterations) {
107     this->iterations_ = iterations;
108     return *this;
109   }
110 
iterations()111   inline size_t iterations() const {
112     return this->iterations_;
113   }
114 
TestQU8()115   void TestQU8() const {
116     std::random_device random_device;
117     auto rng = std::mt19937(random_device());
118     auto u8rng = std::bind(std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), rng);
119 
120     std::vector<uint8_t> input((batch_size() - 1) * input_stride() + channels());
121     std::vector<uint8_t> output((batch_size() - 1) * output_stride() + channels());
122     std::vector<float> output_ref(batch_size() * channels());
123     for (size_t iteration = 0; iteration < iterations(); iteration++) {
124       std::generate(input.begin(), input.end(), std::ref(u8rng));
125       std::fill(output.begin(), output.end(), 0xA5);
126 
127       // Compute reference results.
128       for (size_t i = 0; i < batch_size(); i++) {
129         const int32_t max_input = *std::max_element(
130           input.data() + i * input_stride(),
131           input.data() + i * input_stride() + channels());
132         float sum_exp = 0.0f;
133         for (size_t c = 0; c < channels(); c++) {
134           sum_exp +=
135               std::exp((int32_t(input[i * input_stride() + c]) - max_input) *
136                        input_scale());
137         }
138         for (size_t c = 0; c < channels(); c++) {
139           output_ref[i * channels() + c] =
140               std::exp((int32_t(input[i * input_stride() + c]) - max_input) *
141                        input_scale()) /
142               (sum_exp * output_scale());
143           output_ref[i * channels() + c] = std::min(output_ref[i * channels() + c], 255.0f);
144         }
145       }
146 
147       // Create, setup, run, and destroy SoftMax operator.
148       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
149       xnn_operator_t softmax_op = nullptr;
150 
151       ASSERT_EQ(xnn_status_success,
152         xnn_create_softmax_nc_qu8(
153           channels(), input_stride(), output_stride(),
154           input_scale(),
155           output_zero_point(), output_scale(),
156           0, &softmax_op));
157       ASSERT_NE(nullptr, softmax_op);
158 
159       // Smart pointer to automatically delete softmax_op.
160       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_softmax_op(softmax_op, xnn_delete_operator);
161 
162       ASSERT_EQ(xnn_status_success,
163         xnn_setup_softmax_nc_qu8(
164           softmax_op,
165           batch_size(),
166           input.data(), output.data(),
167           nullptr /* thread pool */));
168 
169       ASSERT_EQ(xnn_status_success,
170         xnn_run_operator(softmax_op, nullptr /* thread pool */));
171 
172       // Verify results.
173       for (size_t i = 0; i < batch_size(); i++) {
174         for (size_t c = 0; c < channels(); c++) {
175           ASSERT_NEAR(float(int32_t(output[i * output_stride() + c])), output_ref[i * channels() + c], 0.6f);
176         }
177       }
178     }
179   }
180 
TestF32()181   void TestF32() const {
182     std::random_device random_device;
183     auto rng = std::mt19937(random_device());
184     auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);
185 
186     std::vector<float> input((batch_size() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float));
187     std::vector<float> output((batch_size() - 1) * output_stride() + channels());
188     std::vector<double> output_ref(batch_size() * channels());
189     for (size_t iteration = 0; iteration < iterations(); iteration++) {
190       std::generate(input.begin(), input.end(), std::ref(f32rng));
191       std::fill(output.begin(), output.end(), std::nanf(""));
192 
193       // Compute reference results.
194       for (size_t i = 0; i < batch_size(); i++) {
195         const double max_input = *std::max_element(
196           input.data() + i * input_stride(),
197           input.data() + i * input_stride() + channels());
198         double sum_exp = 0.0;
199         for (size_t c = 0; c < channels(); c++) {
200           sum_exp += std::exp(double(input[i * input_stride() + c]) - max_input);
201         }
202         for (size_t c = 0; c < channels(); c++) {
203           output_ref[i * channels() + c] =
204               std::exp(double(input[i * input_stride() + c]) - max_input) / sum_exp;
205         }
206       }
207 
208       // Create, setup, run, and destroy SoftMax operator.
209       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
210       xnn_operator_t softmax_op = nullptr;
211 
212       ASSERT_EQ(xnn_status_success,
213         xnn_create_softmax_nc_f32(
214           channels(), input_stride(), output_stride(),
215           0, &softmax_op));
216       ASSERT_NE(nullptr, softmax_op);
217 
218       // Smart pointer to automatically delete softmax_op.
219       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_softmax_op(softmax_op, xnn_delete_operator);
220 
221       ASSERT_EQ(xnn_status_success,
222         xnn_setup_softmax_nc_f32(
223           softmax_op,
224           batch_size(),
225           input.data(), output.data(),
226           nullptr /* thread pool */));
227 
228       ASSERT_EQ(xnn_status_success,
229         xnn_run_operator(softmax_op, nullptr /* thread pool */));
230 
231       // Verify results.
232       for (size_t i = 0; i < batch_size(); i++) {
233         for (size_t c = 0; c < channels(); c++) {
234           ASSERT_NEAR(
235             double(output[i * output_stride() + c]),
236             output_ref[i * channels() + c],
237             output_ref[i * channels() + c] * 1.0e-4);
238         }
239       }
240     }
241   }
242 
243  private:
244   size_t batch_size_{1};
245   size_t channels_{1};
246   size_t input_stride_{0};
247   size_t output_stride_{0};
248   float input_scale_{0.176080093};
249   uint8_t input_zero_point_{121};
250   size_t iterations_{15};
251 };