1 // Copyright 2019 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <fp16.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/log.h>
#include <xnnpack/operator.h>
#include <xnnpack/params-init.h>
#include <xnnpack/params.h>
20
21
// Common constructor shared by every binary elementwise operator in this file.
//
// Verifies that XNNPACK and the requested datatype microkernels were
// initialized, allocates a zero-filled operator descriptor, copies the
// (optional) packed microkernel parameters into it, and records the three
// fused microkernel variants: two-operand (op), operand-plus-constant (opc),
// and reversed operand-plus-constant (ropc).
//
// params may be NULL when params_size == 0 (parameter-free operators such as
// maximum/minimum). On success, writes the new operator to
// *binary_elementwise_op_out and returns xnn_status_success; on failure,
// returns an error status and leaves *binary_elementwise_op_out untouched.
static enum xnn_status create_binary_elementwise_nd(
  uint32_t flags,
  const void* params,
  size_t params_size,
  uint32_t datatype_init_flags,
  enum xnn_operator_type operator_type,
  const struct vbinary_fused_ukernels* vbinary_fused_ukernels,
  xnn_operator_t* binary_elementwise_op_out)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(operator_type));
    return xnn_status_uninitialized;
  }

  // All requested datatype bits must have been set during initialization;
  // otherwise the needed microkernels are unavailable on this hardware.
  if ((xnn_params.init_flags & datatype_init_flags) != datatype_init_flags) {
    xnn_log_error("failed to create %s operator: operations on data type are not supported",
      xnn_operator_type_to_string(operator_type));
    return xnn_status_unsupported_hardware;
  }

  xnn_operator_t binary_elementwise_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
  if (binary_elementwise_op == NULL) {
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator descriptor",
      sizeof(struct xnn_operator), xnn_operator_type_to_string(operator_type));
    return xnn_status_out_of_memory;
  }

  if (params_size != 0) {
    memcpy(&binary_elementwise_op->params, params, params_size);
  }

  binary_elementwise_op->ukernel.vbinary.op_function = vbinary_fused_ukernels->op_ukernel;
  binary_elementwise_op->ukernel.vbinary.opc_function = vbinary_fused_ukernels->opc_ukernel;
  binary_elementwise_op->ukernel.vbinary.ropc_function = vbinary_fused_ukernels->ropc_ukernel;

  binary_elementwise_op->type = operator_type;

  // The operator only becomes runnable after a successful setup_*() call.
  binary_elementwise_op->state = xnn_run_state_invalid;

  *binary_elementwise_op_out = binary_elementwise_op;
  return xnn_status_success;
}
66
// Creates a binary elementwise operator on half-precision (F16) data with a
// fused [output_min, output_max] clamp.
//
// The bounds are validated after conversion to FP16 and back, i.e. at the
// precision the microkernel will actually apply them, so bounds that collapse
// to the same half-precision value are rejected.
//
// Fix: restores `&params` in the delegated call (the source had been
// corrupted to the mojibake `¶ms`, which does not compile).
static enum xnn_status create_binary_elementwise_nd_f16(
  float output_min,
  float output_max,
  uint32_t flags,
  enum xnn_operator_type operator_type,
  const struct vbinary_parameters vbinary[restrict XNN_MIN_ELEMENTS(1)],
  xnn_operator_t* binary_elementwise_op_out)
{
  if (isnan(output_min)) {
    xnn_log_error(
      "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN",
      xnn_operator_type_to_string(operator_type));
    return xnn_status_invalid_parameter;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN",
      xnn_operator_type_to_string(operator_type));
    return xnn_status_invalid_parameter;
  }

  // Compare the bounds as the microkernel will see them: rounded to FP16.
  if (fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_min)) >= fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_max))) {
    xnn_log_error(
      "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(operator_type),
      fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_min)),
      fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_max)));
    return xnn_status_invalid_parameter;
  }

  const struct xnn_f16_minmax_params params = xnn_init_f16_minmax_params(
    fp16_ieee_from_fp32_value(output_min),
    fp16_ieee_from_fp32_value(output_max));

  return create_binary_elementwise_nd(
    flags,
    &params,
    sizeof(params),
    XNN_INIT_FLAG_F16,
    operator_type,
    &vbinary->minmax,
    binary_elementwise_op_out);
}
111
// Creates a binary elementwise operator on single-precision (F32) data with a
// fused [output_min, output_max] clamp.
//
// When the requested range is [-inf, +inf] (linear activation) and the
// microkernel table provides a dedicated clamp-free kernel, that kernel is
// selected instead of the min/max variant.
//
// Fix: restores `&params` in the delegated call (the source had been
// corrupted to the mojibake `¶ms`, which does not compile).
static enum xnn_status create_binary_elementwise_nd_f32(
  float output_min,
  float output_max,
  uint32_t flags,
  enum xnn_operator_type operator_type,
  const struct vbinary_parameters vbinary[restrict XNN_MIN_ELEMENTS(1)],
  xnn_operator_t* binary_elementwise_op_out)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(operator_type));
    return xnn_status_uninitialized;
  }

  if (isnan(output_min)) {
    xnn_log_error(
      "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN",
      xnn_operator_type_to_string(operator_type));
    return xnn_status_invalid_parameter;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN",
      xnn_operator_type_to_string(operator_type));
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(operator_type), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  // [-inf, +inf] clamping is a no-op: prefer a clamp-free kernel if present.
  const bool linear_activation = (output_max == INFINITY) && (output_min == -output_max);
  const struct vbinary_fused_ukernels* vbinary_fused_ukernels = &vbinary->minmax;
  if (linear_activation && vbinary->linear.op_ukernel != NULL) {
    vbinary_fused_ukernels = &vbinary->linear;
  }

  const union xnn_f32_minmax_params params = xnn_init_f32_minmax_params(output_min, output_max);

  return create_binary_elementwise_nd(
    flags,
    &params,
    sizeof(params),
    XNN_INIT_FLAG_F32,
    operator_type,
    vbinary_fused_ukernels,
    binary_elementwise_op_out);
}
164
// Creates an N-dimensional broadcasting Add operator on signed 8-bit
// quantized (QS8) data.
//
// Validates that all three scales are finite and positive, that the output
// clamp range is non-empty, and that both input-to-output scale ratios fall
// in the [2**-14, 2**8) range supported by the requantization scheme. Packs
// two parameter blocks: qs8_add for (input1, input2) order and qs8_radd with
// the inputs swapped, used when setup reverses the operands for broadcasting.
//
// Fix: restores `&params` in the delegated call (the source had been
// corrupted to the mojibake `¶ms`, which does not compile). Requires
// <inttypes.h> for the PRId8 format macros.
enum xnn_status xnn_create_add_nd_qs8(
  int8_t input1_zero_point,
  float input1_scale,
  int8_t input2_zero_point,
  float input2_scale,
  int8_t output_zero_point,
  float output_scale,
  int8_t output_min,
  int8_t output_max,
  uint32_t flags,
  xnn_operator_t* add_op_out)
{
  if (input1_scale <= 0.0f || !isnormal(input1_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input 1 scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qs8), input1_scale);
    return xnn_status_invalid_parameter;
  }

  if (input2_scale <= 0.0f || !isnormal(input2_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input 2 scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qs8), input2_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qs8), output_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRId8 ", %" PRId8 "] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qs8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  const float input1_output_scale = input1_scale / output_scale;
  if (input1_output_scale < 0x1.0p-14f || input1_output_scale >= 0x1.0p+8f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input1-to-output scale ratio: scale ratio must be in [2**-14, 2**8) range",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qs8), input1_output_scale);
    return xnn_status_unsupported_parameter;
  }

  const float input2_output_scale = input2_scale / output_scale;
  if (input2_output_scale < 0x1.0p-14f || input2_output_scale >= 0x1.0p+8f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input2-to-output scale ratio: scale ratio must be in [2**-14, 2**8) range",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qs8), input2_output_scale);
    return xnn_status_unsupported_parameter;
  }

  // Pack both operand orders: addition is commutative in value, but the
  // packed parameters are operand-order specific (zero points and scales).
  const struct {
    union xnn_qs8_add_params qs8_add;
    union xnn_qs8_add_params qs8_radd;
  } params = {
    .qs8_add = xnn_init_qs8_add_params(
      input1_zero_point, input2_zero_point, output_zero_point, input1_output_scale, input2_output_scale, output_min, output_max),
    .qs8_radd = xnn_init_qs8_add_params(
      input2_zero_point, input1_zero_point, output_zero_point, input2_output_scale, input1_output_scale, output_min, output_max),
  };
  return create_binary_elementwise_nd(
    flags,
    &params,
    sizeof(params),
    XNN_INIT_FLAG_QS8,
    xnn_operator_type_add_nd_qs8,
    &xnn_params.qs8.vadd.minmax,
    add_op_out);
}
239
// Creates an N-dimensional broadcasting Add operator for F16 data with a
// fused [output_min, output_max] clamp.
enum xnn_status xnn_create_add_nd_f16(
  float output_min,
  float output_max,
  uint32_t flags,
  xnn_operator_t* add_op_out)
{
  // Delegate to the shared F16 constructor with the Add microkernel table.
  return create_binary_elementwise_nd_f16(
    output_min, output_max, flags,
    xnn_operator_type_add_nd_f16,
    &xnn_params.f16.vadd,
    add_op_out);
}
254
// Creates an N-dimensional broadcasting Add operator for F32 data with a
// fused [output_min, output_max] clamp.
enum xnn_status xnn_create_add_nd_f32(
  float output_min,
  float output_max,
  uint32_t flags,
  xnn_operator_t* add_op_out)
{
  // Delegate to the shared F32 constructor with the Add microkernel table.
  return create_binary_elementwise_nd_f32(
    output_min, output_max, flags,
    xnn_operator_type_add_nd_f32,
    &xnn_params.f32.vadd,
    add_op_out);
}
269
// Creates an N-dimensional broadcasting Divide operator for F32 data with a
// fused [output_min, output_max] clamp.
enum xnn_status xnn_create_divide_nd_f32(
  float output_min,
  float output_max,
  uint32_t flags,
  xnn_operator_t* divide_op_out)
{
  // Delegate to the shared F32 constructor with the Divide microkernel table.
  return create_binary_elementwise_nd_f32(
    output_min, output_max, flags,
    xnn_operator_type_divide_nd_f32,
    &xnn_params.f32.vdiv,
    divide_op_out);
}
284
// Creates an N-dimensional broadcasting Maximum operator for F32 data.
// Maximum needs no packed parameters, so the generic constructor is called
// directly with an empty parameter block.
enum xnn_status xnn_create_maximum_nd_f32(
  uint32_t flags,
  xnn_operator_t* maximum_op_out)
{
  return create_binary_elementwise_nd(
    flags,
    /*params=*/NULL, /*params_size=*/0,
    XNN_INIT_FLAG_F32,
    xnn_operator_type_maximum_nd_f32,
    &xnn_params.f32.vmax.minmax,
    maximum_op_out);
}
298
// Creates an N-dimensional broadcasting Minimum operator for F32 data.
// Minimum needs no packed parameters, so the generic constructor is called
// directly with an empty parameter block.
enum xnn_status xnn_create_minimum_nd_f32(
  uint32_t flags,
  xnn_operator_t* minimum_op_out)
{
  return create_binary_elementwise_nd(
    flags,
    /*params=*/NULL, /*params_size=*/0,
    XNN_INIT_FLAG_F32,
    xnn_operator_type_minimum_nd_f32,
    &xnn_params.f32.vmin.minmax,
    minimum_op_out);
}
312
// Creates an N-dimensional broadcasting Multiply operator for F16 data with a
// fused [output_min, output_max] clamp.
enum xnn_status xnn_create_multiply_nd_f16(
  float output_min,
  float output_max,
  uint32_t flags,
  xnn_operator_t* multiply_op_out)
{
  // Delegate to the shared F16 constructor with the Multiply microkernel table.
  return create_binary_elementwise_nd_f16(
    output_min, output_max, flags,
    xnn_operator_type_multiply_nd_f16,
    &xnn_params.f16.vmul,
    multiply_op_out);
}
327
// Creates an N-dimensional broadcasting Multiply operator for F32 data with a
// fused [output_min, output_max] clamp.
enum xnn_status xnn_create_multiply_nd_f32(
  float output_min,
  float output_max,
  uint32_t flags,
  xnn_operator_t* multiply_op_out)
{
  // Delegate to the shared F32 constructor with the Multiply microkernel table.
  return create_binary_elementwise_nd_f32(
    output_min, output_max, flags,
    xnn_operator_type_multiply_nd_f32,
    &xnn_params.f32.vmul,
    multiply_op_out);
}
342
// Creates an N-dimensional broadcasting Squared Difference operator for F32
// data. The operator needs no packed parameters, so the generic constructor
// is called directly with an empty parameter block.
enum xnn_status xnn_create_squared_difference_nd_f32(
  uint32_t flags,
  xnn_operator_t* squared_difference_op_out)
{
  return create_binary_elementwise_nd(
    flags,
    /*params=*/NULL, /*params_size=*/0,
    XNN_INIT_FLAG_F32,
    xnn_operator_type_squared_difference_nd_f32,
    &xnn_params.f32.vsqrdiff.minmax,
    squared_difference_op_out);
}
356
// Creates an N-dimensional broadcasting Subtract operator for F32 data with a
// fused [output_min, output_max] clamp.
enum xnn_status xnn_create_subtract_nd_f32(
  float output_min,
  float output_max,
  uint32_t flags,
  xnn_operator_t* subtract_op_out)
{
  // Delegate to the shared F32 constructor with the Subtract microkernel table.
  return create_binary_elementwise_nd_f32(
    output_min, output_max, flags,
    xnn_operator_type_subtract_nd_f32,
    &xnn_params.f32.vsub,
    subtract_op_out);
}
371
// Shared setup for all binary elementwise operators.
//
// Validates both input shapes, canonicalizes them into a "compressed"
// broadcast form (consecutive dimensions with the same broadcast behavior are
// merged into one), selects the microkernel variant matching the innermost
// compressed dimension (two-operand, operand+constant, or reversed
// operand+constant), fills in the operator's compute context and strides,
// and marks the operator ready to run.
//
// params/reversed_params: packed microkernel parameters for normal and
//   swapped operand order; the swapped block is applied when input1 is the
//   broadcast side in the innermost dimension (needed for non-commutative
//   parameterizations, e.g. QS8 add).
// log2_element_size: log2 of the element size in bytes (0: int8, 1: half,
//   2: float).
// num_threads: currently unused in this function.
static enum xnn_status setup_binary_elementwise_nd(
  xnn_operator_t binary_elementwise_op,
  enum xnn_operator_type expected_operator_type,
  size_t num_input1_dims,
  const size_t* input1_shape,
  size_t num_input2_dims,
  const size_t* input2_shape,
  const void* input1,
  const void* input2,
  void* output,
  uint32_t datatype_init_flags,
  uint32_t log2_element_size,
  const void* params,
  size_t params_size,
  const void* reversed_params,
  size_t reversed_params_size,
  const struct vbinary_parameters vbinary[restrict XNN_MIN_ELEMENTS(1)],
  size_t num_threads)
{
  // Invalidate up front: any early return leaves the operator not runnable.
  binary_elementwise_op->state = xnn_run_state_invalid;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to setup %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(binary_elementwise_op->type));
    return xnn_status_uninitialized;
  }

  if ((xnn_params.init_flags & datatype_init_flags) != datatype_init_flags) {
    xnn_log_error("failed to setup %s operator: operations on data type are not supported",
      xnn_operator_type_to_string(binary_elementwise_op->type));
    return xnn_status_unsupported_hardware;
  }

  if (binary_elementwise_op->type != expected_operator_type) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(expected_operator_type),
      xnn_operator_type_to_string(binary_elementwise_op->type));
    return xnn_status_invalid_parameter;
  }

  if (max(num_input1_dims, num_input2_dims) > XNN_MAX_TENSOR_DIMS) {
    xnn_log_error(
      "failed to setup %s operator with %zu and %zu dimensions in input shapes: "
      "the number of input dimensions must not exceed %d",
      xnn_operator_type_to_string(binary_elementwise_op->type), num_input1_dims, num_input2_dims, XNN_MAX_TENSOR_DIMS);
    return xnn_status_unsupported_parameter;
  }

  // Zero-sized dimensions are rejected outright (no empty-tensor support).
  for (size_t i = 0; i < num_input1_dims; i++) {
    if (input1_shape[i] == 0) {
      xnn_log_error(
        "failed to setup %s operator: shape dimension #%zu of input #1 is zero",
        xnn_operator_type_to_string(binary_elementwise_op->type), i);
      return xnn_status_invalid_parameter;
    }
  }

  for (size_t i = 0; i < num_input2_dims; i++) {
    if (input2_shape[i] == 0) {
      xnn_log_error(
        "failed to setup %s operator: shape dimension #%zu of input #2 is zero",
        xnn_operator_type_to_string(binary_elementwise_op->type), i);
      return xnn_status_invalid_parameter;
    }
  }

  // Shape compression: walk the common dimensions from innermost to
  // outermost, merging consecutive dimensions that broadcast the same way
  // (input1 broadcast, input2 broadcast, or no broadcast) into a single
  // compressed dimension. All shapes start as 1s so unmerged slots are
  // benign.
  size_t num_compressed_dims = 0;
  size_t compressed_input1_shape[XNN_MAX_TENSOR_DIMS];
  size_t compressed_input2_shape[XNN_MAX_TENSOR_DIMS];
  size_t compressed_output_shape[XNN_MAX_TENSOR_DIMS];
  for (size_t i = 0; i < XNN_MAX_TENSOR_DIMS; i++) {
    compressed_input1_shape[i] = 1;
    compressed_input2_shape[i] = 1;
    compressed_output_shape[i] = 1;
  }
  bool broadcast_input1 = false;
  bool broadcast_input2 = false;
  bool first_nonunit = true;
  const size_t num_common_dims = min(num_input1_dims, num_input2_dims);
  for (size_t i = 1; i <= num_common_dims; i++) {
    const size_t input1_dim = input1_shape[num_input1_dims - i];
    const size_t input2_dim = input2_shape[num_input2_dims - i];
    if (input1_dim == 1 && input2_dim == 1) {
      // Both unit: fully skippable, merges with whatever comes next.
      continue;
    }
    // At most one side can be in a broadcast run at a time.
    assert(!broadcast_input1 || !broadcast_input2);

    if (input1_dim == 1) {
      // input1 broadcasts along this dimension: start a new compressed
      // dimension unless we are already in an input1-broadcast run.
      if (!broadcast_input1) {
        broadcast_input1 = true;
        broadcast_input2 = false;
        num_compressed_dims++;
      }
      compressed_input2_shape[num_compressed_dims - 1] *= input2_dim;
      compressed_output_shape[num_compressed_dims - 1] *= input2_dim;
    } else if (input2_dim == 1) {
      // input2 broadcasts along this dimension.
      if (!broadcast_input2) {
        broadcast_input1 = false;
        broadcast_input2 = true;
        num_compressed_dims++;
      }
      compressed_input1_shape[num_compressed_dims - 1] *= input1_dim;
      compressed_output_shape[num_compressed_dims - 1] *= input1_dim;
    } else if (input1_dim == input2_dim) {
      // No broadcast: merge into the current no-broadcast run, or open one
      // if we were broadcasting (or this is the first non-unit dimension).
      if (broadcast_input1 || broadcast_input2 || first_nonunit) {
        broadcast_input1 = false;
        broadcast_input2 = false;
        num_compressed_dims++;
      }
      compressed_input1_shape[num_compressed_dims - 1] *= input1_dim;
      compressed_input2_shape[num_compressed_dims - 1] *= input1_dim;
      compressed_output_shape[num_compressed_dims - 1] *= input1_dim;
    } else {
      xnn_log_error(
        "failed to setup %s operator: "
        "shape dimension #%zu of input1 (%zu) does not match shape dimension #%zu of input2 (%zu)",
        xnn_operator_type_to_string(binary_elementwise_op->type),
        num_input1_dims - i, input1_dim, num_input2_dims - i, input2_dim);
      return xnn_status_invalid_parameter;
    }
    first_nonunit = false;
  }
  // Fold the extra leading dimensions of the longer-rank input: the missing
  // dimensions of the other input are implicitly 1, i.e. it broadcasts.
  if (num_input1_dims > num_input2_dims) {
    if (!broadcast_input2) {
      num_compressed_dims++;
    }
    for (size_t i = 0; i < num_input1_dims - num_input2_dims; i++) {
      const size_t input1_dim = input1_shape[i];
      compressed_input1_shape[num_compressed_dims - 1] *= input1_dim;
      compressed_output_shape[num_compressed_dims - 1] *= input1_dim;
    }
  } else if (num_input2_dims > num_input1_dims) {
    if (!broadcast_input1) {
      num_compressed_dims++;
    }
    for (size_t i = 0; i < num_input2_dims - num_input1_dims; i++) {
      const size_t input2_dim = input2_shape[i];
      compressed_input2_shape[num_compressed_dims - 1] *= input2_dim;
      compressed_output_shape[num_compressed_dims - 1] *= input2_dim;
    }
  }
  // All-unit shapes (scalars) compress to zero dims; treat as one dimension.
  num_compressed_dims = max(num_compressed_dims, 1);

  // Reinitializing the context zeroes all stride arrays; strides left unset
  // below therefore stay 0, which broadcasts that operand.
  binary_elementwise_op->context.elementwise_binary = (struct elementwise_binary_context) {
    .a = input1,
    .b = input2,
    .y = output,
    .elements = compressed_output_shape[0] << log2_element_size,
  };
  if (params_size != 0) {
    memcpy(&binary_elementwise_op->context.elementwise_binary.params, params, params_size);
  }

  // Pick the microkernel by the innermost compressed dimension:
  //   input1 innermost == 1 -> swap operands and use the reversed
  //     operand+constant kernel (with reversed params if provided);
  //   input2 innermost == 1 -> operand+constant kernel;
  //   equal innermost dims  -> plain two-operand kernel.
  const size_t* compressed_a_shape = compressed_input1_shape;
  const size_t* compressed_b_shape = compressed_input2_shape;
  if (compressed_input1_shape[0] == 1) {
    binary_elementwise_op->context.elementwise_binary.ukernel = binary_elementwise_op->ukernel.vbinary.ropc_function;
    binary_elementwise_op->context.elementwise_binary.a = input2;
    binary_elementwise_op->context.elementwise_binary.b = input1;
    compressed_a_shape = compressed_input2_shape;
    compressed_b_shape = compressed_input1_shape;
    if (reversed_params_size != 0) {
      memcpy(&binary_elementwise_op->context.elementwise_binary.params, reversed_params, reversed_params_size);
    }
  } else if (compressed_input2_shape[0] == 1) {
    binary_elementwise_op->context.elementwise_binary.ukernel = binary_elementwise_op->ukernel.vbinary.opc_function;
  } else if (compressed_input1_shape[0] == compressed_input2_shape[0]) {
    binary_elementwise_op->context.elementwise_binary.ukernel = binary_elementwise_op->ukernel.vbinary.op_function;
  }
  // Compute byte strides for the outer compressed dimensions (the innermost
  // one is handled inside the microkernel via .elements). A stride is only
  // written where the operand actually extends along that dimension; the
  // rest stay 0 (broadcast), courtesy of the context reset above.
  size_t a_stride = compressed_a_shape[0], b_stride = compressed_b_shape[0], y_stride = compressed_output_shape[0];
  for (size_t i = 1; i < num_compressed_dims; i++) {
    if (compressed_a_shape[i] != 1) {
      binary_elementwise_op->context.elementwise_binary.a_stride[XNN_MAX_TENSOR_DIMS - 1 - i] = a_stride << log2_element_size;
    }
    if (compressed_b_shape[i] != 1) {
      binary_elementwise_op->context.elementwise_binary.b_stride[XNN_MAX_TENSOR_DIMS - 1 - i] = b_stride << log2_element_size;
    }
    binary_elementwise_op->context.elementwise_binary.y_stride[XNN_MAX_TENSOR_DIMS - 1 - i] = y_stride << log2_element_size;
    a_stride *= compressed_a_shape[i];
    b_stride *= compressed_b_shape[i];
    y_stride *= compressed_output_shape[i];
  }

  // Parallelize over the 5 outer compressed dimensions (outermost first);
  // indices [1..5] imply XNN_MAX_TENSOR_DIMS == 6 here -- the innermost
  // dimension [0] is processed inside the microkernel.
  binary_elementwise_op->compute.type = xnn_parallelization_type_5d;
  binary_elementwise_op->compute.task_5d = (pthreadpool_task_5d_t) xnn_compute_elementwise_binary_5d;
  binary_elementwise_op->compute.range[0] = compressed_output_shape[5];
  binary_elementwise_op->compute.range[1] = compressed_output_shape[4];
  binary_elementwise_op->compute.range[2] = compressed_output_shape[3];
  binary_elementwise_op->compute.range[3] = compressed_output_shape[2];
  binary_elementwise_op->compute.range[4] = compressed_output_shape[1];
  binary_elementwise_op->compute.tile[0] = 1;
  binary_elementwise_op->compute.tile[1] = 1;
  binary_elementwise_op->state = xnn_run_state_ready;

  return xnn_status_success;
}
568
// F16 setup wrapper: forwards to the generic setup with half-precision
// element size. The same packed min/max parameter block is passed for both
// normal and reversed operand order.
static enum xnn_status setup_binary_elementwise_nd_f16(
  xnn_operator_t binary_elementwise_op,
  enum xnn_operator_type expected_operator_type,
  size_t num_input1_dims,
  const size_t* input1_shape,
  size_t num_input2_dims,
  const size_t* input2_shape,
  const void* input1,
  const void* input2,
  void* output,
  const struct vbinary_parameters vbinary[restrict XNN_MIN_ELEMENTS(1)],
  size_t num_threads)
{
  return setup_binary_elementwise_nd(
    binary_elementwise_op, expected_operator_type,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    XNN_INIT_FLAG_F16,
    1 /* log2(sizeof(uint16_t)) */,
    &binary_elementwise_op->params.f16_minmax, sizeof(binary_elementwise_op->params.f16_minmax),
    &binary_elementwise_op->params.f16_minmax, sizeof(binary_elementwise_op->params.f16_minmax),
    vbinary, num_threads);
}
599
// F32 setup wrapper: forwards to the generic setup with single-precision
// element size. The same packed min/max parameter block is passed for both
// normal and reversed operand order.
static enum xnn_status setup_binary_elementwise_nd_f32(
  xnn_operator_t binary_elementwise_op,
  enum xnn_operator_type expected_operator_type,
  size_t num_input1_dims,
  const size_t* input1_shape,
  size_t num_input2_dims,
  const size_t* input2_shape,
  const float* input1,
  const float* input2,
  float* output,
  const struct vbinary_parameters vbinary[restrict XNN_MIN_ELEMENTS(1)],
  size_t num_threads)
{
  return setup_binary_elementwise_nd(
    binary_elementwise_op,
    expected_operator_type,
    num_input1_dims,
    input1_shape,
    num_input2_dims,
    input2_shape,
    input1,
    input2,
    output,
    XNN_INIT_FLAG_F32,
    2 /* log2(sizeof(float)) */,
    &binary_elementwise_op->params.f32_minmax,
    sizeof(binary_elementwise_op->params.f32_minmax),
    &binary_elementwise_op->params.f32_minmax,
    sizeof(binary_elementwise_op->params.f32_minmax),
    vbinary,
    num_threads);
}
625
// Binds shapes and data pointers to a previously created QS8 Add operator.
// Passes both the normal (qs8_add) and operand-swapped (qs8_radd) parameter
// blocks, since the generic setup may reverse the operands for broadcasting.
enum xnn_status xnn_setup_add_nd_qs8(
  xnn_operator_t add_op,
  size_t num_input1_dims,
  const size_t* input1_shape,
  size_t num_input2_dims,
  const size_t* input2_shape,
  const int8_t* input1,
  const int8_t* input2,
  int8_t* output,
  pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd(
    add_op,
    xnn_operator_type_add_nd_qs8,
    num_input1_dims,
    input1_shape,
    num_input2_dims,
    input2_shape,
    input1,
    input2,
    output,
    XNN_INIT_FLAG_QS8,
    0 /* log2(sizeof(int8_t)) */,
    &add_op->params.qs8_add,
    sizeof(add_op->params.qs8_add),
    &add_op->params.qs8_radd,
    sizeof(add_op->params.qs8_radd),
    &xnn_params.qs8.vadd,
    pthreadpool_get_threads_count(threadpool));
}
649
// Binds shapes and data pointers to a previously created F16 Add operator.
enum xnn_status xnn_setup_add_nd_f16(
  xnn_operator_t add_op,
  size_t num_input1_dims,
  const size_t* input1_shape,
  size_t num_input2_dims,
  const size_t* input2_shape,
  const void* input1,
  const void* input2,
  void* output,
  pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f16(
    add_op,
    xnn_operator_type_add_nd_f16,
    num_input1_dims,
    input1_shape,
    num_input2_dims,
    input2_shape,
    input1,
    input2,
    output,
    &xnn_params.f16.vadd,
    pthreadpool_get_threads_count(threadpool));
}
669
// Binds shapes and data pointers to a previously created F32 Add operator.
enum xnn_status xnn_setup_add_nd_f32(
  xnn_operator_t add_op,
  size_t num_input1_dims,
  const size_t* input1_shape,
  size_t num_input2_dims,
  const size_t* input2_shape,
  const float* input1,
  const float* input2,
  float* output,
  pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    add_op,
    xnn_operator_type_add_nd_f32,
    num_input1_dims,
    input1_shape,
    num_input2_dims,
    input2_shape,
    input1,
    input2,
    output,
    &xnn_params.f32.vadd,
    pthreadpool_get_threads_count(threadpool));
}
689
// Binds shapes and data pointers to a previously created F32 Divide operator.
enum xnn_status xnn_setup_divide_nd_f32(
  xnn_operator_t divide_op,
  size_t num_input1_dims,
  const size_t* input1_shape,
  size_t num_input2_dims,
  const size_t* input2_shape,
  const float* input1,
  const float* input2,
  float* output,
  pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    divide_op,
    xnn_operator_type_divide_nd_f32,
    num_input1_dims,
    input1_shape,
    num_input2_dims,
    input2_shape,
    input1,
    input2,
    output,
    &xnn_params.f32.vdiv,
    pthreadpool_get_threads_count(threadpool));
}
709
// Binds shapes and data pointers to a previously created F32 Maximum operator.
enum xnn_status xnn_setup_maximum_nd_f32(
  xnn_operator_t maximum_op,
  size_t num_input1_dims,
  const size_t* input1_shape,
  size_t num_input2_dims,
  const size_t* input2_shape,
  const float* input1,
  const float* input2,
  float* output,
  pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    maximum_op,
    xnn_operator_type_maximum_nd_f32,
    num_input1_dims,
    input1_shape,
    num_input2_dims,
    input2_shape,
    input1,
    input2,
    output,
    &xnn_params.f32.vmax,
    pthreadpool_get_threads_count(threadpool));
}
729
// Binds shapes and data pointers to a previously created F32 Minimum operator.
enum xnn_status xnn_setup_minimum_nd_f32(
  xnn_operator_t minimum_op,
  size_t num_input1_dims,
  const size_t* input1_shape,
  size_t num_input2_dims,
  const size_t* input2_shape,
  const float* input1,
  const float* input2,
  float* output,
  pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    minimum_op,
    xnn_operator_type_minimum_nd_f32,
    num_input1_dims,
    input1_shape,
    num_input2_dims,
    input2_shape,
    input1,
    input2,
    output,
    &xnn_params.f32.vmin,
    pthreadpool_get_threads_count(threadpool));
}
749
// Binds shapes and data pointers to a previously created F16 Multiply operator.
enum xnn_status xnn_setup_multiply_nd_f16(
  xnn_operator_t multiply_op,
  size_t num_input1_dims,
  const size_t* input1_shape,
  size_t num_input2_dims,
  const size_t* input2_shape,
  const void* input1,
  const void* input2,
  void* output,
  pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f16(
    multiply_op,
    xnn_operator_type_multiply_nd_f16,
    num_input1_dims,
    input1_shape,
    num_input2_dims,
    input2_shape,
    input1,
    input2,
    output,
    &xnn_params.f16.vmul,
    pthreadpool_get_threads_count(threadpool));
}
769
// Binds shapes and data pointers to a previously created F32 Multiply operator.
enum xnn_status xnn_setup_multiply_nd_f32(
  xnn_operator_t multiply_op,
  size_t num_input1_dims,
  const size_t* input1_shape,
  size_t num_input2_dims,
  const size_t* input2_shape,
  const float* input1,
  const float* input2,
  float* output,
  pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    multiply_op,
    xnn_operator_type_multiply_nd_f32,
    num_input1_dims,
    input1_shape,
    num_input2_dims,
    input2_shape,
    input1,
    input2,
    output,
    &xnn_params.f32.vmul,
    pthreadpool_get_threads_count(threadpool));
}
789
// Binds shapes and data pointers to a previously created F32 Squared
// Difference operator.
enum xnn_status xnn_setup_squared_difference_nd_f32(
  xnn_operator_t squared_difference_op,
  size_t num_input1_dims,
  const size_t* input1_shape,
  size_t num_input2_dims,
  const size_t* input2_shape,
  const float* input1,
  const float* input2,
  float* output,
  pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    squared_difference_op,
    xnn_operator_type_squared_difference_nd_f32,
    num_input1_dims,
    input1_shape,
    num_input2_dims,
    input2_shape,
    input1,
    input2,
    output,
    &xnn_params.f32.vsqrdiff,
    pthreadpool_get_threads_count(threadpool));
}
809
// Binds shapes and data pointers to a previously created F32 Subtract operator.
enum xnn_status xnn_setup_subtract_nd_f32(
  xnn_operator_t subtract_op,
  size_t num_input1_dims,
  const size_t* input1_shape,
  size_t num_input2_dims,
  const size_t* input2_shape,
  const float* input1,
  const float* input2,
  float* output,
  pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    subtract_op,
    xnn_operator_type_subtract_nd_f32,
    num_input1_dims,
    input1_shape,
    num_input2_dims,
    input2_shape,
    input1,
    input2,
    output,
    &xnn_params.f32.vsub,
    pthreadpool_get_threads_count(threadpool));
}
829