• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>

25 class ClampOperatorTester {
26  public:
channels(size_t channels)27   inline ClampOperatorTester& channels(size_t channels) {
28     assert(channels != 0);
29     this->channels_ = channels;
30     return *this;
31   }
32 
channels()33   inline size_t channels() const {
34     return this->channels_;
35   }
36 
input_stride(size_t input_stride)37   inline ClampOperatorTester& input_stride(size_t input_stride) {
38     assert(input_stride != 0);
39     this->input_stride_ = input_stride;
40     return *this;
41   }
42 
input_stride()43   inline size_t input_stride() const {
44     if (this->input_stride_ == 0) {
45       return this->channels_;
46     } else {
47       assert(this->input_stride_ >= this->channels_);
48       return this->input_stride_;
49     }
50   }
51 
output_stride(size_t output_stride)52   inline ClampOperatorTester& output_stride(size_t output_stride) {
53     assert(output_stride != 0);
54     this->output_stride_ = output_stride;
55     return *this;
56   }
57 
output_stride()58   inline size_t output_stride() const {
59     if (this->output_stride_ == 0) {
60       return this->channels_;
61     } else {
62       assert(this->output_stride_ >= this->channels_);
63       return this->output_stride_;
64     }
65   }
66 
batch_size(size_t batch_size)67   inline ClampOperatorTester& batch_size(size_t batch_size) {
68     assert(batch_size != 0);
69     this->batch_size_ = batch_size;
70     return *this;
71   }
72 
batch_size()73   inline size_t batch_size() const {
74     return this->batch_size_;
75   }
76 
qmin(uint8_t qmin)77   inline ClampOperatorTester& qmin(uint8_t qmin) {
78     this->qmin_ = qmin;
79     return *this;
80   }
81 
qmin()82   inline uint8_t qmin() const {
83     return this->qmin_;
84   }
85 
qmax(uint8_t qmax)86   inline ClampOperatorTester& qmax(uint8_t qmax) {
87     this->qmax_ = qmax;
88     return *this;
89   }
90 
qmax()91   inline uint8_t qmax() const {
92     return this->qmax_;
93   }
94 
relu_activation(bool relu_activation)95   inline ClampOperatorTester& relu_activation(bool relu_activation) {
96     this->relu_activation_ = relu_activation;
97     return *this;
98   }
99 
relu_activation()100   inline bool relu_activation() const {
101     return this->relu_activation_;
102   }
103 
iterations(size_t iterations)104   inline ClampOperatorTester& iterations(size_t iterations) {
105     this->iterations_ = iterations;
106     return *this;
107   }
108 
iterations()109   inline size_t iterations() const {
110     return this->iterations_;
111   }
112 
TestS8()113   void TestS8() const {
114     std::random_device random_device;
115     auto rng = std::mt19937(random_device());
116     auto i8rng = std::bind(
117       std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()),
118       std::ref(rng));
119 
120     std::vector<int8_t> input(XNN_EXTRA_BYTES / sizeof(int8_t) +
121       (batch_size() - 1) * input_stride() + channels());
122     std::vector<int8_t> output((batch_size() - 1) * output_stride() + channels());
123     std::vector<int8_t> output_ref(batch_size() * channels());
124     for (size_t iteration = 0; iteration < iterations(); iteration++) {
125       std::generate(input.begin(), input.end(), std::ref(i8rng));
126       std::fill(output.begin(), output.end(), INT8_C(0xA5));
127 
128       // Compute reference results.
129       for (size_t i = 0; i < batch_size(); i++) {
130         for (size_t c = 0; c < channels(); c++) {
131           const int8_t x = input[i * input_stride() + c];
132           const int8_t y = std::min(std::max(x, int8_t(qmin() - 0x80)), int8_t(qmax() - 0x80));
133           output_ref[i * channels() + c] = y;
134         }
135       }
136 
137       // Create, setup, run, and destroy Clamp operator.
138       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
139       xnn_operator_t clamp_op = nullptr;
140 
141       ASSERT_EQ(xnn_status_success,
142         xnn_create_clamp_nc_s8(
143           channels(), input_stride(), output_stride(),
144           int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
145           0, &clamp_op));
146       ASSERT_NE(nullptr, clamp_op);
147 
148       // Smart pointer to automatically delete clamp_op.
149       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_clamp_op(clamp_op, xnn_delete_operator);
150 
151       ASSERT_EQ(xnn_status_success,
152         xnn_setup_clamp_nc_s8(
153           clamp_op,
154           batch_size(),
155           input.data(), output.data(),
156           nullptr /* thread pool */));
157 
158       ASSERT_EQ(xnn_status_success,
159         xnn_run_operator(clamp_op, nullptr /* thread pool */));
160 
161       // Verify results .
162       for (size_t i = 0; i < batch_size(); i++) {
163         for (size_t c = 0; c < channels(); c++) {
164           ASSERT_LE(int32_t(output[i * output_stride() + c]), int32_t(qmax() - 0x80))
165             << "at position " << i << ", batch size = " << batch_size() << ", channels = " << channels();
166           ASSERT_GE(int32_t(output[i * output_stride() + c]), int32_t(qmin() - 0x80))
167             << "at position " << i << ", batch size = " << batch_size() << ", channels = " << channels();
168           ASSERT_EQ(int32_t(output_ref[i * channels() + c]), int32_t(output[i * output_stride() + c]))
169             << "at position " << i << ", batch size = " << batch_size() << ", channels = " << channels()
170             << ", qmin = " << int32_t(qmin() - 0x80) << ", qmax = " << int32_t(qmax() - 0x80);
171         }
172       }
173     }
174   }
175 
TestU8()176   void TestU8() const {
177     std::random_device random_device;
178     auto rng = std::mt19937(random_device());
179     auto u8rng = std::bind(std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), rng);
180 
181     std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) +
182       (batch_size() - 1) * input_stride() + channels());
183     std::vector<uint8_t> output((batch_size() - 1) * output_stride() + channels());
184     std::vector<uint8_t> output_ref(batch_size() * channels());
185     for (size_t iteration = 0; iteration < iterations(); iteration++) {
186       std::generate(input.begin(), input.end(), std::ref(u8rng));
187       std::fill(output.begin(), output.end(), 0xA5);
188 
189       // Compute reference results.
190       for (size_t i = 0; i < batch_size(); i++) {
191         for (size_t c = 0; c < channels(); c++) {
192           const uint8_t x = input[i * input_stride() + c];
193           const uint8_t y = std::min(std::max(x, qmin()), qmax());
194           output_ref[i * channels() + c] = y;
195         }
196       }
197 
198       // Create, setup, run, and destroy Clamp operator.
199       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
200       xnn_operator_t clamp_op = nullptr;
201 
202       ASSERT_EQ(xnn_status_success,
203         xnn_create_clamp_nc_u8(
204           channels(), input_stride(), output_stride(),
205           qmin(), qmax(),
206           0, &clamp_op));
207       ASSERT_NE(nullptr, clamp_op);
208 
209       // Smart pointer to automatically delete clamp_op.
210       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_clamp_op(clamp_op, xnn_delete_operator);
211 
212       ASSERT_EQ(xnn_status_success,
213         xnn_setup_clamp_nc_u8(
214           clamp_op,
215           batch_size(),
216           input.data(), output.data(),
217           nullptr /* thread pool */));
218 
219       ASSERT_EQ(xnn_status_success,
220         xnn_run_operator(clamp_op, nullptr /* thread pool */));
221 
222       // Verify results .
223       for (size_t i = 0; i < batch_size(); i++) {
224         for (size_t c = 0; c < channels(); c++) {
225           ASSERT_LE(uint32_t(output[i * output_stride() + c]), uint32_t(qmax()))
226             << "at position " << i << ", batch size = " << batch_size() << ", channels = " << channels();
227           ASSERT_GE(uint32_t(output[i * output_stride() + c]), uint32_t(qmin()))
228             << "at position " << i << ", batch size = " << batch_size() << ", channels = " << channels();
229           ASSERT_EQ(uint32_t(output_ref[i * channels() + c]), uint32_t(output[i * output_stride() + c]))
230             << "at position " << i << ", batch size = " << batch_size() << ", channels = " << channels()
231             << ", qmin = " << uint32_t(qmin()) << ", qmax = " << uint32_t(qmax());
232         }
233       }
234     }
235   }
236 
TestF32()237   void TestF32() const {
238     std::random_device random_device;
239     auto rng = std::mt19937(random_device());
240     auto f32rng = std::bind(std::uniform_real_distribution<float>(0.0f, 255.0f), rng);
241 
242     std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) +
243       (batch_size() - 1) * input_stride() + channels());
244     std::vector<float> output((batch_size() - 1) * output_stride() + channels());
245     std::vector<float> output_ref(batch_size() * channels());
246     for (size_t iteration = 0; iteration < iterations(); iteration++) {
247       std::generate(input.begin(), input.end(), std::ref(f32rng));
248       std::fill(output.begin(), output.end(), std::nanf(""));
249 
250       // Compute reference results.
251       for (size_t i = 0; i < batch_size(); i++) {
252         for (size_t c = 0; c < channels(); c++) {
253           const float x = input[i * input_stride() + c];
254           const float y = relu_activation() ? std::max(x, 0.f) :
255             std::min(std::max(x, float(qmin())), float(qmax()));
256           output_ref[i * channels() + c] = y;
257         }
258       }
259 
260       // Create, setup, run, and destroy Clamp operator.
261       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
262       xnn_operator_t clamp_op = nullptr;
263 
264       const float output_min = relu_activation() ? 0.0f : float(qmin());
265       const float output_max = relu_activation() ? std::numeric_limits<float>::infinity() : float(qmax());
266       ASSERT_EQ(xnn_status_success,
267         xnn_create_clamp_nc_f32(
268           channels(), input_stride(), output_stride(),
269           output_min, output_max,
270           0, &clamp_op));
271       ASSERT_NE(nullptr, clamp_op);
272 
273       // Smart pointer to automatically delete clamp_op.
274       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_clamp_op(clamp_op, xnn_delete_operator);
275 
276       ASSERT_EQ(xnn_status_success,
277         xnn_setup_clamp_nc_f32(
278           clamp_op,
279           batch_size(),
280           input.data(), output.data(),
281           nullptr /* thread pool */));
282 
283       ASSERT_EQ(xnn_status_success,
284         xnn_run_operator(clamp_op, nullptr /* thread pool */));
285 
286       // Verify results.
287       for (size_t i = 0; i < batch_size(); i++) {
288         for (size_t c = 0; c < channels(); c++) {
289           ASSERT_LE(output[i * output_stride() + c], output_max)
290             << "at position " << i << ", batch size = " << batch_size() << ", channels = " << channels();
291           ASSERT_GE(output[i * output_stride() + c], output_min)
292             << "at position " << i << ", batch size = " << batch_size() << ", channels = " << channels();
293           ASSERT_EQ(output_ref[i * channels() + c], output[i * output_stride() + c])
294             << "at position " << i << ", batch size = " << batch_size() << ", channels = " << channels()
295             << ", min = " << output_min << ", max = " << output_max;
296         }
297       }
298     }
299   }
300 
301  private:
302   size_t batch_size_{1};
303   size_t channels_{1};
304   size_t input_stride_{0};
305   size_t output_stride_{0};
306   uint8_t qmin_{5};
307   uint8_t qmax_{250};
308   bool relu_activation_{false};
309   size_t iterations_{15};
310 };
311