1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8
9 #include <assert.h>
10 #include <math.h>
11 #include <stddef.h>
12 #include <stdint.h>
13 #include <stdlib.h>
14
15 #include <xnnpack.h>
16 #include <xnnpack/allocator.h>
17 #include <xnnpack/operator.h>
18 #include <xnnpack/log.h>
19 #include <xnnpack/params.h>
20
21
// Shared constructor for the channel-shuffle NC operators (x8 and x32 variants).
//
// Validates the shape parameters and allocates a zero-initialized operator
// descriptor. On success, *channel_shuffle_op_out receives the new operator and
// xnn_status_success is returned. On failure, an error is logged, any
// partially constructed operator is released via xnn_delete_operator, and the
// corresponding error status is returned; *channel_shuffle_op_out is not written.
//
// groups         - number of channel groups; must be at least 2.
// group_channels - number of channels in each group; must be non-zero.
// input_stride   - stride, in elements, between consecutive input pixels;
//                  must be >= groups * group_channels.
// output_stride  - stride, in elements, between consecutive output pixels;
//                  must be >= groups * group_channels.
// flags          - operator flags; recorded on the operator descriptor.
// operator_type  - concrete operator type recorded on the descriptor and used
//                  in diagnostic messages.
static enum xnn_status create_channel_shuffle_nc(
    size_t groups,
    size_t group_channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    enum xnn_operator_type operator_type,
    xnn_operator_t* channel_shuffle_op_out)
{
  xnn_operator_t channel_shuffle_op = NULL;
  enum xnn_status status = xnn_status_uninitialized;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(operator_type));
    goto error;
  }

  status = xnn_status_invalid_parameter;

  if (groups <= 1) {
    xnn_log_error(
      "failed to create %s operator with %zu groups: at least two groups required",
      xnn_operator_type_to_string(operator_type), groups);
    goto error;
  }

  if (group_channels == 0) {
    xnn_log_error(
      "failed to create %s operator with %zu group channels: number of group channels must be non-zero",
      xnn_operator_type_to_string(operator_type), group_channels);
    goto error;
  }

  // groups >= 2 here, so the division is well-defined. If the total channel
  // count wrapped around, a too-small stride could pass the checks below and
  // the shuffle would later access memory past each pixel.
  if (group_channels > SIZE_MAX / groups) {
    xnn_log_error(
      "failed to create %s operator with %zux%zu channels: total number of channels overflows size_t",
      xnn_operator_type_to_string(operator_type), groups, group_channels);
    goto error;
  }

  const size_t channels = groups * group_channels;
  if (input_stride < channels) {
    xnn_log_error(
      "failed to create %s operator with input element stride of %zu: "
      "stride must be at least as large as the number of channels (%zux%zu)",
      xnn_operator_type_to_string(operator_type), input_stride, groups, group_channels);
    goto error;
  }

  if (output_stride < channels) {
    xnn_log_error(
      "failed to create %s operator with output element stride of %zu: "
      "stride must be at least as large as the number of channels (%zux%zu)",
      xnn_operator_type_to_string(operator_type), output_stride, groups, group_channels);
    goto error;
  }

  status = xnn_status_out_of_memory;

  channel_shuffle_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
  if (channel_shuffle_op == NULL) {
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator descriptor",
      sizeof(struct xnn_operator), xnn_operator_type_to_string(operator_type));
    goto error;
  }

  channel_shuffle_op->groups = groups;
  channel_shuffle_op->group_channels = group_channels;
  channel_shuffle_op->input_pixel_stride = input_stride;
  channel_shuffle_op->output_pixel_stride = output_stride;

  channel_shuffle_op->type = operator_type;
  // Previously the caller's flags were accepted but silently dropped.
  channel_shuffle_op->flags = flags;

  // The operator cannot run until a setup call binds input/output buffers.
  channel_shuffle_op->state = xnn_run_state_invalid;

  *channel_shuffle_op_out = channel_shuffle_op;
  return xnn_status_success;

