1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8
9 #include <assert.h>
10 #include <stdbool.h>
11 #include <stddef.h>
12 #include <stdint.h>
13 #include <string.h>
14 #include <math.h>
15
16 #include <xnnpack.h>
17 #include <xnnpack/allocator.h>
18 #include <xnnpack/log.h>
19 #include <xnnpack/math.h>
20 #include <xnnpack/operator.h>
21 #include <xnnpack/pack.h>
22 #include <xnnpack/params-init.h>
23 #include <xnnpack/params.h>
24
25
// Creates a Fully Connected operator for 8-bit quantized (Q8) data.
//
// Parameters:
//   input_channels, output_channels - channel counts; both must be non-zero.
//   input_stride, output_stride     - element strides; each must be at least the matching channel count.
//   input_zero_point, input_scale   - quantization parameters of the input.
//   kernel_zero_point, kernel_scale - quantization parameters of the kernel.
//   kernel, bias                    - weights and bias handed to the packing routine.
//   output_zero_point, output_scale - quantization parameters of the output.
//   output_min, output_max          - output clamping range; output_min must be strictly below output_max.
//   flags                           - when XNN_FLAG_TRANSPOSE_WEIGHTS is set the kernel is packed with the
//                                     "io" routine, otherwise with the "goi" routine
//                                     (presumably input-major vs. output-major layout - TODO confirm).
//   fully_connected_op_out          - receives the new operator on success; not written on failure.
//
// All scales must be positive, finite, and normalized, and
// input_scale * kernel_scale / output_scale must be below 1.0.
//
// Returns xnn_status_success, or one of xnn_status_uninitialized, xnn_status_invalid_parameter,
// xnn_status_unsupported_parameter, xnn_status_out_of_memory. On failure the partially constructed
// operator (possibly NULL) is passed to xnn_delete_operator.
enum xnn_status xnn_create_fully_connected_nc_q8(
    size_t input_channels,
    size_t output_channels,
    size_t input_stride,
    size_t output_stride,
    uint8_t input_zero_point,
    float input_scale,
    uint8_t kernel_zero_point,
    float kernel_scale,
    const uint8_t* kernel,
    const int32_t* bias,
    uint8_t output_zero_point,
    float output_scale,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    xnn_operator_t* fully_connected_op_out)
{
  xnn_operator_t fully_connected_op = NULL;
  // `status` always holds the code to report if the next group of checks fails.
  enum xnn_status status = xnn_status_uninitialized;

  if (!xnn_params.initialized) {
    xnn_log_error("failed to create Fully Connected operator: XNNPACK is not initialized");
    goto error;
  }

  status = xnn_status_invalid_parameter;

  if (input_channels == 0) {
    xnn_log_error(
      "failed to create Fully Connected operator with %zu input channels: number of channels must be non-zero",
      input_channels);
    goto error;
  }

  if (output_channels == 0) {
    xnn_log_error(
      "failed to create Fully Connected operator with %zu output channels: number of channels must be non-zero",
      output_channels);
    goto error;
  }

  if (input_stride < input_channels) {
    xnn_log_error(
      "failed to create Fully Connected operator with input element stride of %zu: "
      "stride must be at least as large as the number of input channels (%zu)",
      input_stride, input_channels);
    goto error;
  }

  if (output_stride < output_channels) {
    xnn_log_error(
      "failed to create Fully Connected operator with output element stride of %zu: "
      "stride must be at least as large as the number of output channels (%zu)",
      output_stride, output_channels);
    goto error;
  }

  // isnormal() rejects zero, subnormals, infinities, and NaN; the sign is checked separately.
  if (input_scale <= 0.0f || !isnormal(input_scale)) {
    xnn_log_error(
      "failed to create Fully Connected operator with %.7g input scale: scale must be finite, normalized, and positive",
      input_scale);
    goto error;
  }

  if (kernel_scale <= 0.0f || !isnormal(kernel_scale)) {
    xnn_log_error(
      "failed to create Fully Connected operator with %.7g kernel scale: scale must be finite, normalized, and positive",
      kernel_scale);
    goto error;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create Fully Connected operator with %.7g output scale: scale must be finite, normalized, and positive",
      output_scale);
    goto error;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create Fully Connected operator with [%" PRIu8 ", %" PRIu8 "] output range: "
      "range min must be below range max",
      output_min, output_max);
    goto error;
  }

  status = xnn_status_unsupported_parameter;

  const float requantization_scale = input_scale * kernel_scale / output_scale;
  if (requantization_scale >= 1.0f) {
    xnn_log_error(
      "failed to create Fully Connected operator with %.7g input scale, %.7g kernel scale, and %.7g output scale: "
      "requantization scale %.7g is greater or equal to 1.0",
      input_scale, kernel_scale, output_scale, requantization_scale);
    goto error;
  }

  status = xnn_status_out_of_memory;

  fully_connected_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
  if (fully_connected_op == NULL) {
    xnn_log_error("failed to allocate %zu bytes for Fully Connected operator descriptor", sizeof(struct xnn_operator));
    goto error;
  }

  const uint32_t nr = xnn_params.q8.gemm.nr;
  const uint32_t kr = UINT32_C(1) << xnn_params.q8.gemm.log2_kr;

  // Padded dimensions stay in size_t: the channel counts are size_t, and narrowing them to
  // 32 bits would silently truncate very large values and under-size the packed buffer.
  const size_t n_stride = round_up(output_channels, nr);
  const size_t k_stride = round_up_po2(input_channels, kr);

  // Per padded output channel: k_stride uint8_t weights plus one int32_t bias.
  const size_t packed_weights_size = n_stride * (k_stride * sizeof(uint8_t) + sizeof(int32_t));

  fully_connected_op->packed_weights = xnn_allocate_simd_memory(packed_weights_size);
  if (fully_connected_op->packed_weights == NULL) {
    xnn_log_error("failed to allocate %zu bytes for packed weights", packed_weights_size);
    goto error;
  }
  // Pre-fill the whole buffer with the kernel zero point before packing
  // (presumably so padding weights behave as zeros after zero-point subtraction - TODO confirm).
  memset(fully_connected_op->packed_weights, kernel_zero_point, packed_weights_size);

  if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
    xnn_pack_q8_gemm_io_w(
      output_channels, input_channels,
      nr, kr,
      input_zero_point, kernel_zero_point,
      kernel, bias,
      fully_connected_op->packed_weights);
  } else {
    xnn_pack_q8_gemm_goi_w(
      1, output_channels, input_channels,
      nr, kr,
      input_zero_point, kernel_zero_point,
      kernel, bias,
      fully_connected_op->packed_weights);
  }

  fully_connected_op->group_input_channels = input_channels;
  fully_connected_op->group_output_channels = output_channels;
  fully_connected_op->input_pixel_stride = input_stride;
  fully_connected_op->output_pixel_stride = output_stride;

  fully_connected_op->kernel_zero_point = kernel_zero_point;

  fully_connected_op->q8_gemm_params =
    xnn_init_q8_gemm_params(
      input_zero_point, kernel_zero_point,
      requantization_scale, output_zero_point, output_min, output_max);

  fully_connected_op->type = xnn_operator_type_fully_connected_nc_q8;

  fully_connected_op->ukernel.type = xnn_ukernel_type_gemm;
  fully_connected_op->ukernel.gemm = (struct xnn_ukernel_gemm) {
    .default_function = xnn_params.q8.gemm.gemm,
    .mr = xnn_params.q8.gemm.mr,
    .nr = nr,
    .kr = kr,
  };

  // The operator cannot run until a setup call binds buffers to it.
  fully_connected_op->state = xnn_run_state_invalid;

  *fully_connected_op_out = fully_connected_op;
  return xnn_status_success;

