1 // Copyright (c) Facebook, Inc. and its affiliates. 2 // All rights reserved. 3 // 4 // Copyright 2019 Google LLC 5 // 6 // This source code is licensed under the BSD-style license found in the 7 // LICENSE file in the root directory of this source tree. 8 9 #pragma once 10 11 #include <gtest/gtest.h> 12 13 #include <algorithm> 14 #include <cassert> 15 #include <cstddef> 16 #include <cstdlib> 17 #include <functional> 18 #include <limits> 19 #include <random> 20 #include <vector> 21 22 #include <xnnpack.h> 23 24 25 class ChannelShuffleOperatorTester { 26 public: groups(size_t groups)27 inline ChannelShuffleOperatorTester& groups(size_t groups) { 28 assert(groups != 0); 29 this->groups_ = groups; 30 return *this; 31 } 32 groups()33 inline size_t groups() const { 34 return this->groups_; 35 } 36 group_channels(size_t group_channels)37 inline ChannelShuffleOperatorTester& group_channels(size_t group_channels) { 38 assert(group_channels != 0); 39 this->group_channels_ = group_channels; 40 return *this; 41 } 42 group_channels()43 inline size_t group_channels() const { 44 return this->group_channels_; 45 } 46 channels()47 inline size_t channels() const { 48 return groups() * group_channels(); 49 } 50 input_stride(size_t input_stride)51 inline ChannelShuffleOperatorTester& input_stride(size_t input_stride) { 52 assert(input_stride != 0); 53 this->input_stride_ = input_stride; 54 return *this; 55 } 56 input_stride()57 inline size_t input_stride() const { 58 if (this->input_stride_ == 0) { 59 return channels(); 60 } else { 61 assert(this->input_stride_ >= channels()); 62 return this->input_stride_; 63 } 64 } 65 output_stride(size_t output_stride)66 inline ChannelShuffleOperatorTester& output_stride(size_t output_stride) { 67 assert(output_stride != 0); 68 this->output_stride_ = output_stride; 69 return *this; 70 } 71 output_stride()72 inline size_t output_stride() const { 73 if (this->output_stride_ == 0) { 74 return channels(); 75 } else { 76 assert(this->output_stride_ >= channels()); 77 return this->output_stride_; 78 } 79 } 80 batch_size(size_t batch_size)81 inline ChannelShuffleOperatorTester& batch_size(size_t batch_size) { 82 assert(batch_size != 0); 83 this->batch_size_ = batch_size; 84 return *this; 85 } 86 batch_size()87 inline size_t batch_size() const { 88 return this->batch_size_; 89 } 90 iterations(size_t iterations)91 inline ChannelShuffleOperatorTester& iterations(size_t iterations) { 92 this->iterations_ = iterations; 93 return *this; 94 } 95 iterations()96 inline size_t iterations() const { 97 return this->iterations_; 98 } 99 TestX8()100 void TestX8() const { 101 std::random_device random_device; 102 auto rng = std::mt19937(random_device()); 103 auto u8rng = std::bind(std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), rng); 104 105 std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) + (batch_size() - 1) * input_stride() + channels()); 106 std::vector<uint8_t> output((batch_size() - 1) * output_stride() + channels()); 107 for (size_t iteration = 0; iteration < iterations(); iteration++) { 108 std::generate(input.begin(), input.end(), std::ref(u8rng)); 109 std::fill(output.begin(), output.end(), 0xA5); 110 111 // Create, setup, run, and destroy Channel Shuffle operator. 112 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 113 xnn_operator_t channel_shuffle_op = nullptr; 114 115 ASSERT_EQ(xnn_status_success, 116 xnn_create_channel_shuffle_nc_x8( 117 groups(), group_channels(), 118 input_stride(), output_stride(), 119 0, &channel_shuffle_op)); 120 ASSERT_NE(nullptr, channel_shuffle_op); 121 122 // Smart pointer to automatically delete channel_shuffle_op. 123 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_channel_shuffle_op(channel_shuffle_op, xnn_delete_operator); 124 125 ASSERT_EQ(xnn_status_success, 126 xnn_setup_channel_shuffle_nc_x8( 127 channel_shuffle_op, 128 batch_size(), 129 input.data(), output.data(), 130 nullptr /* thread pool */)); 131 132 ASSERT_EQ(xnn_status_success, 133 xnn_run_operator(channel_shuffle_op, nullptr /* thread pool */)); 134 135 // Verify results. 136 for (size_t i = 0; i < batch_size(); i++) { 137 for (size_t g = 0; g < groups(); g++) { 138 for (size_t c = 0; c < group_channels(); c++) { 139 ASSERT_EQ(uint32_t(input[i * input_stride() + g * group_channels() + c]), 140 uint32_t(output[i * output_stride() + c * groups() + g])) 141 << "batch index " << i << ", group " << g << ", channel " << c; 142 } 143 } 144 } 145 } 146 } 147 TestX32()148 void TestX32() const { 149 std::random_device random_device; 150 auto rng = std::mt19937(random_device()); 151 auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng); 152 153 std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + (batch_size() - 1) * input_stride() + channels()); 154 std::vector<float> output((batch_size() - 1) * output_stride() + channels()); 155 for (size_t iteration = 0; iteration < iterations(); iteration++) { 156 std::generate(input.begin(), input.end(), std::ref(f32rng)); 157 std::fill(output.begin(), output.end(), std::nanf("")); 158 159 // Create, setup, run, and destroy Channel Shuffle operator. 160 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 161 xnn_operator_t channel_shuffle_op = nullptr; 162 163 ASSERT_EQ(xnn_status_success, 164 xnn_create_channel_shuffle_nc_x32( 165 groups(), group_channels(), 166 input_stride(), output_stride(), 167 0, &channel_shuffle_op)); 168 ASSERT_NE(nullptr, channel_shuffle_op); 169 170 // Smart pointer to automatically delete channel_shuffle_op. 171 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_channel_shuffle_op(channel_shuffle_op, xnn_delete_operator); 172 173 ASSERT_EQ(xnn_status_success, 174 xnn_setup_channel_shuffle_nc_x32( 175 channel_shuffle_op, 176 batch_size(), 177 input.data(), output.data(), 178 nullptr /* thread pool */)); 179 180 ASSERT_EQ(xnn_status_success, 181 xnn_run_operator(channel_shuffle_op, nullptr /* thread pool */)); 182 183 // Verify results. 184 for (size_t i = 0; i < batch_size(); i++) { 185 for (size_t g = 0; g < groups(); g++) { 186 for (size_t c = 0; c < group_channels(); c++) { 187 ASSERT_EQ(input[i * input_stride() + g * group_channels() + c], 188 output[i * output_stride() + c * groups() + g]) 189 << "batch index " << i << ", group " << g << ", channel " << c; 190 } 191 } 192 } 193 } 194 } 195 196 private: 197 size_t groups_{1}; 198 size_t group_channels_{1}; 199 size_t batch_size_{1}; 200 size_t input_stride_{0}; 201 size_t output_stride_{0}; 202 size_t iterations_{15}; 203 }; 204