1 // Copyright 2019 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <assert.h>
7 #include <math.h>
8 #include <stddef.h>
9 #include <stdint.h>
10 #include <stdlib.h>
#include <string.h>    // memcpy
#include <inttypes.h>  // PRId8 / PRIu8 format macros used in log messages
11 
12 #include <fp16.h>
13 
14 #include <xnnpack.h>
15 #include <xnnpack/allocator.h>
16 #include <xnnpack/log.h>
17 #include <xnnpack/operator.h>
18 #include <xnnpack/microparams-init.h>
19 #include <xnnpack/params.h>
20 
21 
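// Shared creation path for every N-dimensional binary elementwise operator:
// verifies that XNNPACK and the requested datatype are initialized, allocates
// the operator descriptor, copies the packed microkernel parameters, and
// records the op/opc/ropc microkernel entry points for use at setup time.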
22 static enum xnn_status create_binary_elementwise_nd(
23     uint32_t flags,
24     const void* params,
25     size_t params_size,
26     uint32_t datatype_init_flags,
27     enum xnn_operator_type operator_type,
28     const struct vbinary_fused_ukernels* vbinary_fused_ukernels,
29     xnn_operator_t* binary_elementwise_op_out)
30 {
31   if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
32     xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
33       xnn_operator_type_to_string(operator_type));
34     return xnn_status_uninitialized;
35   }
36 
37   if ((xnn_params.init_flags & datatype_init_flags) != datatype_init_flags) {
38     xnn_log_error("failed to create %s operator: operations on data type are not supported",
39       xnn_operator_type_to_string(operator_type));
40     return xnn_status_unsupported_hardware;
41   }
42 
43   xnn_operator_t binary_elementwise_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
44   if (binary_elementwise_op == NULL) {
45     xnn_log_error(
46       "failed to allocate %zu bytes for %s operator descriptor",
47       sizeof(struct xnn_operator), xnn_operator_type_to_string(operator_type));
48     return xnn_status_out_of_memory;
49   }
50 
51   if (params_size != 0) {
52     memcpy(&binary_elementwise_op->params, params, params_size);
53   }
54 
55   binary_elementwise_op->ukernel.vbinary.op_function   = vbinary_fused_ukernels->op_ukernel;
56   binary_elementwise_op->ukernel.vbinary.opc_function  = vbinary_fused_ukernels->opc_ukernel;
57   binary_elementwise_op->ukernel.vbinary.ropc_function = vbinary_fused_ukernels->ropc_ukernel;
58 
59   binary_elementwise_op->type = operator_type;
60   binary_elementwise_op->flags = flags;
61 
62   binary_elementwise_op->state = xnn_run_state_invalid;
63 
64   *binary_elementwise_op_out = binary_elementwise_op;
65   return xnn_status_success;
66 }
67 
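// F16 creation helper: the output bounds are validated and rounded through
// half precision before being packed into the minmax parameters, so the clamp
// limits match what the f16 microkernels actually apply.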
68 static enum xnn_status create_binary_elementwise_nd_f16(
69     float output_min,
70     float output_max,
71     uint32_t flags,
72     enum xnn_operator_type operator_type,
73     const struct vbinary_parameters vbinary[restrict XNN_MIN_ELEMENTS(1)],
74     xnn_operator_t* binary_elementwise_op_out)
75 {
76   if (isnan(output_min)) {
77     xnn_log_error(
78       "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN",
79       xnn_operator_type_to_string(operator_type));
80     return xnn_status_invalid_parameter;
81   }
82 
83   if (isnan(output_max)) {
84     xnn_log_error(
85       "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN",
86       xnn_operator_type_to_string(operator_type));
87     return xnn_status_invalid_parameter;
88   }
89 
90   if (fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_min)) >= fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_max))) {
91     xnn_log_error(
92       "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
93       xnn_operator_type_to_string(operator_type),
94       fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_min)),
95       fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_max)));
96     return xnn_status_invalid_parameter;
97   }
98 
99   union xnn_f16_minmax_params params;
100   if (vbinary->init.f16_minmax != NULL) {
101     vbinary->init.f16_minmax(&params,
102       fp16_ieee_from_fp32_value(output_min), fp16_ieee_from_fp32_value(output_max));
103   }
104   return create_binary_elementwise_nd(
105     flags,
106     &params,
107     sizeof(params),
108     XNN_INIT_FLAG_F16,
109     operator_type,
110     &vbinary->minmax,
111     binary_elementwise_op_out);
112 }
113 
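// F32 creation helper: when the requested output range is (-inf, +inf) and a
// linear (non-clamping) microkernel is available, it is selected instead of
// the minmax variant to skip the redundant clamp.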
114 static enum xnn_status create_binary_elementwise_nd_f32(
115     float output_min,
116     float output_max,
117     uint32_t flags,
118     enum xnn_operator_type operator_type,
119     const struct vbinary_parameters vbinary[restrict XNN_MIN_ELEMENTS(1)],
120     xnn_operator_t* binary_elementwise_op_out)
121 {
122   if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
123     xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
124       xnn_operator_type_to_string(operator_type));
125     return xnn_status_uninitialized;
126   }
127 
128   if (isnan(output_min)) {
129     xnn_log_error(
130       "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN",
131       xnn_operator_type_to_string(operator_type));
132     return xnn_status_invalid_parameter;
133   }
134 
135   if (isnan(output_max)) {
136     xnn_log_error(
137       "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN",
138       xnn_operator_type_to_string(operator_type));
139     return xnn_status_invalid_parameter;
140   }
141 
142   if (output_min >= output_max) {
143     xnn_log_error(
144       "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
145       xnn_operator_type_to_string(operator_type), output_min, output_max);
146     return xnn_status_invalid_parameter;
147   }
148 
149   const bool linear_activation = (output_max == INFINITY) && (output_min == -output_max);
150   const struct vbinary_fused_ukernels* vbinary_fused_ukernels = &vbinary->minmax;
151   if (linear_activation && vbinary->linear.op_ukernel != NULL) {
152     vbinary_fused_ukernels = &vbinary->linear;
153   }
154 
155   union xnn_f32_minmax_params params;
156   if (vbinary->init.f32_minmax != NULL) {
157     vbinary->init.f32_minmax(&params, output_min, output_max);
158   }
159   return create_binary_elementwise_nd(
160     flags,
161     &params,
162     sizeof(params),
163     XNN_INIT_FLAG_F32,
164     operator_type,
165     vbinary_fused_ukernels,
166     binary_elementwise_op_out);
167 }
168 
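// QS8 addition: besides the usual scale and range validation, each
// input_scale/output_scale requantization ratio must lie in [2**-10, 2**8),
// the range accepted by the fixed-point add microkernels. The parameters are
// packed in both operand orders (qs8_add and qs8_radd) so that setup can swap
// the inputs when input1 is the broadcast operand.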
169 enum xnn_status xnn_create_add_nd_qs8(
170     int8_t input1_zero_point,
171     float input1_scale,
172     int8_t input2_zero_point,
173     float input2_scale,
174     int8_t output_zero_point,
175     float output_scale,
176     int8_t output_min,
177     int8_t output_max,
178     uint32_t flags,
179     xnn_operator_t* add_op_out)
180 {
181   if (input1_scale <= 0.0f || !isnormal(input1_scale)) {
182     xnn_log_error(
183       "failed to create %s operator with %.7g input 1 scale: scale must be finite and positive",
184       xnn_operator_type_to_string(xnn_operator_type_add_nd_qs8), input1_scale);
185     return xnn_status_invalid_parameter;
186   }
187 
188   if (input2_scale <= 0.0f || !isnormal(input2_scale)) {
189     xnn_log_error(
190       "failed to create %s operator with %.7g input 2 scale: scale must be finite and positive",
191       xnn_operator_type_to_string(xnn_operator_type_add_nd_qs8), input2_scale);
192     return xnn_status_invalid_parameter;
193   }
194 
195   if (output_scale <= 0.0f || !isnormal(output_scale)) {
196     xnn_log_error(
197       "failed to create %s operator with %.7g output scale: scale must be finite and positive",
198       xnn_operator_type_to_string(xnn_operator_type_add_nd_qs8), output_scale);
199     return xnn_status_invalid_parameter;
200   }
201 
202   if (output_min >= output_max) {
203     xnn_log_error(
204       "failed to create %s operator with [%" PRId8 ", %" PRId8 "] output range: lower bound must be below upper bound",
205       xnn_operator_type_to_string(xnn_operator_type_add_nd_qs8), output_min, output_max);
206     return xnn_status_invalid_parameter;
207   }
208 
209   const float input1_output_scale = input1_scale / output_scale;
210   if (input1_output_scale < 0x1.0p-10f || input1_output_scale >= 0x1.0p+8f) {
211     xnn_log_error(
212       "failed to create %s operator with %.7g input1-to-output scale ratio: scale ratio must be in [2**-10, 2**8) range",
213       xnn_operator_type_to_string(xnn_operator_type_add_nd_qs8), input1_output_scale);
214     return xnn_status_unsupported_parameter;
215   }
216 
217   const float input2_output_scale = input2_scale / output_scale;
218   if (input2_output_scale < 0x1.0p-10f || input2_output_scale >= 0x1.0p+8f) {
219     xnn_log_error(
220       "failed to create %s operator with %.7g input2-to-output scale ratio: scale ratio must be in [2**-10, 2**8) range",
221       xnn_operator_type_to_string(xnn_operator_type_add_nd_qs8), input2_output_scale);
222     return xnn_status_unsupported_parameter;
223   }
224 
225   struct {
226     union xnn_qs8_add_minmax_params qs8_add;
227     union xnn_qs8_add_minmax_params qs8_radd;
228   } params;
229   if (xnn_params.qs8.vadd.init.qs8_add != NULL) {
230     xnn_params.qs8.vadd.init.qs8_add(
231       &params.qs8_add, input1_zero_point, input2_zero_point, output_zero_point,
232       input1_output_scale, input2_output_scale, output_min, output_max);
233     xnn_params.qs8.vadd.init.qs8_add(
234       &params.qs8_radd, input2_zero_point, input1_zero_point, output_zero_point,
235       input2_output_scale, input1_output_scale, output_min, output_max);
236   }
237   return create_binary_elementwise_nd(
238     flags,
239     &params,
240     sizeof(params),
241     XNN_INIT_FLAG_QS8,
242     xnn_operator_type_add_nd_qs8,
243     &xnn_params.qs8.vadd.minmax,
244     add_op_out);
245 }
246 
247 enum xnn_status xnn_create_add_nd_qu8(
248     uint8_t input1_zero_point,
249     float input1_scale,
250     uint8_t input2_zero_point,
251     float input2_scale,
252     uint8_t output_zero_point,
253     float output_scale,
254     uint8_t output_min,
255     uint8_t output_max,
256     uint32_t flags,
257     xnn_operator_t* add_op_out)
258 {
259   if (input1_scale <= 0.0f || !isnormal(input1_scale)) {
260     xnn_log_error(
261       "failed to create %s operator with %.7g input 1 scale: scale must be finite and positive",
262       xnn_operator_type_to_string(xnn_operator_type_add_nd_qu8), input1_scale);
263     return xnn_status_invalid_parameter;
264   }
265 
266   if (input2_scale <= 0.0f || !isnormal(input2_scale)) {
267     xnn_log_error(
268       "failed to create %s operator with %.7g input 2 scale: scale must be finite and positive",
269       xnn_operator_type_to_string(xnn_operator_type_add_nd_qu8), input2_scale);
270     return xnn_status_invalid_parameter;
271   }
272 
273   if (output_scale <= 0.0f || !isnormal(output_scale)) {
274     xnn_log_error(
275       "failed to create %s operator with %.7g output scale: scale must be finite and positive",
276       xnn_operator_type_to_string(xnn_operator_type_add_nd_qu8), output_scale);
277     return xnn_status_invalid_parameter;
278   }
279 
280   if (output_min >= output_max) {
281     xnn_log_error(
282       "failed to create %s operator with [%" PRIu8 ", %" PRIu8 "] output range: lower bound must be below upper bound",
283       xnn_operator_type_to_string(xnn_operator_type_add_nd_qu8), output_min, output_max);
284     return xnn_status_invalid_parameter;
285   }
286 
287   const float input1_output_scale = input1_scale / output_scale;
288   if (input1_output_scale < 0x1.0p-10f || input1_output_scale >= 0x1.0p+8f) {
289     xnn_log_error(
290       "failed to create %s operator with %.7g input1-to-output scale ratio: scale ratio must be in [2**-10, 2**8) range",
291       xnn_operator_type_to_string(xnn_operator_type_add_nd_qu8), input1_output_scale);
292     return xnn_status_unsupported_parameter;
293   }
294 
295   const float input2_output_scale = input2_scale / output_scale;
296   if (input2_output_scale < 0x1.0p-10f || input2_output_scale >= 0x1.0p+8f) {
297     xnn_log_error(
298       "failed to create %s operator with %.7g input2-to-output scale ratio: scale ratio must be in [2**-10, 2**8) range",
299       xnn_operator_type_to_string(xnn_operator_type_add_nd_qu8), input2_output_scale);
300     return xnn_status_unsupported_parameter;
301   }
302 
303   struct {
304     union xnn_qu8_add_minmax_params qu8_add;
305     union xnn_qu8_add_minmax_params qu8_radd;
306   } params;
307   if (xnn_params.qu8.vadd.init.qu8_add != NULL) {
308     xnn_params.qu8.vadd.init.qu8_add(
309       &params.qu8_add, input1_zero_point, input2_zero_point, output_zero_point,
310       input1_output_scale, input2_output_scale, output_min, output_max);
311     xnn_params.qu8.vadd.init.qu8_add(
312       &params.qu8_radd, input2_zero_point, input1_zero_point, output_zero_point,
313       input2_output_scale, input1_output_scale, output_min, output_max);
314   }
315   return create_binary_elementwise_nd(
316     flags,
317     &params,
318     sizeof(params),
319     XNN_INIT_FLAG_QU8,
320     xnn_operator_type_add_nd_qu8,
321     &xnn_params.qu8.vadd.minmax,
322     add_op_out);
323 }
324 
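// The public f16/f32 creation entry points below are thin wrappers: each one
// binds the operator type and the matching vbinary parameter table from
// xnn_params and defers to the helpers above.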
325 enum xnn_status xnn_create_add_nd_f16(
326     float output_min,
327     float output_max,
328     uint32_t flags,
329     xnn_operator_t* add_op_out)
330 {
331   return create_binary_elementwise_nd_f16(
332     output_min,
333     output_max,
334     flags,
335     xnn_operator_type_add_nd_f16,
336     &xnn_params.f16.vadd,
337     add_op_out);
338 }
339 
340 enum xnn_status xnn_create_add_nd_f32(
341     float output_min,
342     float output_max,
343     uint32_t flags,
344     xnn_operator_t* add_op_out)
345 {
346   return create_binary_elementwise_nd_f32(
347     output_min,
348     output_max,
349     flags,
350     xnn_operator_type_add_nd_f32,
351     &xnn_params.f32.vadd,
352     add_op_out);
353 }
354 
355 enum xnn_status xnn_create_divide_nd_f16(
356     float output_min,
357     float output_max,
358     uint32_t flags,
359     xnn_operator_t* divide_op_out)
360 {
361   return create_binary_elementwise_nd_f16(
362     output_min,
363     output_max,
364     flags,
365     xnn_operator_type_divide_nd_f16,
366     &xnn_params.f16.vdiv,
367     divide_op_out);
368 }
369 
370 enum xnn_status xnn_create_divide_nd_f32(
371     float output_min,
372     float output_max,
373     uint32_t flags,
374     xnn_operator_t* divide_op_out)
375 {
376   return create_binary_elementwise_nd_f32(
377     output_min,
378     output_max,
379     flags,
380     xnn_operator_type_divide_nd_f32,
381     &xnn_params.f32.vdiv,
382     divide_op_out);
383 }
384 
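// Maximum, minimum and squared difference take no output range, so they pass
// either NULL parameters (f16) or default f32 parameters to the shared
// creation helper.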
385 enum xnn_status xnn_create_maximum_nd_f16(
386     uint32_t flags,
387     xnn_operator_t* maximum_op_out)
388 {
389   return create_binary_elementwise_nd(
390     flags,
391     NULL,
392     0,
393     XNN_INIT_FLAG_F16,
394     xnn_operator_type_maximum_nd_f16,
395     &xnn_params.f16.vmax.minmax,
396     maximum_op_out);
397 }
398 
399 enum xnn_status xnn_create_maximum_nd_f32(
400     uint32_t flags,
401     xnn_operator_t* maximum_op_out)
402 {
403   union xnn_f32_default_params params;
404   if (xnn_params.f32.vmax.init.f32_default != NULL) {
405     xnn_params.f32.vmax.init.f32_default(&params);
406   }
407   return create_binary_elementwise_nd(
408     flags,
409     &params,
410     sizeof(params),
411     XNN_INIT_FLAG_F32,
412     xnn_operator_type_maximum_nd_f32,
413     &xnn_params.f32.vmax.minmax,
414     maximum_op_out);
415 }
416 
417 enum xnn_status xnn_create_minimum_nd_f16(
418     uint32_t flags,
419     xnn_operator_t* minimum_op_out)
420 {
421   return create_binary_elementwise_nd(
422     flags,
423     NULL,
424     0,
425     XNN_INIT_FLAG_F16,
426     xnn_operator_type_minimum_nd_f16,
427     &xnn_params.f16.vmin.minmax,
428     minimum_op_out);
429 }
430 
431 enum xnn_status xnn_create_minimum_nd_f32(
432     uint32_t flags,
433     xnn_operator_t* minimum_op_out)
434 {
435   union xnn_f32_default_params params;
436   if (xnn_params.f32.vmin.init.f32_default != NULL) {
437     xnn_params.f32.vmin.init.f32_default(&params);
438   }
439   return create_binary_elementwise_nd(
440     flags,
441     &params,
442     sizeof(params),
443     XNN_INIT_FLAG_F32,
444     xnn_operator_type_minimum_nd_f32,
445     &xnn_params.f32.vmin.minmax,
446     minimum_op_out);
447 }
448 
449 enum xnn_status xnn_create_multiply_nd_f16(
450     float output_min,
451     float output_max,
452     uint32_t flags,
453     xnn_operator_t* multiply_op_out)
454 {
455   return create_binary_elementwise_nd_f16(
456     output_min,
457     output_max,
458     flags,
459     xnn_operator_type_multiply_nd_f16,
460     &xnn_params.f16.vmul,
461     multiply_op_out);
462 }
463 
464 enum xnn_status xnn_create_multiply_nd_f32(
465     float output_min,
466     float output_max,
467     uint32_t flags,
468     xnn_operator_t* multiply_op_out)
469 {
470   return create_binary_elementwise_nd_f32(
471     output_min,
472     output_max,
473     flags,
474     xnn_operator_type_multiply_nd_f32,
475     &xnn_params.f32.vmul,
476     multiply_op_out);
477 }
478 
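// Quantized multiplication: the combined requantization ratio
// (input1_scale * input2_scale) / output_scale must lie in [2**-16, 2**8);
// as with addition, the parameters are packed in both operand orders.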
479 enum xnn_status xnn_create_multiply_nd_qs8(
480     int8_t input1_zero_point,
481     float input1_scale,
482     int8_t input2_zero_point,
483     float input2_scale,
484     int8_t output_zero_point,
485     float output_scale,
486     int8_t output_min,
487     int8_t output_max,
488     uint32_t flags,
489     xnn_operator_t* multiply_op_out)
490 {
491   if (input1_scale <= 0.0f || !isnormal(input1_scale)) {
492     xnn_log_error(
493       "failed to create %s operator with %.7g input 1 scale: scale must be finite and positive",
494       xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qs8), input1_scale);
495     return xnn_status_invalid_parameter;
496   }
497 
498   if (input2_scale <= 0.0f || !isnormal(input2_scale)) {
499     xnn_log_error(
500       "failed to create %s operator with %.7g input 2 scale: scale must be finite and positive",
501       xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qs8), input2_scale);
502     return xnn_status_invalid_parameter;
503   }
504 
505   if (output_scale <= 0.0f || !isnormal(output_scale)) {
506     xnn_log_error(
507       "failed to create %s operator with %.7g output scale: scale must be finite and positive",
508       xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qs8), output_scale);
509     return xnn_status_invalid_parameter;
510   }
511 
512   if (output_min >= output_max) {
513     xnn_log_error(
514       "failed to create %s operator with [%" PRId8 ", %" PRId8 "] output range: lower bound must be below upper bound",
515       xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qs8), output_min, output_max);
516     return xnn_status_invalid_parameter;
517   }
518 
519   const float product_scale = input1_scale * input2_scale;
520   const float product_output_scale = product_scale / output_scale;
521   if (product_output_scale < 0x1.0p-16f || product_output_scale >= 0x1.0p+8f) {
522     xnn_log_error(
523       "failed to create %s operator with %.7g product-to-output scale ratio: scale ratio must be in [2**-16, 2**8) range",
524       xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qs8), product_output_scale);
525     return xnn_status_unsupported_parameter;
526   }
527 
528   struct {
529     union xnn_qs8_mul_minmax_params qs8_mul;
530     union xnn_qs8_mul_minmax_params qs8_rmul;
531   } params;
532   if (xnn_params.qs8.vmul.init.qs8_mul != NULL) {
533     xnn_params.qs8.vmul.init.qs8_mul(
534       &params.qs8_mul, input1_zero_point, input2_zero_point, output_zero_point,
535       product_output_scale, output_min, output_max);
536     xnn_params.qs8.vmul.init.qs8_mul(
537       &params.qs8_rmul, input2_zero_point, input1_zero_point, output_zero_point,
538       product_output_scale, output_min, output_max);
539   }
540   return create_binary_elementwise_nd(
541     flags,
542     &params,
543     sizeof(params),
544     XNN_INIT_FLAG_QS8,
545     xnn_operator_type_multiply_nd_qs8,
546     &xnn_params.qs8.vmul.minmax,
547     multiply_op_out);
548 }
549 
550 enum xnn_status xnn_create_multiply_nd_qu8(
551     uint8_t input1_zero_point,
552     float input1_scale,
553     uint8_t input2_zero_point,
554     float input2_scale,
555     uint8_t output_zero_point,
556     float output_scale,
557     uint8_t output_min,
558     uint8_t output_max,
559     uint32_t flags,
560     xnn_operator_t* multiply_op_out)
561 {
562   if (input1_scale <= 0.0f || !isnormal(input1_scale)) {
563     xnn_log_error(
564       "failed to create %s operator with %.7g input 1 scale: scale must be finite and positive",
565       xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qu8), input1_scale);
566     return xnn_status_invalid_parameter;
567   }
568 
569   if (input2_scale <= 0.0f || !isnormal(input2_scale)) {
570     xnn_log_error(
571       "failed to create %s operator with %.7g input 2 scale: scale must be finite and positive",
572       xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qu8), input2_scale);
573     return xnn_status_invalid_parameter;
574   }
575 
576   if (output_scale <= 0.0f || !isnormal(output_scale)) {
577     xnn_log_error(
578       "failed to create %s operator with %.7g output scale: scale must be finite and positive",
579       xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qu8), output_scale);
580     return xnn_status_invalid_parameter;
581   }
582 
583   if (output_min >= output_max) {
584     xnn_log_error(
585       "failed to create %s operator with [%" PRIu8 ", %" PRIu8 "] output range: lower bound must be below upper bound",
586       xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qu8), output_min, output_max);
587     return xnn_status_invalid_parameter;
588   }
589 
590   const float product_scale = input1_scale * input2_scale;
591   const float product_output_scale = product_scale / output_scale;
592   if (product_output_scale < 0x1.0p-16f || product_output_scale >= 0x1.0p+8f) {
593     xnn_log_error(
594       "failed to create %s operator with %.7g product-to-output scale ratio: scale ratio must be in [2**-16, 2**8) range",
595       xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qu8), product_output_scale);
596     return xnn_status_unsupported_parameter;
597   }
598 
599   struct {
600     union xnn_qu8_mul_minmax_params qu8_mul;
601     union xnn_qu8_mul_minmax_params qu8_rmul;
602   } params;
603   if (xnn_params.qu8.vmul.init.qu8_mul != NULL) {
604     xnn_params.qu8.vmul.init.qu8_mul(
605       &params.qu8_mul, input1_zero_point, input2_zero_point, output_zero_point,
606       product_output_scale, output_min, output_max);
607     xnn_params.qu8.vmul.init.qu8_mul(
608       &params.qu8_rmul, input2_zero_point, input1_zero_point, output_zero_point,
609       product_output_scale, output_min, output_max);
610   }
611   return create_binary_elementwise_nd(
612     flags,
613     &params,
614     sizeof(params),
615     XNN_INIT_FLAG_QU8,
616     xnn_operator_type_multiply_nd_qu8,
617     &xnn_params.qu8.vmul.minmax,
618     multiply_op_out);
619 }
620 
621 enum xnn_status xnn_create_squared_difference_nd_f16(
622     uint32_t flags,
623     xnn_operator_t* squared_difference_op_out)
624 {
625   return create_binary_elementwise_nd(
626     flags,
627     NULL,
628     0,
629     XNN_INIT_FLAG_F16,
630     xnn_operator_type_squared_difference_nd_f16,
631     &xnn_params.f16.vsqrdiff.minmax,
632     squared_difference_op_out);
633 }
634 
635 enum xnn_status xnn_create_squared_difference_nd_f32(
636     uint32_t flags,
637     xnn_operator_t* squared_difference_op_out)
638 {
639   union xnn_f32_default_params params;
640   if (xnn_params.f32.vsqrdiff.init.f32_default != NULL) {
641     xnn_params.f32.vsqrdiff.init.f32_default(&params);
642   }
643   return create_binary_elementwise_nd(
644     flags,
645     &params,
646     sizeof(params),
647     XNN_INIT_FLAG_F32,
648     xnn_operator_type_squared_difference_nd_f32,
649     &xnn_params.f32.vsqrdiff.minmax,
650     squared_difference_op_out);
651 }
652 
653 enum xnn_status xnn_create_subtract_nd_f16(
654     float output_min,
655     float output_max,
656     uint32_t flags,
657     xnn_operator_t* subtract_op_out)
658 {
659   return create_binary_elementwise_nd_f16(
660     output_min,
661     output_max,
662     flags,
663     xnn_operator_type_subtract_nd_f16,
664     &xnn_params.f16.vsub,
665     subtract_op_out);
666 }
667 
668 enum xnn_status xnn_create_subtract_nd_f32(
669     float output_min,
670     float output_max,
671     uint32_t flags,
672     xnn_operator_t* subtract_op_out)
673 {
674   return create_binary_elementwise_nd_f32(
675     output_min,
676     output_max,
677     flags,
678     xnn_operator_type_subtract_nd_f32,
679     &xnn_params.f32.vsub,
680     subtract_op_out);
681 }
682 
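// Quantized subtraction reuses the addition microkernels: the subtrahend's
// requantization scale is negated when the parameters are packed, so a - b is
// effectively evaluated as an addition with a negated second-input scale.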
683 enum xnn_status xnn_create_subtract_nd_qs8(
684     int8_t input1_zero_point,
685     float input1_scale,
686     int8_t input2_zero_point,
687     float input2_scale,
688     int8_t output_zero_point,
689     float output_scale,
690     int8_t output_min,
691     int8_t output_max,
692     uint32_t flags,
693     xnn_operator_t* subtract_op_out)
694 {
695   if (input1_scale <= 0.0f || !isnormal(input1_scale)) {
696     xnn_log_error(
697       "failed to create %s operator with %.7g input 1 scale: scale must be finite and positive",
698       xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qs8), input1_scale);
699     return xnn_status_invalid_parameter;
700   }
701 
702   if (input2_scale <= 0.0f || !isnormal(input2_scale)) {
703     xnn_log_error(
704       "failed to create %s operator with %.7g input 2 scale: scale must be finite and positive",
705       xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qs8), input2_scale);
706     return xnn_status_invalid_parameter;
707   }
708 
709   if (output_scale <= 0.0f || !isnormal(output_scale)) {
710     xnn_log_error(
711       "failed to create %s operator with %.7g output scale: scale must be finite and positive",
712       xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qs8), output_scale);
713     return xnn_status_invalid_parameter;
714   }
715 
716   if (output_min >= output_max) {
717     xnn_log_error(
718       "failed to create %s operator with [%" PRId8 ", %" PRId8 "] output range: lower bound must be below upper bound",
719       xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qs8), output_min, output_max);
720     return xnn_status_invalid_parameter;
721   }
722 
723   const float input1_output_scale = input1_scale / output_scale;
724   if (input1_output_scale < 0x1.0p-10f || input1_output_scale >= 0x1.0p+8f) {
725     xnn_log_error(
726       "failed to create %s operator with %.7g input1-to-output scale ratio: scale ratio must be in [2**-10, 2**8) range",
727       xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qs8), input1_output_scale);
728     return xnn_status_unsupported_parameter;
729   }
730 
731   const float input2_output_scale = input2_scale / output_scale;
732   if (input2_output_scale < 0x1.0p-10f || input2_output_scale >= 0x1.0p+8f) {
733     xnn_log_error(
734       "failed to create %s operator with %.7g input2-to-output scale ratio: scale ratio must be in [2**-10, 2**8) range",
735       xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qs8), input2_output_scale);
736     return xnn_status_unsupported_parameter;
737   }
738 
739   struct {
740     union xnn_qs8_add_minmax_params qs8_add;
741     union xnn_qs8_add_minmax_params qs8_radd;
742   } params;
743   if (xnn_params.qs8.vadd.init.qs8_add != NULL) {
744     xnn_params.qs8.vadd.init.qs8_add(
745       &params.qs8_add, input1_zero_point, input2_zero_point, output_zero_point,
746       input1_output_scale, -input2_output_scale, output_min, output_max);
747     xnn_params.qs8.vadd.init.qs8_add(
748       &params.qs8_radd, input2_zero_point, input1_zero_point, output_zero_point,
749       -input2_output_scale, input1_output_scale, output_min, output_max);
750   }
751   return create_binary_elementwise_nd(
752     flags,
753     &params,
754     sizeof(params),
755     XNN_INIT_FLAG_QS8,
756     xnn_operator_type_subtract_nd_qs8,
757     &xnn_params.qs8.vadd.minmax,
758     subtract_op_out);
759 }
760 
761 enum xnn_status xnn_create_subtract_nd_qu8(
762     uint8_t input1_zero_point,
763     float input1_scale,
764     uint8_t input2_zero_point,
765     float input2_scale,
766     uint8_t output_zero_point,
767     float output_scale,
768     uint8_t output_min,
769     uint8_t output_max,
770     uint32_t flags,
771     xnn_operator_t* subtract_op_out)
772 {
773   if (input1_scale <= 0.0f || !isnormal(input1_scale)) {
774     xnn_log_error(
775       "failed to create %s operator with %.7g input 1 scale: scale must be finite and positive",
776       xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qu8), input1_scale);
777     return xnn_status_invalid_parameter;
778   }
779 
780   if (input2_scale <= 0.0f || !isnormal(input2_scale)) {
781     xnn_log_error(
782       "failed to create %s operator with %.7g input 2 scale: scale must be finite and positive",
783       xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qu8), input2_scale);
784     return xnn_status_invalid_parameter;
785   }
786 
787   if (output_scale <= 0.0f || !isnormal(output_scale)) {
788     xnn_log_error(
789       "failed to create %s operator with %.7g output scale: scale must be finite and positive",
790       xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qu8), output_scale);
791     return xnn_status_invalid_parameter;
792   }
793 
794   if (output_min >= output_max) {
795     xnn_log_error(
796       "failed to create %s operator with [%" PRIu8 ", %" PRIu8 "] output range: lower bound must be below upper bound",
797       xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qu8), output_min, output_max);
798     return xnn_status_invalid_parameter;
799   }
800 
801   const float input1_output_scale = input1_scale / output_scale;
802   if (input1_output_scale < 0x1.0p-10f || input1_output_scale >= 0x1.0p+8f) {
803     xnn_log_error(
804       "failed to create %s operator with %.7g input1-to-output scale ratio: scale ratio must be in [2**-10, 2**8) range",
805       xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qu8), input1_output_scale);
806     return xnn_status_unsupported_parameter;
807   }
808 
809   const float input2_output_scale = input2_scale / output_scale;
810   if (input2_output_scale < 0x1.0p-10f || input2_output_scale >= 0x1.0p+8f) {
811     xnn_log_error(
812       "failed to create %s operator with %.7g input2-to-output scale ratio: scale ratio must be in [2**-10, 2**8) range",
813       xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qu8), input2_output_scale);
814     return xnn_status_unsupported_parameter;
815   }
816 
817   struct {
818     union xnn_qu8_add_minmax_params qu8_add;
819     union xnn_qu8_add_minmax_params qu8_radd;
820   } params;
821   if (xnn_params.qu8.vadd.init.qu8_add != NULL) {
822     xnn_params.qu8.vadd.init.qu8_add(
823       &params.qu8_add, input1_zero_point, input2_zero_point, output_zero_point,
824       input1_output_scale, -input2_output_scale, output_min, output_max);
825     xnn_params.qu8.vadd.init.qu8_add(
826       &params.qu8_radd, input2_zero_point, input1_zero_point, output_zero_point,
827       -input2_output_scale, input1_output_scale, output_min, output_max);
828   }
829   return create_binary_elementwise_nd(
830     flags,
831     &params,
832     sizeof(params),
833     XNN_INIT_FLAG_QU8,
834     xnn_operator_type_subtract_nd_qu8,
835     &xnn_params.qu8.vadd.minmax,
836     subtract_op_out);
837 }
838 
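// Shared setup path: validates the operator and input shapes, canonicalizes
// both shapes into at most XNN_MAX_TENSOR_DIMS "compressed" dimensions by
// merging adjacent dimensions with the same broadcasting behavior, selects the
// vector/vector (op), operand-and-constant (opc) or reversed (ropc)
// microkernel, and configures 1D-5D pthreadpool parallelization over the
// outer compressed dimensions.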
839 static enum xnn_status setup_binary_elementwise_nd(
840     xnn_operator_t binary_elementwise_op,
841     enum xnn_operator_type expected_operator_type,
842     size_t num_input1_dims,
843     const size_t* input1_shape,
844     size_t num_input2_dims,
845     const size_t* input2_shape,
846     const void* input1,
847     const void* input2,
848     void* output,
849     uint32_t log2_element_size,
850     const void* params,
851     size_t params_size,
852     const void* reversed_params,
853     size_t reversed_params_size,
854     const struct vbinary_parameters vbinary[restrict XNN_MIN_ELEMENTS(1)],
855     size_t num_threads)
856 {
857   if (binary_elementwise_op->type != expected_operator_type) {
858     xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
859       xnn_operator_type_to_string(expected_operator_type),
860       xnn_operator_type_to_string(binary_elementwise_op->type));
861     return xnn_status_invalid_parameter;
862   }
863   binary_elementwise_op->state = xnn_run_state_invalid;
864 
865   if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
866     xnn_log_error("failed to setup %s operator: XNNPACK is not initialized",
867       xnn_operator_type_to_string(binary_elementwise_op->type));
868     return xnn_status_uninitialized;
869   }
870 
871   if (max(num_input1_dims, num_input2_dims) > XNN_MAX_TENSOR_DIMS) {
872     xnn_log_error(
873       "failed to setup %s operator with %zu and %zu dimensions in input shapes: "
874       "the number of input dimensions must not exceed %d",
875       xnn_operator_type_to_string(binary_elementwise_op->type), num_input1_dims, num_input2_dims, XNN_MAX_TENSOR_DIMS);
876     return xnn_status_unsupported_parameter;
877   }
878 
879   size_t num_compressed_dims = 0;
880   size_t compressed_input1_shape[XNN_MAX_TENSOR_DIMS];
881   size_t compressed_input2_shape[XNN_MAX_TENSOR_DIMS];
882   size_t compressed_output_shape[XNN_MAX_TENSOR_DIMS];
883   for (size_t i = 0; i < XNN_MAX_TENSOR_DIMS; i++) {
884     compressed_input1_shape[i] = 1;
885     compressed_input2_shape[i] = 1;
886     compressed_output_shape[i] = 1;
887   }
888   bool broadcast_input1 = false;
889   bool broadcast_input2 = false;
890   bool first_nonunit = true;
891   bool degenerate_shape = false;
892   const size_t num_common_dims = min(num_input1_dims, num_input2_dims);
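  // Walk the trailing dimensions common to both shapes, innermost first. Runs
  // of dimensions with the same broadcasting direction (input1 broadcast,
  // input2 broadcast, or neither) are multiplied into a single compressed
  // dimension; any other mismatch between non-unit dimensions is an error.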
893   for (size_t i = 1; i <= num_common_dims; i++) {
894     const size_t input1_dim = input1_shape[num_input1_dims - i];
895     const size_t input2_dim = input2_shape[num_input2_dims - i];
896     degenerate_shape |= input1_dim == 0;
897     degenerate_shape |= input2_dim == 0;
898     if (input1_dim == 1 && input2_dim == 1) {
899       continue;
900     }
901     assert(!broadcast_input1 || !broadcast_input2);
902 
903     if (input1_dim == 1) {
904       if (!broadcast_input1) {
905         broadcast_input1 = true;
906         broadcast_input2 = false;
907         num_compressed_dims++;
908       }
909       compressed_input2_shape[num_compressed_dims - 1] *= input2_dim;
910       compressed_output_shape[num_compressed_dims - 1] *= input2_dim;
911     } else if (input2_dim == 1) {
912       if (!broadcast_input2) {
913         broadcast_input1 = false;
914         broadcast_input2 = true;
915         num_compressed_dims++;
916       }
917       compressed_input1_shape[num_compressed_dims - 1] *= input1_dim;
918       compressed_output_shape[num_compressed_dims - 1] *= input1_dim;
919     } else if (input1_dim == input2_dim) {
920       if (broadcast_input1 || broadcast_input2 || first_nonunit) {
921         broadcast_input1 = false;
922         broadcast_input2 = false;
923         num_compressed_dims++;
924       }
925       compressed_input1_shape[num_compressed_dims - 1] *= input1_dim;
926       compressed_input2_shape[num_compressed_dims - 1] *= input1_dim;
927       compressed_output_shape[num_compressed_dims - 1] *= input1_dim;
928     } else {
929       xnn_log_error(
930         "failed to setup %s operator: "
931         "shape dimension #%zu of input1 (%zu) does not match shape dimension #%zu of input2 (%zu)",
932         xnn_operator_type_to_string(binary_elementwise_op->type),
933         num_input1_dims - i, input1_dim, num_input2_dims - i, input2_dim);
934       return xnn_status_invalid_parameter;
935     }
936     first_nonunit = false;
937   }
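  // Leading dimensions present only in the higher-rank input broadcast against
  // implicit ones in the other input; fold them into one more compressed
  // dimension of that input and of the output.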
938   if (num_input1_dims > num_input2_dims) {
939     if (!broadcast_input2) {
940       num_compressed_dims++;
941     }
942     for (size_t i = 0; i < num_input1_dims - num_input2_dims; i++) {
943       const size_t input1_dim = input1_shape[i];
944       degenerate_shape |= input1_dim == 0;
945       compressed_input1_shape[num_compressed_dims - 1] *= input1_dim;
946       compressed_output_shape[num_compressed_dims - 1] *= input1_dim;
947     }
948   } else if (num_input2_dims > num_input1_dims) {
949     if (!broadcast_input1) {
950       num_compressed_dims++;
951     }
952     for (size_t i = 0; i < num_input2_dims - num_input1_dims; i++) {
953       const size_t input2_dim = input2_shape[i];
954       degenerate_shape |= input2_dim == 0;
955       compressed_input2_shape[num_compressed_dims - 1] *= input2_dim;
956       compressed_output_shape[num_compressed_dims - 1] *= input2_dim;
957     }
958   }
959   num_compressed_dims = max(num_compressed_dims, 1);
960 
961   // Early exit without setting up context if any shape dimension is zero.
962   if (degenerate_shape) {
963     binary_elementwise_op->state = xnn_run_state_skip;
964     return xnn_status_success;
965   }
966 
967   binary_elementwise_op->context.elementwise_binary = (struct elementwise_binary_context) {
968     .a = input1,
969     .b = input2,
970     .y = output,
971     .elements = compressed_output_shape[0] << log2_element_size,
972   };
973   if (params_size != 0) {
974     memcpy(&binary_elementwise_op->context.elementwise_binary.params, params, params_size);
975   }
976 
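  // Select the microkernel from the innermost compressed extents: if input1 is
  // broadcast there (extent 1), swap the operands and use the ropc kernel (the
  // variant that reads its second operand as a single value, with reversed
  // operand order) together with the reversed parameters; if input2 is
  // broadcast, use the opc kernel; otherwise both inputs are contiguous and
  // the plain op kernel applies.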
977   const size_t* compressed_a_shape = compressed_input1_shape;
978   const size_t* compressed_b_shape = compressed_input2_shape;
979   if (compressed_input1_shape[0] == 1) {
980     binary_elementwise_op->context.elementwise_binary.ukernel = binary_elementwise_op->ukernel.vbinary.ropc_function;
981     binary_elementwise_op->context.elementwise_binary.a = input2;
982     binary_elementwise_op->context.elementwise_binary.b = input1;
983     compressed_a_shape = compressed_input2_shape;
984     compressed_b_shape = compressed_input1_shape;
985     if (reversed_params_size != 0) {
986       memcpy(&binary_elementwise_op->context.elementwise_binary.params, reversed_params, reversed_params_size);
987     }
988   } else if (compressed_input2_shape[0] == 1) {
989     binary_elementwise_op->context.elementwise_binary.ukernel = binary_elementwise_op->ukernel.vbinary.opc_function;
990   } else if (compressed_input1_shape[0] == compressed_input2_shape[0]) {
991     binary_elementwise_op->context.elementwise_binary.ukernel = binary_elementwise_op->ukernel.vbinary.op_function;
992   }
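  // Compute byte strides for the outer compressed dimensions. A stride is only
  // written when the corresponding extent is not 1; broadcast dimensions keep
  // the zero stride from the context's designated initializer, so the same
  // data is re-read across that dimension.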
993   size_t a_stride = compressed_a_shape[0], b_stride = compressed_b_shape[0], y_stride = compressed_output_shape[0];
994   for (size_t i = 1; i < num_compressed_dims; i++) {
995     if (compressed_a_shape[i] != 1) {
996       binary_elementwise_op->context.elementwise_binary.a_stride[XNN_MAX_TENSOR_DIMS - 1 - i] = a_stride << log2_element_size;
997     }
998     if (compressed_b_shape[i] != 1) {
999       binary_elementwise_op->context.elementwise_binary.b_stride[XNN_MAX_TENSOR_DIMS - 1 - i] = b_stride << log2_element_size;
1000     }
1001     binary_elementwise_op->context.elementwise_binary.y_stride[XNN_MAX_TENSOR_DIMS - 1 - i] = y_stride << log2_element_size;
1002     a_stride *= compressed_a_shape[i];
1003     b_stride *= compressed_b_shape[i];
1004     y_stride *= compressed_output_shape[i];
1005   }
1006 
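  // Map the non-unit outer dimensions onto the narrowest 1D-5D pthreadpool
  // task shape; the innermost compressed dimension is always consumed inside a
  // single microkernel call via context.elements.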
1007   if (compressed_output_shape[5] == 1) {
1008     if (compressed_output_shape[4] == 1) {
1009       if (compressed_output_shape[3] == 1) {
1010         if (compressed_output_shape[2] == 1) {
1011           binary_elementwise_op->compute.type = xnn_parallelization_type_1d;
1012           binary_elementwise_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_elementwise_binary_1d;
1013           binary_elementwise_op->compute.range[0] = compressed_output_shape[1];
1014         } else {
1015           binary_elementwise_op->compute.type = xnn_parallelization_type_2d;
1016           binary_elementwise_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_elementwise_binary_2d;
1017           binary_elementwise_op->compute.range[0] = compressed_output_shape[2];
1018           binary_elementwise_op->compute.range[1] = compressed_output_shape[1];
1019         }
1020       } else {
1021         binary_elementwise_op->compute.type = xnn_parallelization_type_3d;
1022         binary_elementwise_op->compute.task_3d = (pthreadpool_task_3d_t) xnn_compute_elementwise_binary_3d;
1023         binary_elementwise_op->compute.range[0] = compressed_output_shape[3];
1024         binary_elementwise_op->compute.range[1] = compressed_output_shape[2];
1025         binary_elementwise_op->compute.range[2] = compressed_output_shape[1];
1026       }
1027     } else {
1028       binary_elementwise_op->compute.type = xnn_parallelization_type_4d;
1029       binary_elementwise_op->compute.task_4d = (pthreadpool_task_4d_t) xnn_compute_elementwise_binary_4d;
1030       binary_elementwise_op->compute.range[0] = compressed_output_shape[4];
1031       binary_elementwise_op->compute.range[1] = compressed_output_shape[3];
1032       binary_elementwise_op->compute.range[2] = compressed_output_shape[2];
1033       binary_elementwise_op->compute.range[3] = compressed_output_shape[1];
1034     }
1035   } else {
1036     binary_elementwise_op->compute.type = xnn_parallelization_type_5d;
1037     binary_elementwise_op->compute.task_5d = (pthreadpool_task_5d_t) xnn_compute_elementwise_binary_5d;
1038     binary_elementwise_op->compute.range[0] = compressed_output_shape[5];
1039     binary_elementwise_op->compute.range[1] = compressed_output_shape[4];
1040     binary_elementwise_op->compute.range[2] = compressed_output_shape[3];
1041     binary_elementwise_op->compute.range[3] = compressed_output_shape[2];
1042     binary_elementwise_op->compute.range[4] = compressed_output_shape[1];
1043   }
1044   binary_elementwise_op->state = xnn_run_state_ready;
1045 
1046   return xnn_status_success;
1047 }
1048 
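// Typed setup helpers: they supply the element size (as a log2) and hand the
// packed minmax parameters to the shared setup path for both operand orders
// (identical here, since clamping does not depend on the order).
//
// For reference, a caller typically drives these operators through the public
// XNNPACK API roughly as follows (sketch only: error handling is omitted and
// the shapes, pointers and threadpool are illustrative, not part of this
// file):
//
//   xnn_initialize(NULL /* allocator */);
//   xnn_operator_t add_op = NULL;
//   xnn_create_add_nd_f32(-INFINITY, INFINITY, 0 /* flags */, &add_op);
//   const size_t a_shape[2] = { 8, 1 };
//   const size_t b_shape[2] = { 8, 32 };
//   xnn_setup_add_nd_f32(add_op, 2, a_shape, 2, b_shape, a, b, y, threadpool);
//   xnn_run_operator(add_op, threadpool);
//   xnn_delete_operator(add_op);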
1049 static enum xnn_status setup_binary_elementwise_nd_f16(
1050     xnn_operator_t binary_elementwise_op,
1051     enum xnn_operator_type expected_operator_type,
1052     size_t num_input1_dims,
1053     const size_t* input1_shape,
1054     size_t num_input2_dims,
1055     const size_t* input2_shape,
1056     const void* input1,
1057     const void* input2,
1058     void* output,
1059     const struct vbinary_parameters vbinary[restrict XNN_MIN_ELEMENTS(1)],
1060     size_t num_threads)
1061 {
1062   return setup_binary_elementwise_nd(
1063     binary_elementwise_op,
1064     expected_operator_type,
1065     num_input1_dims,
1066     input1_shape,
1067     num_input2_dims,
1068     input2_shape,
1069     input1,
1070     input2,
1071     output,
1072     1 /* log2(sizeof(half)) */,
1073     &binary_elementwise_op->params.f16_minmax, sizeof(binary_elementwise_op->params.f16_minmax),
1074     &binary_elementwise_op->params.f16_minmax, sizeof(binary_elementwise_op->params.f16_minmax),
1075     vbinary,
1076     num_threads);
1077 }
1078 
1079 static enum xnn_status setup_binary_elementwise_nd_f32(
1080     xnn_operator_t binary_elementwise_op,
1081     enum xnn_operator_type expected_operator_type,
1082     size_t num_input1_dims,
1083     const size_t* input1_shape,
1084     size_t num_input2_dims,
1085     const size_t* input2_shape,
1086     const float* input1,
1087     const float* input2,
1088     float* output,
1089     const struct vbinary_parameters vbinary[restrict XNN_MIN_ELEMENTS(1)],
1090     size_t num_threads)
1091 {
1092   return setup_binary_elementwise_nd(
1093     binary_elementwise_op, expected_operator_type,
1094     num_input1_dims, input1_shape,
1095     num_input2_dims, input2_shape,
1096     input1, input2, output,
1097     2 /* log2(sizeof(float)) */,
1098     &binary_elementwise_op->params.f32_minmax, sizeof(binary_elementwise_op->params.f32_minmax),
1099     &binary_elementwise_op->params.f32_minmax, sizeof(binary_elementwise_op->params.f32_minmax),
1100     vbinary,
1101     num_threads);
1102 }
1103 
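// The public setup entry points below are thin wrappers: each binds the
// expected operator type, the per-datatype packed parameters and the matching
// vbinary table, then forwards to the shared setup path.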
1104 enum xnn_status xnn_setup_add_nd_f16(
1105     xnn_operator_t add_op,
1106     size_t num_input1_dims,
1107     const size_t* input1_shape,
1108     size_t num_input2_dims,
1109     const size_t* input2_shape,
1110     const void* input1,
1111     const void* input2,
1112     void* output,
1113     pthreadpool_t threadpool)
1114 {
1115   return setup_binary_elementwise_nd_f16(
1116     add_op, xnn_operator_type_add_nd_f16,
1117     num_input1_dims, input1_shape,
1118     num_input2_dims, input2_shape,
1119     input1, input2, output,
1120     &xnn_params.f16.vadd,
1121     pthreadpool_get_threads_count(threadpool));
1122 }
1123 
1124 enum xnn_status xnn_setup_add_nd_f32(
1125     xnn_operator_t add_op,
1126     size_t num_input1_dims,
1127     const size_t* input1_shape,
1128     size_t num_input2_dims,
1129     const size_t* input2_shape,
1130     const float* input1,
1131     const float* input2,
1132     float* output,
1133     pthreadpool_t threadpool)
1134 {
1135   return setup_binary_elementwise_nd_f32(
1136     add_op, xnn_operator_type_add_nd_f32,
1137     num_input1_dims, input1_shape,
1138     num_input2_dims, input2_shape,
1139     input1, input2, output,
1140     &xnn_params.f32.vadd,
1141     pthreadpool_get_threads_count(threadpool));
1142 }
1143 
1144 enum xnn_status xnn_setup_add_nd_qs8(
1145     xnn_operator_t add_op,
1146     size_t num_input1_dims,
1147     const size_t* input1_shape,
1148     size_t num_input2_dims,
1149     const size_t* input2_shape,
1150     const int8_t* input1,
1151     const int8_t* input2,
1152     int8_t* output,
1153     pthreadpool_t threadpool)
1154 {
1155   return setup_binary_elementwise_nd(
1156     add_op, xnn_operator_type_add_nd_qs8,
1157     num_input1_dims, input1_shape,
1158     num_input2_dims, input2_shape,
1159     input1, input2, output,
1160     0 /* log2(sizeof(int8_t)) */,
1161     &add_op->params.qs8_add, sizeof(add_op->params.qs8_add),
1162     &add_op->params.qs8_radd, sizeof(add_op->params.qs8_radd),
1163     &xnn_params.qs8.vadd,
1164     pthreadpool_get_threads_count(threadpool));
1165 }
1166 
1167 enum xnn_status xnn_setup_add_nd_qu8(
1168     xnn_operator_t add_op,
1169     size_t num_input1_dims,
1170     const size_t* input1_shape,
1171     size_t num_input2_dims,
1172     const size_t* input2_shape,
1173     const uint8_t* input1,
1174     const uint8_t* input2,
1175     uint8_t* output,
1176     pthreadpool_t threadpool)
1177 {
1178   return setup_binary_elementwise_nd(
1179     add_op, xnn_operator_type_add_nd_qu8,
1180     num_input1_dims, input1_shape,
1181     num_input2_dims, input2_shape,
1182     input1, input2, output,
1183     0 /* log2(sizeof(uint8_t)) */,
1184     &add_op->params.qu8_add, sizeof(add_op->params.qu8_add),
1185     &add_op->params.qu8_radd, sizeof(add_op->params.qu8_radd),
1186     &xnn_params.qu8.vadd,
1187     pthreadpool_get_threads_count(threadpool));
1188 }
1189 
1190 enum xnn_status xnn_setup_divide_nd_f16(
1191     xnn_operator_t divide_op,
1192     size_t num_input1_dims,
1193     const size_t* input1_shape,
1194     size_t num_input2_dims,
1195     const size_t* input2_shape,
1196     const void* input1,
1197     const void* input2,
1198     void* output,
1199     pthreadpool_t threadpool)
1200 {
1201   return setup_binary_elementwise_nd_f16(
1202     divide_op, xnn_operator_type_divide_nd_f16,
1203     num_input1_dims, input1_shape,
1204     num_input2_dims, input2_shape,
1205     input1, input2, output,
1206     &xnn_params.f16.vdiv,
1207     pthreadpool_get_threads_count(threadpool));
1208 }
1209 
1210 enum xnn_status xnn_setup_divide_nd_f32(
1211     xnn_operator_t divide_op,
1212     size_t num_input1_dims,
1213     const size_t* input1_shape,
1214     size_t num_input2_dims,
1215     const size_t* input2_shape,
1216     const float* input1,
1217     const float* input2,
1218     float* output,
1219     pthreadpool_t threadpool)
1220 {
1221   return setup_binary_elementwise_nd_f32(
1222     divide_op, xnn_operator_type_divide_nd_f32,
1223     num_input1_dims, input1_shape,
1224     num_input2_dims, input2_shape,
1225     input1, input2, output,
1226     &xnn_params.f32.vdiv,
1227     pthreadpool_get_threads_count(threadpool));
1228 }
1229 
1230 enum xnn_status xnn_setup_maximum_nd_f16(
1231     xnn_operator_t maximum_op,
1232     size_t num_input1_dims,
1233     const size_t* input1_shape,
1234     size_t num_input2_dims,
1235     const size_t* input2_shape,
1236     const void* input1,
1237     const void* input2,
1238     void* output,
1239     pthreadpool_t threadpool)
1240 {
1241   return setup_binary_elementwise_nd_f16(
1242     maximum_op, xnn_operator_type_maximum_nd_f16,
1243     num_input1_dims, input1_shape,
1244     num_input2_dims, input2_shape,
1245     input1, input2, output,
1246     &xnn_params.f16.vmax,
1247     pthreadpool_get_threads_count(threadpool));
1248 }
1249 
1250 enum xnn_status xnn_setup_maximum_nd_f32(
1251     xnn_operator_t maximum_op,
1252     size_t num_input1_dims,
1253     const size_t* input1_shape,
1254     size_t num_input2_dims,
1255     const size_t* input2_shape,
1256     const float* input1,
1257     const float* input2,
1258     float* output,
1259     pthreadpool_t threadpool)
1260 {
1261   return setup_binary_elementwise_nd_f32(
1262     maximum_op, xnn_operator_type_maximum_nd_f32,
1263     num_input1_dims, input1_shape,
1264     num_input2_dims, input2_shape,
1265     input1, input2, output,
1266     &xnn_params.f32.vmax,
1267     pthreadpool_get_threads_count(threadpool));
1268 }
1269 
1270 enum xnn_status xnn_setup_minimum_nd_f16(
1271     xnn_operator_t minimum_op,
1272     size_t num_input1_dims,
1273     const size_t* input1_shape,
1274     size_t num_input2_dims,
1275     const size_t* input2_shape,
1276     const void* input1,
1277     const void* input2,
1278     void* output,
1279     pthreadpool_t threadpool)
1280 {
1281   return setup_binary_elementwise_nd_f16(
1282     minimum_op, xnn_operator_type_minimum_nd_f16,
1283     num_input1_dims, input1_shape,
1284     num_input2_dims, input2_shape,
1285     input1, input2, output,
1286     &xnn_params.f16.vmin,
1287     pthreadpool_get_threads_count(threadpool));
1288 }
1289 
xnn_setup_minimum_nd_f32(xnn_operator_t minimum_op,size_t num_input1_dims,const size_t * input1_shape,size_t num_input2_dims,const size_t * input2_shape,const float * input1,const float * input2,float * output,pthreadpool_t threadpool)1290 enum xnn_status xnn_setup_minimum_nd_f32(
1291     xnn_operator_t minimum_op,
1292     size_t num_input1_dims,
1293     const size_t* input1_shape,
1294     size_t num_input2_dims,
1295     const size_t* input2_shape,
1296     const float* input1,
1297     const float* input2,
1298     float* output,
1299     pthreadpool_t threadpool)
1300 {
1301   return setup_binary_elementwise_nd_f32(
1302     minimum_op, xnn_operator_type_minimum_nd_f32,
1303     num_input1_dims, input1_shape,
1304     num_input2_dims, input2_shape,
1305     input1, input2, output,
1306     &xnn_params.f32.vmin,
1307     pthreadpool_get_threads_count(threadpool));
1308 }
1309 
xnn_setup_multiply_nd_f16(xnn_operator_t multiply_op,size_t num_input1_dims,const size_t * input1_shape,size_t num_input2_dims,const size_t * input2_shape,const void * input1,const void * input2,void * output,pthreadpool_t threadpool)1310 enum xnn_status xnn_setup_multiply_nd_f16(
1311     xnn_operator_t multiply_op,
1312     size_t num_input1_dims,
1313     const size_t* input1_shape,
1314     size_t num_input2_dims,
1315     const size_t* input2_shape,
1316     const void* input1,
1317     const void* input2,
1318     void* output,
1319     pthreadpool_t threadpool)
1320 {
1321   return setup_binary_elementwise_nd_f16(
1322     multiply_op, xnn_operator_type_multiply_nd_f16,
1323     num_input1_dims, input1_shape,
1324     num_input2_dims, input2_shape,
1325     input1, input2, output,
1326     &xnn_params.f16.vmul,
1327     pthreadpool_get_threads_count(threadpool));
1328 }
1329 
xnn_setup_multiply_nd_f32(xnn_operator_t multiply_op,size_t num_input1_dims,const size_t * input1_shape,size_t num_input2_dims,const size_t * input2_shape,const float * input1,const float * input2,float * output,pthreadpool_t threadpool)1330 enum xnn_status xnn_setup_multiply_nd_f32(
1331     xnn_operator_t multiply_op,
1332     size_t num_input1_dims,
1333     const size_t* input1_shape,
1334     size_t num_input2_dims,
1335     const size_t* input2_shape,
1336     const float* input1,
1337     const float* input2,
1338     float* output,
1339     pthreadpool_t threadpool)
1340 {
1341   return setup_binary_elementwise_nd_f32(
1342     multiply_op, xnn_operator_type_multiply_nd_f32,
1343     num_input1_dims, input1_shape,
1344     num_input2_dims, input2_shape,
1345     input1, input2, output,
1346     &xnn_params.f32.vmul,
1347     pthreadpool_get_threads_count(threadpool));
1348 }
1349 
xnn_setup_multiply_nd_qs8(xnn_operator_t multiply_op,size_t num_input1_dims,const size_t * input1_shape,size_t num_input2_dims,const size_t * input2_shape,const int8_t * input1,const int8_t * input2,int8_t * output,pthreadpool_t threadpool)1350 enum xnn_status xnn_setup_multiply_nd_qs8(
1351     xnn_operator_t multiply_op,
1352     size_t num_input1_dims,
1353     const size_t* input1_shape,
1354     size_t num_input2_dims,
1355     const size_t* input2_shape,
1356     const int8_t* input1,
1357     const int8_t* input2,
1358     int8_t* output,
1359     pthreadpool_t threadpool)
1360 {
1361   return setup_binary_elementwise_nd(
1362     multiply_op, xnn_operator_type_multiply_nd_qs8,
1363     num_input1_dims, input1_shape,
1364     num_input2_dims, input2_shape,
1365     input1, input2, output,
1366     0 /* log2(sizeof(int8_t))) */,
1367     &multiply_op->params.qs8_mul, sizeof(multiply_op->params.qs8_mul),
1368     &multiply_op->params.qs8_rmul, sizeof(multiply_op->params.qs8_rmul),
1369     &xnn_params.qs8.vmul,
1370     pthreadpool_get_threads_count(threadpool));
1371 }
1372 
xnn_setup_multiply_nd_qu8(xnn_operator_t multiply_op,size_t num_input1_dims,const size_t * input1_shape,size_t num_input2_dims,const size_t * input2_shape,const uint8_t * input1,const uint8_t * input2,uint8_t * output,pthreadpool_t threadpool)1373 enum xnn_status xnn_setup_multiply_nd_qu8(
1374     xnn_operator_t multiply_op,
1375     size_t num_input1_dims,
1376     const size_t* input1_shape,
1377     size_t num_input2_dims,
1378     const size_t* input2_shape,
1379     const uint8_t* input1,
1380     const uint8_t* input2,
1381     uint8_t* output,
1382     pthreadpool_t threadpool)
1383 {
1384   return setup_binary_elementwise_nd(
1385     multiply_op, xnn_operator_type_multiply_nd_qu8,
1386     num_input1_dims, input1_shape,
1387     num_input2_dims, input2_shape,
1388     input1, input2, output,
1389     0 /* log2(sizeof(uint8_t))) */,
1390     &multiply_op->params.qu8_mul, sizeof(multiply_op->params.qu8_mul),
1391     &multiply_op->params.qu8_rmul, sizeof(multiply_op->params.qu8_rmul),
1392     &xnn_params.qu8.vmul,
1393     pthreadpool_get_threads_count(threadpool));
1394 }
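// Note on the quantized (QS8/QU8) setups: unlike the f16/f32 wrappers, they go
// through the generic setup_binary_elementwise_nd and pass two params blocks,
// e.g. qs8_mul and qs8_rmul above (and qs8_add/qs8_radd for subtract below).
// The reversed-operand block is presumably consumed when broadcasting selects the
// ropc ("reversed operand, constant") ukernel, so the per-input zero points and
// scales stay attached to the correct operand; the leading 0 argument is log2 of
// the element size in bytes.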
enum xnn_status xnn_setup_squared_difference_nd_f16(
    xnn_operator_t squared_difference_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const void* input1,
    const void* input2,
    void* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f16(
    squared_difference_op, xnn_operator_type_squared_difference_nd_f16,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f16.vsqrdiff,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_squared_difference_nd_f32(
    xnn_operator_t squared_difference_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    squared_difference_op, xnn_operator_type_squared_difference_nd_f32,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f32.vsqrdiff,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_subtract_nd_f16(
    xnn_operator_t subtract_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const void* input1,
    const void* input2,
    void* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f16(
    subtract_op, xnn_operator_type_subtract_nd_f16,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f16.vsub,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_subtract_nd_f32(
    xnn_operator_t subtract_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    subtract_op, xnn_operator_type_subtract_nd_f32,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f32.vsub,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_subtract_nd_qs8(
    xnn_operator_t subtract_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const int8_t* input1,
    const int8_t* input2,
    int8_t* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd(
    subtract_op, xnn_operator_type_subtract_nd_qs8,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    0 /* log2(sizeof(int8_t)) */,
    &subtract_op->params.qs8_add, sizeof(subtract_op->params.qs8_add),
    &subtract_op->params.qs8_radd, sizeof(subtract_op->params.qs8_radd),
    &xnn_params.qs8.vadd,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_subtract_nd_qu8(
    xnn_operator_t subtract_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const uint8_t* input1,
    const uint8_t* input2,
    uint8_t* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd(
    subtract_op, xnn_operator_type_subtract_nd_qu8,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    0 /* log2(sizeof(uint8_t)) */,
    &subtract_op->params.qu8_add, sizeof(subtract_op->params.qu8_add),
    &subtract_op->params.qu8_radd, sizeof(subtract_op->params.qu8_radd),
    &xnn_params.qu8.vadd,
    pthreadpool_get_threads_count(threadpool));
}
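// Usage sketch for the quantized path (illustration only, kept as a comment):
// subtract_op is assumed to have been created earlier with the corresponding
// xnn_create_subtract_nd_qs8 call, which is where the zero points, scales, and
// output range are captured into subtract_op->params. All names below are
// hypothetical.
//
//   const size_t a_shape[2] = {16, 64};
//   const size_t b_shape[2] = {1, 64};  // row broadcast
//   enum xnn_status status = xnn_setup_subtract_nd_qs8(
//     subtract_op,
//     2, a_shape,
//     2, b_shape,
//     a, b, y,
//     threadpool);
//   if (status == xnn_status_success) {
//     status = xnn_run_operator(subtract_op, threadpool);
//   }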