// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <inttypes.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#include <math.h>

#include <fp16.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/log.h>
#include <xnnpack/math.h>
#include <xnnpack/operator.h>
#include <xnnpack/pack.h>
#include <xnnpack/params.h>

static enum xnn_status create_fully_connected_nc(
  size_t input_channels,
  size_t output_channels,
  size_t input_stride,
  size_t output_stride,
  const void* kernel,
  const void* bias,
  uint32_t flags,
  uint32_t log2_filter_element_size,
  uint32_t bias_element_size,
  xnn_pack_gemm_io_w_function pack_gemm_io_w,
  xnn_pack_gemm_goi_w_function pack_gemm_goi_w,
  const void* packing_params,
  int packed_weights_padding_byte,
  const void* params,
  size_t params_size,
  const struct gemm_parameters* gemm_parameters,
  const struct gemm_fused_ukernels* gemm_ukernels,
  uint32_t datatype_init_flags,
  enum xnn_operator_type operator_type,
  xnn_operator_t* fully_connected_op_out)
{
  xnn_operator_t fully_connected_op = NULL;
  enum xnn_status status = xnn_status_uninitialized;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(operator_type));
    goto error;
  }

  status = xnn_status_unsupported_hardware;

  if ((xnn_params.init_flags & datatype_init_flags) != datatype_init_flags) {
    xnn_log_error(
      "failed to create %s operator: operations on data type are not supported",
      xnn_operator_type_to_string(operator_type));
    goto error;
  }

  status = xnn_status_invalid_parameter;

  if (input_channels == 0) {
    xnn_log_error(
      "failed to create %s operator with %zu input channels: number of channels must be non-zero",
      xnn_operator_type_to_string(operator_type), input_channels);
    goto error;
  }

  if (output_channels == 0) {
    xnn_log_error(
      "failed to create %s operator with %zu output channels: number of channels must be non-zero",
      xnn_operator_type_to_string(operator_type), output_channels);
    goto error;
  }

  if (input_stride < input_channels) {
    xnn_log_error(
      "failed to create %s operator with input element stride of %zu: "
      "stride must be at least as large as the number of input channels (%zu)",
      xnn_operator_type_to_string(operator_type), input_stride, input_channels);
    goto error;
  }

  if (output_stride < output_channels) {
    xnn_log_error(
      "failed to create %s operator with output element stride of %zu: "
      "stride must be at least as large as the number of output channels (%zu)",
      xnn_operator_type_to_string(operator_type), output_stride, output_channels);
    goto error;
  }

  status = xnn_status_out_of_memory;

  fully_connected_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
  if (fully_connected_op == NULL) {
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator descriptor",
      sizeof(struct xnn_operator), xnn_operator_type_to_string(operator_type));
    goto error;
  }

  const uint32_t nr = gemm_parameters->nr;
  const uint32_t kr = UINT32_C(1) << gemm_parameters->log2_kr;
  const uint32_t sr = UINT32_C(1) << gemm_parameters->log2_sr;

  const size_t n_stride = round_up(output_channels, nr);
  const size_t k_stride = round_up_po2(input_channels, kr);

  const size_t packed_weights_size = n_stride * (bias_element_size + (k_stride << log2_filter_element_size));
  fully_connected_op->packed_weights = xnn_allocate_simd_memory(packed_weights_size);
  if (fully_connected_op->packed_weights == NULL) {
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator packed weights",
      packed_weights_size, xnn_operator_type_to_string(operator_type));
    goto error;
  }
  memset(fully_connected_op->packed_weights, packed_weights_padding_byte, packed_weights_size);
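  // Worked example (editorial sketch, hypothetical sizes): for an F32 operator with
  // output_channels = 100, input_channels = 50, nr = 8, kr = 2, log2_filter_element_size = 2,
  // and bias_element_size = sizeof(float) = 4, the allocation above reserves
  // n_stride = round_up(100, 8) = 104 padded output channels and
  // k_stride = round_up_po2(50, 2) = 50 padded input channels, for a total of
  // 104 * (4 + (50 << 2)) = 21216 bytes of packed biases and weights.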

  if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
    pack_gemm_io_w(
      output_channels, input_channels,
      nr, kr, sr,
      kernel, bias,
      fully_connected_op->packed_weights,
      packing_params);
  } else {
    pack_gemm_goi_w(
      1, output_channels, input_channels,
      nr, kr, sr,
      kernel, bias,
      fully_connected_op->packed_weights,
      0 /* extra bytes */,
      packing_params);
  }

  fully_connected_op->group_input_channels = input_channels;
  fully_connected_op->group_output_channels = output_channels;
  fully_connected_op->input_pixel_stride = input_stride;
  fully_connected_op->output_pixel_stride = output_stride;

  memcpy(&fully_connected_op->params, params, params_size);
  fully_connected_op->type = operator_type;
  fully_connected_op->flags = flags;

  fully_connected_op->ukernel.type = xnn_ukernel_type_gemm;
  fully_connected_op->ukernel.gemm = (struct xnn_ukernel_gemm) {
    .general_case = gemm_ukernels->gemm,
    .mr1_case = gemm_ukernels->gemm1,
    .mr = gemm_parameters->mr,
    .nr = nr,
    .kr = kr,
  };

  fully_connected_op->state = xnn_run_state_invalid;

  *fully_connected_op_out = fully_connected_op;
  return xnn_status_success;

error:
  xnn_delete_operator(fully_connected_op);
  return status;
}

static enum xnn_status setup_fully_connected_nc(
  xnn_operator_t fully_connected_op,
  size_t batch_size,
  const void* input,
  void* output,
  uint32_t datatype_init_flags,
  uint32_t log2_input_element_size,
  uint32_t log2_filter_element_size,
  uint32_t bias_element_size,
  uint32_t log2_output_element_size,
  const void* params,
  size_t params_size,
  size_t num_threads)
{
  fully_connected_op->state = xnn_run_state_invalid;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to setup %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(fully_connected_op->type));
    return xnn_status_uninitialized;
  }

  if (batch_size == 0) {
    fully_connected_op->state = xnn_run_state_skip;
    return xnn_status_success;
  }

  fully_connected_op->batch_size = 1;
  fully_connected_op->input_height = batch_size;
  fully_connected_op->input_width = 1;
  fully_connected_op->input = input;

  fully_connected_op->output_height = batch_size;
  fully_connected_op->output_width = 1;
  fully_connected_op->output = output;

  const size_t input_channels = fully_connected_op->group_input_channels;
  const size_t output_channels = fully_connected_op->group_output_channels;

  uint32_t mr = fully_connected_op->ukernel.gemm.mr;
  const uint32_t nr = fully_connected_op->ukernel.gemm.nr;

  struct xnn_hmp_gemm_ukernel gemm_ukernel = fully_connected_op->ukernel.gemm.general_case;
  if (batch_size == 1 && fully_connected_op->ukernel.gemm.mr1_case.function[XNN_UARCH_DEFAULT] != NULL) {
    gemm_ukernel = fully_connected_op->ukernel.gemm.mr1_case;
    mr = 1;
  }

  fully_connected_op->context.gemm = (struct gemm_context) {
    .k_scaled = input_channels << log2_input_element_size,
    .w_stride = (round_up_po2(input_channels, fully_connected_op->ukernel.gemm.kr) << log2_input_element_size) + bias_element_size,
    .a = input,
    .a_stride = fully_connected_op->input_pixel_stride << log2_input_element_size,
    .packed_w = fully_connected_op->packed_weights,
    .c = output,
    .cm_stride = fully_connected_op->output_pixel_stride << log2_output_element_size,
    .cn_stride = nr << log2_output_element_size,
    .log2_csize = log2_output_element_size,
    .ukernel = gemm_ukernel,
  };
  memcpy(&fully_connected_op->context.gemm.params, params, params_size);

  size_t nc = output_channels;
  if (num_threads > 1) {
    const size_t num_other_tiles = divide_round_up(batch_size, mr);
    const size_t target_tiles_per_thread = 5;
    const size_t max_nc = divide_round_up(output_channels * num_other_tiles, num_threads * target_tiles_per_thread);
    if (max_nc < nc) {
      nc = min(nc, divide_round_up(nc, max_nc * nr) * nr);
    }
  }
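  // Worked example (editorial sketch, hypothetical shapes): with batch_size = 4, mr = 4,
  // output_channels = 1024, nr = 8, and num_threads = 8, num_other_tiles = 1 and
  // max_nc = divide_round_up(1024 * 1, 8 * 5) = 26, so
  // nc = min(1024, divide_round_up(1024, 26 * 8) * 8) = 40, i.e. the output channels are
  // split into divide_round_up(1024, 40) = 26 column tiles that can be spread across threads.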
  fully_connected_op->compute.type = xnn_parallelization_type_2d_tile_2d;
  fully_connected_op->compute.task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
  fully_connected_op->compute.range[0] = batch_size;
  fully_connected_op->compute.range[1] = output_channels;
  fully_connected_op->compute.tile[0] = mr;
  fully_connected_op->compute.tile[1] = nc;
  fully_connected_op->state = xnn_run_state_ready;

  return xnn_status_success;
}

enum xnn_status xnn_create_fully_connected_nc_qu8(
  size_t input_channels,
  size_t output_channels,
  size_t input_stride,
  size_t output_stride,
  uint8_t input_zero_point,
  float input_scale,
  uint8_t kernel_zero_point,
  float kernel_scale,
  const uint8_t* kernel,
  const int32_t* bias,
  uint8_t output_zero_point,
  float output_scale,
  uint8_t output_min,
  uint8_t output_max,
  uint32_t flags,
  xnn_operator_t* fully_connected_op_out)
{
  if (input_scale <= 0.0f || !isnormal(input_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qu8), input_scale);
    return xnn_status_invalid_parameter;
  }

  if (kernel_scale <= 0.0f || !isnormal(kernel_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g kernel scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qu8), kernel_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qu8), output_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRIu8 ", %" PRIu8 "] output range: range min must be below range max",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qu8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  const float requantization_scale = input_scale * kernel_scale / output_scale;
  if (requantization_scale >= 256.0f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale, %.7g kernel scale, and %.7g output scale: "
      "requantization scale %.7g is greater or equal to 256.0",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qu8),
      input_scale, kernel_scale, output_scale, requantization_scale);
    return xnn_status_unsupported_parameter;
  }
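  // Worked example (editorial sketch, hypothetical scales): with input_scale = 0.02f,
  // kernel_scale = 0.01f, and output_scale = 0.05f, the requantization scale is
  // 0.02 * 0.01 / 0.05 = 0.004, well below the 256.0 limit; the GEMM init function below
  // folds this scale into the micro-kernel parameters.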

  union xnn_qu8_conv_minmax_params params;
  if XNN_LIKELY(xnn_params.qu8.gemm.init.qu8 != NULL) {
    xnn_params.qu8.gemm.init.qu8(&params,
      kernel_zero_point, requantization_scale, output_zero_point, output_min, output_max);
  }
  const struct xnn_qu8_packing_params packing_params = {
    .input_zero_point = input_zero_point,
    .kernel_zero_point = kernel_zero_point,
  };
  return create_fully_connected_nc(
    input_channels, output_channels,
    input_stride, output_stride,
    kernel, bias, flags,
    0 /* log2(sizeof(filter element)) = log2(sizeof(uint8_t)) */,
    sizeof(int32_t) /* sizeof(bias element) */,
    (xnn_pack_gemm_io_w_function) xnn_pack_qu8_gemm_io_w,
    (xnn_pack_gemm_goi_w_function) xnn_pack_qu8_gemm_goi_w,
    &packing_params, kernel_zero_point /* packed weights padding byte */,
    &params, sizeof(params),
    &xnn_params.qu8.gemm, &xnn_params.qu8.gemm.minmax,
    XNN_INIT_FLAG_QU8,
    xnn_operator_type_fully_connected_nc_qu8,
    fully_connected_op_out);
}

enum xnn_status xnn_create_fully_connected_nc_qs8(
  size_t input_channels,
  size_t output_channels,
  size_t input_stride,
  size_t output_stride,
  int8_t input_zero_point,
  float input_scale,
  float kernel_scale,
  const int8_t* kernel,
  const int32_t* bias,
  int8_t output_zero_point,
  float output_scale,
  int8_t output_min,
  int8_t output_max,
  uint32_t flags,
  xnn_operator_t* fully_connected_op_out)
{
  if (input_scale <= 0.0f || !isnormal(input_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qs8), input_scale);
    return xnn_status_invalid_parameter;
  }

  if (kernel_scale <= 0.0f || !isnormal(kernel_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g kernel scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qs8), kernel_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qs8), output_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRId8 ", %" PRId8 "] output range: range min must be below range max",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qs8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  const float requantization_scale = input_scale * kernel_scale / output_scale;
  if (requantization_scale >= 256.0f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale, %.7g kernel scale, and %.7g output scale: "
      "requantization scale %.7g is greater or equal to 256.0",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qs8),
      input_scale, kernel_scale, output_scale, requantization_scale);
    return xnn_status_unsupported_parameter;
  }

  union xnn_qs8_conv_minmax_params params;
  if XNN_LIKELY(xnn_params.qs8.gemm.init.qs8 != NULL) {
    xnn_params.qs8.gemm.init.qs8(&params, requantization_scale, output_zero_point, output_min, output_max);
  }
  const struct xnn_qs8_packing_params packing_params = {
    .input_zero_point = input_zero_point,
  };
  return create_fully_connected_nc(
    input_channels, output_channels,
    input_stride, output_stride,
    kernel, bias, flags,
    0 /* log2(sizeof(filter element)) = log2(sizeof(int8_t)) */,
    sizeof(int32_t) /* sizeof(bias element) */,
    (xnn_pack_gemm_io_w_function) xnn_pack_qs8_gemm_io_w,
    (xnn_pack_gemm_goi_w_function) xnn_pack_qs8_gemm_goi_w,
    &packing_params, 0 /* packed weights padding byte */,
    &params, sizeof(params),
    &xnn_params.qs8.gemm, &xnn_params.qs8.gemm.minmax,
    XNN_INIT_FLAG_QS8,
    xnn_operator_type_fully_connected_nc_qs8,
    fully_connected_op_out);
}

enum xnn_status xnn_create_fully_connected_nc_f32(
  size_t input_channels,
  size_t output_channels,
  size_t input_stride,
  size_t output_stride,
  const float* kernel,
  const float* bias,
  float output_min,
  float output_max,
  uint32_t flags,
  xnn_operator_t* fully_connected_op_out)
{
  if (isnan(output_min)) {
    xnn_log_error(
      "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32));
    return xnn_status_invalid_parameter;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32));
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  const struct gemm_fused_ukernels* gemm_ukernels = &xnn_params.f32.gemm.minmax;
  const bool linear_activation = (output_max == INFINITY) && (output_min == -output_max);
  if (linear_activation && xnn_params.f32.gemm.linear.gemm.function[XNN_UARCH_DEFAULT] != NULL) {
    gemm_ukernels = &xnn_params.f32.gemm.linear;
  }

  union xnn_f32_minmax_params params;
  if XNN_LIKELY(xnn_params.f32.gemm.init.f32 != NULL) {
    xnn_params.f32.gemm.init.f32(&params, output_min, output_max);
  }
  return create_fully_connected_nc(
    input_channels, output_channels,
    input_stride, output_stride,
    kernel, bias, flags,
    2 /* log2(sizeof(filter element)) = log2(sizeof(float)) */,
    sizeof(float) /* sizeof(bias element) */,
    (xnn_pack_gemm_io_w_function) xnn_pack_f32_gemm_io_w,
    (xnn_pack_gemm_goi_w_function) xnn_pack_f32_gemm_goi_w,
    NULL /* packing params */, 0 /* packed weights padding byte */,
    &params, sizeof(params),
    &xnn_params.f32.gemm, gemm_ukernels,
    XNN_INIT_FLAG_F32,
    xnn_operator_type_fully_connected_nc_f32,
    fully_connected_op_out);
}
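
// Usage sketch (editorial addition, not part of the original source): creating, running, and
// destroying an F32 fully connected operator for a hypothetical batch of 4 rows with
// 50 input channels and 100 output channels. The kernel, bias, input, output buffers and the
// threadpool are assumed to be provided by the caller; error handling is elided.
//
//   xnn_initialize(NULL /* allocator */);
//   xnn_operator_t fc = NULL;
//   xnn_create_fully_connected_nc_f32(
//     50 /* input channels */, 100 /* output channels */,
//     50 /* input stride */, 100 /* output stride */,
//     kernel /* 100x50 weights */, bias /* 100 biases */,
//     -INFINITY, INFINITY /* linear activation */,
//     0 /* flags */, &fc);
//   xnn_setup_fully_connected_nc_f32(fc, 4 /* batch size */, input, output, threadpool);
//   xnn_run_operator(fc, threadpool);
//   xnn_delete_operator(fc);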

enum xnn_status xnn_create_fully_connected_nc_f16(
  size_t input_channels,
  size_t output_channels,
  size_t input_stride,
  size_t output_stride,
  const void* kernel,
  const void* bias,
  float output_min,
  float output_max,
  uint32_t flags,
  xnn_operator_t* fully_connected_op_out)
{
  if (isnan(output_min)) {
    xnn_log_error(
      "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f16));
    return xnn_status_invalid_parameter;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f16));
    return xnn_status_invalid_parameter;
  }

  const uint16_t fp16_output_min = fp16_ieee_from_fp32_value(output_min);
  const uint16_t fp16_output_max = fp16_ieee_from_fp32_value(output_max);
  const float rounded_output_min = fp16_ieee_to_fp32_value(fp16_output_min);
  const float rounded_output_max = fp16_ieee_to_fp32_value(fp16_output_max);
  if (rounded_output_min >= rounded_output_max) {
    xnn_log_error(
      "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f16), rounded_output_min, rounded_output_max);
    return xnn_status_invalid_parameter;
  }

  union xnn_f16_scaleminmax_params params;
  if XNN_LIKELY(xnn_params.f16.gemm.init.f16 != NULL) {
    xnn_params.f16.gemm.init.f16(&params, UINT16_C(0x3C00) /* 1.0 */, fp16_output_min, fp16_output_max);
  }
  xnn_pack_gemm_io_w_function pack_gemm_io_w = (xnn_pack_gemm_io_w_function) xnn_pack_f16_gemm_io_w;
  xnn_pack_gemm_goi_w_function pack_gemm_goi_w = (xnn_pack_gemm_goi_w_function) xnn_pack_f16_gemm_goi_w;
  if (flags & XNN_FLAG_FP32_STATIC_WEIGHTS) {
    pack_gemm_io_w = (xnn_pack_gemm_io_w_function) xnn_pack_f32_to_f16_gemm_io_w;
    pack_gemm_goi_w = (xnn_pack_gemm_goi_w_function) xnn_pack_f32_to_f16_gemm_goi_w;
  }
  return create_fully_connected_nc(
    input_channels, output_channels,
    input_stride, output_stride,
    kernel, bias, flags,
    1 /* log2(sizeof(filter element)) = log2(sizeof(uint16_t)) */,
    sizeof(uint16_t) /* sizeof(bias element) */,
    pack_gemm_io_w,
    pack_gemm_goi_w,
    NULL /* packing params */, 0 /* packed weights padding byte */,
    &params, sizeof(params),
    &xnn_params.f16.gemm, &xnn_params.f16.gemm.minmax,
    XNN_INIT_FLAG_F16,
    xnn_operator_type_fully_connected_nc_f16,
    fully_connected_op_out);
}

enum xnn_status xnn_setup_fully_connected_nc_qu8(
  xnn_operator_t fully_connected_op,
  size_t batch_size,
  const uint8_t* input,
  uint8_t* output,
  pthreadpool_t threadpool)
{
  if (fully_connected_op->type != xnn_operator_type_fully_connected_nc_qu8) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qu8),
      xnn_operator_type_to_string(fully_connected_op->type));
    return xnn_status_invalid_parameter;
  }

  return setup_fully_connected_nc(
    fully_connected_op,
    batch_size,
    input, output,
    XNN_INIT_FLAG_QU8,
    0 /* log2(sizeof(input element)) = log2(sizeof(uint8_t)) */,
    0 /* log2(sizeof(filter element)) = log2(sizeof(uint8_t)) */,
    sizeof(int32_t) /* sizeof(bias element) */,
    0 /* log2(sizeof(output element)) = log2(sizeof(uint8_t)) */,
    &fully_connected_op->params.qu8_conv_minmax,
    sizeof(fully_connected_op->params.qu8_conv_minmax),
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_fully_connected_nc_qs8(
  xnn_operator_t fully_connected_op,
  size_t batch_size,
  const int8_t* input,
  int8_t* output,
  pthreadpool_t threadpool)
{
  if (fully_connected_op->type != xnn_operator_type_fully_connected_nc_qs8) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qs8),
      xnn_operator_type_to_string(fully_connected_op->type));
    return xnn_status_invalid_parameter;
  }

  return setup_fully_connected_nc(
    fully_connected_op,
    batch_size,
    input, output,
    XNN_INIT_FLAG_QS8,
    0 /* log2(sizeof(input element)) = log2(sizeof(int8_t)) */,
    0 /* log2(sizeof(filter element)) = log2(sizeof(int8_t)) */,
    sizeof(int32_t) /* sizeof(bias element) */,
    0 /* log2(sizeof(output element)) = log2(sizeof(int8_t)) */,
    &fully_connected_op->params.qs8_conv_minmax,
    sizeof(fully_connected_op->params.qs8_conv_minmax),
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_fully_connected_nc_f32(
  xnn_operator_t fully_connected_op,
  size_t batch_size,
  const float* input,
  float* output,
  pthreadpool_t threadpool)
{
  if (fully_connected_op->type != xnn_operator_type_fully_connected_nc_f32) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32),
      xnn_operator_type_to_string(fully_connected_op->type));
    return xnn_status_invalid_parameter;
  }

  return setup_fully_connected_nc(
    fully_connected_op,
    batch_size,
    input, output,
    XNN_INIT_FLAG_F32,
    2 /* log2(sizeof(input element)) = log2(sizeof(float)) */,
    2 /* log2(sizeof(filter element)) = log2(sizeof(float)) */,
    sizeof(float) /* sizeof(bias element) */,
    2 /* log2(sizeof(output element)) = log2(sizeof(float)) */,
    &fully_connected_op->params.f32_minmax,
    sizeof(fully_connected_op->params.f32_minmax),
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_fully_connected_nc_f16(
  xnn_operator_t fully_connected_op,
  size_t batch_size,
  const void* input,
  void* output,
  pthreadpool_t threadpool)
{
  if (fully_connected_op->type != xnn_operator_type_fully_connected_nc_f16) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f16),
      xnn_operator_type_to_string(fully_connected_op->type));
    return xnn_status_invalid_parameter;
  }

  return setup_fully_connected_nc(
    fully_connected_op,
    batch_size,
    input, output,
    XNN_INIT_FLAG_F16,
    1 /* log2(sizeof(input element)) = log2(sizeof(uint16_t)) */,
    1 /* log2(sizeof(filter element)) = log2(sizeof(uint16_t)) */,
    sizeof(uint16_t) /* sizeof(bias element) */,
    1 /* log2(sizeof(output element)) = log2(sizeof(uint16_t)) */,
    &fully_connected_op->params.f16_scaleminmax,
    sizeof(fully_connected_op->params.f16_scaleminmax),
    pthreadpool_get_threads_count(threadpool));
}