error:
  xnn_delete_operator(fully_connected_op);
  return status;
}
193
// Creates a Fully Connected operator for single-precision floating-point (F32) data.
//
// Parameters:
//   input_channels, output_channels - channel counts; both must be non-zero.
//   input_stride, output_stride     - element strides; each must be at least the matching channel count.
//   kernel, bias                    - weights and bias handed to the packing routine.
//   output_min, output_max          - output clamping range; neither may be NaN and output_min must be
//                                     strictly below output_max.
//   flags                           - when XNN_FLAG_TRANSPOSE_WEIGHTS is set the kernel is packed with the
//                                     "io" routine, otherwise with the "goi" routine
//                                     (presumably input-major vs. output-major layout - TODO confirm).
//   fully_connected_op_out          - receives the new operator on success; not written on failure.
//
// Returns xnn_status_success, or one of xnn_status_uninitialized, xnn_status_invalid_parameter,
// xnn_status_out_of_memory. On failure the partially constructed operator (possibly NULL) is passed
// to xnn_delete_operator.
enum xnn_status xnn_create_fully_connected_nc_f32(
    size_t input_channels,
    size_t output_channels,
    size_t input_stride,
    size_t output_stride,
    const float* kernel,
    const float* bias,
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_operator_t* fully_connected_op_out)
{
  xnn_operator_t fully_connected_op = NULL;
  // `status` always holds the code to report if the next group of checks fails.
  enum xnn_status status = xnn_status_uninitialized;

  if (!xnn_params.initialized) {
    xnn_log_error("failed to create Fully Connected operator: XNNPACK is not initialized");
    goto error;
  }

  status = xnn_status_invalid_parameter;

  if (input_channels == 0) {
    xnn_log_error(
      "failed to create Fully Connected operator with %zu input channels: number of channels must be non-zero",
      input_channels);
    goto error;
  }

  if (output_channels == 0) {
    xnn_log_error(
      "failed to create Fully Connected operator with %zu output channels: number of channels must be non-zero",
      output_channels);
    goto error;
  }

  if (input_stride < input_channels) {
    xnn_log_error(
      "failed to create Fully Connected operator with input element stride of %zu: "
      "stride must be at least as large as the number of input channels (%zu)",
      input_stride, input_channels);
    goto error;
  }

  if (output_stride < output_channels) {
    xnn_log_error(
      "failed to create Fully Connected operator with output element stride of %zu: "
      "stride must be at least as large as the number of output channels (%zu)",
      output_stride, output_channels);
    goto error;
  }

  // NaN bounds are rejected explicitly: the ordering check below is false for NaN operands.
  if (isnan(output_min)) {
    xnn_log_error(
      "failed to create Fully Connected operator with NaN output lower bound: lower bound must be non-NaN");
    goto error;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to create Fully Connected operator with NaN output upper bound: upper bound must be non-NaN");
    goto error;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create Fully Connected operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      output_min, output_max);
    goto error;
  }

  status = xnn_status_out_of_memory;

  fully_connected_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
  if (fully_connected_op == NULL) {
    xnn_log_error("failed to allocate %zu bytes for Fully Connected operator descriptor", sizeof(struct xnn_operator));
    goto error;
  }

  const uint32_t nr = xnn_params.f32.gemm.nr;
  const uint32_t kr = UINT32_C(1) << xnn_params.f32.gemm.log2_kr;
  const uint32_t sr = UINT32_C(1) << xnn_params.f32.gemm.log2_sr;

  // Padded dimensions stay in size_t: the channel counts are size_t, and narrowing them to
  // 32 bits would silently truncate very large values and under-size the packed buffer.
  const size_t n_stride = round_up(output_channels, nr);
  const size_t k_stride = round_up_po2(input_channels, kr);

  // Per padded output channel: k_stride float weights plus one float bias.
  const size_t packed_weights_size = n_stride * (k_stride * sizeof(float) + sizeof(float));

  fully_connected_op->packed_weights = xnn_allocate_simd_memory(packed_weights_size);
  if (fully_connected_op->packed_weights == NULL) {
    xnn_log_error("failed to allocate %zu bytes for packed weights", packed_weights_size);
    goto error;
  }
  // Zero the whole buffer first so any element the packing routine does not write is zero.
  memset(fully_connected_op->packed_weights, 0, packed_weights_size);

  if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
    xnn_pack_f32_gemm_io_w(
      output_channels, input_channels,
      nr, kr, sr,
      kernel, bias,
      fully_connected_op->packed_weights);
  } else {
    xnn_pack_f32_gemm_goi_w(
      1, output_channels, input_channels,
      nr, kr, sr,
      kernel, bias,
      fully_connected_op->packed_weights);
  }

  fully_connected_op->group_input_channels = input_channels;
  fully_connected_op->group_output_channels = output_channels;
  fully_connected_op->input_pixel_stride = input_stride;
  fully_connected_op->output_pixel_stride = output_stride;

  fully_connected_op->f32_output_params = xnn_init_f32_output_params(output_min, output_max);

  fully_connected_op->type = xnn_operator_type_fully_connected_nc_f32;

  fully_connected_op->ukernel.type = xnn_ukernel_type_gemm;
  fully_connected_op->ukernel.gemm = (struct xnn_ukernel_gemm) {
    .default_function = xnn_params.f32.gemm.gemm,
    .mr1_function = xnn_params.f32.gemm.gemm1,
    .mr = xnn_params.f32.gemm.mr,
    .nr = nr,
    .kr = kr,
  };

  // The operator cannot run until a setup call binds buffers to it.
  fully_connected_op->state = xnn_run_state_invalid;

  *fully_connected_op_out = fully_connected_op;
  return xnn_status_success;

