• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8 
#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include <fp16.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/log.h>
#include <xnnpack/math.h>
#include <xnnpack/operator.h>
#include <xnnpack/pack.h>
#include <xnnpack/params.h>
25 
26 
// Shared factory for all fully-connected operator datatypes (QU8/QS8/F16/F32).
//
// Validates channel counts and strides, allocates the operator descriptor,
// packs kernel+bias into the layout expected by the GEMM micro-kernels, and
// records the GEMM tiling parameters. The datatype-specific wrappers supply:
//  - log2_filter_element_size / bias_element_size: element sizes of the
//    packed filter and bias data;
//  - pack_gemm_io_w / pack_gemm_goi_w: packing routines for the transposed
//    (input-channels-major) vs. regular (output-channels-major) kernel layout;
//  - packed_weights_padding_byte: fill value for the padded tail of the
//    packed-weights buffer (the kernel zero point for QU8);
//  - params / params_size: opaque micro-kernel parameter blob, copied into
//    the operator descriptor;
//  - datatype_init_flags: XNN_INIT_FLAG_* bit(s) that must be set in
//    xnn_params.init_flags for this datatype to be supported.
// On success stores the new operator in *fully_connected_op_out and returns
// xnn_status_success; on failure releases any partial allocation via
// xnn_delete_operator and returns an error status.
static enum xnn_status create_fully_connected_nc(
    size_t input_channels,
    size_t output_channels,
    size_t input_stride,
    size_t output_stride,
    const void* kernel,
    const void* bias,
    uint32_t flags,
    uint32_t log2_filter_element_size,
    uint32_t bias_element_size,
    xnn_pack_gemm_io_w_function pack_gemm_io_w,
    xnn_pack_gemm_goi_w_function pack_gemm_goi_w,
    const void* packing_params,
    int packed_weights_padding_byte,
    const void* params,
    size_t params_size,
    const struct gemm_parameters* gemm_parameters,
    const struct gemm_fused_ukernels* gemm_ukernels,
    uint32_t datatype_init_flags,
    enum xnn_operator_type operator_type,
    xnn_operator_t* fully_connected_op_out)
{
  xnn_operator_t fully_connected_op = NULL;
  enum xnn_status status = xnn_status_uninitialized;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(operator_type));
    goto error;
  }

  status = xnn_status_unsupported_hardware;

  // All datatype bits requested by the wrapper must have been initialized.
  if ((xnn_params.init_flags & datatype_init_flags) != datatype_init_flags) {
    xnn_log_error(
      "failed to create %s operator: operations on data type are not supported",
      xnn_operator_type_to_string(operator_type));
    goto error;
  }

  status = xnn_status_invalid_parameter;

  if (input_channels == 0) {
    xnn_log_error(
      "failed to create %s operator with %zu input channels: number of channels must be non-zero",
      xnn_operator_type_to_string(operator_type), input_channels);
    goto error;
  }

  if (output_channels == 0) {
    xnn_log_error(
      "failed to create %s operator with %zu output channels: number of channels must be non-zero",
      xnn_operator_type_to_string(operator_type), output_channels);
    goto error;
  }

  // Strides are in elements, measured between consecutive batch rows; they
  // must cover at least one full row of channels.
  if (input_stride < input_channels) {
    xnn_log_error(
      "failed to create %s operator with input element stride of %zu: "
      "stride must be at least as large as the number of input channels (%zu)",
      xnn_operator_type_to_string(operator_type), input_stride, input_channels);
    goto error;
  }

  if (output_stride < output_channels) {
    xnn_log_error(
      "failed to create %s operator with output element stride of %zu: "
      "stride must be at least as large as the number of output channels (%zu)",
      xnn_operator_type_to_string(operator_type), output_stride, output_channels);
    goto error;
  }

  status = xnn_status_out_of_memory;

  fully_connected_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
  if (fully_connected_op == NULL) {
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator descriptor",
      sizeof(struct xnn_operator), xnn_operator_type_to_string(operator_type));
    goto error;
  }

  // GEMM tiling parameters; kr and sr are stored as log2 in gemm_parameters.
  const uint32_t nr = gemm_parameters->nr;
  const uint32_t kr = UINT32_C(1) << gemm_parameters->log2_kr;
  const uint32_t sr = UINT32_C(1) << gemm_parameters->log2_sr;

  // Pad output channels up to the nr tile and input channels up to kr.
  const size_t n_stride = round_up(output_channels, nr);
  const size_t k_stride = round_up_po2(input_channels, kr);

  // Packed layout: per (padded) output channel, one bias element followed by
  // k_stride filter elements.
  const size_t packed_weights_size = n_stride * (bias_element_size + (k_stride << log2_filter_element_size));
  fully_connected_op->packed_weights = xnn_allocate_simd_memory(packed_weights_size);
  if (fully_connected_op->packed_weights == NULL) {
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator packed weights",
      packed_weights_size, xnn_operator_type_to_string(operator_type));
    goto error;
  }
  // Pre-fill so padding lanes hold a deterministic value (kernel zero point
  // for QU8, so padded taps contribute nothing to the accumulation).
  memset(fully_connected_op->packed_weights, packed_weights_padding_byte, packed_weights_size);

  if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
    // Kernel is laid out input-channels-major.
    pack_gemm_io_w(
      output_channels, input_channels,
      nr, kr, sr,
      kernel, bias,
      fully_connected_op->packed_weights,
      packing_params);
  } else {
    // Kernel is laid out output-channels-major (single group).
    pack_gemm_goi_w(
      1, output_channels, input_channels,
      nr, kr, sr,
      kernel, bias,
      fully_connected_op->packed_weights,
      0 /* extra bytes */,
      packing_params);
  }

  fully_connected_op->group_input_channels = input_channels;
  fully_connected_op->group_output_channels = output_channels;
  fully_connected_op->input_pixel_stride = input_stride;
  fully_connected_op->output_pixel_stride = output_stride;

  memcpy(&fully_connected_op->params, params, params_size);
  fully_connected_op->type = operator_type;
  fully_connected_op->flags = flags;

  // Record the micro-kernel selection; kr stored here is already decoded
  // from its log2 form.
  fully_connected_op->ukernel.type = xnn_ukernel_type_gemm;
  fully_connected_op->ukernel.gemm = (struct xnn_ukernel_gemm) {
    .general_case = gemm_ukernels->gemm,
    .mr1_case = gemm_ukernels->gemm1,
    .mr = gemm_parameters->mr,
    .nr = nr,
    .kr = kr,
  };

  fully_connected_op->state = xnn_run_state_invalid;

  *fully_connected_op_out = fully_connected_op;
  return xnn_status_success;