error:
  // channel_shuffle_op is still NULL on every path that fails before allocation.
  xnn_delete_operator(channel_shuffle_op);
  return status;
}
99
100
// Creates a channel-shuffle operator for 8-bit elements in NC layout.
//
// All validation and allocation are done by create_channel_shuffle_nc. This
// wrapper only supplies the x8 operator type. That type is checked later by
// xnn_setup_channel_shuffle_nc_x8, which rejects operators of any other type.
enum xnn_status xnn_create_channel_shuffle_nc_x8(
    size_t groups,
    size_t group_channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* channel_shuffle_op_out)
{
  const enum xnn_operator_type x8_type = xnn_operator_type_channel_shuffle_nc_x8;
  return create_channel_shuffle_nc(
    groups, group_channels, input_stride, output_stride, flags,
    x8_type, channel_shuffle_op_out);
}
118
// Creates a channel-shuffle operator for 32-bit elements in NC layout.
//
// All validation and allocation are done by create_channel_shuffle_nc. This
// wrapper only supplies the x32 operator type. That type is checked later by
// xnn_setup_channel_shuffle_nc_x32, which rejects operators of any other type.
enum xnn_status xnn_create_channel_shuffle_nc_x32(
    size_t groups,
    size_t group_channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* channel_shuffle_op_out)
{
  const enum xnn_operator_type x32_type = xnn_operator_type_channel_shuffle_nc_x32;
  return create_channel_shuffle_nc(
    groups, group_channels, input_stride, output_stride, flags,
    x32_type, channel_shuffle_op_out);
}
136
// Shared setup for the channel-shuffle NC operators.
//
// Binds the input/output buffers and batch size to the operator and prepares
// its compute descriptor. Element strides stored at creation time are converted
// to byte strides by shifting left by log2_element_size.
//
// The work is a 1-D parallel task over batch_size. Operators with 2, 3 or 4
// groups use the fixed zip micro-kernels (zip->x2 / x3 / x4). Larger group
// counts use the variable micro-kernel zip->xm. Fewer than two groups cannot
// reach this point, because creation rejects them.
//
// Returns xnn_status_uninitialized if XNNPACK has not been initialized. A zero
// batch size succeeds and marks the operator as skippable.
static enum xnn_status setup_channel_shuffle_nc(
    xnn_operator_t channel_shuffle_op,
    size_t batch_size,
    const void* input,
    void* output,
    uint32_t log2_element_size,
    const struct zip_parameters zip[restrict XNN_MIN_ELEMENTS(1)])
{
  // Invalidate up front so that an early error return leaves the operator unrunnable.
  channel_shuffle_op->state = xnn_run_state_invalid;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to setup %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(channel_shuffle_op->type));
    return xnn_status_uninitialized;
  }

  // An empty batch is valid: nothing to bind, mark the operator as a no-op.
  if (batch_size == 0) {
    channel_shuffle_op->state = xnn_run_state_skip;
    return xnn_status_success;
  }

  channel_shuffle_op->batch_size = batch_size;
  channel_shuffle_op->input = input;
  channel_shuffle_op->output = output;

  const size_t num_groups = channel_shuffle_op->groups;
  struct channel_shuffle_context* ctx = &channel_shuffle_op->context.channel_shuffle;
  *ctx = (struct channel_shuffle_context) {
    .x = input,
    .x_stride = channel_shuffle_op->input_pixel_stride << log2_element_size,
    .y = output,
    .y_stride = channel_shuffle_op->output_pixel_stride << log2_element_size,
    .n = channel_shuffle_op->group_channels << log2_element_size,
    .m = num_groups,
  };

  channel_shuffle_op->compute.type = xnn_parallelization_type_1d;
  channel_shuffle_op->compute.range[0] = batch_size;

  if (num_groups > 4) {
    // No dedicated kernel: fall back to the kernel that takes the group count at run time.
    channel_shuffle_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_channel_shuffle_variable;
    ctx->variable_ukernel = zip->xm;
  } else {
    channel_shuffle_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_channel_shuffle_fixed;
    switch (num_groups) {
      case 2:
        ctx->fixed_ukernel = zip->x2;
        break;
      case 3:
        ctx->fixed_ukernel = zip->x3;
        break;
      case 4:
        ctx->fixed_ukernel = zip->x4;
        break;
      default:
        // 0 or 1 groups: rejected at creation.
        XNN_UNREACHABLE;
    }
  }

  channel_shuffle_op->state = xnn_run_state_ready;
  return xnn_status_success;
}
198
// Binds input/output buffers to a channel-shuffle operator with 8-bit elements.
//
// Fails with xnn_status_invalid_parameter (after logging) if the operator was
// not created as an x8 channel shuffle. Otherwise it delegates to
// setup_channel_shuffle_nc with a log2 element size of 0 and the x8 zip
// micro-kernels. The threadpool argument is not used here.
enum xnn_status xnn_setup_channel_shuffle_nc_x8(
    xnn_operator_t channel_shuffle_op,
    size_t batch_size,
    const void* input,
    void* output,
    pthreadpool_t threadpool)
{
  (void) threadpool;  // unused by this function

  const enum xnn_operator_type expected_type = xnn_operator_type_channel_shuffle_nc_x8;
  if (channel_shuffle_op->type == expected_type) {
    return setup_channel_shuffle_nc(
      channel_shuffle_op, batch_size, input, output,
      0 /* log2(sizeof(uint8_t)) */,
      &xnn_params.x8.zip);
  }

  xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
    xnn_operator_type_to_string(expected_type),
    xnn_operator_type_to_string(channel_shuffle_op->type));
  return xnn_status_invalid_parameter;
}
221
// Binds input/output buffers to a channel-shuffle operator with 32-bit elements.
//
// Fails with xnn_status_invalid_parameter (after logging) if the operator was
// not created as an x32 channel shuffle. Otherwise it delegates to
// setup_channel_shuffle_nc with a log2 element size of 2 and the x32 zip
// micro-kernels. The threadpool argument is not used here.
enum xnn_status xnn_setup_channel_shuffle_nc_x32(
    xnn_operator_t channel_shuffle_op,
    size_t batch_size,
    const void* input,
    void* output,
    pthreadpool_t threadpool)
{
  (void) threadpool;  // unused by this function

  const enum xnn_operator_type expected_type = xnn_operator_type_channel_shuffle_nc_x32;
  if (channel_shuffle_op->type == expected_type) {
    return setup_channel_shuffle_nc(
      channel_shuffle_op, batch_size, input, output,
      2 /* log2(sizeof(uint32_t)) */,
      &xnn_params.x32.zip);
  }

  xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
    xnn_operator_type_to_string(expected_type),
    xnn_operator_type_to_string(channel_shuffle_op->type));
  return xnn_status_invalid_parameter;
}
244