• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #pragma once
7 
#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <random>
#include <vector>

#include <fp16.h>

#include <xnnpack.h>
21 
22 
23 class TruncationOperatorTester {
24  public:
channels(size_t channels)25   inline TruncationOperatorTester& channels(size_t channels) {
26     assert(channels != 0);
27     this->channels_ = channels;
28     return *this;
29   }
30 
channels()31   inline size_t channels() const {
32     return this->channels_;
33   }
34 
input_stride(size_t input_stride)35   inline TruncationOperatorTester& input_stride(size_t input_stride) {
36     assert(input_stride != 0);
37     this->input_stride_ = input_stride;
38     return *this;
39   }
40 
input_stride()41   inline size_t input_stride() const {
42     if (this->input_stride_ == 0) {
43       return this->channels_;
44     } else {
45       assert(this->input_stride_ >= this->channels_);
46       return this->input_stride_;
47     }
48   }
49 
output_stride(size_t output_stride)50   inline TruncationOperatorTester& output_stride(size_t output_stride) {
51     assert(output_stride != 0);
52     this->output_stride_ = output_stride;
53     return *this;
54   }
55 
output_stride()56   inline size_t output_stride() const {
57     if (this->output_stride_ == 0) {
58       return this->channels_;
59     } else {
60       assert(this->output_stride_ >= this->channels_);
61       return this->output_stride_;
62     }
63   }
64 
batch_size(size_t batch_size)65   inline TruncationOperatorTester& batch_size(size_t batch_size) {
66     assert(batch_size != 0);
67     this->batch_size_ = batch_size;
68     return *this;
69   }
70 
batch_size()71   inline size_t batch_size() const {
72     return this->batch_size_;
73   }
74 
iterations(size_t iterations)75   inline TruncationOperatorTester& iterations(size_t iterations) {
76     this->iterations_ = iterations;
77     return *this;
78   }
79 
iterations()80   inline size_t iterations() const {
81     return this->iterations_;
82   }
83 
TestF16()84   void TestF16() const {
85     std::random_device random_device;
86     auto rng = std::mt19937(random_device());
87     std::uniform_real_distribution<float> f32dist(-5.0f, 5.0f);
88 
89     std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) +
90       (batch_size() - 1) * input_stride() + channels());
91     std::vector<uint16_t> output((batch_size() - 1) * output_stride() + channels());
92     std::vector<uint16_t> output_ref(batch_size() * channels());
93     for (size_t iteration = 0; iteration < iterations(); iteration++) {
94       std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
95       std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);
96 
97       // Compute reference results.
98       for (size_t i = 0; i < batch_size(); i++) {
99         for (size_t c = 0; c < channels(); c++) {
100           output_ref[i * channels() + c] = fp16_ieee_from_fp32_value(std::trunc(fp16_ieee_to_fp32_value(input[i * input_stride() + c])));
101         }
102       }
103 
104       // Create, setup, run, and destroy Truncation operator.
105       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
106       xnn_operator_t truncation_op = nullptr;
107 
108       const xnn_status status = xnn_create_truncation_nc_f16(
109           channels(), input_stride(), output_stride(),
110           0, &truncation_op);
111       if (status == xnn_status_unsupported_hardware) {
112         GTEST_SKIP();
113       }
114       ASSERT_EQ(xnn_status_success, status);
115       ASSERT_NE(nullptr, truncation_op);
116 
117       // Smart pointer to automatically delete truncation_op.
118       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_truncation_op(truncation_op, xnn_delete_operator);
119 
120       ASSERT_EQ(xnn_status_success,
121         xnn_setup_truncation_nc_f16(
122           truncation_op,
123           batch_size(),
124           input.data(), output.data(),
125           nullptr /* thread pool */));
126 
127       ASSERT_EQ(xnn_status_success,
128         xnn_run_operator(truncation_op, nullptr /* thread pool */));
129 
130       // Verify results.
131       for (size_t i = 0; i < batch_size(); i++) {
132         for (size_t c = 0; c < channels(); c++) {
133           ASSERT_EQ(output_ref[i * channels() + c], output[i * output_stride() + c])
134             << "at batch " << i << " / " << batch_size() << ", channel " << c << " / " << channels();
135         }
136       }
137     }
138   }
139 
TestF32()140   void TestF32() const {
141     std::random_device random_device;
142     auto rng = std::mt19937(random_device());
143     std::uniform_real_distribution<float> f32dist(-5.0f, 5.0f);
144 
145     std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) +
146       (batch_size() - 1) * input_stride() + channels());
147     std::vector<float> output((batch_size() - 1) * output_stride() + channels());
148     std::vector<float> output_ref(batch_size() * channels());
149     for (size_t iteration = 0; iteration < iterations(); iteration++) {
150       std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
151       std::fill(output.begin(), output.end(), std::nanf(""));
152 
153       // Compute reference results.
154       for (size_t i = 0; i < batch_size(); i++) {
155         for (size_t c = 0; c < channels(); c++) {
156           output_ref[i * channels() + c] = std::trunc(input[i * input_stride() + c]);
157         }
158       }
159 
160       // Create, setup, run, and destroy Truncation operator.
161       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
162       xnn_operator_t truncation_op = nullptr;
163 
164       ASSERT_EQ(xnn_status_success,
165         xnn_create_truncation_nc_f32(
166           channels(), input_stride(), output_stride(),
167           0, &truncation_op));
168       ASSERT_NE(nullptr, truncation_op);
169 
170       // Smart pointer to automatically delete truncation_op.
171       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_truncation_op(truncation_op, xnn_delete_operator);
172 
173       ASSERT_EQ(xnn_status_success,
174         xnn_setup_truncation_nc_f32(
175           truncation_op,
176           batch_size(),
177           input.data(), output.data(),
178           nullptr /* thread pool */));
179 
180       ASSERT_EQ(xnn_status_success,
181         xnn_run_operator(truncation_op, nullptr /* thread pool */));
182 
183       // Verify results.
184       for (size_t i = 0; i < batch_size(); i++) {
185         for (size_t c = 0; c < channels(); c++) {
186           ASSERT_EQ(output_ref[i * channels() + c], output[i * output_stride() + c])
187             << "at batch " << i << " / " << batch_size() << ", channel " << c << " / " << channels();
188         }
189       }
190     }
191   }
192 
193  private:
194   size_t batch_size_{1};
195   size_t channels_{1};
196   size_t input_stride_{0};
197   size_t output_stride_{0};
198   size_t iterations_{15};
199 };
200