1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8
9 #include <assert.h>
10 #include <math.h>
11 #include <stddef.h>
12 #include <stdint.h>
13 #include <stdlib.h>
14
15 #include <xnnpack.h>
16 #include <xnnpack/allocator.h>
17 #include <xnnpack/operator.h>
18 #include <xnnpack/log.h>
19 #include <xnnpack/params.h>
20
21
// Shared constructor behind xnn_create_channel_shuffle_nc_x8/_x32.
//
// Validates the grouping and stride parameters, allocates a zero-initialized
// operator descriptor and records the configuration for a later setup call.
//
// groups                 - number of channel groups; must be at least 2.
// group_channels         - channels per group; must be non-zero.
// input_stride           - input element stride; must be at least
//                          groups * group_channels.
// output_stride          - output element stride; same requirement.
// flags                  - currently not read by this function.
// operator_type          - type tag stored on the descriptor; the setup entry
//                          points compare against it.
// channel_shuffle_op_out - receives the new operator on success only.
//
// Returns xnn_status_success, xnn_status_uninitialized (XNNPACK not
// initialized), xnn_status_invalid_parameter (bad arguments, including a
// channel count that overflows size_t) or xnn_status_out_of_memory.
static enum xnn_status create_channel_shuffle_nc(
    size_t groups,
    size_t group_channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    enum xnn_operator_type operator_type,
    xnn_operator_t* channel_shuffle_op_out)
{
  xnn_operator_t channel_shuffle_op = NULL;
  enum xnn_status status = xnn_status_uninitialized;

  if (!xnn_params.initialized) {
    xnn_log_error("failed to create Channel Shuffle operator: XNNPACK is not initialized");
    goto error;
  }

  status = xnn_status_invalid_parameter;

  if (groups <= 1) {
    xnn_log_error(
      "failed to create Channel Shuffle operator with %zu groups: at least two groups required", groups);
    goto error;
  }

  if (group_channels == 0) {
    xnn_log_error(
      "failed to create Channel Shuffle operator with %zu group channels: number of group channels must be non-zero",
      group_channels);
    goto error;
  }

  // groups >= 2 here, so the division is safe. Without this check the product
  // below could wrap around and a tiny wrapped value would slip past the
  // stride validation.
  if (group_channels > SIZE_MAX / groups) {
    xnn_log_error(
      "failed to create Channel Shuffle operator with %zu groups and %zu group channels: "
      "number of channels overflows size_t",
      groups, group_channels);
    goto error;
  }

  const size_t channels = groups * group_channels;
  if (input_stride < channels) {
    xnn_log_error(
      "failed to create Channel Shuffle operator with input element stride of %zu: "
      "stride must be at least as large as the number of channels (%zux%zu)",
      input_stride, groups, group_channels);
    goto error;
  }

  if (output_stride < channels) {
    xnn_log_error(
      "failed to create Channel Shuffle operator with output element stride of %zu: "
      "stride must be at least as large as the number of channels (%zux%zu)",
      output_stride, groups, group_channels);
    goto error;
  }

  status = xnn_status_out_of_memory;

  channel_shuffle_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
  if (channel_shuffle_op == NULL) {
    xnn_log_error("failed to allocate %zu bytes for Channel Shuffle operator descriptor", sizeof(struct xnn_operator));
    goto error;
  }

  channel_shuffle_op->groups = groups;
  channel_shuffle_op->group_channels = group_channels;
  channel_shuffle_op->input_pixel_stride = input_stride;
  channel_shuffle_op->output_pixel_stride = output_stride;

  channel_shuffle_op->type = operator_type;
  channel_shuffle_op->ukernel.type = xnn_ukernel_type_channel_shuffle;

  // The operator is not runnable until a setup call binds buffers to it.
  channel_shuffle_op->state = xnn_run_state_invalid;

  *channel_shuffle_op_out = channel_shuffle_op;
  return xnn_status_success;

error:
  // channel_shuffle_op is still NULL on every path except a post-allocation failure.
  xnn_delete_operator(channel_shuffle_op);
  return status;
}
96
97
// Public entry point: creates a Channel Shuffle (NC) operator for the X8
// variant. All arguments are forwarded unchanged to create_channel_shuffle_nc;
// the only thing this wrapper contributes is the X8 operator type tag.
enum xnn_status xnn_create_channel_shuffle_nc_x8(
    size_t groups,
    size_t group_channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* channel_shuffle_op_out)
{
  const enum xnn_operator_type x8_type = xnn_operator_type_channel_shuffle_nc_x8;
  return create_channel_shuffle_nc(
    groups, group_channels, input_stride, output_stride, flags, x8_type, channel_shuffle_op_out);
}
115
// Public entry point: creates a Channel Shuffle (NC) operator for the X32
// variant. All arguments are forwarded unchanged to create_channel_shuffle_nc;
// the only thing this wrapper contributes is the X32 operator type tag.
enum xnn_status xnn_create_channel_shuffle_nc_x32(
    size_t groups,
    size_t group_channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* channel_shuffle_op_out)
{
  const enum xnn_operator_type x32_type = xnn_operator_type_channel_shuffle_nc_x32;
  return create_channel_shuffle_nc(
    groups, group_channels, input_stride, output_stride, flags, x32_type, channel_shuffle_op_out);
}
133
setup_channel_shuffle_nc(xnn_operator_t channel_shuffle_op,size_t batch_size,const void * input,void * output,uint32_t log2_element_size,const struct zip_parameters zip[restrict static1])134 static enum xnn_status setup_channel_shuffle_nc(
135 xnn_operator_t channel_shuffle_op,
136 size_t batch_size,
137 const void* input,
138 void* output,
139 uint32_t log2_element_size,
140 const struct zip_parameters zip[restrict static 1])
141 {
142 channel_shuffle_op->state = xnn_run_state_invalid;
143
144 if (!xnn_params.initialized) {
145 xnn_log_error("failed to setup Channel Shuffle operator: XNNPACK is not initialized");
146 return xnn_status_uninitialized;
147 }
148
149 if (batch_size == 0) {
150 channel_shuffle_op->state = xnn_run_state_skip;
151 return xnn_status_success;
152 }
153
154 channel_shuffle_op->batch_size = batch_size;
155 channel_shuffle_op->input = input;
156 channel_shuffle_op->output = output;
157
158 const size_t groups = channel_shuffle_op->groups;
159 channel_shuffle_op->context.channel_shuffle = (struct channel_shuffle_context) {
160 .x = input,
161 .x_stride = channel_shuffle_op->input_pixel_stride << log2_element_size,
162 .y = output,
163 .y_stride = channel_shuffle_op->output_pixel_stride << log2_element_size,
164 .n = channel_shuffle_op->group_channels << log2_element_size,
165 .m = groups,
166 };
167 channel_shuffle_op->compute.type = xnn_parallelization_type_1d;
168 channel_shuffle_op->compute.range[0] = batch_size;
169 switch (groups) {
170 case 2:
171 channel_shuffle_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_channel_shuffle_fixed;
172 channel_shuffle_op->context.channel_shuffle.fixed_ukernel = zip->x2;
173 break;
174 case 3:
175 channel_shuffle_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_channel_shuffle_fixed;
176 channel_shuffle_op->context.channel_shuffle.fixed_ukernel = zip->x3;
177 break;
178 case 4:
179 channel_shuffle_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_channel_shuffle_fixed;
180 channel_shuffle_op->context.channel_shuffle.fixed_ukernel = zip->x4;
181 break;
182 default:
183 channel_shuffle_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_channel_shuffle_variable;
184 channel_shuffle_op->context.channel_shuffle.variable_ukernel = zip->xm;
185 break;
186 case 0:
187 case 1:
188 XNN_UNREACHABLE;
189 }
190 channel_shuffle_op->state = xnn_run_state_ready;
191
192 return xnn_status_success;
193 }
194
// Public entry point: sets up an X8 Channel Shuffle (NC) operator over
// `batch_size` rows of `input`/`output`, using the zip kernels in
// xnn_params.x8.zip. Returns xnn_status_invalid_parameter if the operator's
// type tag is not xnn_operator_type_channel_shuffle_nc_x8. `threadpool` is not
// used by this function.
enum xnn_status xnn_setup_channel_shuffle_nc_x8(
    xnn_operator_t channel_shuffle_op,
    size_t batch_size,
    const void* input,
    void* output,
    pthreadpool_t threadpool)
{
  if (channel_shuffle_op->type == xnn_operator_type_channel_shuffle_nc_x8) {
    // Elements are uint8_t: log2(sizeof(uint8_t)) == 0, so byte strides equal element strides.
    const uint32_t log2_uint8_size = 0;
    return setup_channel_shuffle_nc(
      channel_shuffle_op, batch_size, input, output, log2_uint8_size, &xnn_params.x8.zip);
  }

  xnn_log_error("failed to setup Channel Shuffle (NC, X8) operator: operator type mismatch");
  return xnn_status_invalid_parameter;
}
215
// Public entry point: sets up an X32 Channel Shuffle (NC) operator over
// `batch_size` rows of `input`/`output`, using the zip kernels in
// xnn_params.x32.zip. Returns xnn_status_invalid_parameter if the operator's
// type tag is not xnn_operator_type_channel_shuffle_nc_x32. `threadpool` is not
// used by this function.
enum xnn_status xnn_setup_channel_shuffle_nc_x32(
    xnn_operator_t channel_shuffle_op,
    size_t batch_size,
    const void* input,
    void* output,
    pthreadpool_t threadpool)
{
  if (channel_shuffle_op->type == xnn_operator_type_channel_shuffle_nc_x32) {
    // Elements are uint32_t: log2(sizeof(uint32_t)) == 2, so strides are scaled by 4 bytes.
    const uint32_t log2_uint32_size = 2;
    return setup_channel_shuffle_nc(
      channel_shuffle_op, batch_size, input, output, log2_uint32_size, &xnn_params.x32.zip);
  }

  xnn_log_error("failed to setup Channel Shuffle (NC, X32) operator: operator type mismatch");
  return xnn_status_invalid_parameter;
}
236