// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <fp16.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/log.h>
#include <xnnpack/operator.h>
#include <xnnpack/params-init.h>
#include <xnnpack/params.h>


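// Common creation path for all binary element-wise operators: verifies that
// XNNPACK and the requested datatype are initialized, allocates the operator
// structure, copies the operator-specific parameters, and records the three
// micro-kernel variants (two-vector, constant-second-operand, and
// constant-first-operand) that setup selects between later.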
static enum xnn_status create_binary_elementwise_nd(
    uint32_t flags,
    const void* params,
    size_t params_size,
    uint32_t datatype_init_flags,
    enum xnn_operator_type operator_type,
    const struct vbinary_fused_ukernels* vbinary_fused_ukernels,
    xnn_operator_t* binary_elementwise_op_out)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(operator_type));
    return xnn_status_uninitialized;
  }

  if ((xnn_params.init_flags & datatype_init_flags) != datatype_init_flags) {
    xnn_log_error("failed to create %s operator: operations on data type are not supported",
      xnn_operator_type_to_string(operator_type));
    return xnn_status_unsupported_hardware;
  }

  xnn_operator_t binary_elementwise_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
  if (binary_elementwise_op == NULL) {
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator descriptor",
      sizeof(struct xnn_operator), xnn_operator_type_to_string(operator_type));
    return xnn_status_out_of_memory;
  }

  if (params_size != 0) {
    memcpy(&binary_elementwise_op->params, params, params_size);
  }

  binary_elementwise_op->ukernel.vbinary.op_function = vbinary_fused_ukernels->op_ukernel;
  binary_elementwise_op->ukernel.vbinary.opc_function = vbinary_fused_ukernels->opc_ukernel;
  binary_elementwise_op->ukernel.vbinary.ropc_function = vbinary_fused_ukernels->ropc_ukernel;

  binary_elementwise_op->type = operator_type;
  binary_elementwise_op->flags = flags;

  binary_elementwise_op->state = xnn_run_state_invalid;

  *binary_elementwise_op_out = binary_elementwise_op;
  return xnn_status_success;
}

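// The half-precision creator validates the clamping bounds after rounding
// them to IEEE FP16, so the [min, max] check reflects the values the
// micro-kernel will actually apply.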
static enum xnn_status create_binary_elementwise_nd_f16(
    float output_min,
    float output_max,
    uint32_t flags,
    enum xnn_operator_type operator_type,
    const struct vbinary_parameters vbinary[restrict XNN_MIN_ELEMENTS(1)],
    xnn_operator_t* binary_elementwise_op_out)
{
  if (isnan(output_min)) {
    xnn_log_error(
      "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN",
      xnn_operator_type_to_string(operator_type));
    return xnn_status_invalid_parameter;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN",
      xnn_operator_type_to_string(operator_type));
    return xnn_status_invalid_parameter;
  }

  if (fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_min)) >= fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_max))) {
    xnn_log_error(
      "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(operator_type),
      fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_min)),
      fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_max)));
    return xnn_status_invalid_parameter;
  }

  union xnn_f16_minmax_params params;
  if (vbinary->init.f16_minmax != NULL) {
    vbinary->init.f16_minmax(&params,
      fp16_ieee_from_fp32_value(output_min), fp16_ieee_from_fp32_value(output_max));
  }
  return create_binary_elementwise_nd(
    flags,
    &params,
    sizeof(params),
    XNN_INIT_FLAG_F16,
    operator_type,
    &vbinary->minmax,
    binary_elementwise_op_out);
}

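// The single-precision creator additionally recognizes a linear activation
// (output range (-inf, +inf)) and, when the kernel table provides a dedicated
// unclamped kernel, prefers it over the min/max kernel.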
static enum xnn_status create_binary_elementwise_nd_f32(
    float output_min,
    float output_max,
    uint32_t flags,
    enum xnn_operator_type operator_type,
    const struct vbinary_parameters vbinary[restrict XNN_MIN_ELEMENTS(1)],
    xnn_operator_t* binary_elementwise_op_out)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(operator_type));
    return xnn_status_uninitialized;
  }

  if (isnan(output_min)) {
    xnn_log_error(
      "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN",
      xnn_operator_type_to_string(operator_type));
    return xnn_status_invalid_parameter;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN",
      xnn_operator_type_to_string(operator_type));
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(operator_type), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  const bool linear_activation = (output_max == INFINITY) && (output_min == -output_max);
  const struct vbinary_fused_ukernels* vbinary_fused_ukernels = &vbinary->minmax;
  if (linear_activation && vbinary->linear.op_ukernel != NULL) {
    vbinary_fused_ukernels = &vbinary->linear;
  }

  union xnn_f32_minmax_params params;
  if (vbinary->init.f32_minmax != NULL) {
    vbinary->init.f32_minmax(&params, output_min, output_max);
  }
  return create_binary_elementwise_nd(
    flags,
    &params,
    sizeof(params),
    XNN_INIT_FLAG_F32,
    operator_type,
    vbinary_fused_ukernels,
    binary_elementwise_op_out);
}

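// QS8 addition: requantization supports input-to-output scale ratios in the
// [2**-10, 2**8) range. Both a direct and an operand-swapped parameter set are
// initialized up front so that setup can commute the inputs when broadcasting
// makes the first operand the constant one.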
enum xnn_status xnn_create_add_nd_qs8(
    int8_t input1_zero_point,
    float input1_scale,
    int8_t input2_zero_point,
    float input2_scale,
    int8_t output_zero_point,
    float output_scale,
    int8_t output_min,
    int8_t output_max,
    uint32_t flags,
    xnn_operator_t* add_op_out)
{
  if (input1_scale <= 0.0f || !isnormal(input1_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input 1 scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qs8), input1_scale);
    return xnn_status_invalid_parameter;
  }

  if (input2_scale <= 0.0f || !isnormal(input2_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input 2 scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qs8), input2_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qs8), output_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRId8 ", %" PRId8 "] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qs8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  const float input1_output_scale = input1_scale / output_scale;
  if (input1_output_scale < 0x1.0p-10f || input1_output_scale >= 0x1.0p+8f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input1-to-output scale ratio: scale ratio must be in [2**-10, 2**8) range",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qs8), input1_output_scale);
    return xnn_status_unsupported_parameter;
  }

  const float input2_output_scale = input2_scale / output_scale;
  if (input2_output_scale < 0x1.0p-10f || input2_output_scale >= 0x1.0p+8f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input2-to-output scale ratio: scale ratio must be in [2**-10, 2**8) range",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qs8), input2_output_scale);
    return xnn_status_unsupported_parameter;
  }

  struct {
    union xnn_qs8_addsub_minmax_params qs8_addsub;
    union xnn_qs8_addsub_minmax_params qs8_raddsub;
  } params;
  if (xnn_params.qs8.vadd.init.qs8_addsub != NULL) {
    xnn_params.qs8.vadd.init.qs8_addsub(
      &params.qs8_addsub, input1_zero_point, input2_zero_point, output_zero_point,
      input1_output_scale, input2_output_scale, output_min, output_max);
    xnn_params.qs8.vadd.init.qs8_addsub(
      &params.qs8_raddsub, input2_zero_point, input1_zero_point, output_zero_point,
      input2_output_scale, input1_output_scale, output_min, output_max);
  }
  return create_binary_elementwise_nd(
    flags,
    &params,
    sizeof(params),
    XNN_INIT_FLAG_QS8,
    xnn_operator_type_add_nd_qs8,
    &xnn_params.qs8.vadd.minmax,
    add_op_out);
}

enum xnn_status xnn_create_add_nd_qu8(
    uint8_t input1_zero_point,
    float input1_scale,
    uint8_t input2_zero_point,
    float input2_scale,
    uint8_t output_zero_point,
    float output_scale,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    xnn_operator_t* add_op_out)
{
  if (input1_scale <= 0.0f || !isnormal(input1_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input 1 scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qu8), input1_scale);
    return xnn_status_invalid_parameter;
  }

  if (input2_scale <= 0.0f || !isnormal(input2_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input 2 scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qu8), input2_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qu8), output_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRIu8 ", %" PRIu8 "] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qu8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  const float input1_output_scale = input1_scale / output_scale;
  if (input1_output_scale < 0x1.0p-10f || input1_output_scale >= 0x1.0p+8f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input1-to-output scale ratio: scale ratio must be in [2**-10, 2**8) range",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qu8), input1_output_scale);
    return xnn_status_unsupported_parameter;
  }

  const float input2_output_scale = input2_scale / output_scale;
  if (input2_output_scale < 0x1.0p-10f || input2_output_scale >= 0x1.0p+8f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input2-to-output scale ratio: scale ratio must be in [2**-10, 2**8) range",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qu8), input2_output_scale);
    return xnn_status_unsupported_parameter;
  }

  struct {
    union xnn_qu8_addsub_minmax_params qu8_addsub;
    union xnn_qu8_addsub_minmax_params qu8_raddsub;
  } params;
  if (xnn_params.qu8.vadd.init.qu8_addsub != NULL) {
    xnn_params.qu8.vadd.init.qu8_addsub(
      &params.qu8_addsub, input1_zero_point, input2_zero_point, output_zero_point,
      input1_output_scale, input2_output_scale, output_min, output_max);
    xnn_params.qu8.vadd.init.qu8_addsub(
      &params.qu8_raddsub, input2_zero_point, input1_zero_point, output_zero_point,
      input2_output_scale, input1_output_scale, output_min, output_max);
  }
  return create_binary_elementwise_nd(
    flags,
    &params,
    sizeof(params),
    XNN_INIT_FLAG_QU8,
    xnn_operator_type_add_nd_qu8,
    &xnn_params.qu8.vadd.minmax,
    add_op_out);
}

enum xnn_status xnn_create_add_nd_f16(
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_operator_t* add_op_out)
{
  return create_binary_elementwise_nd_f16(
    output_min,
    output_max,
    flags,
    xnn_operator_type_add_nd_f16,
    &xnn_params.f16.vadd,
    add_op_out);
}

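// Example (sketch, not part of this file): creating, setting up, and running
// an F32 addition with broadcasting. xnn_initialize, xnn_run_operator, and
// xnn_delete_operator are part of the public XNNPACK API; a, b, and y are
// caller-provided float buffers, and error handling is omitted.
//
//   xnn_initialize(NULL /* allocator */);
//   xnn_operator_t add_op = NULL;
//   xnn_create_add_nd_f32(-INFINITY, INFINITY, 0 /* flags */, &add_op);
//   const size_t a_shape[2] = {8, 32};
//   const size_t b_shape[1] = {32};  // broadcast across the leading dimension
//   xnn_setup_add_nd_f32(add_op, 2, a_shape, 1, b_shape, a, b, y, NULL /* threadpool */);
//   xnn_run_operator(add_op, NULL /* threadpool */);
//   xnn_delete_operator(add_op);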
enum xnn_status xnn_create_add_nd_f32(
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_operator_t* add_op_out)
{
  return create_binary_elementwise_nd_f32(
    output_min,
    output_max,
    flags,
    xnn_operator_type_add_nd_f32,
    &xnn_params.f32.vadd,
    add_op_out);
}

enum xnn_status xnn_create_divide_nd_f32(
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_operator_t* divide_op_out)
{
  return create_binary_elementwise_nd_f32(
    output_min,
    output_max,
    flags,
    xnn_operator_type_divide_nd_f32,
    &xnn_params.f32.vdiv,
    divide_op_out);
}

enum xnn_status xnn_create_maximum_nd_f32(
    uint32_t flags,
    xnn_operator_t* maximum_op_out)
{
  union xnn_f32_default_params params;
  if (xnn_params.f32.vmax.init.f32_default != NULL) {
    xnn_params.f32.vmax.init.f32_default(&params);
  }
  return create_binary_elementwise_nd(
    flags,
    &params,
    sizeof(params),
    XNN_INIT_FLAG_F32,
    xnn_operator_type_maximum_nd_f32,
    &xnn_params.f32.vmax.minmax,
    maximum_op_out);
}

enum xnn_status xnn_create_minimum_nd_f32(
    uint32_t flags,
    xnn_operator_t* minimum_op_out)
{
  union xnn_f32_default_params params;
  if (xnn_params.f32.vmin.init.f32_default != NULL) {
    xnn_params.f32.vmin.init.f32_default(&params);
  }
  return create_binary_elementwise_nd(
    flags,
    &params,
    sizeof(params),
    XNN_INIT_FLAG_F32,
    xnn_operator_type_minimum_nd_f32,
    &xnn_params.f32.vmin.minmax,
    minimum_op_out);
}

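// QS8 multiplication: the product-to-output scale ratio
// (input1_scale * input2_scale / output_scale) must lie in [2**-16, 2**8).
// As with addition, a direct and an operand-swapped parameter set are prepared
// for the broadcast-commuted kernel variant.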
enum xnn_status xnn_create_multiply_nd_qs8(
    int8_t input1_zero_point,
    float input1_scale,
    int8_t input2_zero_point,
    float input2_scale,
    int8_t output_zero_point,
    float output_scale,
    int8_t output_min,
    int8_t output_max,
    uint32_t flags,
    xnn_operator_t* multiply_op_out)
{
  if (input1_scale <= 0.0f || !isnormal(input1_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input 1 scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qs8), input1_scale);
    return xnn_status_invalid_parameter;
  }

  if (input2_scale <= 0.0f || !isnormal(input2_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input 2 scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qs8), input2_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qs8), output_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRId8 ", %" PRId8 "] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qs8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  const float product_scale = input1_scale * input2_scale;
  const float product_output_scale = product_scale / output_scale;
  if (product_output_scale < 0x1.0p-16f || product_output_scale >= 0x1.0p+8f) {
    xnn_log_error(
      "failed to create %s operator with %.7g product-to-output scale ratio: scale ratio must be in [2**-16, 2**8) range",
      xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qs8), product_output_scale);
    return xnn_status_unsupported_parameter;
  }

  struct {
    union xnn_qs8_mul_minmax_params qs8_mul;
    union xnn_qs8_mul_minmax_params qs8_rmul;
  } params;
  if (xnn_params.qs8.vmul.init.qs8_mul != NULL) {
    xnn_params.qs8.vmul.init.qs8_mul(
      &params.qs8_mul, input1_zero_point, input2_zero_point, output_zero_point,
      product_output_scale, output_min, output_max);
    xnn_params.qs8.vmul.init.qs8_mul(
      &params.qs8_rmul, input2_zero_point, input1_zero_point, output_zero_point,
      product_output_scale, output_min, output_max);
  }
  return create_binary_elementwise_nd(
    flags,
    &params,
    sizeof(params),
    XNN_INIT_FLAG_QS8,
    xnn_operator_type_multiply_nd_qs8,
    &xnn_params.qs8.vmul.minmax,
    multiply_op_out);
}

enum xnn_status xnn_create_multiply_nd_qu8(
    uint8_t input1_zero_point,
    float input1_scale,
    uint8_t input2_zero_point,
    float input2_scale,
    uint8_t output_zero_point,
    float output_scale,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    xnn_operator_t* multiply_op_out)
{
  if (input1_scale <= 0.0f || !isnormal(input1_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input 1 scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qu8), input1_scale);
    return xnn_status_invalid_parameter;
  }

  if (input2_scale <= 0.0f || !isnormal(input2_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input 2 scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qu8), input2_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qu8), output_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRIu8 ", %" PRIu8 "] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qu8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  const float product_scale = input1_scale * input2_scale;
  const float product_output_scale = product_scale / output_scale;
  if (product_output_scale < 0x1.0p-16f || product_output_scale >= 0x1.0p+8f) {
    xnn_log_error(
      "failed to create %s operator with %.7g product-to-output scale ratio: scale ratio must be in [2**-16, 2**8) range",
      xnn_operator_type_to_string(xnn_operator_type_multiply_nd_qu8), product_output_scale);
    return xnn_status_unsupported_parameter;
  }

  struct {
    union xnn_qu8_mul_minmax_params qu8_mul;
    union xnn_qu8_mul_minmax_params qu8_rmul;
  } params;
  if (xnn_params.qu8.vmul.init.qu8_mul != NULL) {
    xnn_params.qu8.vmul.init.qu8_mul(
      &params.qu8_mul, input1_zero_point, input2_zero_point, output_zero_point,
      product_output_scale, output_min, output_max);
    xnn_params.qu8.vmul.init.qu8_mul(
      &params.qu8_rmul, input2_zero_point, input1_zero_point, output_zero_point,
      product_output_scale, output_min, output_max);
  }
  return create_binary_elementwise_nd(
    flags,
    &params,
    sizeof(params),
    XNN_INIT_FLAG_QU8,
    xnn_operator_type_multiply_nd_qu8,
    &xnn_params.qu8.vmul.minmax,
    multiply_op_out);
}

enum xnn_status xnn_create_multiply_nd_f16(
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_operator_t* multiply_op_out)
{
  return create_binary_elementwise_nd_f16(
    output_min,
    output_max,
    flags,
    xnn_operator_type_multiply_nd_f16,
    &xnn_params.f16.vmul,
    multiply_op_out);
}

enum xnn_status xnn_create_multiply_nd_f32(
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_operator_t* multiply_op_out)
{
  return create_binary_elementwise_nd_f32(
    output_min,
    output_max,
    flags,
    xnn_operator_type_multiply_nd_f32,
    &xnn_params.f32.vmul,
    multiply_op_out);
}

enum xnn_status xnn_create_squared_difference_nd_f32(
    uint32_t flags,
    xnn_operator_t* squared_difference_op_out)
{
  union xnn_f32_default_params params;
  if (xnn_params.f32.vsqrdiff.init.f32_default != NULL) {
    xnn_params.f32.vsqrdiff.init.f32_default(&params);
  }
  return create_binary_elementwise_nd(
    flags,
    &params,
    sizeof(params),
    XNN_INIT_FLAG_F32,
    xnn_operator_type_squared_difference_nd_f32,
    &xnn_params.f32.vsqrdiff.minmax,
    squared_difference_op_out);
}

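// Quantized subtraction reuses the addition micro-kernels: the parameters are
// initialized with the input2-to-output scale negated (and the operands
// swapped for the reversed set), so no dedicated subtraction kernels are
// needed.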
enum xnn_status xnn_create_subtract_nd_qs8(
    int8_t input1_zero_point,
    float input1_scale,
    int8_t input2_zero_point,
    float input2_scale,
    int8_t output_zero_point,
    float output_scale,
    int8_t output_min,
    int8_t output_max,
    uint32_t flags,
    xnn_operator_t* subtract_op_out)
{
  if (input1_scale <= 0.0f || !isnormal(input1_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input 1 scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qs8), input1_scale);
    return xnn_status_invalid_parameter;
  }

  if (input2_scale <= 0.0f || !isnormal(input2_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input 2 scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qs8), input2_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qs8), output_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRId8 ", %" PRId8 "] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qs8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  const float input1_output_scale = input1_scale / output_scale;
  if (input1_output_scale < 0x1.0p-10f || input1_output_scale >= 0x1.0p+8f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input1-to-output scale ratio: scale ratio must be in [2**-10, 2**8) range",
      xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qs8), input1_output_scale);
    return xnn_status_unsupported_parameter;
  }

  const float input2_output_scale = input2_scale / output_scale;
  if (input2_output_scale < 0x1.0p-10f || input2_output_scale >= 0x1.0p+8f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input2-to-output scale ratio: scale ratio must be in [2**-10, 2**8) range",
      xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qs8), input2_output_scale);
    return xnn_status_unsupported_parameter;
  }

  struct {
    union xnn_qs8_addsub_minmax_params qs8_addsub;
    union xnn_qs8_addsub_minmax_params qs8_raddsub;
  } params;
  if (xnn_params.qs8.vadd.init.qs8_addsub != NULL) {
    xnn_params.qs8.vadd.init.qs8_addsub(
      &params.qs8_addsub, input1_zero_point, input2_zero_point, output_zero_point,
      input1_output_scale, -input2_output_scale, output_min, output_max);
    xnn_params.qs8.vadd.init.qs8_addsub(
      &params.qs8_raddsub, input2_zero_point, input1_zero_point, output_zero_point,
      -input2_output_scale, input1_output_scale, output_min, output_max);
  }
  return create_binary_elementwise_nd(
    flags,
    &params,
    sizeof(params),
    XNN_INIT_FLAG_QS8,
    xnn_operator_type_subtract_nd_qs8,
    &xnn_params.qs8.vadd.minmax,
    subtract_op_out);
}

enum xnn_status xnn_create_subtract_nd_qu8(
    uint8_t input1_zero_point,
    float input1_scale,
    uint8_t input2_zero_point,
    float input2_scale,
    uint8_t output_zero_point,
    float output_scale,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    xnn_operator_t* subtract_op_out)
{
  if (input1_scale <= 0.0f || !isnormal(input1_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input 1 scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qu8), input1_scale);
    return xnn_status_invalid_parameter;
  }

  if (input2_scale <= 0.0f || !isnormal(input2_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input 2 scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qu8), input2_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qu8), output_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRIu8 ", %" PRIu8 "] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qu8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  const float input1_output_scale = input1_scale / output_scale;
  if (input1_output_scale < 0x1.0p-10f || input1_output_scale >= 0x1.0p+8f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input1-to-output scale ratio: scale ratio must be in [2**-10, 2**8) range",
      xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qu8), input1_output_scale);
    return xnn_status_unsupported_parameter;
  }

  const float input2_output_scale = input2_scale / output_scale;
  if (input2_output_scale < 0x1.0p-10f || input2_output_scale >= 0x1.0p+8f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input2-to-output scale ratio: scale ratio must be in [2**-10, 2**8) range",
      xnn_operator_type_to_string(xnn_operator_type_subtract_nd_qu8), input2_output_scale);
    return xnn_status_unsupported_parameter;
  }

  struct {
    union xnn_qu8_addsub_minmax_params qu8_addsub;
    union xnn_qu8_addsub_minmax_params qu8_raddsub;
  } params;
  if (xnn_params.qu8.vadd.init.qu8_addsub != NULL) {
    xnn_params.qu8.vadd.init.qu8_addsub(
      &params.qu8_addsub, input1_zero_point, input2_zero_point, output_zero_point,
      input1_output_scale, -input2_output_scale, output_min, output_max);
    xnn_params.qu8.vadd.init.qu8_addsub(
      &params.qu8_raddsub, input2_zero_point, input1_zero_point, output_zero_point,
      -input2_output_scale, input1_output_scale, output_min, output_max);
  }
  return create_binary_elementwise_nd(
    flags,
    &params,
    sizeof(params),
    XNN_INIT_FLAG_QU8,
    xnn_operator_type_subtract_nd_qu8,
    &xnn_params.qu8.vadd.minmax,
    subtract_op_out);
}

enum xnn_status xnn_create_subtract_nd_f32(
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_operator_t* subtract_op_out)
{
  return create_binary_elementwise_nd_f32(
    output_min,
    output_max,
    flags,
    xnn_operator_type_subtract_nd_f32,
    &xnn_params.f32.vsub,
    subtract_op_out);
}

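// Shared setup path: validates the operator and shapes, canonicalizes
// NumPy-style broadcasting into at most XNN_MAX_TENSOR_DIMS compressed
// dimensions, selects the micro-kernel variant, and fills in the 5-D
// parallelization context consumed when the operator is run.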
static enum xnn_status setup_binary_elementwise_nd(
    xnn_operator_t binary_elementwise_op,
    enum xnn_operator_type expected_operator_type,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const void* input1,
    const void* input2,
    void* output,
    uint32_t datatype_init_flags,
    uint32_t log2_element_size,
    const void* params,
    size_t params_size,
    const void* reversed_params,
    size_t reversed_params_size,
    const struct vbinary_parameters vbinary[restrict XNN_MIN_ELEMENTS(1)],
    size_t num_threads)
{
  binary_elementwise_op->state = xnn_run_state_invalid;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to setup %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(binary_elementwise_op->type));
    return xnn_status_uninitialized;
  }

  if ((xnn_params.init_flags & datatype_init_flags) != datatype_init_flags) {
    xnn_log_error("failed to setup %s operator: operations on data type are not supported",
      xnn_operator_type_to_string(binary_elementwise_op->type));
    return xnn_status_unsupported_hardware;
  }

  if (binary_elementwise_op->type != expected_operator_type) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(expected_operator_type),
      xnn_operator_type_to_string(binary_elementwise_op->type));
    return xnn_status_invalid_parameter;
  }

  if (max(num_input1_dims, num_input2_dims) > XNN_MAX_TENSOR_DIMS) {
    xnn_log_error(
      "failed to setup %s operator with %zu and %zu dimensions in input shapes: "
      "the number of input dimensions must not exceed %d",
      xnn_operator_type_to_string(binary_elementwise_op->type), num_input1_dims, num_input2_dims, XNN_MAX_TENSOR_DIMS);
    return xnn_status_unsupported_parameter;
  }

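  // Collapse both shapes from the innermost dimension outward: consecutive
  // dimensions with the same broadcast pattern (input1 broadcast, input2
  // broadcast, or element-wise) are merged into a single compressed dimension,
  // and dimensions where both inputs are 1 are dropped.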
  size_t num_compressed_dims = 0;
  size_t compressed_input1_shape[XNN_MAX_TENSOR_DIMS];
  size_t compressed_input2_shape[XNN_MAX_TENSOR_DIMS];
  size_t compressed_output_shape[XNN_MAX_TENSOR_DIMS];
  for (size_t i = 0; i < XNN_MAX_TENSOR_DIMS; i++) {
    compressed_input1_shape[i] = 1;
    compressed_input2_shape[i] = 1;
    compressed_output_shape[i] = 1;
  }
  bool broadcast_input1 = false;
  bool broadcast_input2 = false;
  bool first_nonunit = true;
  bool degenerate_shape = false;
  const size_t num_common_dims = min(num_input1_dims, num_input2_dims);
  for (size_t i = 1; i <= num_common_dims; i++) {
    const size_t input1_dim = input1_shape[num_input1_dims - i];
    const size_t input2_dim = input2_shape[num_input2_dims - i];
    degenerate_shape |= input1_dim == 0;
    degenerate_shape |= input2_dim == 0;
    if (input1_dim == 1 && input2_dim == 1) {
      continue;
    }
    assert(!broadcast_input1 || !broadcast_input2);

    if (input1_dim == 1) {
      if (!broadcast_input1) {
        broadcast_input1 = true;
        broadcast_input2 = false;
        num_compressed_dims++;
      }
      compressed_input2_shape[num_compressed_dims - 1] *= input2_dim;
      compressed_output_shape[num_compressed_dims - 1] *= input2_dim;
    } else if (input2_dim == 1) {
      if (!broadcast_input2) {
        broadcast_input1 = false;
        broadcast_input2 = true;
        num_compressed_dims++;
      }
      compressed_input1_shape[num_compressed_dims - 1] *= input1_dim;
      compressed_output_shape[num_compressed_dims - 1] *= input1_dim;
    } else if (input1_dim == input2_dim) {
      if (broadcast_input1 || broadcast_input2 || first_nonunit) {
        broadcast_input1 = false;
        broadcast_input2 = false;
        num_compressed_dims++;
      }
      compressed_input1_shape[num_compressed_dims - 1] *= input1_dim;
      compressed_input2_shape[num_compressed_dims - 1] *= input1_dim;
      compressed_output_shape[num_compressed_dims - 1] *= input1_dim;
    } else {
      xnn_log_error(
        "failed to setup %s operator: "
        "shape dimension #%zu of input1 (%zu) does not match shape dimension #%zu of input2 (%zu)",
        xnn_operator_type_to_string(binary_elementwise_op->type),
        num_input1_dims - i, input1_dim, num_input2_dims - i, input2_dim);
      return xnn_status_invalid_parameter;
    }
    first_nonunit = false;
  }
  if (num_input1_dims > num_input2_dims) {
    if (!broadcast_input2) {
      num_compressed_dims++;
    }
    for (size_t i = 0; i < num_input1_dims - num_input2_dims; i++) {
      const size_t input1_dim = input1_shape[i];
      degenerate_shape |= input1_dim == 0;
      compressed_input1_shape[num_compressed_dims - 1] *= input1_dim;
      compressed_output_shape[num_compressed_dims - 1] *= input1_dim;
    }
  } else if (num_input2_dims > num_input1_dims) {
    if (!broadcast_input1) {
      num_compressed_dims++;
    }
    for (size_t i = 0; i < num_input2_dims - num_input1_dims; i++) {
      const size_t input2_dim = input2_shape[i];
      degenerate_shape |= input2_dim == 0;
      compressed_input2_shape[num_compressed_dims - 1] *= input2_dim;
      compressed_output_shape[num_compressed_dims - 1] *= input2_dim;
    }
  }
  num_compressed_dims = max(num_compressed_dims, 1);

  // Early exit without setting up context if any shape dimension is zero.
  if (degenerate_shape) {
    binary_elementwise_op->state = xnn_run_state_skip;
    return xnn_status_success;
  }

  binary_elementwise_op->context.elementwise_binary = (struct elementwise_binary_context) {
    .a = input1,
    .b = input2,
    .y = output,
    .elements = compressed_output_shape[0] << log2_element_size,
  };
  if (params_size != 0) {
    memcpy(&binary_elementwise_op->context.elementwise_binary.params, params, params_size);
  }

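  // Pick the micro-kernel variant based on the innermost compressed dimension:
  // if input1 is a scalar there, swap the operands and use the
  // reversed-constant-operand kernel together with the reversed parameters;
  // if input2 is a scalar, use the constant-operand kernel; otherwise use the
  // two-vector kernel.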
  const size_t* compressed_a_shape = compressed_input1_shape;
  const size_t* compressed_b_shape = compressed_input2_shape;
  if (compressed_input1_shape[0] == 1) {
    binary_elementwise_op->context.elementwise_binary.ukernel = binary_elementwise_op->ukernel.vbinary.ropc_function;
    binary_elementwise_op->context.elementwise_binary.a = input2;
    binary_elementwise_op->context.elementwise_binary.b = input1;
    compressed_a_shape = compressed_input2_shape;
    compressed_b_shape = compressed_input1_shape;
    if (reversed_params_size != 0) {
      memcpy(&binary_elementwise_op->context.elementwise_binary.params, reversed_params, reversed_params_size);
    }
  } else if (compressed_input2_shape[0] == 1) {
    binary_elementwise_op->context.elementwise_binary.ukernel = binary_elementwise_op->ukernel.vbinary.opc_function;
  } else if (compressed_input1_shape[0] == compressed_input2_shape[0]) {
    binary_elementwise_op->context.elementwise_binary.ukernel = binary_elementwise_op->ukernel.vbinary.op_function;
  }
  size_t a_stride = compressed_a_shape[0], b_stride = compressed_b_shape[0], y_stride = compressed_output_shape[0];
  for (size_t i = 1; i < num_compressed_dims; i++) {
    if (compressed_a_shape[i] != 1) {
      binary_elementwise_op->context.elementwise_binary.a_stride[XNN_MAX_TENSOR_DIMS - 1 - i] = a_stride << log2_element_size;
    }
    if (compressed_b_shape[i] != 1) {
      binary_elementwise_op->context.elementwise_binary.b_stride[XNN_MAX_TENSOR_DIMS - 1 - i] = b_stride << log2_element_size;
    }
    binary_elementwise_op->context.elementwise_binary.y_stride[XNN_MAX_TENSOR_DIMS - 1 - i] = y_stride << log2_element_size;
    a_stride *= compressed_a_shape[i];
    b_stride *= compressed_b_shape[i];
    y_stride *= compressed_output_shape[i];
  }

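  // The innermost compressed dimension is processed inside the micro-kernel
  // (via .elements); the remaining compressed dimensions are mapped, outermost
  // first, onto the 5-D parallelized loop below.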
  binary_elementwise_op->compute.type = xnn_parallelization_type_5d;
  binary_elementwise_op->compute.task_5d = (pthreadpool_task_5d_t) xnn_compute_elementwise_binary_5d;
  binary_elementwise_op->compute.range[0] = compressed_output_shape[5];
  binary_elementwise_op->compute.range[1] = compressed_output_shape[4];
  binary_elementwise_op->compute.range[2] = compressed_output_shape[3];
  binary_elementwise_op->compute.range[3] = compressed_output_shape[2];
  binary_elementwise_op->compute.range[4] = compressed_output_shape[1];
  binary_elementwise_op->compute.tile[0] = 1;
  binary_elementwise_op->compute.tile[1] = 1;
  binary_elementwise_op->state = xnn_run_state_ready;

  return xnn_status_success;
}

static enum xnn_status setup_binary_elementwise_nd_f16(
    xnn_operator_t binary_elementwise_op,
    enum xnn_operator_type expected_operator_type,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const void* input1,
    const void* input2,
    void* output,
    const struct vbinary_parameters vbinary[restrict XNN_MIN_ELEMENTS(1)],
    size_t num_threads)
{
  return setup_binary_elementwise_nd(
    binary_elementwise_op,
    expected_operator_type,
    num_input1_dims,
    input1_shape,
    num_input2_dims,
    input2_shape,
    input1,
    input2,
    output,
    XNN_INIT_FLAG_F16,
    1 /* log2(sizeof(half)) */,
    &binary_elementwise_op->params.f16_minmax, sizeof(binary_elementwise_op->params.f16_minmax),
    &binary_elementwise_op->params.f16_minmax, sizeof(binary_elementwise_op->params.f16_minmax),
    vbinary,
    num_threads);
}

static enum xnn_status setup_binary_elementwise_nd_f32(
    xnn_operator_t binary_elementwise_op,
    enum xnn_operator_type expected_operator_type,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    const struct vbinary_parameters vbinary[restrict XNN_MIN_ELEMENTS(1)],
    size_t num_threads)
{
  return setup_binary_elementwise_nd(
    binary_elementwise_op, expected_operator_type,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    XNN_INIT_FLAG_F32,
    2 /* log2(sizeof(float)) */,
    &binary_elementwise_op->params.f32_minmax, sizeof(binary_elementwise_op->params.f32_minmax),
    &binary_elementwise_op->params.f32_minmax, sizeof(binary_elementwise_op->params.f32_minmax),
    vbinary,
    num_threads);
}

enum xnn_status xnn_setup_add_nd_qs8(
    xnn_operator_t add_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const int8_t* input1,
    const int8_t* input2,
    int8_t* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd(
    add_op, xnn_operator_type_add_nd_qs8,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    XNN_INIT_FLAG_QS8,
    0 /* log2(sizeof(int8_t)) */,
    &add_op->params.qs8_addsub, sizeof(add_op->params.qs8_addsub),
    &add_op->params.qs8_raddsub, sizeof(add_op->params.qs8_raddsub),
    &xnn_params.qs8.vadd,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_add_nd_qu8(
    xnn_operator_t add_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const uint8_t* input1,
    const uint8_t* input2,
    uint8_t* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd(
    add_op, xnn_operator_type_add_nd_qu8,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    XNN_INIT_FLAG_QU8,
    0 /* log2(sizeof(uint8_t)) */,
    &add_op->params.qu8_addsub, sizeof(add_op->params.qu8_addsub),
    &add_op->params.qu8_raddsub, sizeof(add_op->params.qu8_raddsub),
    &xnn_params.qu8.vadd,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_add_nd_f16(
    xnn_operator_t add_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const void* input1,
    const void* input2,
    void* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f16(
    add_op, xnn_operator_type_add_nd_f16,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f16.vadd,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_add_nd_f32(
    xnn_operator_t add_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    add_op, xnn_operator_type_add_nd_f32,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f32.vadd,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_divide_nd_f32(
    xnn_operator_t divide_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    divide_op, xnn_operator_type_divide_nd_f32,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f32.vdiv,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_maximum_nd_f32(
    xnn_operator_t maximum_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    maximum_op, xnn_operator_type_maximum_nd_f32,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f32.vmax,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_minimum_nd_f32(
    xnn_operator_t minimum_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    minimum_op, xnn_operator_type_minimum_nd_f32,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f32.vmin,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_multiply_nd_qs8(
    xnn_operator_t multiply_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const int8_t* input1,
    const int8_t* input2,
    int8_t* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd(
    multiply_op, xnn_operator_type_multiply_nd_qs8,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    XNN_INIT_FLAG_QS8,
    0 /* log2(sizeof(int8_t)) */,
    &multiply_op->params.qs8_mul, sizeof(multiply_op->params.qs8_mul),
    &multiply_op->params.qs8_rmul, sizeof(multiply_op->params.qs8_rmul),
    &xnn_params.qs8.vmul,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_multiply_nd_qu8(
    xnn_operator_t multiply_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const uint8_t* input1,
    const uint8_t* input2,
    uint8_t* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd(
    multiply_op, xnn_operator_type_multiply_nd_qu8,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    XNN_INIT_FLAG_QU8,
    0 /* log2(sizeof(uint8_t)) */,
    &multiply_op->params.qu8_mul, sizeof(multiply_op->params.qu8_mul),
    &multiply_op->params.qu8_rmul, sizeof(multiply_op->params.qu8_rmul),
    &xnn_params.qu8.vmul,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_multiply_nd_f16(
    xnn_operator_t multiply_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const void* input1,
    const void* input2,
    void* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f16(
    multiply_op, xnn_operator_type_multiply_nd_f16,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f16.vmul,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_multiply_nd_f32(
    xnn_operator_t multiply_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    multiply_op, xnn_operator_type_multiply_nd_f32,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f32.vmul,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_squared_difference_nd_f32(
    xnn_operator_t squared_difference_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    squared_difference_op, xnn_operator_type_squared_difference_nd_f32,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f32.vsqrdiff,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_subtract_nd_qs8(
    xnn_operator_t subtract_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const int8_t* input1,
    const int8_t* input2,
    int8_t* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd(
    subtract_op, xnn_operator_type_subtract_nd_qs8,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    XNN_INIT_FLAG_QS8,
    0 /* log2(sizeof(int8_t)) */,
    &subtract_op->params.qs8_addsub, sizeof(subtract_op->params.qs8_addsub),
    &subtract_op->params.qs8_raddsub, sizeof(subtract_op->params.qs8_raddsub),
    &xnn_params.qs8.vadd,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_subtract_nd_qu8(
    xnn_operator_t subtract_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const uint8_t* input1,
    const uint8_t* input2,
    uint8_t* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd(
    subtract_op, xnn_operator_type_subtract_nd_qu8,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    XNN_INIT_FLAG_QU8,
    0 /* log2(sizeof(uint8_t)) */,
    &subtract_op->params.qu8_addsub, sizeof(subtract_op->params.qu8_addsub),
    &subtract_op->params.qu8_raddsub, sizeof(subtract_op->params.qu8_raddsub),
    &xnn_params.qu8.vadd,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_subtract_nd_f32(
    xnn_operator_t subtract_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    subtract_op, xnn_operator_type_subtract_nd_f32,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f32.vsub,
    pthreadpool_get_threads_count(threadpool));
}
