// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <inttypes.h>
#include <string.h>
#include <math.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/log.h>
#include <xnnpack/math.h>
#include <xnnpack/operator.h>
#include <xnnpack/pack.h>
#include <xnnpack/params-init.h>
#include <xnnpack/params.h>

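// Common implementation behind the type-specific xnn_create_fully_connected_nc_* entry points:
// validates the channel counts and strides, allocates the operator descriptor, packs the kernel
// and bias into the microkernel-friendly layout, and records the GEMM microkernels to use.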
static enum xnn_status create_fully_connected_nc(
  size_t input_channels,
  size_t output_channels,
  size_t input_stride,
  size_t output_stride,
  const void* kernel,
  const void* bias,
  uint32_t flags,
  uint32_t log2_filter_element_size,
  uint32_t bias_element_size,
  xnn_pack_gemm_io_w_function pack_gemm_io_w,
  xnn_pack_gemm_goi_w_function pack_gemm_goi_w,
  const void* packing_params,
  int packed_weights_padding_byte,
  const void* params,
  size_t params_size,
  const struct gemm_parameters* gemm_parameters,
  const struct gemm_fused_ukernels* gemm_ukernels,
  enum xnn_operator_type operator_type,
  xnn_operator_t* fully_connected_op_out)
{
  xnn_operator_t fully_connected_op = NULL;
  enum xnn_status status = xnn_status_uninitialized;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(operator_type));
    goto error;
  }

  status = xnn_status_invalid_parameter;

  if (input_channels == 0) {
    xnn_log_error(
      "failed to create %s operator with %zu input channels: number of channels must be non-zero",
      xnn_operator_type_to_string(operator_type), input_channels);
    goto error;
  }

  if (output_channels == 0) {
    xnn_log_error(
      "failed to create %s operator with %zu output channels: number of channels must be non-zero",
      xnn_operator_type_to_string(operator_type), output_channels);
    goto error;
  }

  if (input_stride < input_channels) {
    xnn_log_error(
      "failed to create %s operator with input element stride of %zu: "
      "stride must be at least as large as the number of input channels (%zu)",
      xnn_operator_type_to_string(operator_type), input_stride, input_channels);
    goto error;
  }

  if (output_stride < output_channels) {
    xnn_log_error(
      "failed to create %s operator with output element stride of %zu: "
      "stride must be at least as large as the number of output channels (%zu)",
      xnn_operator_type_to_string(operator_type), output_stride, output_channels);
    goto error;
  }

  status = xnn_status_out_of_memory;

  fully_connected_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
  if (fully_connected_op == NULL) {
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator descriptor",
      sizeof(struct xnn_operator), xnn_operator_type_to_string(operator_type));
    goto error;
  }

  const uint32_t nr = gemm_parameters->nr;
  const uint32_t kr = UINT32_C(1) << gemm_parameters->log2_kr;
  const uint32_t sr = UINT32_C(1) << gemm_parameters->log2_sr;

  const size_t n_stride = round_up(output_channels, nr);
  const size_t k_stride = round_up_po2(input_channels, kr);

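  // Packed weights are laid out in groups of nr output channels: each group holds the nr bias
  // values followed by an nr x k_stride block of filter elements, giving the per-output-channel
  // footprint of bias_element_size + k_stride * filter_element_size bytes used below.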
  const size_t packed_weights_size = n_stride * (bias_element_size + (k_stride << log2_filter_element_size));
  fully_connected_op->packed_weights = xnn_allocate_simd_memory(packed_weights_size);
  if (fully_connected_op->packed_weights == NULL) {
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator packed weights",
      packed_weights_size, xnn_operator_type_to_string(operator_type));
    goto error;
  }
  memset(fully_connected_op->packed_weights, packed_weights_padding_byte, packed_weights_size);

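  // XNN_FLAG_TRANSPOSE_WEIGHTS selects the expected kernel layout: GOI packing reads the kernel
  // as [output_channels][input_channels], while IO packing reads the transposed
  // [input_channels][output_channels] layout.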
  if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
    pack_gemm_io_w(
      output_channels, input_channels,
      nr, kr, sr,
      kernel, bias,
      fully_connected_op->packed_weights,
      packing_params);
  } else {
    pack_gemm_goi_w(
      1, output_channels, input_channels,
      nr, kr, sr,
      kernel, bias,
      fully_connected_op->packed_weights,
      packing_params);
  }

  fully_connected_op->group_input_channels = input_channels;
  fully_connected_op->group_output_channels = output_channels;
  fully_connected_op->input_pixel_stride = input_stride;
  fully_connected_op->output_pixel_stride = output_stride;

  memcpy(&fully_connected_op->params, params, params_size);
  fully_connected_op->type = operator_type;

  fully_connected_op->ukernel.type = xnn_ukernel_type_gemm;
  fully_connected_op->ukernel.gemm = (struct xnn_ukernel_gemm) {
    .general_case = gemm_ukernels->gemm,
    .mr1_case = gemm_ukernels->gemm1,
    .mr = gemm_parameters->mr,
    .nr = nr,
    .kr = kr,
  };

  fully_connected_op->state = xnn_run_state_invalid;

  *fully_connected_op_out = fully_connected_op;
  return xnn_status_success;

error:
  xnn_delete_operator(fully_connected_op);
  return status;
}

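// Common implementation behind the type-specific xnn_setup_fully_connected_nc_* entry points:
// binds the input/output pointers, fills in the GEMM context, and chooses tile sizes for
// parallel execution.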
static enum xnn_status setup_fully_connected_nc(
  xnn_operator_t fully_connected_op,
  size_t batch_size,
  const void* input,
  void* output,
  uint32_t log2_input_element_size,
  uint32_t log2_filter_element_size,
  uint32_t bias_element_size,
  uint32_t log2_output_element_size,
  const void* params,
  size_t params_size,
  size_t num_threads)
{
  fully_connected_op->state = xnn_run_state_invalid;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to setup %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(fully_connected_op->type));
    return xnn_status_uninitialized;
  }

  if (batch_size == 0) {
    fully_connected_op->state = xnn_run_state_skip;
    return xnn_status_success;
  }

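  // A fully connected layer is a plain GEMM in which every batch element contributes one row
  // (the M dimension), so the operator is recorded as a batch-1 "image" of height batch_size
  // and width 1.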
  fully_connected_op->batch_size = 1;
  fully_connected_op->input_height = batch_size;
  fully_connected_op->input_width = 1;
  fully_connected_op->input = input;

  fully_connected_op->output_height = batch_size;
  fully_connected_op->output_width = 1;
  fully_connected_op->output = output;

  const size_t input_channels = fully_connected_op->group_input_channels;
  const size_t output_channels = fully_connected_op->group_output_channels;

  uint32_t mr = fully_connected_op->ukernel.gemm.mr;
  const uint32_t nr = fully_connected_op->ukernel.gemm.nr;

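  // For single-element batches, prefer the specialized mr == 1 microkernel when one is available.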
  struct xnn_hmp_gemm_ukernel gemm_ukernel = fully_connected_op->ukernel.gemm.general_case;
  if (batch_size == 1 && fully_connected_op->ukernel.gemm.mr1_case.function[XNN_UARCH_DEFAULT] != NULL) {
    gemm_ukernel = fully_connected_op->ukernel.gemm.mr1_case;
    mr = 1;
  }

  fully_connected_op->context.gemm = (struct gemm_context) {
    .k_scaled = input_channels << log2_input_element_size,
    .w_stride = (round_up_po2(input_channels, fully_connected_op->ukernel.gemm.kr) << log2_input_element_size) + bias_element_size,
    .a = input,
    .a_stride = fully_connected_op->input_pixel_stride << log2_input_element_size,
    .packed_w = fully_connected_op->packed_weights,
    .c = output,
    .cm_stride = fully_connected_op->output_pixel_stride << log2_output_element_size,
    .cn_stride = nr << log2_output_element_size,
    .log2_csize = log2_output_element_size,
    .ukernel = gemm_ukernel,
  };
  memcpy(&fully_connected_op->context.gemm.params, params, params_size);

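  // When running multi-threaded, shrink the output-channel tile so that the 2D tile grid
  // produces roughly target_tiles_per_thread tiles per thread, while keeping the tile width a
  // multiple of nr.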
  size_t nc = output_channels;
  if (num_threads > 1) {
    const size_t num_other_tiles = divide_round_up(batch_size, mr);
    const size_t target_tiles_per_thread = 5;
    const size_t max_nc = divide_round_up(output_channels * num_other_tiles, num_threads * target_tiles_per_thread);
    if (max_nc < nc) {
      nc = min(nc, divide_round_up(nc, max_nc * nr) * nr);
    }
  }
  fully_connected_op->compute.type = xnn_parallelization_type_2d_tile_2d;
  fully_connected_op->compute.task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
  fully_connected_op->compute.range[0] = batch_size;
  fully_connected_op->compute.range[1] = output_channels;
  fully_connected_op->compute.tile[0] = mr;
  fully_connected_op->compute.tile[1] = nc;
  fully_connected_op->state = xnn_run_state_ready;

  return xnn_status_success;
}

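// Creates a fully connected operator with unsigned 8-bit quantized input, kernel, and output,
// and signed 32-bit bias.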
enum xnn_status xnn_create_fully_connected_nc_qu8(
  size_t input_channels,
  size_t output_channels,
  size_t input_stride,
  size_t output_stride,
  uint8_t input_zero_point,
  float input_scale,
  uint8_t kernel_zero_point,
  float kernel_scale,
  const uint8_t* kernel,
  const int32_t* bias,
  uint8_t output_zero_point,
  float output_scale,
  uint8_t output_min,
  uint8_t output_max,
  uint32_t flags,
  xnn_operator_t* fully_connected_op_out)
{
  if (input_scale <= 0.0f || !isnormal(input_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qu8), input_scale);
    return xnn_status_invalid_parameter;
  }

  if (kernel_scale <= 0.0f || !isnormal(kernel_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g kernel scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qu8), kernel_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qu8), output_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRIu8 ", %" PRIu8 "] output range: range min must be below range max",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qu8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

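  // The fixed-point requantization used by the QU8 microkernels can only represent a combined
  // scale strictly below 1.0, so the product of the input and kernel scales divided by the
  // output scale is validated here.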
  const float requantization_scale = input_scale * kernel_scale / output_scale;
  if (requantization_scale >= 1.0f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale, %.7g kernel scale, and %.7g output scale: "
      "requantization scale %.7g is greater or equal to 1.0",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qu8),
      input_scale, kernel_scale, output_scale, requantization_scale);
    return xnn_status_unsupported_parameter;
  }

  const union xnn_qu8_gemm_params params = xnn_init_qu8_gemm_params(
    kernel_zero_point, requantization_scale, output_zero_point, output_min, output_max);
  const struct xnn_qu8_packing_params packing_params = {
    .input_zero_point = input_zero_point,
    .kernel_zero_point = kernel_zero_point,
  };
  return create_fully_connected_nc(
    input_channels, output_channels,
    input_stride, output_stride,
    kernel, bias, flags,
    0 /* log2(sizeof(filter element)) = log2(sizeof(uint8_t)) */,
    sizeof(int32_t) /* sizeof(bias element) */,
    (xnn_pack_gemm_io_w_function) xnn_pack_qu8_gemm_io_w,
    (xnn_pack_gemm_goi_w_function) xnn_pack_qu8_gemm_goi_w,
    &packing_params, kernel_zero_point /* packed weights padding byte */,
    &params, sizeof(params),
    &xnn_params.qu8.gemm, &xnn_params.qu8.gemm.minmax,
    xnn_operator_type_fully_connected_nc_qu8,
    fully_connected_op_out);
}

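// Creates a fully connected operator with single-precision floating-point input, kernel, and
// output, and an [output_min, output_max] clamping range.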
enum xnn_status xnn_create_fully_connected_nc_f32(
  size_t input_channels,
  size_t output_channels,
  size_t input_stride,
  size_t output_stride,
  const float* kernel,
  const float* bias,
  float output_min,
  float output_max,
  uint32_t flags,
  xnn_operator_t* fully_connected_op_out)
{
  if (isnan(output_min)) {
    xnn_log_error(
      "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32));
    return xnn_status_invalid_parameter;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32));
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

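  // An unbounded (-inf, +inf) output range makes clamping a no-op, so a linear (non-minmax)
  // GEMM microkernel is preferred in that case when one is available.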
  const struct gemm_fused_ukernels* gemm_ukernels = &xnn_params.f32.gemm.minmax;
  const bool linear_activation = (output_max == INFINITY) && (output_min == -output_max);
  if (linear_activation && xnn_params.f32.gemm.linear.gemm.function[XNN_UARCH_DEFAULT] != NULL) {
    gemm_ukernels = &xnn_params.f32.gemm.linear;
  }

  const union xnn_f32_minmax_params params = xnn_init_f32_minmax_params(output_min, output_max);
  return create_fully_connected_nc(
    input_channels, output_channels,
    input_stride, output_stride,
    kernel, bias, flags,
    2 /* log2(sizeof(filter element)) = log2(sizeof(float)) */,
    sizeof(float) /* sizeof(bias element) */,
    (xnn_pack_gemm_io_w_function) xnn_pack_f32_gemm_io_w,
    (xnn_pack_gemm_goi_w_function) xnn_pack_f32_gemm_goi_w,
    NULL /* packing params */, 0 /* packed weights padding byte */,
    &params, sizeof(params),
    &xnn_params.f32.gemm, gemm_ukernels,
    xnn_operator_type_fully_connected_nc_f32,
    fully_connected_op_out);
}

enum xnn_status xnn_setup_fully_connected_nc_qu8(
  xnn_operator_t fully_connected_op,
  size_t batch_size,
  const uint8_t* input,
  uint8_t* output,
  pthreadpool_t threadpool)
{
  if (fully_connected_op->type != xnn_operator_type_fully_connected_nc_qu8) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qu8),
      xnn_operator_type_to_string(fully_connected_op->type));
    return xnn_status_invalid_parameter;
  }

  return setup_fully_connected_nc(
    fully_connected_op,
    batch_size,
    input, output,
    0 /* log2(sizeof(input element)) = log2(sizeof(uint8_t)) */,
    0 /* log2(sizeof(filter element)) = log2(sizeof(uint8_t)) */,
    sizeof(int32_t) /* sizeof(bias element) */,
    0 /* log2(sizeof(output element)) = log2(sizeof(uint8_t)) */,
    &fully_connected_op->params.qu8_gemm,
    sizeof(fully_connected_op->params.qu8_gemm),
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_fully_connected_nc_f32(
  xnn_operator_t fully_connected_op,
  size_t batch_size,
  const float* input,
  float* output,
  pthreadpool_t threadpool)
{
  if (fully_connected_op->type != xnn_operator_type_fully_connected_nc_f32) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32),
      xnn_operator_type_to_string(fully_connected_op->type));
    return xnn_status_invalid_parameter;
  }

  return setup_fully_connected_nc(
    fully_connected_op,
    batch_size,
    input, output,
    2 /* log2(sizeof(input element)) = log2(sizeof(float)) */,
    2 /* log2(sizeof(filter element)) = log2(sizeof(float)) */,
    sizeof(float) /* sizeof(bias element) */,
    2 /* log2(sizeof(output element)) = log2(sizeof(float)) */,
    &fully_connected_op->params.f32_minmax,
    sizeof(fully_connected_op->params.f32_minmax),
    pthreadpool_get_threads_count(threadpool));
}

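// A minimal usage sketch of the F32 path through the public API above. The dimensions, buffers,
// and NULL threadpool are illustrative assumptions rather than part of this file, and checks of
// the returned xnn_status values are omitted for brevity:
//
//   float kernel[4 * 3];  // [output_channels][input_channels]
//   float bias[4];
//   float input[2 * 3];   // [batch_size][input_channels]
//   float output[2 * 4];  // [batch_size][output_channels]
//
//   xnn_initialize(NULL /* use the default allocator */);
//   xnn_operator_t fc = NULL;
//   xnn_create_fully_connected_nc_f32(
//     /*input_channels=*/3, /*output_channels=*/4,
//     /*input_stride=*/3, /*output_stride=*/4,
//     kernel, bias, -INFINITY, INFINITY, /*flags=*/0, &fc);
//   xnn_setup_fully_connected_nc_f32(fc, /*batch_size=*/2, input, output, /*threadpool=*/NULL);
//   xnn_run_operator(fc, /*threadpool=*/NULL);
//   xnn_delete_operator(fc);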