error:
  xnn_delete_operator(fully_connected_op);
  return status;
}
170 
// Shared setup for all fully-connected operator datatypes: binds input/output
// pointers, configures the GEMM context, and picks the parallelization tiling.
//
// Fix: datatype_init_flags was accepted but never checked, so setup could
// silently proceed on hardware where micro-kernels for this datatype were
// never initialized. The check now mirrors the one in
// create_fully_connected_nc.
static enum xnn_status setup_fully_connected_nc(
  xnn_operator_t fully_connected_op,
  size_t batch_size,
  const void* input,
  void* output,
  uint32_t datatype_init_flags,
  uint32_t log2_input_element_size,
  uint32_t log2_filter_element_size,
  uint32_t bias_element_size,
  uint32_t log2_output_element_size,
  const void* params,
  size_t params_size,
  size_t num_threads)
{
  // Invalidate first so an early error leaves the operator unrunnable.
  fully_connected_op->state = xnn_run_state_invalid;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to setup %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(fully_connected_op->type));
    return xnn_status_uninitialized;
  }

  // All datatype bits requested by the wrapper must have been initialized.
  if ((xnn_params.init_flags & datatype_init_flags) != datatype_init_flags) {
    xnn_log_error(
      "failed to setup %s operator: operations on data type are not supported",
      xnn_operator_type_to_string(fully_connected_op->type));
    return xnn_status_unsupported_hardware;
  }

  if (batch_size == 0) {
    // Nothing to compute: mark the operator as a no-op and report success.
    fully_connected_op->state = xnn_run_state_skip;
    return xnn_status_success;
  }

  // A fully-connected layer is modeled as a 1x1-"image" GEMM with batch_size
  // rows of input/output channels.
  fully_connected_op->batch_size = 1;
  fully_connected_op->input_height = batch_size;
  fully_connected_op->input_width = 1;
  fully_connected_op->input = input;

  fully_connected_op->output_height = batch_size;
  fully_connected_op->output_width = 1;
  fully_connected_op->output = output;

  const size_t input_channels = fully_connected_op->group_input_channels;
  const size_t output_channels = fully_connected_op->group_output_channels;

  uint32_t mr = fully_connected_op->ukernel.gemm.mr;
  const uint32_t nr = fully_connected_op->ukernel.gemm.nr;

  // For a single-row batch, prefer the specialized MR=1 micro-kernel when one
  // is available.
  struct xnn_hmp_gemm_ukernel gemm_ukernel = fully_connected_op->ukernel.gemm.general_case;
  if (batch_size == 1 && fully_connected_op->ukernel.gemm.mr1_case.function[XNN_UARCH_DEFAULT] != NULL) {
    gemm_ukernel = fully_connected_op->ukernel.gemm.mr1_case;
    mr = 1;
  }

  fully_connected_op->context.gemm = (struct gemm_context) {
    .k_scaled = input_channels << log2_input_element_size,
    // Stride of one packed output channel: bias element + padded filter row.
    // NOTE(review): uses log2_input_element_size, which equals
    // log2_filter_element_size for every datatype wired up in this file —
    // confirm before adding a datatype where the two sizes diverge.
    .w_stride = (round_up_po2(input_channels, fully_connected_op->ukernel.gemm.kr) << log2_input_element_size) + bias_element_size,
    .a = input,
    .a_stride = fully_connected_op->input_pixel_stride << log2_input_element_size,
    .packed_w = fully_connected_op->packed_weights,
    .c = output,
    .cm_stride = fully_connected_op->output_pixel_stride << log2_output_element_size,
    .cn_stride = nr << log2_output_element_size,
    .log2_csize = log2_output_element_size,
    .ukernel = gemm_ukernel,
  };
  memcpy(&fully_connected_op->context.gemm.params, params, params_size);

  // When multi-threaded, shrink the N tile so each thread gets roughly
  // target_tiles_per_thread tiles, keeping the tile a multiple of nr.
  size_t nc = output_channels;
  if (num_threads > 1) {
    const size_t num_other_tiles = divide_round_up(batch_size, mr);
    const size_t target_tiles_per_thread = 5;
    const size_t max_nc = divide_round_up(output_channels * num_other_tiles, num_threads * target_tiles_per_thread);
    if (max_nc < nc) {
      nc = min(nc, divide_round_up(nc, max_nc * nr) * nr);
    }
  }
  fully_connected_op->compute.type = xnn_parallelization_type_2d_tile_2d;
  fully_connected_op->compute.task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
  fully_connected_op->compute.range[0] = batch_size;
  fully_connected_op->compute.range[1] = output_channels;
  fully_connected_op->compute.tile[0] = mr;
  fully_connected_op->compute.tile[1] = nc;
  fully_connected_op->state = xnn_run_state_ready;

  return xnn_status_success;
}
252 
// Creates a fully-connected operator with asymmetrically quantized unsigned
// 8-bit input, kernel, and output, and 32-bit signed bias.
enum xnn_status xnn_create_fully_connected_nc_qu8(
    size_t input_channels,
    size_t output_channels,
    size_t input_stride,
    size_t output_stride,
    uint8_t input_zero_point,
    float input_scale,
    uint8_t kernel_zero_point,
    float kernel_scale,
    const uint8_t* kernel,
    const int32_t* bias,
    uint8_t output_zero_point,
    float output_scale,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    xnn_operator_t* fully_connected_op_out)
{
  // Each quantization scale must be a positive, normal (finite, non-subnormal)
  // floating-point value.
  if (!(input_scale > 0.0f && isnormal(input_scale))) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qu8), input_scale);
    return xnn_status_invalid_parameter;
  }

  if (!(kernel_scale > 0.0f && isnormal(kernel_scale))) {
    xnn_log_error(
      "failed to create %s operator with %.7g kernel scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qu8), kernel_scale);
    return xnn_status_invalid_parameter;
  }

  if (!(output_scale > 0.0f && isnormal(output_scale))) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qu8), output_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRIu8 ", %" PRIu8 "] output range: range min must be below range max",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qu8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  // Fold the three scales into a single requantization multiplier; the
  // fixed-point requantization scheme requires it to be below 256.
  const float requant_scale = input_scale * kernel_scale / output_scale;
  if (requant_scale >= 256.0f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale, %.7g kernel scale, and %.7g output scale: "
      "requantization scale %.7g is greater or equal to 256.0",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qu8),
      input_scale, kernel_scale, output_scale, requant_scale);
    return xnn_status_unsupported_parameter;
  }

  // Zero points consumed by the weight-packing routine.
  const struct xnn_qu8_packing_params packing_params = {
    .input_zero_point = input_zero_point,
    .kernel_zero_point = kernel_zero_point,
  };
  union xnn_qu8_conv_minmax_params gemm_params;
  if XNN_LIKELY(xnn_params.qu8.gemm.init.qu8 != NULL) {
    xnn_params.qu8.gemm.init.qu8(&gemm_params,
      kernel_zero_point, requant_scale, output_zero_point, output_min, output_max);
  }
  return create_fully_connected_nc(
    input_channels, output_channels,
    input_stride, output_stride,
    kernel, bias, flags,
    0 /* log2(sizeof(filter element)) = log2(sizeof(uint8_t)) */,
    sizeof(int32_t) /* sizeof(bias element) */,
    (xnn_pack_gemm_io_w_function) xnn_pack_qu8_gemm_io_w,
    (xnn_pack_gemm_goi_w_function) xnn_pack_qu8_gemm_goi_w,
    &packing_params, kernel_zero_point /* packed weights padding byte */,
    &gemm_params, sizeof(gemm_params),
    &xnn_params.qu8.gemm, &xnn_params.qu8.gemm.minmax,
    XNN_INIT_FLAG_QU8,
    xnn_operator_type_fully_connected_nc_qu8,
    fully_connected_op_out);
}
333 
// Creates a fully-connected operator with signed 8-bit quantized input and
// output, symmetrically quantized (zero-point-free) kernel, and 32-bit bias.
enum xnn_status xnn_create_fully_connected_nc_qs8(
    size_t input_channels,
    size_t output_channels,
    size_t input_stride,
    size_t output_stride,
    int8_t input_zero_point,
    float input_scale,
    float kernel_scale,
    const int8_t* kernel,
    const int32_t* bias,
    int8_t output_zero_point,
    float output_scale,
    int8_t output_min,
    int8_t output_max,
    uint32_t flags,
    xnn_operator_t* fully_connected_op_out)
{
  // Each quantization scale must be a positive, normal (finite, non-subnormal)
  // floating-point value.
  if (!(input_scale > 0.0f && isnormal(input_scale))) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qs8), input_scale);
    return xnn_status_invalid_parameter;
  }

  if (!(kernel_scale > 0.0f && isnormal(kernel_scale))) {
    xnn_log_error(
      "failed to create %s operator with %.7g kernel scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qs8), kernel_scale);
    return xnn_status_invalid_parameter;
  }

  if (!(output_scale > 0.0f && isnormal(output_scale))) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qs8), output_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRId8 ", %" PRId8 "] output range: range min must be below range max",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qs8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  // Fold the three scales into a single requantization multiplier; the
  // fixed-point requantization scheme requires it to be below 256.
  const float requant_scale = input_scale * kernel_scale / output_scale;
  if (requant_scale >= 256.0f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale, %.7g kernel scale, and %.7g output scale: "
      "requantization scale %.7g is greater or equal to 256.0",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qs8),
      input_scale, kernel_scale, output_scale, requant_scale);
    return xnn_status_unsupported_parameter;
  }

  // Only the input zero point is needed for packing; QS8 kernels are
  // symmetrically quantized.
  const struct xnn_qs8_packing_params packing_params = {
    .input_zero_point = input_zero_point,
  };
  union xnn_qs8_conv_minmax_params gemm_params;
  if XNN_LIKELY(xnn_params.qs8.gemm.init.qs8 != NULL) {
    xnn_params.qs8.gemm.init.qs8(&gemm_params, requant_scale, output_zero_point, output_min, output_max);
  }
  return create_fully_connected_nc(
    input_channels, output_channels,
    input_stride, output_stride,
    kernel, bias, flags,
    0 /* log2(sizeof(filter element)) = log2(sizeof(int8_t)) */,
    sizeof(int32_t) /* sizeof(bias element) */,
    (xnn_pack_gemm_io_w_function) xnn_pack_qs8_gemm_io_w,
    (xnn_pack_gemm_goi_w_function) xnn_pack_qs8_gemm_goi_w,
    &packing_params, 0 /* packed weights padding byte */,
    &gemm_params, sizeof(gemm_params),
    &xnn_params.qs8.gemm, &xnn_params.qs8.gemm.minmax,
    XNN_INIT_FLAG_QS8,
    xnn_operator_type_fully_connected_nc_qs8,
    fully_connected_op_out);
}
411 
// Creates a single-precision floating-point fully-connected operator with an
// optional [output_min, output_max] clamp.
enum xnn_status xnn_create_fully_connected_nc_f32(
    size_t input_channels,
    size_t output_channels,
    size_t input_stride,
    size_t output_stride,
    const float* kernel,
    const float* bias,
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_operator_t* fully_connected_op_out)
{
  // NaN clamp bounds would make every comparison against them false.
  if (isnan(output_min)) {
    xnn_log_error(
      "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32));
    return xnn_status_invalid_parameter;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32));
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  // When the requested range is (-inf, +inf), prefer the clamp-free (linear)
  // GEMM micro-kernels if they exist; otherwise fall back to min/max kernels.
  const bool linear_activation = (output_max == INFINITY) && (output_min == -output_max);
  const struct gemm_fused_ukernels* gemm_ukernels =
    (linear_activation && xnn_params.f32.gemm.linear.gemm.function[XNN_UARCH_DEFAULT] != NULL)
      ? &xnn_params.f32.gemm.linear
      : &xnn_params.f32.gemm.minmax;

  union xnn_f32_minmax_params params;
  if XNN_LIKELY(xnn_params.f32.gemm.init.f32 != NULL) {
    xnn_params.f32.gemm.init.f32(&params, output_min, output_max);
  }
  return create_fully_connected_nc(
    input_channels, output_channels,
    input_stride, output_stride,
    kernel, bias, flags,
    2 /* log2(sizeof(filter element)) = log2(sizeof(float)) */,
    sizeof(float) /* sizeof(bias element) */,
    (xnn_pack_gemm_io_w_function) xnn_pack_f32_gemm_io_w,
    (xnn_pack_gemm_goi_w_function) xnn_pack_f32_gemm_goi_w,
    NULL /* packing params */, 0 /* packed weights padding byte */,
    &params, sizeof(params),
    &xnn_params.f32.gemm, gemm_ukernels,
    XNN_INIT_FLAG_F32,
    xnn_operator_type_fully_connected_nc_f32,
    fully_connected_op_out);
}
470 
// Creates a half-precision (FP16) fully-connected operator. Weights may be
// supplied either as FP16, or as FP32 with XNN_FLAG_FP32_STATIC_WEIGHTS set,
// in which case they are converted to FP16 during packing.
enum xnn_status xnn_create_fully_connected_nc_f16(
    size_t input_channels,
    size_t output_channels,
    size_t input_stride,
    size_t output_stride,
    const void* kernel,
    const void* bias,
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_operator_t* fully_connected_op_out)
{
  // NaN clamp bounds would make every comparison against them false.
  if (isnan(output_min)) {
    xnn_log_error(
      "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f16));
    return xnn_status_invalid_parameter;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f16));
    return xnn_status_invalid_parameter;
  }

  // The clamp is applied in half precision, so validate the range after
  // rounding both bounds to FP16.
  const uint16_t fp16_output_min = fp16_ieee_from_fp32_value(output_min);
  const uint16_t fp16_output_max = fp16_ieee_from_fp32_value(output_max);
  const float rounded_output_min = fp16_ieee_to_fp32_value(fp16_output_min);
  const float rounded_output_max = fp16_ieee_to_fp32_value(fp16_output_max);
  if (rounded_output_min >= rounded_output_max) {
    xnn_log_error(
      "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f16), rounded_output_min, rounded_output_max);
    return xnn_status_invalid_parameter;
  }

  // Select FP32->FP16 converting packers when the static weights are FP32.
  const bool fp32_weights = (flags & XNN_FLAG_FP32_STATIC_WEIGHTS) != 0;
  const xnn_pack_gemm_io_w_function pack_gemm_io_w = fp32_weights
    ? (xnn_pack_gemm_io_w_function) xnn_pack_f32_to_f16_gemm_io_w
    : (xnn_pack_gemm_io_w_function) xnn_pack_f16_gemm_io_w;
  const xnn_pack_gemm_goi_w_function pack_gemm_goi_w = fp32_weights
    ? (xnn_pack_gemm_goi_w_function) xnn_pack_f32_to_f16_gemm_goi_w
    : (xnn_pack_gemm_goi_w_function) xnn_pack_f16_gemm_goi_w;

  union xnn_f16_scaleminmax_params params;
  if XNN_LIKELY(xnn_params.f16.gemm.init.f16 != NULL) {
    xnn_params.f16.gemm.init.f16(&params, UINT16_C(0x3C00) /* 1.0 */, fp16_output_min, fp16_output_max);
  }
  return create_fully_connected_nc(
    input_channels, output_channels,
    input_stride, output_stride,
    kernel, bias, flags,
    1 /* log2(sizeof(filter element)) = log2(sizeof(uint16_t)) */,
    sizeof(uint16_t) /* sizeof(bias element) */,
    pack_gemm_io_w,
    pack_gemm_goi_w,
    NULL /* packing params */, 0 /* packed weights padding byte */,
    &params, sizeof(params),
    &xnn_params.f16.gemm, &xnn_params.f16.gemm.minmax,
    XNN_INIT_FLAG_F16,
    xnn_operator_type_fully_connected_nc_f16,
    fully_connected_op_out);
}
533 
// Binds batch size, input, and output to a QU8 fully-connected operator.
enum xnn_status xnn_setup_fully_connected_nc_qu8(
    xnn_operator_t fully_connected_op,
    size_t batch_size,
    const uint8_t* input,
    uint8_t* output,
    pthreadpool_t threadpool)
{
  // Reject descriptors created by a different operator factory.
  if (fully_connected_op->type != xnn_operator_type_fully_connected_nc_qu8) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qu8),
      xnn_operator_type_to_string(fully_connected_op->type));
    return xnn_status_invalid_parameter;
  }

  // Input, filter, and output elements are all one byte (log2 size = 0);
  // bias elements are int32.
  return setup_fully_connected_nc(
    fully_connected_op, batch_size, input, output,
    XNN_INIT_FLAG_QU8,
    0 /* log2(sizeof(input element)) = log2(sizeof(uint8_t)) */,
    0 /* log2(sizeof(filter element)) = log2(sizeof(uint8_t)) */,
    sizeof(int32_t) /* sizeof(bias element) */,
    0 /* log2(sizeof(output element)) = log2(sizeof(uint8_t)) */,
    &fully_connected_op->params.qu8_conv_minmax,
    sizeof(fully_connected_op->params.qu8_conv_minmax),
    pthreadpool_get_threads_count(threadpool));
}
561 
// Binds batch size, input, and output to a QS8 fully-connected operator.
enum xnn_status xnn_setup_fully_connected_nc_qs8(
    xnn_operator_t fully_connected_op,
    size_t batch_size,
    const int8_t* input,
    int8_t* output,
    pthreadpool_t threadpool)
{
  // Reject descriptors created by a different operator factory.
  if (fully_connected_op->type != xnn_operator_type_fully_connected_nc_qs8) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qs8),
      xnn_operator_type_to_string(fully_connected_op->type));
    return xnn_status_invalid_parameter;
  }

  // Input, filter, and output elements are all one byte (log2 size = 0);
  // bias elements are int32.
  return setup_fully_connected_nc(
    fully_connected_op, batch_size, input, output,
    XNN_INIT_FLAG_QS8,
    0 /* log2(sizeof(input element)) = log2(sizeof(int8_t)) */,
    0 /* log2(sizeof(filter element)) = log2(sizeof(int8_t)) */,
    sizeof(int32_t) /* sizeof(bias element) */,
    0 /* log2(sizeof(output element)) = log2(sizeof(int8_t)) */,
    &fully_connected_op->params.qs8_conv_minmax,
    sizeof(fully_connected_op->params.qs8_conv_minmax),
    pthreadpool_get_threads_count(threadpool));
}
589 
// Binds batch size, input, and output to an F32 fully-connected operator.
enum xnn_status xnn_setup_fully_connected_nc_f32(
    xnn_operator_t fully_connected_op,
    size_t batch_size,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  // Reject descriptors created by a different operator factory.
  if (fully_connected_op->type != xnn_operator_type_fully_connected_nc_f32) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f32),
      xnn_operator_type_to_string(fully_connected_op->type));
    return xnn_status_invalid_parameter;
  }

  // All data elements are 4-byte floats (log2 size = 2).
  return setup_fully_connected_nc(
    fully_connected_op, batch_size, input, output,
    XNN_INIT_FLAG_F32,
    2 /* log2(sizeof(input element)) = log2(sizeof(float)) */,
    2 /* log2(sizeof(filter element)) = log2(sizeof(float)) */,
    sizeof(float) /* sizeof(bias element) */,
    2 /* log2(sizeof(output element)) = log2(sizeof(float)) */,
    &fully_connected_op->params.f32_minmax,
    sizeof(fully_connected_op->params.f32_minmax),
    pthreadpool_get_threads_count(threadpool));
}
617 
// Binds batch size, input, and output to an F16 fully-connected operator.
enum xnn_status xnn_setup_fully_connected_nc_f16(
    xnn_operator_t fully_connected_op,
    size_t batch_size,
    const void* input,
    void* output,
    pthreadpool_t threadpool)
{
  // Reject descriptors created by a different operator factory.
  if (fully_connected_op->type != xnn_operator_type_fully_connected_nc_f16) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f16),
      xnn_operator_type_to_string(fully_connected_op->type));
    return xnn_status_invalid_parameter;
  }

  // All data elements are 2-byte half floats (log2 size = 1).
  return setup_fully_connected_nc(
    fully_connected_op,
    batch_size,
    input, output,
    // Fix: previously passed XNN_INIT_FLAG_F32 (copy-paste from the F32
    // wrapper); this operator requires F16 micro-kernel support.
    XNN_INIT_FLAG_F16,
    1 /* log2(sizeof(input element)) = log2(sizeof(uint16_t)) */,
    1 /* log2(sizeof(filter element)) = log2(sizeof(uint16_t)) */,
    sizeof(uint16_t) /* sizeof(bias element) */,
    1 /* log2(sizeof(output element)) = log2(sizeof(uint16_t)) */,
    &fully_connected_op->params.f16_scaleminmax,
    sizeof(fully_connected_op->params.f16_scaleminmax),
    pthreadpool_get_threads_count(threadpool));
}
645