// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <fp16.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/log.h>
#include <xnnpack/operator.h>
#include <xnnpack/params-init.h>
#include <xnnpack/params.h>

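// Creation and setup of N-dimensional binary elementwise operators (add,
// subtract, multiply, divide, minimum, maximum, and squared difference) in
// F16, F32, QS8, and QU8 variants.
//
// Typical caller flow (illustrative sketch only; error handling omitted):
//
//   xnn_operator_t add_op = NULL;
//   xnn_initialize(NULL /* allocator */);
//   xnn_create_add_nd_f32(-INFINITY, INFINITY, 0 /* flags */, &add_op);
//   xnn_setup_add_nd_f32(add_op, 4, a_shape, 4, b_shape, a, b, y, threadpool);
//   xnn_run_operator(add_op, threadpool);
//   xnn_delete_operator(add_op);

// Shared creator for all binary elementwise operators: allocates the operator
// structure, copies the operator-specific parameters, and records the fused
// microkernel variants (two-operand, operand-with-constant, and reversed
// operand-with-constant).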
static enum xnn_status create_binary_elementwise_nd(
    uint32_t flags,
    const void* params,
    size_t params_size,
    uint32_t datatype_init_flags,
    enum xnn_operator_type operator_type,
    const struct vbinary_fused_ukernels* vbinary_fused_ukernels,
    xnn_operator_t* binary_elementwise_op_out)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(operator_type));
    return xnn_status_uninitialized;
  }

  if ((xnn_params.init_flags & datatype_init_flags) != datatype_init_flags) {
    xnn_log_error("failed to create %s operator: operations on data type are not supported",
      xnn_operator_type_to_string(operator_type));
    return xnn_status_unsupported_hardware;
  }

  xnn_operator_t binary_elementwise_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
  if (binary_elementwise_op == NULL) {
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator descriptor",
      sizeof(struct xnn_operator), xnn_operator_type_to_string(operator_type));
    return xnn_status_out_of_memory;
  }

  if (params_size != 0) {
    memcpy(&binary_elementwise_op->params, params, params_size);
  }

  binary_elementwise_op->ukernel.vbinary.op_function   = vbinary_fused_ukernels->op_ukernel;
  binary_elementwise_op->ukernel.vbinary.opc_function  = vbinary_fused_ukernels->opc_ukernel;
  binary_elementwise_op->ukernel.vbinary.ropc_function = vbinary_fused_ukernels->ropc_ukernel;

  binary_elementwise_op->type = operator_type;
  binary_elementwise_op->flags = flags;

  binary_elementwise_op->state = xnn_run_state_invalid;

  *binary_elementwise_op_out = binary_elementwise_op;
  return xnn_status_success;
}

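// FP16 creator: validates the output range after rounding the bounds to half
// precision, initializes the f16 min/max parameters, and delegates to the
// shared creator with the min/max microkernels.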
static enum xnn_status create_binary_elementwise_nd_f16(
    float output_min,
    float output_max,
    uint32_t flags,
    enum xnn_operator_type operator_type,
    const struct vbinary_parameters vbinary[restrict XNN_MIN_ELEMENTS(1)],
    xnn_operator_t* binary_elementwise_op_out)
{
  if (isnan(output_min)) {
    xnn_log_error(
      "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN",
      xnn_operator_type_to_string(operator_type));
    return xnn_status_invalid_parameter;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN",
      xnn_operator_type_to_string(operator_type));
    return xnn_status_invalid_parameter;
  }

  if (fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_min)) >= fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_max))) {
    xnn_log_error(
      "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(operator_type),
      fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_min)),
      fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_max)));
    return xnn_status_invalid_parameter;
  }

  union xnn_f16_minmax_params params;
  if (vbinary->init.f16_minmax != NULL) {
    vbinary->init.f16_minmax(&params,
      fp16_ieee_from_fp32_value(output_min), fp16_ieee_from_fp32_value(output_max));
  }
  return create_binary_elementwise_nd(
    flags,
    &params,
    sizeof(params),
    XNN_INIT_FLAG_F16,
    operator_type,
    &vbinary->minmax,
    binary_elementwise_op_out);
}

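// FP32 creator: validates the output range and, when the range is unbounded
// ([-inf, +inf]) and an unclamped (linear) microkernel is available, selects
// it instead of the min/max variant.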
static enum xnn_status create_binary_elementwise_nd_f32(
    float output_min,
    float output_max,
    uint32_t flags,
    enum xnn_operator_type operator_type,
    const struct vbinary_parameters vbinary[restrict XNN_MIN_ELEMENTS(1)],
    xnn_operator_t* binary_elementwise_op_out)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(operator_type));
    return xnn_status_uninitialized;
  }

  if (isnan(output_min)) {
    xnn_log_error(
      "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN",
      xnn_operator_type_to_string(operator_type));
    return xnn_status_invalid_parameter;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN",
      xnn_operator_type_to_string(operator_type));
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(operator_type), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  const bool linear_activation = (output_max == INFINITY) && (output_min == -output_max);
  const struct vbinary_fused_ukernels* vbinary_fused_ukernels = &vbinary->minmax;
  if (linear_activation && vbinary->linear.op_ukernel != NULL) {
    vbinary_fused_ukernels = &vbinary->linear;
  }

  union xnn_f32_minmax_params params;
  if (vbinary->init.f32_minmax != NULL) {
    vbinary->init.f32_minmax(&params, output_min, output_max);
  }
  return create_binary_elementwise_nd(
    flags,
    &params,
    sizeof(params),
    XNN_INIT_FLAG_F32,
    operator_type,
    vbinary_fused_ukernels,
    binary_elementwise_op_out);
}

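// QS8 addition: validates the quantization parameters (scales must be finite
// and positive, and each input-to-output scale ratio must fit the
// requantization range [2**-10, 2**8)), then packs two parameter sets: one
// for the given operand order and one with the operands swapped, so that the
// reversed (ropc) microkernel can be used when input1 is the broadcast side.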
enum xnn_status xnn_create_add_nd_qs8(
    int8_t input1_zero_point,
    float input1_scale,
    int8_t input2_zero_point,
    float input2_scale,
    int8_t output_zero_point,
    float output_scale,
    int8_t output_min,
    int8_t output_max,
    uint32_t flags,
    xnn_operator_t* add_op_out)
{
  if (input1_scale <= 0.0f || !isnormal(input1_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input 1 scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qs8), input1_scale);
    return xnn_status_invalid_parameter;
  }

  if (input2_scale <= 0.0f || !isnormal(input2_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input 2 scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qs8), input2_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qs8), output_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRId8 ", %" PRId8 "] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qs8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  const float input1_output_scale = input1_scale / output_scale;
  if (input1_output_scale < 0x1.0p-10f || input1_output_scale >= 0x1.0p+8f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input1-to-output scale ratio: scale ratio must be in [2**-10, 2**8) range",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qs8), input1_output_scale);
    return xnn_status_unsupported_parameter;
  }

  const float input2_output_scale = input2_scale / output_scale;
  if (input2_output_scale < 0x1.0p-10f || input2_output_scale >= 0x1.0p+8f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input2-to-output scale ratio: scale ratio must be in [2**-10, 2**8) range",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qs8), input2_output_scale);
    return xnn_status_unsupported_parameter;
  }

  struct {
    union xnn_qs8_addsub_minmax_params qs8_addsub;
    union xnn_qs8_addsub_minmax_params qs8_raddsub;
  } params;
  if (xnn_params.qs8.vadd.init.qs8_addsub != NULL) {
    xnn_params.qs8.vadd.init.qs8_addsub(
      &params.qs8_addsub, input1_zero_point, input2_zero_point, output_zero_point,
      input1_output_scale, input2_output_scale, output_min, output_max);
    xnn_params.qs8.vadd.init.qs8_addsub(
      &params.qs8_raddsub, input2_zero_point, input1_zero_point, output_zero_point,
      input2_output_scale, input1_output_scale, output_min, output_max);
  }
  return create_binary_elementwise_nd(
    flags,
    &params,
    sizeof(params),
    XNN_INIT_FLAG_QS8,
    xnn_operator_type_add_nd_qs8,
    &xnn_params.qs8.vadd.minmax,
    add_op_out);
}

enum xnn_status xnn_create_add_nd_qu8(
    uint8_t input1_zero_point,
    float input1_scale,
    uint8_t input2_zero_point,
    float input2_scale,
    uint8_t output_zero_point,
    float output_scale,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    xnn_operator_t* add_op_out)
{
  if (input1_scale <= 0.0f || !isnormal(input1_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input 1 scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qu8), input1_scale);
    return xnn_status_invalid_parameter;
  }

  if (input2_scale <= 0.0f || !isnormal(input2_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input 2 scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qu8), input2_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qu8), output_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRIu8 ", %" PRIu8 "] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qu8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  const float input1_output_scale = input1_scale / output_scale;
  if (input1_output_scale < 0x1.0p-10f || input1_output_scale >= 0x1.0p+8f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input1-to-output scale ratio: scale ratio must be in [2**-10, 2**8) range",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qu8), input1_output_scale);
    return xnn_status_unsupported_parameter;
  }

  const float input2_output_scale = input2_scale / output_scale;
  if (input2_output_scale < 0x1.0p-10f || input2_output_scale >= 0x1.0p+8f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input2-to-output scale ratio: scale ratio must be in [2**-10, 2**8) range",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qu8), input2_output_scale);
    return xnn_status_unsupported_parameter;
  }

  struct {
    union xnn_qu8_addsub_minmax_params qu8_addsub;
    union xnn_qu8_addsub_minmax_params qu8_raddsub;
  } params;
  if (xnn_params.qu8.vadd.init.qu8_addsub != NULL) {
    xnn_params.qu8.vadd.init.qu8_addsub(
      &params.qu8_addsub, input1_zero_point, input2_zero_point, output_zero_point,
      input1_output_scale, input2_output_scale, output_min, output_max);
    xnn_params.qu8.vadd.init.qu8_addsub(
      &params.qu8_raddsub, input2_zero_point, input1_zero_point, output_zero_point,
      input2_output_scale, input1_output_scale, output_min, output_max);
  }
  return create_binary_elementwise_nd(
    flags,
    &params,
    sizeof(params),
    XNN_INIT_FLAG_QU8,
    xnn_operator_type_add_nd_qu8,
    &xnn_params.qu8.vadd.minmax,
    add_op_out);
}

enum xnn_status xnn_create_add_nd_f16(
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_operator_t* add_op_out)
{
  return create_binary_elementwise_nd_f16(
    output_min,
    output_max,
    flags,
    xnn_operator_type_add_nd_f16,
    &xnn_params.f16.vadd,
    add_op_out);
}

enum xnn_status xnn_create_add_nd_f32(
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_operator_t* add_op_out)
{
  return create_binary_elementwise_nd_f32(
    output_min,
    output_max,
    flags,
    xnn_operator_type_add_nd_f32,
    &xnn_params.f32.vadd,
    add_op_out);
}

enum xnn_status xnn_create_divide_nd_f32(
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_operator_t* divide_op_out)
{
  return create_binary_elementwise_nd_f32(
    output_min,
    output_max,
    flags,
    xnn_operator_type_divide_nd_f32,
    &xnn_params.f32.vdiv,
    divide_op_out);
}

enum xnn_status xnn_create_maximum_nd_f32(
    uint32_t flags,
    xnn_operator_t* maximum_op_out)
{
  union xnn_f32_default_params params;
  if (xnn_params.f32.vmin.init.f32_default != NULL) {
    xnn_params.f32.vmin.init.f32_default(&params);
  }
  return create_binary_elementwise_nd(
    flags,
    &params,
    sizeof(params),
    XNN_INIT_FLAG_F32,
    xnn_operator_type_maximum_nd_f32,
    &xnn_params.f32.vmax.minmax,
    maximum_op_out);
}

enum xnn_status xnn_create_minimum_nd_f32(
    uint32_t flags,
    xnn_operator_t* minimum_op_out)
{
  union xnn_f32_default_params params;
  if (xnn_params.f32.vmin.init.f32_default != NULL) {
    xnn_params.f32.vmin.init.f32_default(&params);
  }
  return create_binary_elementwise_nd(
    flags,
    &params,
    sizeof(params),
    XNN_INIT_FLAG_F32,
    xnn_operator_type_minimum_nd_f32,
    &xnn_params.f32.vmin.minmax,
    minimum_op_out);
}

enum xnn_status xnn_create_multiply_nd_qs8(
    int8_t input1_zero_point,
    float input1_scale,
    int8_t input2_zero_point,
    float input2_scale,
    int8_t output_zero_point,
    float output_scale,
    int8_t output_min,
    int8_t output_max,
    uint32_t flags,
    xnn_operator_t* multiply_op_out)
{
  if (input1_scale <= 0.0f || !isnormal(input1_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input 1 scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qs8), input1_scale);
    return xnn_status_invalid_parameter;
  }

  if (input2_scale <= 0.0f || !isnormal(input2_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input 2 scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qs8), input2_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qs8), output_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRId8 ", %" PRId8 "] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qs8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  const float product_scale = input1_scale * input2_scale;
  const float product_output_scale = product_scale / output_scale;
  if (product_output_scale < 0x1.0p-16f || product_output_scale >= 0x1.0p+8f) {
    xnn_log_error(
      "failed to create %s operator with %.7g product-to-output scale ratio: scale ratio must be in [2**-16, 2**8) range",
      xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qs8), product_output_scale);
    return xnn_status_unsupported_parameter;
  }

  struct {
    union xnn_qs8_mul_minmax_params qs8_mul;
    union xnn_qs8_mul_minmax_params qs8_rmul;
  } params;
  if (xnn_params.qs8.vmul.init.qs8_mul != NULL) {
    xnn_params.qs8.vmul.init.qs8_mul(
      &params.qs8_mul, input1_zero_point, input2_zero_point, output_zero_point,
      product_output_scale, output_min, output_max);
    xnn_params.qs8.vmul.init.qs8_mul(
      &params.qs8_rmul, input2_zero_point, input1_zero_point, output_zero_point,
      product_output_scale, output_min, output_max);
  }
  return create_binary_elementwise_nd(
    flags,
    &params,
    sizeof(params),
    XNN_INIT_FLAG_QS8,
    xnn_operator_type_multiply_nd_qs8,
    &xnn_params.qs8.vmul.minmax,
    multiply_op_out);
}

enum xnn_status xnn_create_multiply_nd_qu8(
    uint8_t input1_zero_point,
    float input1_scale,
    uint8_t input2_zero_point,
    float input2_scale,
    uint8_t output_zero_point,
    float output_scale,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    xnn_operator_t* multiply_op_out)
{
  if (input1_scale <= 0.0f || !isnormal(input1_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input 1 scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qu8), input1_scale);
    return xnn_status_invalid_parameter;
  }

  if (input2_scale <= 0.0f || !isnormal(input2_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input 2 scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qu8), input2_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qu8), output_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRIu8 ", %" PRIu8 "] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qu8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  const float product_scale = input1_scale * input2_scale;
  const float product_output_scale = product_scale / output_scale;
  if (product_output_scale < 0x1.0p-16f || product_output_scale >= 0x1.0p+8f) {
    xnn_log_error(
      "failed to create %s operator with %.7g product-to-output scale ratio: scale ratio must be in [2**-16, 2**8) range",
      xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qu8), product_output_scale);
    return xnn_status_unsupported_parameter;
  }

  struct {
    union xnn_qu8_mul_minmax_params qu8_mul;
    union xnn_qu8_mul_minmax_params qu8_rmul;
  } params;
  if (xnn_params.qu8.vmul.init.qu8_mul != NULL) {
    xnn_params.qu8.vmul.init.qu8_mul(
      &params.qu8_mul, input1_zero_point, input2_zero_point, output_zero_point,
      product_output_scale, output_min, output_max);
    xnn_params.qu8.vmul.init.qu8_mul(
      &params.qu8_rmul, input2_zero_point, input1_zero_point, output_zero_point,
      product_output_scale, output_min, output_max);
  }
  return create_binary_elementwise_nd(
    flags,
    &params,
    sizeof(params),
    XNN_INIT_FLAG_QU8,
    xnn_operator_type_multiply_nd_qu8,
    &xnn_params.qu8.vmul.minmax,
    multiply_op_out);
}

enum xnn_status xnn_create_multiply_nd_f16(
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_operator_t* multiply_op_out)
{
  return create_binary_elementwise_nd_f16(
    output_min,
    output_max,
    flags,
    xnn_operator_type_multiply_nd_f16,
    &xnn_params.f16.vmul,
    multiply_op_out);
}

enum xnn_status xnn_create_multiply_nd_f32(
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_operator_t* multiply_op_out)
{
  return create_binary_elementwise_nd_f32(
    output_min,
    output_max,
    flags,
    xnn_operator_type_multiply_nd_f32,
    &xnn_params.f32.vmul,
    multiply_op_out);
}

enum xnn_status xnn_create_squared_difference_nd_f32(
    uint32_t flags,
    xnn_operator_t* squared_difference_op_out)
{
  union xnn_f32_default_params params;
  if (xnn_params.f32.vmin.init.f32_default != NULL) {
    xnn_params.f32.vmin.init.f32_default(&params);
  }
  return create_binary_elementwise_nd(
    flags,
    &params,
    sizeof(params),
    XNN_INIT_FLAG_F32,
    xnn_operator_type_squared_difference_nd_f32,
    &xnn_params.f32.vsqrdiff.minmax,
    squared_difference_op_out);
}

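// QS8/QU8 subtraction reuses the add/sub microkernels and parameter layout:
// note that the scale passed for input2 is negated (-input2_output_scale),
// which makes the fused add/sub parameters describe a subtraction.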
enum xnn_status xnn_create_subtract_nd_qs8(
    int8_t input1_zero_point,
    float input1_scale,
    int8_t input2_zero_point,
    float input2_scale,
    int8_t output_zero_point,
    float output_scale,
    int8_t output_min,
    int8_t output_max,
    uint32_t flags,
    xnn_operator_t* subtract_op_out)
{
  if (input1_scale <= 0.0f || !isnormal(input1_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input 1 scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qs8), input1_scale);
    return xnn_status_invalid_parameter;
  }

  if (input2_scale <= 0.0f || !isnormal(input2_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input 2 scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qs8), input2_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qs8), output_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRId8 ", %" PRId8 "] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qs8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  const float input1_output_scale = input1_scale / output_scale;
  if (input1_output_scale < 0x1.0p-10f || input1_output_scale >= 0x1.0p+8f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input1-to-output scale ratio: scale ratio must be in [2**-10, 2**8) range",
      xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qs8), input1_output_scale);
    return xnn_status_unsupported_parameter;
  }

  const float input2_output_scale = input2_scale / output_scale;
  if (input2_output_scale < 0x1.0p-10f || input2_output_scale >= 0x1.0p+8f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input2-to-output scale ratio: scale ratio must be in [2**-10, 2**8) range",
      xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qs8), input2_output_scale);
    return xnn_status_unsupported_parameter;
  }

  struct {
    union xnn_qs8_addsub_minmax_params qs8_addsub;
    union xnn_qs8_addsub_minmax_params qs8_raddsub;
  } params;
  if (xnn_params.qs8.vadd.init.qs8_addsub != NULL) {
    xnn_params.qs8.vadd.init.qs8_addsub(
      &params.qs8_addsub, input1_zero_point, input2_zero_point, output_zero_point,
      input1_output_scale, -input2_output_scale, output_min, output_max);
    xnn_params.qs8.vadd.init.qs8_addsub(
      &params.qs8_raddsub, input2_zero_point, input1_zero_point, output_zero_point,
      -input2_output_scale, input1_output_scale, output_min, output_max);
  }
  return create_binary_elementwise_nd(
    flags,
    &params,
    sizeof(params),
    XNN_INIT_FLAG_QS8,
    xnn_operator_type_subtract_nd_qs8,
    &xnn_params.qs8.vadd.minmax,
    subtract_op_out);
}

enum xnn_status xnn_create_subtract_nd_qu8(
    uint8_t input1_zero_point,
    float input1_scale,
    uint8_t input2_zero_point,
    float input2_scale,
    uint8_t output_zero_point,
    float output_scale,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    xnn_operator_t* subtract_op_out)
{
  if (input1_scale <= 0.0f || !isnormal(input1_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input 1 scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qu8), input1_scale);
    return xnn_status_invalid_parameter;
  }

  if (input2_scale <= 0.0f || !isnormal(input2_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input 2 scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qu8), input2_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qu8), output_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRIu8 ", %" PRIu8 "] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qu8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  const float input1_output_scale = input1_scale / output_scale;
  if (input1_output_scale < 0x1.0p-10f || input1_output_scale >= 0x1.0p+8f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input1-to-output scale ratio: scale ratio must be in [2**-10, 2**8) range",
      xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qu8), input1_output_scale);
    return xnn_status_unsupported_parameter;
  }

  const float input2_output_scale = input2_scale / output_scale;
  if (input2_output_scale < 0x1.0p-10f || input2_output_scale >= 0x1.0p+8f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input2-to-output scale ratio: scale ratio must be in [2**-10, 2**8) range",
      xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qu8), input2_output_scale);
    return xnn_status_unsupported_parameter;
  }

  struct {
    union xnn_qu8_addsub_minmax_params qu8_addsub;
    union xnn_qu8_addsub_minmax_params qu8_raddsub;
  } params;
  if (xnn_params.qu8.vadd.init.qu8_addsub != NULL) {
    xnn_params.qu8.vadd.init.qu8_addsub(
      &params.qu8_addsub, input1_zero_point, input2_zero_point, output_zero_point,
      input1_output_scale, -input2_output_scale, output_min, output_max);
    xnn_params.qu8.vadd.init.qu8_addsub(
      &params.qu8_raddsub, input2_zero_point, input1_zero_point, output_zero_point,
      -input2_output_scale, input1_output_scale, output_min, output_max);
  }
  return create_binary_elementwise_nd(
    flags,
    &params,
    sizeof(params),
    XNN_INIT_FLAG_QU8,
    xnn_operator_type_subtract_nd_qu8,
    &xnn_params.qu8.vadd.minmax,
    subtract_op_out);
}

enum xnn_status xnn_create_subtract_nd_f32(
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_operator_t* subtract_op_out)
{
  return create_binary_elementwise_nd_f32(
    output_min,
    output_max,
    flags,
    xnn_operator_type_subtract_nd_f32,
    &xnn_params.f32.vsub,
    subtract_op_out);
}

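// Shared setup for all binary elementwise operators. The two input shapes are
// broadcast against each other, aligned at the trailing dimension, and then
// compressed: runs of consecutive dimensions with the same broadcasting
// behavior are folded into a single dimension. The innermost compressed
// dimension selects the microkernel variant; the outer compressed dimensions
// drive a 5-D parallelized loop.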
static enum xnn_status setup_binary_elementwise_nd(
    xnn_operator_t binary_elementwise_op,
    enum xnn_operator_type expected_operator_type,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const void* input1,
    const void* input2,
    void* output,
    uint32_t datatype_init_flags,
    uint32_t log2_element_size,
    const void* params,
    size_t params_size,
    const void* reversed_params,
    size_t reversed_params_size,
    const struct vbinary_parameters vbinary[restrict XNN_MIN_ELEMENTS(1)],
    size_t num_threads)
{
  binary_elementwise_op->state = xnn_run_state_invalid;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to setup %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(binary_elementwise_op->type));
    return xnn_status_uninitialized;
  }

  if ((xnn_params.init_flags & datatype_init_flags) != datatype_init_flags) {
    xnn_log_error("failed to setup %s operator: operations on data type are not supported",
      xnn_operator_type_to_string(binary_elementwise_op->type));
    return xnn_status_unsupported_hardware;
  }

  if (binary_elementwise_op->type != expected_operator_type) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(expected_operator_type),
      xnn_operator_type_to_string(binary_elementwise_op->type));
    return xnn_status_invalid_parameter;
  }

  if (max(num_input1_dims, num_input2_dims) > XNN_MAX_TENSOR_DIMS) {
    xnn_log_error(
      "failed to setup %s operator with %zu and %zu dimensions in input shapes: "
      "the number of input dimensions must not exceed %d",
      xnn_operator_type_to_string(binary_elementwise_op->type), num_input1_dims, num_input2_dims, XNN_MAX_TENSOR_DIMS);
    return xnn_status_unsupported_parameter;
  }

  size_t num_compressed_dims = 0;
  size_t compressed_input1_shape[XNN_MAX_TENSOR_DIMS];
  size_t compressed_input2_shape[XNN_MAX_TENSOR_DIMS];
  size_t compressed_output_shape[XNN_MAX_TENSOR_DIMS];
  for (size_t i = 0; i < XNN_MAX_TENSOR_DIMS; i++) {
    compressed_input1_shape[i] = 1;
    compressed_input2_shape[i] = 1;
    compressed_output_shape[i] = 1;
  }
  bool broadcast_input1 = false;
  bool broadcast_input2 = false;
  bool first_nonunit = true;
  bool degenerate_shape = false;
  const size_t num_common_dims = min(num_input1_dims, num_input2_dims);
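  // Walk the common trailing dimensions from innermost to outermost. Unit
  // dimensions on both sides are dropped; consecutive dimensions with the
  // same broadcasting behavior (same input broadcast, or both inputs equal)
  // are multiplied into the current compressed dimension; a new compressed
  // dimension is started whenever that behavior changes. Mismatched non-unit
  // dimensions are rejected.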
  for (size_t i = 1; i <= num_common_dims; i++) {
    const size_t input1_dim = input1_shape[num_input1_dims - i];
    const size_t input2_dim = input2_shape[num_input2_dims - i];
    degenerate_shape |= input1_dim == 0;
    degenerate_shape |= input2_dim == 0;
    if (input1_dim == 1 && input2_dim == 1) {
      continue;
    }
    assert(!broadcast_input1 || !broadcast_input2);

    if (input1_dim == 1) {
      if (!broadcast_input1) {
        broadcast_input1 = true;
        broadcast_input2 = false;
        num_compressed_dims++;
      }
      compressed_input2_shape[num_compressed_dims - 1] *= input2_dim;
      compressed_output_shape[num_compressed_dims - 1] *= input2_dim;
    } else if (input2_dim == 1) {
      if (!broadcast_input2) {
        broadcast_input1 = false;
        broadcast_input2 = true;
        num_compressed_dims++;
      }
      compressed_input1_shape[num_compressed_dims - 1] *= input1_dim;
      compressed_output_shape[num_compressed_dims - 1] *= input1_dim;
    } else if (input1_dim == input2_dim) {
      if (broadcast_input1 || broadcast_input2 || first_nonunit) {
        broadcast_input1 = false;
        broadcast_input2 = false;
        num_compressed_dims++;
      }
      compressed_input1_shape[num_compressed_dims - 1] *= input1_dim;
      compressed_input2_shape[num_compressed_dims - 1] *= input1_dim;
      compressed_output_shape[num_compressed_dims - 1] *= input1_dim;
    } else {
      xnn_log_error(
        "failed to setup %s operator: "
        "shape dimension #%zu of input1 (%zu) does not match shape dimension #%zu of input2 (%zu)",
        xnn_operator_type_to_string(binary_elementwise_op->type),
        num_input1_dims - i, input1_dim, num_input2_dims - i, input2_dim);
      return xnn_status_invalid_parameter;
    }
    first_nonunit = false;
  }
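  // The leading dimensions of the higher-rank input have no counterpart in
  // the other shape, so the other input is implicitly broadcast over them;
  // fold them into the outermost compressed dimension.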
  if (num_input1_dims > num_input2_dims) {
    if (!broadcast_input2) {
      num_compressed_dims++;
    }
    for (size_t i = 0; i < num_input1_dims - num_input2_dims; i++) {
      const size_t input1_dim = input1_shape[i];
      degenerate_shape |= input1_dim == 0;
      compressed_input1_shape[num_compressed_dims - 1] *= input1_dim;
      compressed_output_shape[num_compressed_dims - 1] *= input1_dim;
    }
  } else if (num_input2_dims > num_input1_dims) {
    if (!broadcast_input1) {
      num_compressed_dims++;
    }
    for (size_t i = 0; i < num_input2_dims - num_input1_dims; i++) {
      const size_t input2_dim = input2_shape[i];
      degenerate_shape |= input2_dim == 0;
      compressed_input2_shape[num_compressed_dims - 1] *= input2_dim;
      compressed_output_shape[num_compressed_dims - 1] *= input2_dim;
    }
  }
  num_compressed_dims = max(num_compressed_dims, 1);

  // Early exit without setting up context if any shape dimension is zero.
  if (degenerate_shape) {
    binary_elementwise_op->state = xnn_run_state_skip;
    return xnn_status_success;
  }

  binary_elementwise_op->context.elementwise_binary = (struct elementwise_binary_context) {
    .a = input1,
    .b = input2,
    .y = output,
    .elements = compressed_output_shape[0] << log2_element_size,
  };
  if (params_size != 0) {
    memcpy(&binary_elementwise_op->context.elementwise_binary.params, params, params_size);
  }

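  // Select the microkernel based on the innermost compressed dimension: if
  // input1 has a single element there (it is broadcast along the innermost
  // dimension), swap the operands, use the reversed operand-with-constant
  // (ropc) kernel, and overwrite the parameters with the reversed set; if
  // input2 is the broadcast side, use the operand-with-constant (opc) kernel;
  // otherwise both inputs advance element by element and the two-operand
  // kernel is used.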
  const size_t* compressed_a_shape = compressed_input1_shape;
  const size_t* compressed_b_shape = compressed_input2_shape;
  if (compressed_input1_shape[0] == 1) {
    binary_elementwise_op->context.elementwise_binary.ukernel = binary_elementwise_op->ukernel.vbinary.ropc_function;
    binary_elementwise_op->context.elementwise_binary.a = input2;
    binary_elementwise_op->context.elementwise_binary.b = input1;
    compressed_a_shape = compressed_input2_shape;
    compressed_b_shape = compressed_input1_shape;
    if (reversed_params_size != 0) {
      memcpy(&binary_elementwise_op->context.elementwise_binary.params, reversed_params, reversed_params_size);
    }
  } else if (compressed_input2_shape[0] == 1) {
    binary_elementwise_op->context.elementwise_binary.ukernel = binary_elementwise_op->ukernel.vbinary.opc_function;
  } else if (compressed_input1_shape[0] == compressed_input2_shape[0]) {
    binary_elementwise_op->context.elementwise_binary.ukernel = binary_elementwise_op->ukernel.vbinary.op_function;
  }
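  // Compute byte strides for the outer compressed dimensions, innermost
  // first. The context was initialized with a compound literal above, so
  // dimensions where an input has size 1 keep a zero stride and re-read the
  // same data, which implements broadcasting.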
  size_t a_stride = compressed_a_shape[0], b_stride = compressed_b_shape[0], y_stride = compressed_output_shape[0];
  for (size_t i = 1; i < num_compressed_dims; i++) {
    if (compressed_a_shape[i] != 1) {
      binary_elementwise_op->context.elementwise_binary.a_stride[XNN_MAX_TENSOR_DIMS - 1 - i] = a_stride << log2_element_size;
    }
    if (compressed_b_shape[i] != 1) {
      binary_elementwise_op->context.elementwise_binary.b_stride[XNN_MAX_TENSOR_DIMS - 1 - i] = b_stride << log2_element_size;
    }
    binary_elementwise_op->context.elementwise_binary.y_stride[XNN_MAX_TENSOR_DIMS - 1 - i] = y_stride << log2_element_size;
    a_stride *= compressed_a_shape[i];
    b_stride *= compressed_b_shape[i];
    y_stride *= compressed_output_shape[i];
  }

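  // Parallelize over the five outer compressed dimensions; the innermost
  // compressed dimension is handled inside the microkernel via the
  // precomputed element count.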
  binary_elementwise_op->compute.type = xnn_parallelization_type_5d;
  binary_elementwise_op->compute.task_5d = (pthreadpool_task_5d_t) xnn_compute_elementwise_binary_5d;
  binary_elementwise_op->compute.range[0] = compressed_output_shape[5];
  binary_elementwise_op->compute.range[1] = compressed_output_shape[4];
  binary_elementwise_op->compute.range[2] = compressed_output_shape[3];
  binary_elementwise_op->compute.range[3] = compressed_output_shape[2];
  binary_elementwise_op->compute.range[4] = compressed_output_shape[1];
  binary_elementwise_op->compute.tile[0] = 1;
  binary_elementwise_op->compute.tile[1] = 1;
  binary_elementwise_op->state = xnn_run_state_ready;

  return xnn_status_success;
}

static enum xnn_status setup_binary_elementwise_nd_f16(
    xnn_operator_t binary_elementwise_op,
    enum xnn_operator_type expected_operator_type,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const void* input1,
    const void* input2,
    void* output,
    const struct vbinary_parameters vbinary[restrict XNN_MIN_ELEMENTS(1)],
    size_t num_threads)
{
  return setup_binary_elementwise_nd(
    binary_elementwise_op,
    expected_operator_type,
    num_input1_dims,
    input1_shape,
    num_input2_dims,
    input2_shape,
    input1,
    input2,
    output,
    XNN_INIT_FLAG_F16,
    1 /* log2(sizeof(half)) */,
    &binary_elementwise_op->params.f16_minmax, sizeof(binary_elementwise_op->params.f16_minmax),
    &binary_elementwise_op->params.f16_minmax, sizeof(binary_elementwise_op->params.f16_minmax),
    vbinary,
    num_threads);
}

static enum xnn_status setup_binary_elementwise_nd_f32(
    xnn_operator_t binary_elementwise_op,
    enum xnn_operator_type expected_operator_type,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    const struct vbinary_parameters vbinary[restrict XNN_MIN_ELEMENTS(1)],
    size_t num_threads)
{
  return setup_binary_elementwise_nd(
    binary_elementwise_op, expected_operator_type,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    XNN_INIT_FLAG_F32,
    2 /* log2(sizeof(float)) */,
    &binary_elementwise_op->params.f32_minmax, sizeof(binary_elementwise_op->params.f32_minmax),
    &binary_elementwise_op->params.f32_minmax, sizeof(binary_elementwise_op->params.f32_minmax),
    vbinary,
    num_threads);
}

enum xnn_status xnn_setup_add_nd_qs8(
    xnn_operator_t add_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const int8_t* input1,
    const int8_t* input2,
    int8_t* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd(
    add_op, xnn_operator_type_add_nd_qs8,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    XNN_INIT_FLAG_QS8,
    0 /* log2(sizeof(int8_t)) */,
    &add_op->params.qs8_addsub, sizeof(add_op->params.qs8_addsub),
    &add_op->params.qs8_raddsub, sizeof(add_op->params.qs8_raddsub),
    &xnn_params.qs8.vadd,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_add_nd_qu8(
    xnn_operator_t add_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const uint8_t* input1,
    const uint8_t* input2,
    uint8_t* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd(
    add_op, xnn_operator_type_add_nd_qu8,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    XNN_INIT_FLAG_QU8,
    0 /* log2(sizeof(uint8_t)) */,
    &add_op->params.qu8_addsub, sizeof(add_op->params.qu8_addsub),
    &add_op->params.qu8_raddsub, sizeof(add_op->params.qu8_raddsub),
    &xnn_params.qu8.vadd,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_add_nd_f16(
    xnn_operator_t add_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const void* input1,
    const void* input2,
    void* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f16(
    add_op, xnn_operator_type_add_nd_f16,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f16.vadd,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_add_nd_f32(
    xnn_operator_t add_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    add_op, xnn_operator_type_add_nd_f32,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f32.vadd,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_divide_nd_f32(
    xnn_operator_t divide_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    divide_op, xnn_operator_type_divide_nd_f32,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f32.vdiv,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_maximum_nd_f32(
    xnn_operator_t maximum_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    maximum_op, xnn_operator_type_maximum_nd_f32,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f32.vmax,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_minimum_nd_f32(
    xnn_operator_t minimum_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    minimum_op, xnn_operator_type_minimum_nd_f32,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f32.vmin,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_multiply_nd_qs8(
    xnn_operator_t multiply_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const int8_t* input1,
    const int8_t* input2,
    int8_t* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd(
    multiply_op, xnn_operator_type_multiply_nd_qs8,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    XNN_INIT_FLAG_QS8,
    0 /* log2(sizeof(int8_t)) */,
    &multiply_op->params.qs8_mul, sizeof(multiply_op->params.qs8_mul),
    &multiply_op->params.qs8_rmul, sizeof(multiply_op->params.qs8_rmul),
    &xnn_params.qs8.vmul,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_multiply_nd_qu8(
    xnn_operator_t multiply_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const uint8_t* input1,
    const uint8_t* input2,
    uint8_t* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd(
    multiply_op, xnn_operator_type_multiply_nd_qu8,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    XNN_INIT_FLAG_QU8,
    0 /* log2(sizeof(uint8_t)) */,
    &multiply_op->params.qu8_mul, sizeof(multiply_op->params.qu8_mul),
    &multiply_op->params.qu8_rmul, sizeof(multiply_op->params.qu8_rmul),
    &xnn_params.qu8.vmul,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_multiply_nd_f16(
    xnn_operator_t multiply_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const void* input1,
    const void* input2,
    void* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f16(
    multiply_op, xnn_operator_type_multiply_nd_f16,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f16.vmul,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_multiply_nd_f32(
    xnn_operator_t multiply_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    multiply_op, xnn_operator_type_multiply_nd_f32,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f32.vmul,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_squared_difference_nd_f32(
    xnn_operator_t squared_difference_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    squared_difference_op, xnn_operator_type_squared_difference_nd_f32,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f32.vsqrdiff,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_subtract_nd_qs8(
    xnn_operator_t subtract_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const int8_t* input1,
    const int8_t* input2,
    int8_t* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd(
    subtract_op, xnn_operator_type_subtract_nd_qs8,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    XNN_INIT_FLAG_QS8,
    0 /* log2(sizeof(int8_t)) */,
    &subtract_op->params.qs8_addsub, sizeof(subtract_op->params.qs8_addsub),
    &subtract_op->params.qs8_raddsub, sizeof(subtract_op->params.qs8_raddsub),
    &xnn_params.qs8.vadd,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_subtract_nd_qu8(
    xnn_operator_t subtract_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const uint8_t* input1,
    const uint8_t* input2,
    uint8_t* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd(
    subtract_op, xnn_operator_type_subtract_nd_qu8,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    XNN_INIT_FLAG_QU8,
    0 /* log2(sizeof(uint8_t)) */,
    &subtract_op->params.qu8_addsub, sizeof(subtract_op->params.qu8_addsub),
    &subtract_op->params.qu8_raddsub, sizeof(subtract_op->params.qu8_raddsub),
    &xnn_params.qu8.vadd,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_subtract_nd_f32(
    xnn_operator_t subtract_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    subtract_op, xnn_operator_type_subtract_nd_f32,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f32.vsub,
    pthreadpool_get_threads_count(threadpool));
}