• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8 
9 #pragma once
10 
#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>
23 
24 
25 class ChannelShuffleOperatorTester {
26  public:
groups(size_t groups)27   inline ChannelShuffleOperatorTester& groups(size_t groups) {
28     assert(groups != 0);
29     this->groups_ = groups;
30     return *this;
31   }
32 
groups()33   inline size_t groups() const {
34     return this->groups_;
35   }
36 
group_channels(size_t group_channels)37   inline ChannelShuffleOperatorTester& group_channels(size_t group_channels) {
38     assert(group_channels != 0);
39     this->group_channels_ = group_channels;
40     return *this;
41   }
42 
group_channels()43   inline size_t group_channels() const {
44     return this->group_channels_;
45   }
46 
channels()47   inline size_t channels() const {
48     return groups() * group_channels();
49   }
50 
input_stride(size_t input_stride)51   inline ChannelShuffleOperatorTester& input_stride(size_t input_stride) {
52     assert(input_stride != 0);
53     this->input_stride_ = input_stride;
54     return *this;
55   }
56 
input_stride()57   inline size_t input_stride() const {
58     if (this->input_stride_ == 0) {
59       return channels();
60     } else {
61       assert(this->input_stride_ >= channels());
62       return this->input_stride_;
63     }
64   }
65 
output_stride(size_t output_stride)66   inline ChannelShuffleOperatorTester& output_stride(size_t output_stride) {
67     assert(output_stride != 0);
68     this->output_stride_ = output_stride;
69     return *this;
70   }
71 
output_stride()72   inline size_t output_stride() const {
73     if (this->output_stride_ == 0) {
74       return channels();
75     } else {
76       assert(this->output_stride_ >= channels());
77       return this->output_stride_;
78     }
79   }
80 
batch_size(size_t batch_size)81   inline ChannelShuffleOperatorTester& batch_size(size_t batch_size) {
82     assert(batch_size != 0);
83     this->batch_size_ = batch_size;
84     return *this;
85   }
86 
batch_size()87   inline size_t batch_size() const {
88     return this->batch_size_;
89   }
90 
iterations(size_t iterations)91   inline ChannelShuffleOperatorTester& iterations(size_t iterations) {
92     this->iterations_ = iterations;
93     return *this;
94   }
95 
iterations()96   inline size_t iterations() const {
97     return this->iterations_;
98   }
99 
TestX8()100   void TestX8() const {
101     std::random_device random_device;
102     auto rng = std::mt19937(random_device());
103     auto u8rng = std::bind(std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), rng);
104 
105     std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) + (batch_size() - 1) * input_stride() + channels());
106     std::vector<uint8_t> output((batch_size() - 1) * output_stride() + channels());
107     for (size_t iteration = 0; iteration < iterations(); iteration++) {
108       std::generate(input.begin(), input.end(), std::ref(u8rng));
109       std::fill(output.begin(), output.end(), 0xA5);
110 
111       // Create, setup, run, and destroy Channel Shuffle operator.
112       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
113       xnn_operator_t channel_shuffle_op = nullptr;
114 
115       ASSERT_EQ(xnn_status_success,
116         xnn_create_channel_shuffle_nc_x8(
117           groups(), group_channels(),
118           input_stride(), output_stride(),
119           0, &channel_shuffle_op));
120       ASSERT_NE(nullptr, channel_shuffle_op);
121 
122       // Smart pointer to automatically delete channel_shuffle_op.
123       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_channel_shuffle_op(channel_shuffle_op, xnn_delete_operator);
124 
125       ASSERT_EQ(xnn_status_success,
126         xnn_setup_channel_shuffle_nc_x8(
127           channel_shuffle_op,
128           batch_size(),
129           input.data(), output.data(),
130           nullptr /* thread pool */));
131 
132       ASSERT_EQ(xnn_status_success,
133         xnn_run_operator(channel_shuffle_op, nullptr /* thread pool */));
134 
135       // Verify results.
136       for (size_t i = 0; i < batch_size(); i++) {
137         for (size_t g = 0; g < groups(); g++) {
138           for (size_t c = 0; c < group_channels(); c++) {
139             ASSERT_EQ(uint32_t(input[i * input_stride() + g * group_channels() + c]),
140                 uint32_t(output[i * output_stride() + c * groups() + g]))
141               << "batch index " << i << ", group " << g << ", channel " << c;
142           }
143         }
144       }
145     }
146   }
147 
TestX32()148   void TestX32() const {
149     std::random_device random_device;
150     auto rng = std::mt19937(random_device());
151     auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);
152 
153     std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + (batch_size() - 1) * input_stride() + channels());
154     std::vector<float> output((batch_size() - 1) * output_stride() + channels());
155     for (size_t iteration = 0; iteration < iterations(); iteration++) {
156       std::generate(input.begin(), input.end(), std::ref(f32rng));
157       std::fill(output.begin(), output.end(), std::nanf(""));
158 
159       // Create, setup, run, and destroy Channel Shuffle operator.
160       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
161       xnn_operator_t channel_shuffle_op = nullptr;
162 
163       ASSERT_EQ(xnn_status_success,
164         xnn_create_channel_shuffle_nc_x32(
165           groups(), group_channels(),
166           input_stride(), output_stride(),
167           0, &channel_shuffle_op));
168       ASSERT_NE(nullptr, channel_shuffle_op);
169 
170       // Smart pointer to automatically delete channel_shuffle_op.
171       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_channel_shuffle_op(channel_shuffle_op, xnn_delete_operator);
172 
173       ASSERT_EQ(xnn_status_success,
174         xnn_setup_channel_shuffle_nc_x32(
175           channel_shuffle_op,
176           batch_size(),
177           input.data(), output.data(),
178           nullptr /* thread pool */));
179 
180       ASSERT_EQ(xnn_status_success,
181         xnn_run_operator(channel_shuffle_op, nullptr /* thread pool */));
182 
183       // Verify results.
184       for (size_t i = 0; i < batch_size(); i++) {
185         for (size_t g = 0; g < groups(); g++) {
186           for (size_t c = 0; c < group_channels(); c++) {
187             ASSERT_EQ(input[i * input_stride() + g * group_channels() + c],
188                 output[i * output_stride() + c * groups() + g])
189               << "batch index " << i << ", group " << g << ", channel " << c;
190           }
191         }
192       }
193     }
194   }
195 
196  private:
197   size_t groups_{1};
198   size_t group_channels_{1};
199   size_t batch_size_{1};
200   size_t input_stride_{0};
201   size_t output_stride_{0};
202   size_t iterations_{15};
203 };
204