error:
  xnn_delete_operator(fully_connected_op);
  return status;
}
329
// Shared setup path for the Fully Connected operators: binds the input/output buffers and batch
// size to the operator and fills in its GEMM context and parallelization parameters.
//
// Parameters:
//   fully_connected_op       - operator to configure; its state is reset to invalid on entry.
//   batch_size               - number of rows to process; 0 marks the operator as "skip".
//   input, output            - buffers recorded in the operator and in the GEMM context.
//   log2_input_element_size  - log2 of the byte size of one input element.
//   log2_filter_element_size - log2 of the byte size of one packed filter element.
//   bias_element_size        - byte size of one packed bias element.
//   log2_output_element_size - log2 of the byte size of one output element.
//   params                   - micro-kernel parameters; sizeof(context.gemm.params) bytes are copied
//                              from it, so it must point to at least that many readable bytes.
//   num_threads              - thread count used to choose the output-channel tile size.
//
// Returns xnn_status_uninitialized if XNNPACK is not initialized, otherwise xnn_status_success.
static enum xnn_status setup_fully_connected_nc(
    xnn_operator_t fully_connected_op,
    size_t batch_size,
    const void* input,
    void* output,
    uint32_t log2_input_element_size,
    uint32_t log2_filter_element_size,
    uint32_t bias_element_size,
    uint32_t log2_output_element_size,
    const void* params,
    size_t num_threads)
{
  fully_connected_op->state = xnn_run_state_invalid;

  if (!xnn_params.initialized) {
    xnn_log_error("failed to setup Fully Connected operator: XNNPACK is not initialized");
    return xnn_status_uninitialized;
  }

  if (batch_size == 0) {
    // Nothing to compute: leave the operator in a runnable no-op state.
    fully_connected_op->state = xnn_run_state_skip;
    return xnn_status_success;
  }

  // The batch is recorded as the height of a single image of width 1.
  fully_connected_op->batch_size = 1;
  fully_connected_op->input_height = batch_size;
  fully_connected_op->input_width = 1;
  fully_connected_op->input = input;

  fully_connected_op->output_height = batch_size;
  fully_connected_op->output_width = 1;
  fully_connected_op->output = output;

  const size_t input_channels = fully_connected_op->group_input_channels;
  const size_t output_channels = fully_connected_op->group_output_channels;

  uint32_t mr = fully_connected_op->ukernel.gemm.mr;
  const uint32_t nr = fully_connected_op->ukernel.gemm.nr;

  // Use the dedicated single-row micro-kernel for a batch of one, when the operator has one.
  xnn_gemm_ukernel_function gemm_ukernel = fully_connected_op->ukernel.gemm.default_function;
  if (batch_size == 1 && fully_connected_op->ukernel.gemm.mr1_function != NULL) {
    gemm_ukernel = fully_connected_op->ukernel.gemm.mr1_function;
    mr = 1;
  }

  fully_connected_op->context.gemm = (struct gemm_context) {
    .k_scaled = input_channels << log2_input_element_size,
    // Packed weights hold filter elements, so their byte stride is scaled by the filter element
    // size (not the input element size), plus one bias element.
    .w_stride = (round_up_po2(input_channels, fully_connected_op->ukernel.gemm.kr) << log2_filter_element_size) + bias_element_size,
    .a = input,
    .a_stride = fully_connected_op->input_pixel_stride << log2_input_element_size,
    .packed_w = fully_connected_op->packed_weights,
    .c = output,
    .cm_stride = fully_connected_op->output_pixel_stride << log2_output_element_size,
    .cn_stride = nr << log2_output_element_size,
    .log2_csize = log2_output_element_size,
    .ukernel = gemm_ukernel,
  };
  memcpy(&fully_connected_op->context.gemm.params, params, sizeof(fully_connected_op->context.gemm.params));

  // Output-channel tile size. With one thread, all output channels form a single tile; with
  // several threads, the tile is shrunk (kept a multiple of nr, capped at output_channels) using a
  // heuristic target of tiles per thread.
  size_t nc = output_channels;
  if (num_threads > 1) {
    const size_t num_other_tiles = divide_round_up(batch_size, mr);
    const size_t target_tiles_per_thread = 5;
    const size_t max_nc = divide_round_up(output_channels * num_other_tiles, num_threads * target_tiles_per_thread);
    if (max_nc < nc) {
      nc = min(nc, divide_round_up(nc, max_nc * nr) * nr);
    }
  }
  // 2D tiled parallelization over (batch rows, output channels) with tiles of (mr, nc).
  fully_connected_op->compute.type = xnn_parallelization_type_2d_tile_2d;
  fully_connected_op->compute.task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
  fully_connected_op->compute.range[0] = batch_size;
  fully_connected_op->compute.range[1] = output_channels;
  fully_connected_op->compute.tile[0] = mr;
  fully_connected_op->compute.tile[1] = nc;
  fully_connected_op->state = xnn_run_state_ready;

  return xnn_status_success;
}
408
// Binds a batch size and input/output buffers to a Q8 Fully Connected operator.
//
// Returns xnn_status_invalid_parameter when the operator's type is not
// xnn_operator_type_fully_connected_nc_q8; otherwise returns whatever setup_fully_connected_nc
// reports. The thread count is queried from `threadpool` via pthreadpool_get_threads_count.
enum xnn_status xnn_setup_fully_connected_nc_q8(
    xnn_operator_t fully_connected_op,
    size_t batch_size,
    const uint8_t* input,
    uint8_t* output,
    pthreadpool_t threadpool)
{
  const bool type_matches = fully_connected_op->type == xnn_operator_type_fully_connected_nc_q8;
  if (!type_matches) {
    xnn_log_error("failed to setup Fully Connected (NC, Q8) operator: operator type mismatch");
    return xnn_status_invalid_parameter;
  }

  // Input, filter, and output elements are all uint8_t: log2(sizeof(uint8_t)) == 0.
  const uint32_t log2_element_size = 0;
  // Bias elements are int32_t.
  const uint32_t bias_element_size = sizeof(int32_t);
  const size_t num_threads = pthreadpool_get_threads_count(threadpool);

  return setup_fully_connected_nc(
    fully_connected_op,
    batch_size,
    input, output,
    log2_element_size /* input */,
    log2_element_size /* filter */,
    bias_element_size,
    log2_element_size /* output */,
    &fully_connected_op->q8_gemm_params,
    num_threads);
}
432
// Binds a batch size and input/output buffers to an F32 Fully Connected operator.
//
// Returns xnn_status_invalid_parameter when the operator's type is not
// xnn_operator_type_fully_connected_nc_f32; otherwise returns whatever setup_fully_connected_nc
// reports. The thread count is queried from `threadpool` via pthreadpool_get_threads_count.
enum xnn_status xnn_setup_fully_connected_nc_f32(
    xnn_operator_t fully_connected_op,
    size_t batch_size,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  const bool type_matches = fully_connected_op->type == xnn_operator_type_fully_connected_nc_f32;
  if (!type_matches) {
    xnn_log_error("failed to setup Fully Connected (NC, F32) operator: operator type mismatch");
    return xnn_status_invalid_parameter;
  }

  // Input, filter, and output elements are all float: log2(sizeof(float)) == 2.
  const uint32_t log2_element_size = 2;
  // Bias elements are float as well.
  const uint32_t bias_element_size = sizeof(float);
  const size_t num_threads = pthreadpool_get_threads_count(threadpool);

  return setup_fully_connected_nc(
    fully_connected_op,
    batch_size,
    input, output,
    log2_element_size /* input */,
    log2_element_size /* filter */,
    bias_element_size,
    log2_element_size /* output */,
    &fully_connected_op->f32_output_params,
    num_threads);
}
456