• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>

#include <xnnpack.h>
#include <xnnpack/log.h>
#include <xnnpack/params.h>
#include <xnnpack/subgraph.h>
14 
15 
// Instantiates the underlying XNNPACK convolution operator for a
// Convolution 2D subgraph node.
//
// The node is expected to have inputs [input, filter, optional bias] and a
// single output.  Filter (and bias, when present) must be static values:
// their data pointers are read here and passed to the operator constructor.
// On success, the input tensor's N/H/W dimensions and the input/output value
// IDs are cached in opdata for use by setup_convolution_operator.
static enum xnn_status create_convolution_operator(
  const struct xnn_node* node,
  const struct xnn_value* values,
  size_t num_values,
  struct xnn_operator_data* opdata)
{
  assert(node->num_inputs >= 2);
  assert(node->num_inputs <= 3);
  const uint32_t input_id = node->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_values);
  const uint32_t filter_id = node->inputs[1];
  assert(filter_id != XNN_INVALID_VALUE_ID);
  assert(filter_id < num_values);

  assert(node->num_outputs == 1);
  const uint32_t output_id = node->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_values);

  // Filter weights must be static (data already resident).
  const void* filter_data = values[filter_id].data;
  assert(filter_data != NULL);

  // The optional third input is a static bias tensor; bias_data stays NULL
  // for bias-less convolutions.
  const void* bias_data = NULL;
  if (node->num_inputs > 2) {
    const uint32_t bias_id = node->inputs[2];
    assert(bias_id != XNN_INVALID_VALUE_ID);
    assert(bias_id < num_values);

    bias_data = values[bias_id].data;
    assert(bias_data != NULL);
  }

  enum xnn_status status;
  if (values[output_id].layout == xnn_layout_type_nchw) {
    // NCHW (channels-first) output: only the FP32 variant exists.
    assert(node->compute_type == xnn_compute_type_fp32);
    status = xnn_create_convolution2d_nchw_f32(
      node->params.convolution_2d.input_padding_top,
      node->params.convolution_2d.input_padding_right,
      node->params.convolution_2d.input_padding_bottom,
      node->params.convolution_2d.input_padding_left,
      node->params.convolution_2d.kernel_height,
      node->params.convolution_2d.kernel_width,
      node->params.convolution_2d.subsampling_height,
      node->params.convolution_2d.subsampling_width,
      node->params.convolution_2d.dilation_height,
      node->params.convolution_2d.dilation_width,
      node->params.convolution_2d.groups,
      node->params.convolution_2d.group_input_channels,
      node->params.convolution_2d.group_output_channels,
      node->params.convolution_2d.group_input_channels * node->params.convolution_2d.groups /* input_pixel_stride */,
      node->params.convolution_2d.group_output_channels * node->params.convolution_2d.groups /* output_pixel_stride */,
      filter_data,
      bias_data,
      node->activation.output_min,
      node->activation.output_max,
      // If the input is NHWC while the output is NCHW, tell the operator to
      // consume NHWC input directly.
      node->flags | (values[input_id].layout == xnn_layout_type_nhwc ? XNN_FLAG_INPUT_NHWC : 0),
      &opdata->operator_object);
  } else {
    // NHWC (channels-last) path: dispatch on the node's compute type.
    assert(values[input_id].layout == xnn_layout_type_nhwc);
    assert(values[output_id].layout == xnn_layout_type_nhwc);
    switch (node->compute_type) {
      case xnn_compute_type_fp32:
        status = xnn_create_convolution2d_nhwc_f32(
          node->params.convolution_2d.input_padding_top,
          node->params.convolution_2d.input_padding_right,
          node->params.convolution_2d.input_padding_bottom,
          node->params.convolution_2d.input_padding_left,
          node->params.convolution_2d.kernel_height,
          node->params.convolution_2d.kernel_width,
          node->params.convolution_2d.subsampling_height,
          node->params.convolution_2d.subsampling_width,
          node->params.convolution_2d.dilation_height,
          node->params.convolution_2d.dilation_width,
          node->params.convolution_2d.groups,
          node->params.convolution_2d.group_input_channels,
          node->params.convolution_2d.group_output_channels,
          node->params.convolution_2d.group_input_channels * node->params.convolution_2d.groups /* input_pixel_stride */,
          node->params.convolution_2d.group_output_channels * node->params.convolution_2d.groups /* output_pixel_stride */,
          filter_data,
          bias_data,
          node->activation.output_min,
          node->activation.output_max,
          node->flags,
          &opdata->operator_object);
        break;
#ifndef XNN_NO_F16_OPERATORS
      case xnn_compute_type_fp16:
        status = xnn_create_convolution2d_nhwc_f16(
          node->params.convolution_2d.input_padding_top,
          node->params.convolution_2d.input_padding_right,
          node->params.convolution_2d.input_padding_bottom,
          node->params.convolution_2d.input_padding_left,
          node->params.convolution_2d.kernel_height,
          node->params.convolution_2d.kernel_width,
          node->params.convolution_2d.subsampling_height,
          node->params.convolution_2d.subsampling_width,
          node->params.convolution_2d.dilation_height,
          node->params.convolution_2d.dilation_width,
          node->params.convolution_2d.groups,
          node->params.convolution_2d.group_input_channels,
          node->params.convolution_2d.group_output_channels,
          node->params.convolution_2d.group_input_channels * node->params.convolution_2d.groups /* input_pixel_stride */,
          node->params.convolution_2d.group_output_channels * node->params.convolution_2d.groups /* output_pixel_stride */,
          filter_data,
          bias_data,
          node->activation.output_min,
          node->activation.output_max,
          // Static weights are supplied as FP32; the flag presumably makes
          // the operator convert them to FP16 internally — confirm against
          // the xnn_create_convolution2d_nhwc_f16 contract.
          node->flags | XNN_FLAG_FP32_STATIC_WEIGHTS,
          &opdata->operator_object);
        break;
#endif  // XNN_NO_F16_OPERATORS
#ifndef XNN_NO_QS8_OPERATORS
      case xnn_compute_type_qs8:
      {
        // Clamp the float activation bounds into the quantized int8 output
        // domain, rounding to nearest.
        const float output_scale = values[output_id].quantization.scale;
        const int32_t output_zero_point = values[output_id].quantization.zero_point;
        const int8_t output_min =
          (int8_t) lrintf(fminf(fmaxf(node->activation.output_min / output_scale + (float) output_zero_point, -128.0f), 127.0f));
        const int8_t output_max =
          (int8_t) lrintf(fminf(fmaxf(node->activation.output_max / output_scale + (float) output_zero_point, -128.0f), 127.0f));
        status = xnn_create_convolution2d_nhwc_qs8(
          node->params.convolution_2d.input_padding_top,
          node->params.convolution_2d.input_padding_right,
          node->params.convolution_2d.input_padding_bottom,
          node->params.convolution_2d.input_padding_left,
          node->params.convolution_2d.kernel_height,
          node->params.convolution_2d.kernel_width,
          node->params.convolution_2d.subsampling_height,
          node->params.convolution_2d.subsampling_width,
          node->params.convolution_2d.dilation_height,
          node->params.convolution_2d.dilation_width,
          node->params.convolution_2d.groups,
          node->params.convolution_2d.group_input_channels,
          node->params.convolution_2d.group_output_channels,
          node->params.convolution_2d.group_input_channels * node->params.convolution_2d.groups /* input_pixel_stride */,
          node->params.convolution_2d.group_output_channels * node->params.convolution_2d.groups /* output_pixel_stride */,
          (int8_t) values[input_id].quantization.zero_point,
          values[input_id].quantization.scale,
          values[filter_id].quantization.scale,
          filter_data,
          bias_data,
          (int8_t) output_zero_point,
          output_scale, output_min, output_max,
          node->flags,
          &opdata->operator_object);
        break;
      }
      case xnn_compute_type_qc8:
      {
        // Same int8 output clamping as QS8; the filter uses per-channel
        // scales instead of a single per-tensor scale.
        const float output_scale = values[output_id].quantization.scale;
        const int32_t output_zero_point = values[output_id].quantization.zero_point;
        const int8_t output_min =
          (int8_t) lrintf(fminf(fmaxf(node->activation.output_min / output_scale + (float) output_zero_point, -128.0f), 127.0f));
        const int8_t output_max =
          (int8_t) lrintf(fminf(fmaxf(node->activation.output_max / output_scale + (float) output_zero_point, -128.0f), 127.0f));
        status = xnn_create_convolution2d_nhwc_qc8(
          node->params.convolution_2d.input_padding_top,
          node->params.convolution_2d.input_padding_right,
          node->params.convolution_2d.input_padding_bottom,
          node->params.convolution_2d.input_padding_left,
          node->params.convolution_2d.kernel_height,
          node->params.convolution_2d.kernel_width,
          node->params.convolution_2d.subsampling_height,
          node->params.convolution_2d.subsampling_width,
          node->params.convolution_2d.dilation_height,
          node->params.convolution_2d.dilation_width,
          node->params.convolution_2d.groups,
          node->params.convolution_2d.group_input_channels,
          node->params.convolution_2d.group_output_channels,
          node->params.convolution_2d.group_input_channels * node->params.convolution_2d.groups /* input_pixel_stride */,
          node->params.convolution_2d.group_output_channels * node->params.convolution_2d.groups /* output_pixel_stride */,
          (int8_t) values[input_id].quantization.zero_point,
          values[input_id].quantization.scale,
          values[filter_id].quantization.channelwise_scale,
          filter_data,
          bias_data,
          (int8_t) output_zero_point,
          output_scale, output_min, output_max,
          node->flags,
          &opdata->operator_object);
        break;
      }
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
      case xnn_compute_type_qu8:
      {
        // Clamp the float activation bounds into the quantized uint8 output
        // domain [0, 255].
        const float output_scale = values[output_id].quantization.scale;
        const int32_t output_zero_point = values[output_id].quantization.zero_point;
        const uint8_t output_min =
          (uint8_t) lrintf(fminf(fmaxf(node->activation.output_min / output_scale + (float) output_zero_point, 0.0f), 255.0f));
        const uint8_t output_max =
          (uint8_t) lrintf(fminf(fmaxf(node->activation.output_max / output_scale + (float) output_zero_point, 0.0f), 255.0f));
        status = xnn_create_convolution2d_nhwc_qu8(
          node->params.convolution_2d.input_padding_top,
          node->params.convolution_2d.input_padding_right,
          node->params.convolution_2d.input_padding_bottom,
          node->params.convolution_2d.input_padding_left,
          node->params.convolution_2d.kernel_height,
          node->params.convolution_2d.kernel_width,
          node->params.convolution_2d.subsampling_height,
          node->params.convolution_2d.subsampling_width,
          node->params.convolution_2d.dilation_height,
          node->params.convolution_2d.dilation_width,
          node->params.convolution_2d.groups,
          node->params.convolution_2d.group_input_channels,
          node->params.convolution_2d.group_output_channels,
          node->params.convolution_2d.group_input_channels * node->params.convolution_2d.groups /* input_pixel_stride */,
          node->params.convolution_2d.group_output_channels * node->params.convolution_2d.groups /* output_pixel_stride */,
          (uint8_t) values[input_id].quantization.zero_point,
          values[input_id].quantization.scale,
          // Unlike QS8/QC8, the QU8 filter also carries a zero point.
          (uint8_t) values[filter_id].quantization.zero_point,
          values[filter_id].quantization.scale,
          filter_data,
          bias_data,
          (uint8_t) output_zero_point,
          output_scale, output_min, output_max,
          node->flags,
          &opdata->operator_object);
        break;
      }
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      default:
        XNN_UNREACHABLE;
    }
  }
  if (status == xnn_status_success) {
    // Cache input geometry (NHWC dims 0..2 = batch, height, width) and the
    // tensor IDs needed later by setup_convolution_operator.
    opdata->batch_size = values[input_id].shape.dim[0];
    opdata->input_height = values[input_id].shape.dim[1];
    opdata->input_width = values[input_id].shape.dim[2];
    opdata->inputs[0] = input_id;
    opdata->outputs[0] = output_id;
  }
  return status;
}
251 
// Binds runtime input/output blobs to a convolution operator previously
// built by create_convolution_operator and configures it for the cached
// batch size and input spatial dimensions.
//
// Dispatches on the concrete operator type and returns the status of the
// corresponding xnn_setup_convolution2d_* call.
//
// Fix: removed the unreachable `break;` statements that followed each
// `return` in the switch (dead code).
static enum xnn_status setup_convolution_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input_id = opdata->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_blobs);

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  // Blob data pointers must be populated by the runtime before setup.
  const struct xnn_blob* input_blob = blobs + input_id;
  const void* input_data = input_blob->data;
  assert(input_data != NULL);

  const struct xnn_blob* output_blob = blobs + output_id;
  void* output_data = output_blob->data;
  assert(output_data != NULL);

  switch (opdata->operator_object->type) {
    case xnn_operator_type_convolution_nchw_f32:
      return xnn_setup_convolution2d_nchw_f32(
        opdata->operator_object,
        opdata->batch_size,
        opdata->input_height,
        opdata->input_width,
        input_data,
        output_data,
        threadpool);
    case xnn_operator_type_convolution_nhwc_f32:
      return xnn_setup_convolution2d_nhwc_f32(
        opdata->operator_object,
        opdata->batch_size,
        opdata->input_height,
        opdata->input_width,
        input_data,
        output_data,
        threadpool);
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_convolution_nhwc_f16:
      return xnn_setup_convolution2d_nhwc_f16(
        opdata->operator_object,
        opdata->batch_size,
        opdata->input_height,
        opdata->input_width,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_F16_OPERATORS)
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_operator_type_convolution_nhwc_qc8:
      return xnn_setup_convolution2d_nhwc_qc8(
        opdata->operator_object,
        opdata->batch_size,
        opdata->input_height,
        opdata->input_width,
        input_data,
        output_data,
        threadpool);
    case xnn_operator_type_convolution_nhwc_qs8:
      return xnn_setup_convolution2d_nhwc_qs8(
        opdata->operator_object,
        opdata->batch_size,
        opdata->input_height,
        opdata->input_width,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_operator_type_convolution_nhwc_qu8:
      return xnn_setup_convolution2d_nhwc_qu8(
        opdata->operator_object,
        opdata->batch_size,
        opdata->input_height,
        opdata->input_width,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}
345 
validate_datatypes_with_bias(enum xnn_datatype input_datatype,enum xnn_datatype filter_datatype,enum xnn_datatype bias_datatype,enum xnn_datatype output_datatype)346 static inline enum xnn_compute_type validate_datatypes_with_bias(
347   enum xnn_datatype input_datatype,
348   enum xnn_datatype filter_datatype,
349   enum xnn_datatype bias_datatype,
350   enum xnn_datatype output_datatype)
351 {
352   switch (filter_datatype) {
353     case xnn_datatype_fp32:
354       if (input_datatype == xnn_datatype_fp32 &&
355           bias_datatype == xnn_datatype_fp32 &&
356           output_datatype == xnn_datatype_fp32)
357       {
358         return xnn_compute_type_fp32;
359       }
360       break;
361 #ifndef XNN_NO_QS8_OPERATORS
362     case xnn_datatype_qint8:
363       if (input_datatype == xnn_datatype_qint8 &&
364           bias_datatype == xnn_datatype_qint32 &&
365           output_datatype == xnn_datatype_qint8)
366       {
367         return xnn_compute_type_qs8;
368       }
369       break;
370     case xnn_datatype_qcint8:
371       if (input_datatype == xnn_datatype_qint8 &&
372           bias_datatype == xnn_datatype_qcint32 &&
373           output_datatype == xnn_datatype_qint8)
374       {
375         return xnn_compute_type_qc8;
376       }
377       break;
378 #endif  // !defined(XNN_NO_QS8_OPERATORS)
379 #ifndef XNN_NO_QU8_OPERATORS
380     case xnn_datatype_quint8:
381       if (input_datatype == xnn_datatype_quint8 &&
382           bias_datatype == xnn_datatype_qint32 &&
383           output_datatype == xnn_datatype_quint8)
384       {
385         return xnn_compute_type_qu8;
386       }
387       break;
388 #endif  // !defined(XNN_NO_QU8_OPERATORS)
389     default:
390       XNN_UNREACHABLE;
391   }
392   return xnn_compute_type_invalid;
393 }
394 
validate_datatypes_without_bias(enum xnn_datatype input_datatype,enum xnn_datatype filter_datatype,enum xnn_datatype output_datatype)395 static inline enum xnn_compute_type validate_datatypes_without_bias(
396   enum xnn_datatype input_datatype,
397   enum xnn_datatype filter_datatype,
398   enum xnn_datatype output_datatype)
399 {
400   switch (filter_datatype) {
401     case xnn_datatype_fp32:
402       if (input_datatype == xnn_datatype_fp32 && output_datatype == xnn_datatype_fp32) {
403         return xnn_compute_type_fp32;
404       }
405       break;
406 #ifndef XNN_NO_QS8_OPERATORS
407     case xnn_datatype_qint8:
408       if (input_datatype == xnn_datatype_qint8 && output_datatype == xnn_datatype_qint8) {
409         return xnn_compute_type_qs8;
410       }
411       break;
412     case xnn_datatype_qcint8:
413       if (input_datatype == xnn_datatype_qint8 && output_datatype == xnn_datatype_qint8) {
414         return xnn_compute_type_qc8;
415       }
416       break;
417 #endif  // !defined(XNN_NO_QS8_OPERATORS)
418 #ifndef XNN_NO_QU8_OPERATORS
419     case xnn_datatype_quint8:
420       if (input_datatype == xnn_datatype_quint8 && output_datatype == xnn_datatype_quint8) {
421         return xnn_compute_type_qu8;
422       }
423       break;
424 #endif  // !defined(XNN_NO_QU8_OPERATORS)
425     default:
426       XNN_UNREACHABLE;
427   }
428   return xnn_compute_type_invalid;
429 }
430 
xnn_define_convolution_2d(xnn_subgraph_t subgraph,uint32_t input_padding_top,uint32_t input_padding_right,uint32_t input_padding_bottom,uint32_t input_padding_left,uint32_t kernel_height,uint32_t kernel_width,uint32_t subsampling_height,uint32_t subsampling_width,uint32_t dilation_height,uint32_t dilation_width,uint32_t groups,size_t group_input_channels,size_t group_output_channels,float output_min,float output_max,uint32_t input_id,uint32_t filter_id,uint32_t bias_id,uint32_t output_id,uint32_t flags)431 enum xnn_status xnn_define_convolution_2d(
432   xnn_subgraph_t subgraph,
433   uint32_t input_padding_top,
434   uint32_t input_padding_right,
435   uint32_t input_padding_bottom,
436   uint32_t input_padding_left,
437   uint32_t kernel_height,
438   uint32_t kernel_width,
439   uint32_t subsampling_height,
440   uint32_t subsampling_width,
441   uint32_t dilation_height,
442   uint32_t dilation_width,
443   uint32_t groups,
444   size_t group_input_channels,
445   size_t group_output_channels,
446   float output_min,
447   float output_max,
448   uint32_t input_id,
449   uint32_t filter_id,
450   uint32_t bias_id,
451   uint32_t output_id,
452   uint32_t flags)
453 {
454   if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
455     xnn_log_error("failed to define %s operator: XNNPACK is not initialized",
456       xnn_node_type_to_string(xnn_node_type_convolution_2d));
457     return xnn_status_uninitialized;
458   }
459 
460   if (kernel_width == 0 || kernel_height == 0) {
461     xnn_log_error(
462       "failed to define %s operator with %" PRIu32 "x%" PRIu32 " kernel: kernel dimensions must be non-zero",
463       xnn_node_type_to_string(xnn_node_type_convolution_2d), kernel_width, kernel_height);
464     return xnn_status_invalid_parameter;
465   }
466 
467   if (subsampling_width == 0 || subsampling_height == 0) {
468     xnn_log_error(
469       "failed to define %s operator with %" PRIu32 "x%" PRIu32 " subsampling: subsampling dimensions must be non-zero",
470       xnn_node_type_to_string(xnn_node_type_convolution_2d), subsampling_width, subsampling_height);
471     return xnn_status_invalid_parameter;
472   }
473 
474   if (dilation_width == 0 || dilation_height == 0) {
475     xnn_log_error(
476       "failed to define %s operator with %" PRIu32 "x%" PRIu32 " dilation: dilation dimensions must be non-zero",
477       xnn_node_type_to_string(xnn_node_type_convolution_2d), dilation_width, dilation_height);
478     return xnn_status_invalid_parameter;
479   }
480 
481   if (groups == 0) {
482     xnn_log_error(
483       "failed to define %s operator with %" PRIu32 " groups: number of groups must be non-zero",
484       xnn_node_type_to_string(xnn_node_type_convolution_2d), groups);
485     return xnn_status_invalid_parameter;
486   }
487 
488   if (group_input_channels == 0) {
489     xnn_log_error(
490       "failed to define %s operator with %zu input channels per group: number of channels must be non-zero",
491       xnn_node_type_to_string(xnn_node_type_convolution_2d), group_input_channels);
492     return xnn_status_invalid_parameter;
493   }
494 
495   if (group_output_channels == 0) {
496     xnn_log_error(
497       "failed to define %s operator with %zu output channels per group: number of channels must be non-zero",
498       xnn_node_type_to_string(xnn_node_type_convolution_2d), group_output_channels);
499     return xnn_status_invalid_parameter;
500   }
501 
502   if (isnan(output_min)) {
503     xnn_log_error(
504       "failed to define %s operator with NaN output lower bound: lower bound must be non-NaN",
505       xnn_node_type_to_string(xnn_node_type_convolution_2d));
506     return xnn_status_invalid_parameter;
507   }
508 
509   if (isnan(output_max)) {
510     xnn_log_error(
511       "failed to define %s operator with NaN output upper bound: upper bound must be non-NaN",
512       xnn_node_type_to_string(xnn_node_type_convolution_2d));
513     return xnn_status_invalid_parameter;
514   }
515 
516   if (output_min >= output_max) {
517     xnn_log_error(
518       "failed to define %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
519       xnn_node_type_to_string(xnn_node_type_convolution_2d), output_min, output_max);
520     return xnn_status_invalid_parameter;
521   }
522 
523   const uint32_t supported_flags = XNN_FLAG_TENSORFLOW_SAME_PADDING;
524   const uint32_t invalid_flags = flags & ~supported_flags;
525   if (invalid_flags != 0) {
526     xnn_log_error(
527       "failed to define %s operator with 0x%08" PRIx32 " flags: invalid flags 0x%08" PRIx32,
528       xnn_node_type_to_string(xnn_node_type_convolution_2d), flags, invalid_flags);
529     return xnn_status_invalid_parameter;
530   }
531 
532   const bool any_padding = (input_padding_left | input_padding_top | input_padding_right | input_padding_bottom) != 0;
533   if ((flags & XNN_FLAG_TENSORFLOW_SAME_PADDING) != 0 && any_padding) {
534     xnn_log_error(
535       "failed to define %s operator with %" PRIu32 "+%" PRIu32 "x%" PRIu32 "+%" PRIu32" padding: "
536       "TensorFlow SAME padding can't be combined with explicit padding specification",
537       xnn_node_type_to_string(xnn_node_type_convolution_2d),
538       input_padding_top, input_padding_left, input_padding_bottom, input_padding_right);
539     return xnn_status_invalid_parameter;
540   }
541 
542   // Convert TensorFlow SAME padding to explicit padding specification whenever possible
543   if ((flags & XNN_FLAG_TENSORFLOW_SAME_PADDING) != 0 && (subsampling_height | subsampling_width) == 1) {
544     flags &= ~XNN_FLAG_TENSORFLOW_SAME_PADDING;
545     const uint32_t padding_height = (kernel_height - 1) * dilation_height;
546     const uint32_t padding_width = (kernel_width - 1) * dilation_width;
547     input_padding_left = padding_width / 2;
548     input_padding_top = padding_height / 2;
549     input_padding_right = padding_width - input_padding_left;
550     input_padding_bottom = padding_height - input_padding_top;
551   }
552 
553   if (input_id >= subgraph->num_values) {
554     xnn_log_error(
555       "failed to define %s operator with input ID #%" PRIu32 ": invalid Value ID",
556       xnn_node_type_to_string(xnn_node_type_convolution_2d), input_id);
557     return xnn_status_invalid_parameter;
558   }
559 
560   const struct xnn_value* input_value = &subgraph->values[input_id];
561   if (input_value->type != xnn_value_type_dense_tensor) {
562     xnn_log_error(
563       "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
564       xnn_node_type_to_string(xnn_node_type_convolution_2d), input_id, input_value->type);
565     return xnn_status_invalid_parameter;
566   }
567 
568   switch (input_value->datatype) {
569     case xnn_datatype_fp32:
570 #ifndef XNN_NO_QS8_OPERATORS
571     case xnn_datatype_qint8:
572 #endif  // !defined(XNN_NO_QS8_OPERATORS)
573 #ifndef XNN_NO_QU8_OPERATORS
574     case xnn_datatype_quint8:
575 #endif  // !defined(XNN_NO_QU8_OPERATORS)
576       break;
577     default:
578       xnn_log_error(
579         "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
580         xnn_node_type_to_string(xnn_node_type_convolution_2d), input_id,
581         xnn_datatype_to_string(input_value->datatype), input_value->datatype);
582       return xnn_status_invalid_parameter;
583   }
584 
585   if (filter_id >= subgraph->num_values) {
586     xnn_log_error(
587       "failed to define %s operator with filter ID #%" PRIu32 ": invalid Value ID",
588       xnn_node_type_to_string(xnn_node_type_convolution_2d), filter_id);
589     return xnn_status_invalid_parameter;
590   }
591 
592   const struct xnn_value* filter_value = &subgraph->values[filter_id];
593   if (filter_value->type != xnn_value_type_dense_tensor) {
594     xnn_log_error(
595       "failed to define %s operator with filter ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
596       xnn_node_type_to_string(xnn_node_type_convolution_2d), filter_id, filter_value->type);
597     return xnn_status_invalid_parameter;
598   }
599 
600   if (filter_value->data == NULL) {
601     xnn_log_error(
602       "failed to define %s operator with filter ID #%" PRIu32 ": non-static Value",
603       xnn_node_type_to_string(xnn_node_type_convolution_2d), filter_id);
604     return xnn_status_invalid_parameter;
605   }
606 
607   switch (filter_value->datatype) {
608     case xnn_datatype_fp32:
609       break;
610 #ifndef XNN_NO_QS8_OPERATORS
611     case xnn_datatype_qint8:
612       if (filter_value->quantization.zero_point != 0) {
613         xnn_log_error(
614           "failed to define %s operator with filter ID #%" PRIu32 ": unsupported quantization zero point %" PRId32 " for datatype %s",
615           xnn_node_type_to_string(xnn_node_type_convolution_2d), filter_id,
616           filter_value->quantization.zero_point, xnn_datatype_to_string(filter_value->datatype));
617       }
618       break;
619     case xnn_datatype_qcint8:
620       break;
621 #endif  // !defined(XNN_NO_QS8_OPERATORS)
622 #ifndef XNN_NO_QU8_OPERATORS
623     case xnn_datatype_quint8:
624       break;
625 #endif  // !defined(XNN_NO_QU8_OPERATORS)
626     default:
627       xnn_log_error(
628         "failed to define %s operator with filter ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
629         xnn_node_type_to_string(xnn_node_type_convolution_2d), filter_id,
630         xnn_datatype_to_string(filter_value->datatype), filter_value->datatype);
631       return xnn_status_invalid_parameter;
632   }
633 
634   const struct xnn_value* bias_value = NULL;
635   if (bias_id != XNN_INVALID_VALUE_ID) {
636     if (bias_id >= subgraph->num_values) {
637       xnn_log_error(
638         "failed to define %s operator with bias ID #%" PRIu32 ": invalid Value ID",
639         xnn_node_type_to_string(xnn_node_type_convolution_2d), bias_id);
640       return xnn_status_invalid_parameter;
641     }
642 
643     bias_value = &subgraph->values[bias_id];
644     if (bias_value->type != xnn_value_type_dense_tensor) {
645       xnn_log_error(
646         "failed to define %s operator with bias ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
647         xnn_node_type_to_string(xnn_node_type_convolution_2d), bias_id, bias_value->type);
648       return xnn_status_invalid_parameter;
649     }
650 
651     if (bias_value->data == NULL) {
652       xnn_log_error(
653         "failed to define %s operator with bias ID #%" PRIu32 ": non-static Value",
654         xnn_node_type_to_string(xnn_node_type_convolution_2d), bias_id);
655       return xnn_status_invalid_parameter;
656     }
657 
658     switch (bias_value->datatype) {
659       case xnn_datatype_fp32:
660 #if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
661       case xnn_datatype_qint32:
662 #endif  // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
663 #ifndef XNN_NO_QS8_OPERATORS
664       case xnn_datatype_qcint32:
665 #endif  // !defined(XNN_NO_QS8_OPERATORS)
666         break;
667       default:
668         xnn_log_error(
669           "failed to define %s operator with bias ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
670           xnn_node_type_to_string(xnn_node_type_convolution_2d), bias_id,
671           xnn_datatype_to_string(bias_value->datatype), bias_value->datatype);
672         return xnn_status_invalid_parameter;
673     }
674   }
675 
676   if (output_id >= subgraph->num_values) {
677     xnn_log_error(
678       "failed to define %s operator with output ID #%" PRIu32 ": invalid Value ID",
679       xnn_node_type_to_string(xnn_node_type_convolution_2d), output_id);
680     return xnn_status_invalid_parameter;
681   }
682 
683   const struct xnn_value* output_value = &subgraph->values[output_id];
684   if (output_value->type != xnn_value_type_dense_tensor) {
685     xnn_log_error(
686       "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
687       xnn_node_type_to_string(xnn_node_type_convolution_2d), output_id, output_value->type);
688     return xnn_status_invalid_parameter;
689   }
690 
691   switch (output_value->datatype) {
692     case xnn_datatype_fp32:
693 #ifndef XNN_NO_QS8_OPERATORS
694     case xnn_datatype_qint8:
695 #endif  // !defined(XNN_NO_QS8_OPERATORS)
696 #ifndef XNN_NO_QU8_OPERATORS
697     case xnn_datatype_quint8:
698 #endif  // !defined(XNN_NO_QU8_OPERATORS)
699       break;
700     default:
701       xnn_log_error(
702         "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
703         xnn_node_type_to_string(xnn_node_type_convolution_2d), output_id,
704         xnn_datatype_to_string(output_value->datatype), output_value->datatype);
705       return xnn_status_invalid_parameter;
706   }
707 
708   enum xnn_compute_type compute_type = xnn_compute_type_invalid;
709   if (bias_value != NULL) {
710     compute_type = validate_datatypes_with_bias(
711       input_value->datatype, filter_value->datatype, bias_value->datatype, output_value->datatype);
712     if (compute_type == xnn_compute_type_invalid) {
713       xnn_log_error(
714         "failed to define %s operator with input ID #%" PRIu32 ", filter ID #%" PRIu32 ", bias ID #%" PRIu32 ", and output ID #%" PRIu32
715         ": mismatching datatypes across input (%s), filter (%s), bias (%s), and output (%s)",
716         xnn_node_type_to_string(xnn_node_type_convolution_2d), input_id, filter_id, bias_id, output_id,
717         xnn_datatype_to_string(input_value->datatype),
718         xnn_datatype_to_string(filter_value->datatype),
719         xnn_datatype_to_string(bias_value->datatype),
720         xnn_datatype_to_string(output_value->datatype));
721       return xnn_status_invalid_parameter;
722     }
723   } else {
724     compute_type = validate_datatypes_without_bias(
725       input_value->datatype, filter_value->datatype, output_value->datatype);
726     if (compute_type == xnn_compute_type_invalid) {
727       xnn_log_error(
728         "failed to define %s operator with input ID #%" PRIu32 ", filter ID #%" PRIu32 ", and output ID #%" PRIu32
729         ": mismatching datatypes across input (%s), filter (%s), and output (%s)",
730         xnn_node_type_to_string(xnn_node_type_convolution_2d), input_id, filter_id, output_id,
731         xnn_datatype_to_string(input_value->datatype),
732         xnn_datatype_to_string(filter_value->datatype),
733         xnn_datatype_to_string(output_value->datatype));
734       return xnn_status_invalid_parameter;
735     }
736   }
737 
738 #ifndef XNN_NO_QS8_OPERATORS
739   if (filter_value->datatype == xnn_datatype_qcint8) {
740     if (filter_value->quantization.channel_dimension != 0) {
741       xnn_log_error(
742         "failed to define %s operator with filter ID #%" PRIu32 ": invalid channel dimension %zu",
743         xnn_node_type_to_string(xnn_node_type_convolution_2d), input_id, filter_value->quantization.channel_dimension);
744       return xnn_status_invalid_parameter;
745     }
746 
747     if (bias_value != NULL) {
748       assert(bias_value->datatype == xnn_datatype_qcint32);
749       if (bias_value->quantization.channel_dimension != 0) {
750         xnn_log_error(
751           "failed to define %s operator with bias ID #%" PRIu32 ": invalid channel dimension %zu",
752           xnn_node_type_to_string(xnn_node_type_convolution_2d), bias_id, bias_value->quantization.channel_dimension);
753         return xnn_status_invalid_parameter;
754       }
755     }
756   }
757 #endif  // !defined(XNN_NO_QS8_OPERATORS)
758 
759   struct xnn_node* node = xnn_subgraph_new_node(subgraph);
760   if (node == NULL) {
761     return xnn_status_out_of_memory;
762   }
763 
764   node->type = xnn_node_type_convolution_2d;
765   node->compute_type = compute_type;
766   node->params.convolution_2d.input_padding_top = input_padding_top;
767   node->params.convolution_2d.input_padding_right = input_padding_right;
768   node->params.convolution_2d.input_padding_bottom = input_padding_bottom;
769   node->params.convolution_2d.input_padding_left = input_padding_left;
770   node->params.convolution_2d.kernel_height = kernel_height;
771   node->params.convolution_2d.kernel_width = kernel_width;
772   node->params.convolution_2d.subsampling_height = subsampling_height;
773   node->params.convolution_2d.subsampling_width = subsampling_width;
774   node->params.convolution_2d.dilation_height = dilation_height;
775   node->params.convolution_2d.dilation_width = dilation_width;
776   node->params.convolution_2d.groups = groups;
777   node->params.convolution_2d.group_input_channels = group_input_channels;
778   node->params.convolution_2d.group_output_channels = group_output_channels;
779   node->activation.output_min = output_min;
780   node->activation.output_max = output_max;
781   node->num_inputs = 2 + (size_t) (bias_id != XNN_INVALID_VALUE_ID);
782   node->inputs[0] = input_id;
783   node->inputs[1] = filter_id;
784   node->inputs[2] = bias_id;
785   node->num_outputs = 1;
786   node->outputs[0] = output_id;
787   node->flags = flags;
788 
789   node->create = create_convolution_operator;
790   node->setup = setup_convolution_operator;
791 
792   return xnn_status_success;
793 };
794