1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>

#include <xnnpack.h>
#include <xnnpack/log.h>
#include <xnnpack/params.h>
#include <xnnpack/subgraph.h>
14
15
create_fully_connected_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata)16 static enum xnn_status create_fully_connected_operator(
17 const struct xnn_node* node,
18 const struct xnn_value* values,
19 size_t num_values,
20 struct xnn_operator_data* opdata)
21 {
22 assert(node->num_inputs >= 2);
23 assert(node->num_inputs <= 3);
24 const uint32_t input_id = node->inputs[0];
25 assert(input_id != XNN_INVALID_VALUE_ID);
26 assert(input_id < num_values);
27 const uint32_t filter_id = node->inputs[1];
28 assert(filter_id != XNN_INVALID_VALUE_ID);
29 assert(filter_id < num_values);
30
31 assert(node->num_outputs == 1);
32 const uint32_t output_id = node->outputs[0];
33 assert(output_id != XNN_INVALID_VALUE_ID);
34 assert(output_id < num_values);
35
36 const size_t num_input_elements = xnn_shape_multiply_all_dims(&values[node->inputs[0]].shape);
37 size_t output_channels, input_channels;
38 if (node->flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
39 input_channels = values[node->inputs[1]].shape.dim[0];
40 output_channels = values[node->inputs[1]].shape.dim[1];
41 } else {
42 output_channels = values[node->inputs[1]].shape.dim[0];
43 input_channels = values[node->inputs[1]].shape.dim[1];
44 }
45
46 const void* filter_data = values[filter_id].data;
47 assert(filter_data != NULL);
48
49 const void* bias_data = NULL;
50 if (node->num_inputs > 2) {
51 const uint32_t bias_id = node->inputs[2];
52 assert(bias_id != XNN_INVALID_VALUE_ID);
53 assert(bias_id < num_values);
54
55 bias_data = values[bias_id].data;
56 assert(bias_data != NULL);
57 }
58
59 enum xnn_status status;
60 switch (node->compute_type) {
61 case xnn_compute_type_fp32:
62 status = xnn_create_fully_connected_nc_f32(
63 input_channels,
64 output_channels,
65 input_channels /* input stride */,
66 output_channels /* output stride */,
67 filter_data,
68 bias_data,
69 node->activation.output_min,
70 node->activation.output_max,
71 node->flags /* flags */,
72 &opdata->operator_object);
73 break;
74 #ifndef XNN_NO_QS8_OPERATORS
75 case xnn_compute_type_qs8:
76 {
77 const float output_scale = values[output_id].quantization.scale;
78 const int32_t output_zero_point = values[output_id].quantization.zero_point;
79 const int8_t output_min =
80 (int8_t) lrintf(fminf(fmaxf(node->activation.output_min / output_scale + (float) output_zero_point, -128.0f), 127.0f));
81 const int8_t output_max =
82 (int8_t) lrintf(fminf(fmaxf(node->activation.output_max / output_scale + (float) output_zero_point, -128.0f), 127.0f));
83 status = xnn_create_fully_connected_nc_qs8(
84 input_channels,
85 output_channels,
86 input_channels /* input stride */,
87 output_channels /* output stride */,
88 (int8_t) values[input_id].quantization.zero_point,
89 values[input_id].quantization.scale,
90 values[filter_id].quantization.scale,
91 filter_data,
92 bias_data,
93 (int8_t) output_zero_point,
94 output_scale, output_min, output_max,
95 node->flags /* flags */,
96 &opdata->operator_object);
97 break;
98 }
99 #endif // !defined(XNN_NO_QS8_OPERATORS)
100 #ifndef XNN_NO_QU8_OPERATORS
101 case xnn_compute_type_qu8:
102 {
103 const float output_scale = values[output_id].quantization.scale;
104 const int32_t output_zero_point = values[output_id].quantization.zero_point;
105 const uint8_t output_min =
106 (uint8_t) lrintf(fminf(fmaxf(node->activation.output_min / output_scale + (float) output_zero_point, 0.0f), 255.0f));
107 const uint8_t output_max =
108 (uint8_t) lrintf(fminf(fmaxf(node->activation.output_max / output_scale + (float) output_zero_point, 0.0f), 255.0f));
109 status = xnn_create_fully_connected_nc_qu8(
110 input_channels,
111 output_channels,
112 input_channels /* input stride */,
113 output_channels /* output stride */,
114 (uint8_t) values[input_id].quantization.zero_point,
115 values[input_id].quantization.scale,
116 (uint8_t) values[filter_id].quantization.zero_point,
117 values[filter_id].quantization.scale,
118 filter_data,
119 bias_data,
120 (uint8_t) output_zero_point,
121 output_scale, output_min, output_max,
122 node->flags /* flags */,
123 &opdata->operator_object);
124 break;
125 }
126 #endif // !defined(XNN_NO_QU8_OPERATORS)
127 default:
128 XNN_UNREACHABLE;
129 }
130 if (status == xnn_status_success) {
131 opdata->batch_size = num_input_elements / input_channels;
132 opdata->inputs[0] = input_id;
133 opdata->outputs[0] = output_id;
134 }
135 return status;
136 }
137
// Binds the input and output blobs of a previously created fully-connected
// operator.
//
// Calls the xnn_setup_fully_connected_nc_* function that matches the operator
// type. The batch size recorded in `opdata` at creation time is passed through.
//
// Returns the status reported by that setup call.
static enum xnn_status setup_fully_connected_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t in_id = opdata->inputs[0];
  assert(in_id != XNN_INVALID_VALUE_ID);
  assert(in_id < num_blobs);

  const uint32_t out_id = opdata->outputs[0];
  assert(out_id != XNN_INVALID_VALUE_ID);
  assert(out_id < num_blobs);

  const void* in_data = blobs[in_id].data;
  assert(in_data != NULL);

  void* out_data = blobs[out_id].data;
  assert(out_data != NULL);

  const size_t batch = opdata->batch_size;
  switch (opdata->operator_object->type) {
    case xnn_operator_type_fully_connected_nc_f32:
      return xnn_setup_fully_connected_nc_f32(
        opdata->operator_object, batch, in_data, out_data, threadpool);
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_operator_type_fully_connected_nc_qs8:
      return xnn_setup_fully_connected_nc_qs8(
        opdata->operator_object, batch, in_data, out_data, threadpool);
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_operator_type_fully_connected_nc_qu8:
      return xnn_setup_fully_connected_nc_qu8(
        opdata->operator_object, batch, in_data, out_data, threadpool);
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}
190
validate_datatypes_with_bias(enum xnn_datatype input_datatype,enum xnn_datatype filter_datatype,enum xnn_datatype bias_datatype,enum xnn_datatype output_datatype)191 static inline enum xnn_compute_type validate_datatypes_with_bias(
192 enum xnn_datatype input_datatype,
193 enum xnn_datatype filter_datatype,
194 enum xnn_datatype bias_datatype,
195 enum xnn_datatype output_datatype)
196 {
197 switch (filter_datatype) {
198 case xnn_datatype_fp32:
199 if (input_datatype == xnn_datatype_fp32 &&
200 bias_datatype == xnn_datatype_fp32 &&
201 output_datatype == xnn_datatype_fp32)
202 {
203 return xnn_compute_type_fp32;
204 }
205 break;
206 #ifndef XNN_NO_QS8_OPERATORS
207 case xnn_datatype_qint8:
208 if (input_datatype == xnn_datatype_qint8 &&
209 bias_datatype == xnn_datatype_qint32 &&
210 output_datatype == xnn_datatype_qint8)
211 {
212 return xnn_compute_type_qs8;
213 }
214 break;
215 #endif // !defined(XNN_NO_QS8_OPERATORS)
216 #ifndef XNN_NO_QU8_OPERATORS
217 case xnn_datatype_quint8:
218 if (input_datatype == xnn_datatype_quint8 &&
219 bias_datatype == xnn_datatype_qint32 &&
220 output_datatype == xnn_datatype_quint8)
221 {
222 return xnn_compute_type_qu8;
223 }
224 break;
225 #endif // !defined(XNN_NO_QU8_OPERATORS)
226 default:
227 XNN_UNREACHABLE;
228 }
229 return xnn_compute_type_invalid;
230 }
231
validate_datatypes_without_bias(enum xnn_datatype input_datatype,enum xnn_datatype filter_datatype,enum xnn_datatype output_datatype)232 static inline enum xnn_compute_type validate_datatypes_without_bias(
233 enum xnn_datatype input_datatype,
234 enum xnn_datatype filter_datatype,
235 enum xnn_datatype output_datatype)
236 {
237 switch (filter_datatype) {
238 case xnn_datatype_fp32:
239 if (input_datatype == xnn_datatype_fp32 && output_datatype == xnn_datatype_fp32) {
240 return xnn_compute_type_fp32;
241 }
242 break;
243 #ifndef XNN_NO_QS8_OPERATORS
244 case xnn_datatype_qint8:
245 if (input_datatype == xnn_datatype_qint8 && output_datatype == xnn_datatype_qint8) {
246 return xnn_compute_type_qs8;
247 }
248 break;
249 #endif // !defined(XNN_NO_QS8_OPERATORS)
250 #ifndef XNN_NO_QU8_OPERATORS
251 case xnn_datatype_quint8:
252 if (input_datatype == xnn_datatype_quint8 && output_datatype == xnn_datatype_quint8) {
253 return xnn_compute_type_qu8;
254 }
255 break;
256 #endif // !defined(XNN_NO_QU8_OPERATORS)
257 default:
258 XNN_UNREACHABLE;
259 }
260 return xnn_compute_type_invalid;
261 }
262
// Defines a Fully Connected Node in `subgraph`.
//
// Validates the activation range and the input, filter, optional bias and
// output Values. Infers the compute type from the operand datatypes, then
// appends a new node whose create/setup callbacks are
// create_fully_connected_operator and setup_fully_connected_operator.
//
// Parameters:
//   subgraph   - subgraph that receives the new node.
//   output_min - lower activation bound; must be non-NaN and below output_max.
//   output_max - upper activation bound; must be non-NaN.
//   input_id   - ID of a dense-tensor input Value.
//   filter_id  - ID of a dense-tensor filter Value with static data. A qint8
//                filter must have a quantization zero point of 0.
//   bias_id    - ID of a static dense-tensor bias Value, or
//                XNN_INVALID_VALUE_ID when the node has no bias.
//   output_id  - ID of a dense-tensor output Value.
//   flags      - stored on the node and forwarded to operator creation.
//                XNN_FLAG_TRANSPOSE_WEIGHTS selects the filter layout.
//
// Returns:
//   xnn_status_uninitialized     if XNNPACK is not initialized;
//   xnn_status_invalid_parameter if any validation check fails;
//   xnn_status_out_of_memory     if the node cannot be allocated;
//   xnn_status_success           otherwise.
enum xnn_status xnn_define_fully_connected(
  xnn_subgraph_t subgraph,
  float output_min,
  float output_max,
  uint32_t input_id,
  uint32_t filter_id,
  uint32_t bias_id,
  uint32_t output_id,
  uint32_t flags)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to define %s operator: XNNPACK is not initialized",
      xnn_node_type_to_string(xnn_node_type_fully_connected));
    return xnn_status_uninitialized;
  }

  if (isnan(output_min)) {
    xnn_log_error(
      "failed to define %s operator with NaN output lower bound: lower bound must be non-NaN",
      xnn_node_type_to_string(xnn_node_type_fully_connected));
    return xnn_status_invalid_parameter;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to define %s operator with NaN output upper bound: upper bound must be non-NaN",
      xnn_node_type_to_string(xnn_node_type_fully_connected));
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to define %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      xnn_node_type_to_string(xnn_node_type_fully_connected), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  // ---- Input ----
  if (input_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_fully_connected), input_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  if (input_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_fully_connected), input_id, input_value->type);
    return xnn_status_invalid_parameter;
  }

  switch (input_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_fully_connected), input_id,
        xnn_datatype_to_string(input_value->datatype), input_value->datatype);
      return xnn_status_invalid_parameter;
  }

  // ---- Filter ----
  if (filter_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with filter ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_fully_connected), filter_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* filter_value = &subgraph->values[filter_id];
  if (filter_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with filter ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_fully_connected), filter_id, filter_value->type);
    return xnn_status_invalid_parameter;
  }

  if (filter_value->data == NULL) {
    xnn_log_error(
      "failed to define %s operator with filter ID #%" PRIu32 ": non-static Value",
      xnn_node_type_to_string(xnn_node_type_fully_connected), filter_id);
    return xnn_status_invalid_parameter;
  }

  switch (filter_value->datatype) {
    case xnn_datatype_fp32:
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
      // The QS8 operator is created without a filter zero point (see
      // create_fully_connected_operator). Only a zero point of 0 is therefore
      // representable; anything else must be rejected, not just logged.
      if (filter_value->quantization.zero_point != 0) {
        xnn_log_error(
          "failed to define %s operator with filter ID #%" PRIu32 ": unsupported quantization zero point %" PRId32 " for datatype %s",
          xnn_node_type_to_string(xnn_node_type_fully_connected), filter_id,
          filter_value->quantization.zero_point, xnn_datatype_to_string(filter_value->datatype));
        return xnn_status_invalid_parameter;
      }
      break;
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
      break;
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with filter ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_fully_connected), filter_id,
        xnn_datatype_to_string(filter_value->datatype), filter_value->datatype);
      return xnn_status_invalid_parameter;
  }

  // ---- Bias (optional) ----
  const struct xnn_value* bias_value = NULL;
  if (bias_id != XNN_INVALID_VALUE_ID) {
    if (bias_id >= subgraph->num_values) {
      xnn_log_error(
        "failed to define %s operator with bias ID #%" PRIu32 ": invalid Value ID",
        xnn_node_type_to_string(xnn_node_type_fully_connected), bias_id);
      return xnn_status_invalid_parameter;
    }

    bias_value = &subgraph->values[bias_id];
    if (bias_value->type != xnn_value_type_dense_tensor) {
      xnn_log_error(
        "failed to define %s operator with bias ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
        xnn_node_type_to_string(xnn_node_type_fully_connected), bias_id, bias_value->type);
      return xnn_status_invalid_parameter;
    }

    if (bias_value->data == NULL) {
      xnn_log_error(
        "failed to define %s operator with bias ID #%" PRIu32 ": non-static Value",
        xnn_node_type_to_string(xnn_node_type_fully_connected), bias_id);
      return xnn_status_invalid_parameter;
    }

    switch (bias_value->datatype) {
      case xnn_datatype_fp32:
#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
      case xnn_datatype_qint32:
#endif  // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
        break;
      default:
        xnn_log_error(
          "failed to define %s operator with bias ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
          xnn_node_type_to_string(xnn_node_type_fully_connected), bias_id,
          xnn_datatype_to_string(bias_value->datatype), bias_value->datatype);
        return xnn_status_invalid_parameter;
    }
  }

  // ---- Output ----
  if (output_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_fully_connected), output_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  if (output_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_fully_connected), output_id, output_value->type);
    return xnn_status_invalid_parameter;
  }

  switch (output_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_fully_connected), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  // ---- Cross-operand datatype consistency ----
  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  if (bias_value != NULL) {
    compute_type = validate_datatypes_with_bias(
      input_value->datatype, filter_value->datatype, bias_value->datatype, output_value->datatype);
    if (compute_type == xnn_compute_type_invalid) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ", filter ID #%" PRIu32 ", bias ID #%" PRIu32 ", and output ID #%" PRIu32
        ": mismatching datatypes across input (%s), filter (%s), bias (%s), and output (%s)",
        xnn_node_type_to_string(xnn_node_type_fully_connected), input_id, filter_id, bias_id, output_id,
        xnn_datatype_to_string(input_value->datatype),
        xnn_datatype_to_string(filter_value->datatype),
        xnn_datatype_to_string(bias_value->datatype),
        xnn_datatype_to_string(output_value->datatype));
      return xnn_status_invalid_parameter;
    }
  } else {
    compute_type = validate_datatypes_without_bias(
      input_value->datatype, filter_value->datatype, output_value->datatype);
    if (compute_type == xnn_compute_type_invalid) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ", filter ID #%" PRIu32 ", and output ID #%" PRIu32
        ": mismatching datatypes across input (%s), filter (%s), and output (%s)",
        xnn_node_type_to_string(xnn_node_type_fully_connected), input_id, filter_id, output_id,
        xnn_datatype_to_string(input_value->datatype),
        xnn_datatype_to_string(filter_value->datatype),
        xnn_datatype_to_string(output_value->datatype));
      return xnn_status_invalid_parameter;
    }
  }

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->type = xnn_node_type_fully_connected;
  node->compute_type = compute_type;
  node->activation.output_min = output_min;
  node->activation.output_max = output_max;
  // The bias is counted as a third input only when a bias Value was supplied.
  node->num_inputs = 2 + (size_t) (bias_id != XNN_INVALID_VALUE_ID);
  node->inputs[0] = input_id;
  node->inputs[1] = filter_id;
  node->inputs[2] = bias_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  node->create = create_fully_connected_operator;
  node->setup = setup_fully_connected_operator;

  return xnn_status_success;
}
502