1 // Copyright 2020 Google LLC 2 // 3 // This source code is licensed under the BSD-style license found in the 4 // LICENSE file in the root directory of this source tree. 5 6 #pragma once 7 8 #include <gtest/gtest.h> 9 10 #include <algorithm> 11 #include <cassert> 12 #include <cstddef> 13 #include <cstdlib> 14 #include <functional> 15 #include <limits> 16 #include <random> 17 #include <vector> 18 19 #include <xnnpack.h> 20 21 22 class CopyOperatorTester { 23 public: channels(size_t channels)24 inline CopyOperatorTester& channels(size_t channels) { 25 assert(channels != 0); 26 this->channels_ = channels; 27 return *this; 28 } 29 channels()30 inline size_t channels() const { 31 return this->channels_; 32 } 33 input_stride(size_t input_stride)34 inline CopyOperatorTester& input_stride(size_t input_stride) { 35 assert(input_stride != 0); 36 this->input_stride_ = input_stride; 37 return *this; 38 } 39 input_stride()40 inline size_t input_stride() const { 41 if (this->input_stride_ == 0) { 42 return this->channels_; 43 } else { 44 assert(this->input_stride_ >= this->channels_); 45 return this->input_stride_; 46 } 47 } 48 output_stride(size_t output_stride)49 inline CopyOperatorTester& output_stride(size_t output_stride) { 50 assert(output_stride != 0); 51 this->output_stride_ = output_stride; 52 return *this; 53 } 54 output_stride()55 inline size_t output_stride() const { 56 if (this->output_stride_ == 0) { 57 return this->channels_; 58 } else { 59 assert(this->output_stride_ >= this->channels_); 60 return this->output_stride_; 61 } 62 } 63 batch_size(size_t batch_size)64 inline CopyOperatorTester& batch_size(size_t batch_size) { 65 assert(batch_size != 0); 66 this->batch_size_ = batch_size; 67 return *this; 68 } 69 batch_size()70 inline size_t batch_size() const { 71 return this->batch_size_; 72 } 73 iterations(size_t iterations)74 inline CopyOperatorTester& iterations(size_t iterations) { 75 this->iterations_ = iterations; 76 return *this; 77 } 78 iterations()79 inline size_t iterations() const { 80 return this->iterations_; 81 } 82 TestX8()83 void TestX8() const { 84 std::random_device random_device; 85 auto rng = std::mt19937(random_device()); 86 auto u8rng = std::bind( 87 std::uniform_int_distribution<uint32_t>( std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max()), 88 rng); 89 90 std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) + 91 (batch_size() - 1) * input_stride() + channels()); 92 std::vector<uint8_t> output((batch_size() - 1) * output_stride() + channels()); 93 std::vector<uint8_t> output_ref(batch_size() * channels()); 94 for (size_t iteration = 0; iteration < iterations(); iteration++) { 95 std::generate(input.begin(), input.end(), std::ref(u8rng)); 96 std::fill(output.begin(), output.end(), UINT16_C(0xFA)); 97 98 // Compute reference results. 99 for (size_t i = 0; i < batch_size(); i++) { 100 for (size_t c = 0; c < channels(); c++) { 101 output_ref[i * channels() + c] = input[i * input_stride() + c]; 102 } 103 } 104 105 // Create, setup, run, and destroy Copy operator. 106 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 107 xnn_operator_t copy_op = nullptr; 108 109 ASSERT_EQ(xnn_status_success, 110 xnn_create_copy_nc_x8( 111 channels(), input_stride(), output_stride(), 112 0, ©_op)); 113 ASSERT_NE(nullptr, copy_op); 114 115 // Smart pointer to automatically delete copy_op. 116 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_copy_op(copy_op, xnn_delete_operator); 117 118 ASSERT_EQ(xnn_status_success, 119 xnn_setup_copy_nc_x8( 120 copy_op, 121 batch_size(), 122 input.data(), output.data(), 123 nullptr /* thread pool */)); 124 125 ASSERT_EQ(xnn_status_success, 126 xnn_run_operator(copy_op, nullptr /* thread pool */)); 127 128 // Verify results. 129 for (size_t i = 0; i < batch_size(); i++) { 130 for (size_t c = 0; c < channels(); c++) { 131 ASSERT_EQ(output_ref[i * channels() + c], output[i * output_stride() + c]) 132 << "at batch " << i << " / " << batch_size() << ", channel = " << c << " / " << channels(); 133 } 134 } 135 } 136 } 137 TestX16()138 void TestX16() const { 139 std::random_device random_device; 140 auto rng = std::mt19937(random_device()); 141 auto u16rng = std::bind(std::uniform_int_distribution<uint16_t>(), rng); 142 143 std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) + 144 (batch_size() - 1) * input_stride() + channels()); 145 std::vector<uint16_t> output((batch_size() - 1) * output_stride() + channels()); 146 std::vector<uint16_t> output_ref(batch_size() * channels()); 147 for (size_t iteration = 0; iteration < iterations(); iteration++) { 148 std::generate(input.begin(), input.end(), std::ref(u16rng)); 149 std::fill(output.begin(), output.end(), UINT16_C(0xDEAD)); 150 151 // Compute reference results. 152 for (size_t i = 0; i < batch_size(); i++) { 153 for (size_t c = 0; c < channels(); c++) { 154 output_ref[i * channels() + c] = input[i * input_stride() + c]; 155 } 156 } 157 158 // Create, setup, run, and destroy Copy operator. 159 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 160 xnn_operator_t copy_op = nullptr; 161 162 ASSERT_EQ(xnn_status_success, 163 xnn_create_copy_nc_x16( 164 channels(), input_stride(), output_stride(), 165 0, ©_op)); 166 ASSERT_NE(nullptr, copy_op); 167 168 // Smart pointer to automatically delete copy_op. 169 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_copy_op(copy_op, xnn_delete_operator); 170 171 ASSERT_EQ(xnn_status_success, 172 xnn_setup_copy_nc_x16( 173 copy_op, 174 batch_size(), 175 input.data(), output.data(), 176 nullptr /* thread pool */)); 177 178 ASSERT_EQ(xnn_status_success, 179 xnn_run_operator(copy_op, nullptr /* thread pool */)); 180 181 // Verify results. 182 for (size_t i = 0; i < batch_size(); i++) { 183 for (size_t c = 0; c < channels(); c++) { 184 ASSERT_EQ(output_ref[i * channels() + c], output[i * output_stride() + c]) 185 << "at batch " << i << " / " << batch_size() << ", channel = " << c << " / " << channels(); 186 } 187 } 188 } 189 } 190 TestX32()191 void TestX32() const { 192 std::random_device random_device; 193 auto rng = std::mt19937(random_device()); 194 auto u32rng = std::bind(std::uniform_int_distribution<uint32_t>(), rng); 195 196 std::vector<uint32_t> input(XNN_EXTRA_BYTES / sizeof(uint32_t) + 197 (batch_size() - 1) * input_stride() + channels()); 198 std::vector<uint32_t> output((batch_size() - 1) * output_stride() + channels()); 199 std::vector<uint32_t> output_ref(batch_size() * channels()); 200 for (size_t iteration = 0; iteration < iterations(); iteration++) { 201 std::generate(input.begin(), input.end(), std::ref(u32rng)); 202 std::fill(output.begin(), output.end(), UINT32_C(0xDEADBEEF)); 203 204 // Compute reference results. 205 for (size_t i = 0; i < batch_size(); i++) { 206 for (size_t c = 0; c < channels(); c++) { 207 output_ref[i * channels() + c] = input[i * input_stride() + c]; 208 } 209 } 210 211 // Create, setup, run, and destroy Copy operator. 212 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 213 xnn_operator_t copy_op = nullptr; 214 215 ASSERT_EQ(xnn_status_success, 216 xnn_create_copy_nc_x32( 217 channels(), input_stride(), output_stride(), 218 0, ©_op)); 219 ASSERT_NE(nullptr, copy_op); 220 221 // Smart pointer to automatically delete copy_op. 222 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_copy_op(copy_op, xnn_delete_operator); 223 224 ASSERT_EQ(xnn_status_success, 225 xnn_setup_copy_nc_x32( 226 copy_op, 227 batch_size(), 228 input.data(), output.data(), 229 nullptr /* thread pool */)); 230 231 ASSERT_EQ(xnn_status_success, 232 xnn_run_operator(copy_op, nullptr /* thread pool */)); 233 234 // Verify results. 235 for (size_t i = 0; i < batch_size(); i++) { 236 for (size_t c = 0; c < channels(); c++) { 237 ASSERT_EQ(output_ref[i * channels() + c], output[i * output_stride() + c]) 238 << "at batch " << i << " / " << batch_size() << ", channel = " << c << " / " << channels(); 239 } 240 } 241 } 242 } 243 244 private: 245 size_t batch_size_{1}; 246 size_t channels_{1}; 247 size_t input_stride_{0}; 248 size_t output_stride_{0}; 249 size_t iterations_{15}; 250 }; 251