1 // Copyright 2020 Google LLC 2 // 3 // This source code is licensed under the BSD-style license found in the 4 // LICENSE file in the root directory of this source tree. 5 6 #pragma once 7 8 #include <gtest/gtest.h> 9 10 #include <algorithm> 11 #include <cassert> 12 #include <cmath> 13 #include <cstddef> 14 #include <cstdlib> 15 #include <random> 16 #include <vector> 17 18 #include <fp16.h> 19 20 #include <xnnpack.h> 21 22 23 class TruncationOperatorTester { 24 public: channels(size_t channels)25 inline TruncationOperatorTester& channels(size_t channels) { 26 assert(channels != 0); 27 this->channels_ = channels; 28 return *this; 29 } 30 channels()31 inline size_t channels() const { 32 return this->channels_; 33 } 34 input_stride(size_t input_stride)35 inline TruncationOperatorTester& input_stride(size_t input_stride) { 36 assert(input_stride != 0); 37 this->input_stride_ = input_stride; 38 return *this; 39 } 40 input_stride()41 inline size_t input_stride() const { 42 if (this->input_stride_ == 0) { 43 return this->channels_; 44 } else { 45 assert(this->input_stride_ >= this->channels_); 46 return this->input_stride_; 47 } 48 } 49 output_stride(size_t output_stride)50 inline TruncationOperatorTester& output_stride(size_t output_stride) { 51 assert(output_stride != 0); 52 this->output_stride_ = output_stride; 53 return *this; 54 } 55 output_stride()56 inline size_t output_stride() const { 57 if (this->output_stride_ == 0) { 58 return this->channels_; 59 } else { 60 assert(this->output_stride_ >= this->channels_); 61 return this->output_stride_; 62 } 63 } 64 batch_size(size_t batch_size)65 inline TruncationOperatorTester& batch_size(size_t batch_size) { 66 assert(batch_size != 0); 67 this->batch_size_ = batch_size; 68 return *this; 69 } 70 batch_size()71 inline size_t batch_size() const { 72 return this->batch_size_; 73 } 74 iterations(size_t iterations)75 inline TruncationOperatorTester& iterations(size_t iterations) { 76 this->iterations_ = iterations; 77 return *this; 78 } 79 iterations()80 inline size_t iterations() const { 81 return this->iterations_; 82 } 83 TestF16()84 void TestF16() const { 85 std::random_device random_device; 86 auto rng = std::mt19937(random_device()); 87 std::uniform_real_distribution<float> f32dist(-5.0f, 5.0f); 88 89 std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) + 90 (batch_size() - 1) * input_stride() + channels()); 91 std::vector<uint16_t> output((batch_size() - 1) * output_stride() + channels()); 92 std::vector<uint16_t> output_ref(batch_size() * channels()); 93 for (size_t iteration = 0; iteration < iterations(); iteration++) { 94 std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 95 std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */); 96 97 // Compute reference results. 98 for (size_t i = 0; i < batch_size(); i++) { 99 for (size_t c = 0; c < channels(); c++) { 100 output_ref[i * channels() + c] = fp16_ieee_from_fp32_value(std::trunc(fp16_ieee_to_fp32_value(input[i * input_stride() + c]))); 101 } 102 } 103 104 // Create, setup, run, and destroy Truncation operator. 105 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 106 xnn_operator_t truncation_op = nullptr; 107 108 const xnn_status status = xnn_create_truncation_nc_f16( 109 channels(), input_stride(), output_stride(), 110 0, &truncation_op); 111 if (status == xnn_status_unsupported_hardware) { 112 GTEST_SKIP(); 113 } 114 ASSERT_EQ(xnn_status_success, status); 115 ASSERT_NE(nullptr, truncation_op); 116 117 // Smart pointer to automatically delete truncation_op. 118 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_truncation_op(truncation_op, xnn_delete_operator); 119 120 ASSERT_EQ(xnn_status_success, 121 xnn_setup_truncation_nc_f16( 122 truncation_op, 123 batch_size(), 124 input.data(), output.data(), 125 nullptr /* thread pool */)); 126 127 ASSERT_EQ(xnn_status_success, 128 xnn_run_operator(truncation_op, nullptr /* thread pool */)); 129 130 // Verify results. 131 for (size_t i = 0; i < batch_size(); i++) { 132 for (size_t c = 0; c < channels(); c++) { 133 ASSERT_EQ(output_ref[i * channels() + c], output[i * output_stride() + c]) 134 << "at batch " << i << " / " << batch_size() << ", channel " << c << " / " << channels(); 135 } 136 } 137 } 138 } 139 TestF32()140 void TestF32() const { 141 std::random_device random_device; 142 auto rng = std::mt19937(random_device()); 143 std::uniform_real_distribution<float> f32dist(-5.0f, 5.0f); 144 145 std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + 146 (batch_size() - 1) * input_stride() + channels()); 147 std::vector<float> output((batch_size() - 1) * output_stride() + channels()); 148 std::vector<float> output_ref(batch_size() * channels()); 149 for (size_t iteration = 0; iteration < iterations(); iteration++) { 150 std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); 151 std::fill(output.begin(), output.end(), std::nanf("")); 152 153 // Compute reference results. 154 for (size_t i = 0; i < batch_size(); i++) { 155 for (size_t c = 0; c < channels(); c++) { 156 output_ref[i * channels() + c] = std::trunc(input[i * input_stride() + c]); 157 } 158 } 159 160 // Create, setup, run, and destroy Truncation operator. 161 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 162 xnn_operator_t truncation_op = nullptr; 163 164 ASSERT_EQ(xnn_status_success, 165 xnn_create_truncation_nc_f32( 166 channels(), input_stride(), output_stride(), 167 0, &truncation_op)); 168 ASSERT_NE(nullptr, truncation_op); 169 170 // Smart pointer to automatically delete truncation_op. 171 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_truncation_op(truncation_op, xnn_delete_operator); 172 173 ASSERT_EQ(xnn_status_success, 174 xnn_setup_truncation_nc_f32( 175 truncation_op, 176 batch_size(), 177 input.data(), output.data(), 178 nullptr /* thread pool */)); 179 180 ASSERT_EQ(xnn_status_success, 181 xnn_run_operator(truncation_op, nullptr /* thread pool */)); 182 183 // Verify results. 184 for (size_t i = 0; i < batch_size(); i++) { 185 for (size_t c = 0; c < channels(); c++) { 186 ASSERT_EQ(output_ref[i * channels() + c], output[i * output_stride() + c]) 187 << "at batch " << i << " / " << batch_size() << ", channel " << c << " / " << channels(); 188 } 189 } 190 } 191 } 192 193 private: 194 size_t batch_size_{1}; 195 size_t channels_{1}; 196 size_t input_stride_{0}; 197 size_t output_stride_{0}; 198 size_t iterations_{15}; 199 }; 200