• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <fp16.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/log.h>
#include <xnnpack/operator.h>
#include <xnnpack/params-init.h>
#include <xnnpack/params.h>
20 
21 
// Common creation path for all N-dimensional binary elementwise operators.
// Allocates the operator descriptor, copies the (optional) micro-kernel
// parameters into it, and records the fused micro-kernel function pointers.
//
// params/params_size: opaque micro-kernel parameter blob; params may be NULL
//   when params_size is 0.
// datatype_init_flags: XNN_INIT_FLAG_* bit(s) that must be set in
//   xnn_params.init_flags for the operator's data type to be supported.
// Returns xnn_status_success and stores the new operator in
// *binary_elementwise_op_out, or an error status (nothing is stored).
// NOTE(review): `flags` is accepted for API symmetry but not consumed here —
// confirm against callers whether it should be recorded on the operator.
static enum xnn_status create_binary_elementwise_nd(
    uint32_t flags,
    const void* params,
    size_t params_size,
    uint32_t datatype_init_flags,
    enum xnn_operator_type operator_type,
    const struct vbinary_fused_ukernels* vbinary_fused_ukernels,
    xnn_operator_t* binary_elementwise_op_out)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(operator_type));
    return xnn_status_uninitialized;
  }

  if ((xnn_params.init_flags & datatype_init_flags) != datatype_init_flags) {
    xnn_log_error("failed to create %s operator: operations on data type are not supported",
      xnn_operator_type_to_string(operator_type));
    return xnn_status_unsupported_hardware;
  }

  xnn_operator_t binary_elementwise_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
  if (binary_elementwise_op == NULL) {
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator descriptor",
      sizeof(struct xnn_operator), xnn_operator_type_to_string(operator_type));
    return xnn_status_out_of_memory;
  }

  if (params_size != 0) {
    memcpy(&binary_elementwise_op->params, params, params_size);
  }

  // op:   both operands are full tensors.
  // opc:  second operand is a single (constant) element.
  // ropc: first operand is a single (constant) element (reversed form).
  binary_elementwise_op->ukernel.vbinary.op_function   = vbinary_fused_ukernels->op_ukernel;
  binary_elementwise_op->ukernel.vbinary.opc_function  = vbinary_fused_ukernels->opc_ukernel;
  binary_elementwise_op->ukernel.vbinary.ropc_function = vbinary_fused_ukernels->ropc_ukernel;

  binary_elementwise_op->type = operator_type;

  // Operator becomes runnable only after a successful setup call.
  binary_elementwise_op->state = xnn_run_state_invalid;

  *binary_elementwise_op_out = binary_elementwise_op;
  return xnn_status_success;
}
66 
// Creates an F16 binary elementwise operator with a fused [output_min,
// output_max] clamp. The bounds are validated after rounding to half
// precision, because that is the precision at which the kernel clamps.
// Returns xnn_status_invalid_parameter for NaN bounds or an empty range.
static enum xnn_status create_binary_elementwise_nd_f16(
    float output_min,
    float output_max,
    uint32_t flags,
    enum xnn_operator_type operator_type,
    const struct vbinary_parameters vbinary[restrict XNN_MIN_ELEMENTS(1)],
    xnn_operator_t* binary_elementwise_op_out)
{
  if (isnan(output_min)) {
    xnn_log_error(
      "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN",
      xnn_operator_type_to_string(operator_type));
    return xnn_status_invalid_parameter;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN",
      xnn_operator_type_to_string(operator_type));
    return xnn_status_invalid_parameter;
  }

  // Compare the bounds as they will actually be applied, i.e. after a
  // round-trip through the IEEE half-precision representation.
  if (fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_min)) >= fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_max))) {
    xnn_log_error(
      "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(operator_type),
      fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_min)),
      fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_max)));
    return xnn_status_invalid_parameter;
  }

  const struct xnn_f16_minmax_params params = xnn_init_f16_minmax_params(
    fp16_ieee_from_fp32_value(output_min),
    fp16_ieee_from_fp32_value(output_max));

  return create_binary_elementwise_nd(
    flags,
    &params,
    sizeof(params),
    XNN_INIT_FLAG_F16,
    operator_type,
    &vbinary->minmax,
    binary_elementwise_op_out);
}
111 
// Creates an F32 binary elementwise operator with a fused [output_min,
// output_max] clamp. When the range is (-inf, +inf) and the parameter table
// provides a dedicated linear (clamp-free) kernel, that kernel is used
// instead of the min/max one.
// Returns xnn_status_invalid_parameter for NaN bounds or an empty range.
static enum xnn_status create_binary_elementwise_nd_f32(
    float output_min,
    float output_max,
    uint32_t flags,
    enum xnn_operator_type operator_type,
    const struct vbinary_parameters vbinary[restrict XNN_MIN_ELEMENTS(1)],
    xnn_operator_t* binary_elementwise_op_out)
{
  // Checked again in create_binary_elementwise_nd, but verified up front so
  // invalid bounds on an uninitialized library report the right error first.
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(operator_type));
    return xnn_status_uninitialized;
  }

  if (isnan(output_min)) {
    xnn_log_error(
      "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN",
      xnn_operator_type_to_string(operator_type));
    return xnn_status_invalid_parameter;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN",
      xnn_operator_type_to_string(operator_type));
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(operator_type), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  // [-inf, +inf] means no clamping: prefer the linear kernel if available.
  const bool linear_activation = (output_max == INFINITY) && (output_min == -output_max);
  const struct vbinary_fused_ukernels* vbinary_fused_ukernels = &vbinary->minmax;
  if (linear_activation && vbinary->linear.op_ukernel != NULL) {
    vbinary_fused_ukernels = &vbinary->linear;
  }

  const union xnn_f32_minmax_params params = xnn_init_f32_minmax_params(output_min, output_max);

  return create_binary_elementwise_nd(
    flags,
    &params,
    sizeof(params),
    XNN_INIT_FLAG_F32,
    operator_type,
    vbinary_fused_ukernels,
    binary_elementwise_op_out);
}
164 
// Creates an N-dimensional QS8 (signed 8-bit quantized) Add operator.
// Validates the quantization parameters, then precomputes two parameter
// blobs: one for the (input1, input2) operand order and one for the reversed
// order, so setup can swap operands for broadcasting without re-deriving
// parameters.
// Scale constraints: every scale must be positive and normal, and each
// input-to-output scale ratio must lie in [2**-14, 2**8).
enum xnn_status xnn_create_add_nd_qs8(
    int8_t input1_zero_point,
    float input1_scale,
    int8_t input2_zero_point,
    float input2_scale,
    int8_t output_zero_point,
    float output_scale,
    int8_t output_min,
    int8_t output_max,
    uint32_t flags,
    xnn_operator_t* add_op_out)
{
  if (input1_scale <= 0.0f || !isnormal(input1_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input 1 scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qs8), input1_scale);
    return xnn_status_invalid_parameter;
  }

  if (input2_scale <= 0.0f || !isnormal(input2_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input 2 scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qs8), input2_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite and positive",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qs8), output_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRId8 ", %" PRId8 "] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qs8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  const float input1_output_scale = input1_scale / output_scale;
  if (input1_output_scale < 0x1.0p-14f || input1_output_scale >= 0x1.0p+8f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input1-to-output scale ratio: scale ratio must be in [2**-14, 2**8) range",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qs8), input1_output_scale);
    return xnn_status_unsupported_parameter;
  }

  const float input2_output_scale = input2_scale / output_scale;
  if (input2_output_scale < 0x1.0p-14f || input2_output_scale >= 0x1.0p+8f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input2-to-output scale ratio: scale ratio must be in [2**-14, 2**8) range",
      xnn_operator_type_to_string(xnn_operator_type_add_nd_qs8), input2_output_scale);
    return xnn_status_unsupported_parameter;
  }

  // Both operand orders are precomputed; setup copies the matching blob into
  // the micro-kernel context depending on which operand is broadcast.
  const struct {
    union xnn_qs8_add_params qs8_add;
    union xnn_qs8_add_params qs8_radd;
  } params = {
    .qs8_add = xnn_init_qs8_add_params(
      input1_zero_point, input2_zero_point, output_zero_point, input1_output_scale, input2_output_scale, output_min, output_max),
    .qs8_radd = xnn_init_qs8_add_params(
      input2_zero_point, input1_zero_point, output_zero_point, input2_output_scale, input1_output_scale, output_min, output_max),
  };
  return create_binary_elementwise_nd(
    flags,
    &params,
    sizeof(params),
    XNN_INIT_FLAG_QS8,
    xnn_operator_type_add_nd_qs8,
    &xnn_params.qs8.vadd.minmax,
    add_op_out);
}
239 
// Creates an N-dimensional F16 Add operator with output clamped to
// [output_min, output_max]. Thin wrapper over the shared F16 creation path.
enum xnn_status xnn_create_add_nd_f16(
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_operator_t* add_op_out)
{
  return create_binary_elementwise_nd_f16(
    output_min,
    output_max,
    flags,
    xnn_operator_type_add_nd_f16,
    &xnn_params.f16.vadd,
    add_op_out);
}
254 
// Creates an N-dimensional F32 Add operator with output clamped to
// [output_min, output_max]. Thin wrapper over the shared F32 creation path.
enum xnn_status xnn_create_add_nd_f32(
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_operator_t* add_op_out)
{
  return create_binary_elementwise_nd_f32(
    output_min,
    output_max,
    flags,
    xnn_operator_type_add_nd_f32,
    &xnn_params.f32.vadd,
    add_op_out);
}
269 
// Creates an N-dimensional F32 Divide operator with output clamped to
// [output_min, output_max]. Thin wrapper over the shared F32 creation path.
enum xnn_status xnn_create_divide_nd_f32(
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_operator_t* divide_op_out)
{
  return create_binary_elementwise_nd_f32(
    output_min,
    output_max,
    flags,
    xnn_operator_type_divide_nd_f32,
    &xnn_params.f32.vdiv,
    divide_op_out);
}
284 
// Creates an N-dimensional F32 Maximum operator. No clamp parameters are
// needed (max is idempotent under clamping), so the params blob is empty.
enum xnn_status xnn_create_maximum_nd_f32(
    uint32_t flags,
    xnn_operator_t* maximum_op_out)
{
  return create_binary_elementwise_nd(
    flags,
    NULL /* params */,
    0 /* params size */,
    XNN_INIT_FLAG_F32,
    xnn_operator_type_maximum_nd_f32,
    &xnn_params.f32.vmax.minmax,
    maximum_op_out);
}
298 
// Creates an N-dimensional F32 Minimum operator. No clamp parameters are
// needed (min is idempotent under clamping), so the params blob is empty.
enum xnn_status xnn_create_minimum_nd_f32(
    uint32_t flags,
    xnn_operator_t* minimum_op_out)
{
  return create_binary_elementwise_nd(
    flags,
    NULL /* params */,
    0 /* params size */,
    XNN_INIT_FLAG_F32,
    xnn_operator_type_minimum_nd_f32,
    &xnn_params.f32.vmin.minmax,
    minimum_op_out);
}
312 
// Creates an N-dimensional F16 Multiply operator with output clamped to
// [output_min, output_max]. Thin wrapper over the shared F16 creation path.
enum xnn_status xnn_create_multiply_nd_f16(
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_operator_t* multiply_op_out)
{
  return create_binary_elementwise_nd_f16(
    output_min,
    output_max,
    flags,
    xnn_operator_type_multiply_nd_f16,
    &xnn_params.f16.vmul,
    multiply_op_out);
}
327 
// Creates an N-dimensional F32 Multiply operator with output clamped to
// [output_min, output_max]. Thin wrapper over the shared F32 creation path.
enum xnn_status xnn_create_multiply_nd_f32(
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_operator_t* multiply_op_out)
{
  return create_binary_elementwise_nd_f32(
    output_min,
    output_max,
    flags,
    xnn_operator_type_multiply_nd_f32,
    &xnn_params.f32.vmul,
    multiply_op_out);
}
342 
// Creates an N-dimensional F32 Squared Difference operator ((a-b)^2).
// No clamp parameters are supported, so the params blob is empty.
enum xnn_status xnn_create_squared_difference_nd_f32(
    uint32_t flags,
    xnn_operator_t* squared_difference_op_out)
{
  return create_binary_elementwise_nd(
    flags,
    NULL /* params */,
    0 /* params size */,
    XNN_INIT_FLAG_F32,
    xnn_operator_type_squared_difference_nd_f32,
    &xnn_params.f32.vsqrdiff.minmax,
    squared_difference_op_out);
}
356 
// Creates an N-dimensional F32 Subtract operator with output clamped to
// [output_min, output_max]. Thin wrapper over the shared F32 creation path.
enum xnn_status xnn_create_subtract_nd_f32(
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_operator_t* subtract_op_out)
{
  return create_binary_elementwise_nd_f32(
    output_min,
    output_max,
    flags,
    xnn_operator_type_subtract_nd_f32,
    &xnn_params.f32.vsub,
    subtract_op_out);
}
371 
// Common setup path for all N-dimensional binary elementwise operators.
// Validates the two input shapes against numpy-style broadcasting rules,
// compresses adjacent dimensions with the same broadcast behavior into
// single larger dimensions (to minimize loop nesting), selects the
// appropriate micro-kernel variant (two-tensor, operand-2-scalar, or
// reversed operand-1-scalar), computes per-dimension byte strides, and
// configures the operator for 5-D parallel execution.
//
// params/reversed_params: micro-kernel parameter blobs for the normal and
//   operand-swapped orders; either may be empty (size 0).
// log2_element_size: log2 of the element size in bytes (0 for int8,
//   1 for half, 2 for float).
// NOTE(review): `vbinary` and `num_threads` are accepted but not consumed in
// this body — confirm whether they are intentionally reserved.
static enum xnn_status setup_binary_elementwise_nd(
    xnn_operator_t binary_elementwise_op,
    enum xnn_operator_type expected_operator_type,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const void* input1,
    const void* input2,
    void* output,
    uint32_t datatype_init_flags,
    uint32_t log2_element_size,
    const void* params,
    size_t params_size,
    const void* reversed_params,
    size_t reversed_params_size,
    const struct vbinary_parameters vbinary[restrict XNN_MIN_ELEMENTS(1)],
    size_t num_threads)
{
  // Invalidate first: on any early return the operator must not be runnable.
  binary_elementwise_op->state = xnn_run_state_invalid;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to setup %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(binary_elementwise_op->type));
    return xnn_status_uninitialized;
  }

  if ((xnn_params.init_flags & datatype_init_flags) != datatype_init_flags) {
    xnn_log_error("failed to setup %s operator: operations on data type are not supported",
      xnn_operator_type_to_string(binary_elementwise_op->type));
    return xnn_status_unsupported_hardware;
  }

  if (binary_elementwise_op->type != expected_operator_type) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(expected_operator_type),
      xnn_operator_type_to_string(binary_elementwise_op->type));
    return xnn_status_invalid_parameter;
  }

  if (max(num_input1_dims, num_input2_dims) > XNN_MAX_TENSOR_DIMS) {
    xnn_log_error(
      "failed to setup %s operator with %zu and %zu dimensions in input shapes: "
      "the number of input dimensions must not exceed %d",
      xnn_operator_type_to_string(binary_elementwise_op->type), num_input1_dims, num_input2_dims, XNN_MAX_TENSOR_DIMS);
    return xnn_status_unsupported_parameter;
  }

  for (size_t i = 0; i < num_input1_dims; i++) {
    if (input1_shape[i] == 0) {
      xnn_log_error(
        "failed to setup %s operator: shape dimension #%zu of input #1 is zero",
        xnn_operator_type_to_string(binary_elementwise_op->type), i);
      return xnn_status_invalid_parameter;
    }
  }

  for (size_t i = 0; i < num_input2_dims; i++) {
    if (input2_shape[i] == 0) {
      xnn_log_error(
        "failed to setup %s operator: shape dimension #%zu of input #2 is zero",
        xnn_operator_type_to_string(binary_elementwise_op->type), i);
      return xnn_status_invalid_parameter;
    }
  }

  // Dimension compression: walk the shapes from the innermost dimension
  // outward; consecutive dimensions with the same broadcast direction
  // (input1 broadcast, input2 broadcast, or no broadcast) are folded into a
  // single compressed dimension. Unused compressed dimensions stay 1.
  size_t num_compressed_dims = 0;
  size_t compressed_input1_shape[XNN_MAX_TENSOR_DIMS];
  size_t compressed_input2_shape[XNN_MAX_TENSOR_DIMS];
  size_t compressed_output_shape[XNN_MAX_TENSOR_DIMS];
  for (size_t i = 0; i < XNN_MAX_TENSOR_DIMS; i++) {
    compressed_input1_shape[i] = 1;
    compressed_input2_shape[i] = 1;
    compressed_output_shape[i] = 1;
  }
  bool broadcast_input1 = false;
  bool broadcast_input2 = false;
  bool first_nonunit = true;
  const size_t num_common_dims = min(num_input1_dims, num_input2_dims);
  for (size_t i = 1; i <= num_common_dims; i++) {
    const size_t input1_dim = input1_shape[num_input1_dims - i];
    const size_t input2_dim = input2_shape[num_input2_dims - i];
    if (input1_dim == 1 && input2_dim == 1) {
      // Degenerate dimension in both inputs: contributes nothing.
      continue;
    }
    assert(!broadcast_input1 || !broadcast_input2);

    if (input1_dim == 1) {
      // Input 1 is broadcast along this dimension.
      if (!broadcast_input1) {
        broadcast_input1 = true;
        broadcast_input2 = false;
        num_compressed_dims++;
      }
      compressed_input2_shape[num_compressed_dims - 1] *= input2_dim;
      compressed_output_shape[num_compressed_dims - 1] *= input2_dim;
    } else if (input2_dim == 1) {
      // Input 2 is broadcast along this dimension.
      if (!broadcast_input2) {
        broadcast_input1 = false;
        broadcast_input2 = true;
        num_compressed_dims++;
      }
      compressed_input1_shape[num_compressed_dims - 1] *= input1_dim;
      compressed_output_shape[num_compressed_dims - 1] *= input1_dim;
    } else if (input1_dim == input2_dim) {
      // Matching non-unit dimensions: no broadcast.
      if (broadcast_input1 || broadcast_input2 || first_nonunit) {
        broadcast_input1 = false;
        broadcast_input2 = false;
        num_compressed_dims++;
      }
      compressed_input1_shape[num_compressed_dims - 1] *= input1_dim;
      compressed_input2_shape[num_compressed_dims - 1] *= input1_dim;
      compressed_output_shape[num_compressed_dims - 1] *= input1_dim;
    } else {
      xnn_log_error(
        "failed to setup %s operator: "
        "shape dimension #%zu of input1 (%zu) does not match shape dimension #%zu of input2 (%zu)",
        xnn_operator_type_to_string(binary_elementwise_op->type),
        num_input1_dims - i, input1_dim, num_input2_dims - i, input2_dim);
      return xnn_status_invalid_parameter;
    }
    first_nonunit = false;
  }
  // Leading dimensions present in only one input implicitly broadcast the
  // other input; fold them into the outermost compressed dimension.
  if (num_input1_dims > num_input2_dims) {
    if (!broadcast_input2) {
      num_compressed_dims++;
    }
    for (size_t i = 0; i < num_input1_dims - num_input2_dims; i++) {
      const size_t input1_dim = input1_shape[i];
      compressed_input1_shape[num_compressed_dims - 1] *= input1_dim;
      compressed_output_shape[num_compressed_dims - 1] *= input1_dim;
    }
  } else if (num_input2_dims > num_input1_dims) {
    if (!broadcast_input1) {
      num_compressed_dims++;
    }
    for (size_t i = 0; i < num_input2_dims - num_input1_dims; i++) {
      const size_t input2_dim = input2_shape[i];
      compressed_input2_shape[num_compressed_dims - 1] *= input2_dim;
      compressed_output_shape[num_compressed_dims - 1] *= input2_dim;
    }
  }
  // All-unit shapes compress to zero dimensions; treat as one dimension of 1.
  num_compressed_dims = max(num_compressed_dims, 1);

  binary_elementwise_op->context.elementwise_binary = (struct elementwise_binary_context) {
    .a = input1,
    .b = input2,
    .y = output,
    .elements = compressed_output_shape[0] << log2_element_size,
  };
  if (params_size != 0) {
    memcpy(&binary_elementwise_op->context.elementwise_binary.params, params, params_size);
  }

  // Kernel selection on the innermost compressed dimension:
  //  - input1 innermost extent 1: swap operands, use the reversed
  //    per-element-constant kernel (with reversed params if provided);
  //  - input2 innermost extent 1: use the per-element-constant kernel;
  //  - matching extents: use the plain two-tensor kernel.
  const size_t* compressed_a_shape = compressed_input1_shape;
  const size_t* compressed_b_shape = compressed_input2_shape;
  if (compressed_input1_shape[0] == 1) {
    binary_elementwise_op->context.elementwise_binary.ukernel = binary_elementwise_op->ukernel.vbinary.ropc_function;
    binary_elementwise_op->context.elementwise_binary.a = input2;
    binary_elementwise_op->context.elementwise_binary.b = input1;
    compressed_a_shape = compressed_input2_shape;
    compressed_b_shape = compressed_input1_shape;
    if (reversed_params_size != 0) {
      memcpy(&binary_elementwise_op->context.elementwise_binary.params, reversed_params, reversed_params_size);
    }
  } else if (compressed_input2_shape[0] == 1) {
    binary_elementwise_op->context.elementwise_binary.ukernel = binary_elementwise_op->ukernel.vbinary.opc_function;
  } else if (compressed_input1_shape[0] == compressed_input2_shape[0]) {
    binary_elementwise_op->context.elementwise_binary.ukernel = binary_elementwise_op->ukernel.vbinary.op_function;
  }
  // Byte strides for the outer dimensions; a stride stays 0 (zero-initialized
  // context) where the operand is broadcast along that dimension.
  size_t a_stride = compressed_a_shape[0], b_stride = compressed_b_shape[0], y_stride = compressed_output_shape[0];
  for (size_t i = 1; i < num_compressed_dims; i++) {
    if (compressed_a_shape[i] != 1) {
      binary_elementwise_op->context.elementwise_binary.a_stride[XNN_MAX_TENSOR_DIMS - 1 - i] = a_stride << log2_element_size;
    }
    if (compressed_b_shape[i] != 1) {
      binary_elementwise_op->context.elementwise_binary.b_stride[XNN_MAX_TENSOR_DIMS - 1 - i] = b_stride << log2_element_size;
    }
    binary_elementwise_op->context.elementwise_binary.y_stride[XNN_MAX_TENSOR_DIMS - 1 - i] = y_stride << log2_element_size;
    a_stride *= compressed_a_shape[i];
    b_stride *= compressed_b_shape[i];
    y_stride *= compressed_output_shape[i];
  }

  // 5-D parallelization over the outer compressed dimensions; the innermost
  // dimension is handled inside the micro-kernel via `elements`.
  // NOTE(review): the explicit indices 5..1 assume XNN_MAX_TENSOR_DIMS == 6.
  binary_elementwise_op->compute.type = xnn_parallelization_type_5d;
  binary_elementwise_op->compute.task_5d = (pthreadpool_task_5d_t) xnn_compute_elementwise_binary_5d;
  binary_elementwise_op->compute.range[0] = compressed_output_shape[5];
  binary_elementwise_op->compute.range[1] = compressed_output_shape[4];
  binary_elementwise_op->compute.range[2] = compressed_output_shape[3];
  binary_elementwise_op->compute.range[3] = compressed_output_shape[2];
  binary_elementwise_op->compute.range[4] = compressed_output_shape[1];
  binary_elementwise_op->compute.tile[0] = 1;
  binary_elementwise_op->compute.tile[1] = 1;
  binary_elementwise_op->state = xnn_run_state_ready;

  return xnn_status_success;
}
568 
// F16 setup wrapper: forwards to the generic setup with half-precision
// element size and the operator's stored f16 min/max parameters. The same
// parameter blob is used for both operand orders (min/max clamping is
// symmetric in the operands).
static enum xnn_status setup_binary_elementwise_nd_f16(
    xnn_operator_t binary_elementwise_op,
    enum xnn_operator_type expected_operator_type,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const void* input1,
    const void* input2,
    void* output,
    const struct vbinary_parameters vbinary[restrict XNN_MIN_ELEMENTS(1)],
    size_t num_threads)
{
  return setup_binary_elementwise_nd(
    binary_elementwise_op,
    expected_operator_type,
    num_input1_dims,
    input1_shape,
    num_input2_dims,
    input2_shape,
    input1,
    input2,
    output,
    XNN_INIT_FLAG_F16,
    1 /* log2(sizeof(half)) */,
    &binary_elementwise_op->params.f16_minmax, sizeof(binary_elementwise_op->params.f16_minmax),
    &binary_elementwise_op->params.f16_minmax, sizeof(binary_elementwise_op->params.f16_minmax),
    vbinary,
    num_threads);
}
599 
// F32 setup wrapper: forwards to the generic setup with single-precision
// element size and the operator's stored f32 min/max parameters. The same
// parameter blob is used for both operand orders (min/max clamping is
// symmetric in the operands).
static enum xnn_status setup_binary_elementwise_nd_f32(
    xnn_operator_t binary_elementwise_op,
    enum xnn_operator_type expected_operator_type,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    const struct vbinary_parameters vbinary[restrict XNN_MIN_ELEMENTS(1)],
    size_t num_threads)
{
  return setup_binary_elementwise_nd(
    binary_elementwise_op, expected_operator_type,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    XNN_INIT_FLAG_F32,
    2 /* log2(sizeof(float)) */,
    &binary_elementwise_op->params.f32_minmax, sizeof(binary_elementwise_op->params.f32_minmax),
    &binary_elementwise_op->params.f32_minmax, sizeof(binary_elementwise_op->params.f32_minmax),
    vbinary,
    num_threads);
}
625 
// Sets up a QS8 Add operator for the given input shapes and buffers.
// Passes both the normal and the reversed parameter blobs, since QS8
// addition parameters depend on operand order.
enum xnn_status xnn_setup_add_nd_qs8(
    xnn_operator_t add_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const int8_t* input1,
    const int8_t* input2,
    int8_t* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd(
    add_op, xnn_operator_type_add_nd_qs8,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    XNN_INIT_FLAG_QS8,
    0 /* log2(sizeof(int8_t))) */,
    &add_op->params.qs8_add, sizeof(add_op->params.qs8_add),
    &add_op->params.qs8_radd, sizeof(add_op->params.qs8_radd),
    &xnn_params.qs8.vadd,
    pthreadpool_get_threads_count(threadpool));
}
649 
// Sets up an F16 Add operator for the given input shapes and buffers.
enum xnn_status xnn_setup_add_nd_f16(
    xnn_operator_t add_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const void* input1,
    const void* input2,
    void* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f16(
    add_op, xnn_operator_type_add_nd_f16,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f16.vadd,
    pthreadpool_get_threads_count(threadpool));
}
669 
// Sets up an F32 Add operator for the given input shapes and buffers.
enum xnn_status xnn_setup_add_nd_f32(
    xnn_operator_t add_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    add_op, xnn_operator_type_add_nd_f32,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f32.vadd,
    pthreadpool_get_threads_count(threadpool));
}
689 
// Sets up an F32 Divide operator for the given input shapes and buffers.
enum xnn_status xnn_setup_divide_nd_f32(
    xnn_operator_t divide_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    divide_op, xnn_operator_type_divide_nd_f32,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f32.vdiv,
    pthreadpool_get_threads_count(threadpool));
}
709 
// Sets up an F32 Maximum operator for the given input shapes and buffers.
enum xnn_status xnn_setup_maximum_nd_f32(
    xnn_operator_t maximum_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    maximum_op, xnn_operator_type_maximum_nd_f32,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f32.vmax,
    pthreadpool_get_threads_count(threadpool));
}
729 
// Sets up an F32 Minimum operator for the given input shapes and buffers.
enum xnn_status xnn_setup_minimum_nd_f32(
    xnn_operator_t minimum_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    pthreadpool_t threadpool)
{
  return setup_binary_elementwise_nd_f32(
    minimum_op, xnn_operator_type_minimum_nd_f32,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f32.vmin,
    pthreadpool_get_threads_count(threadpool));
}
749 
// Sets up an N-dimensional broadcasted multiplication operator on F16 data.
// Inputs/output are passed as void* since half-precision has no standard C
// type here. Thin wrapper: forwards to the shared F16 binary-elementwise
// setup routine with the multiplication (vmul) microkernel parameters.
enum xnn_status xnn_setup_multiply_nd_f16(
    xnn_operator_t multiply_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const void* input1,
    const void* input2,
    void* output,
    pthreadpool_t threadpool)
{
  // The shared setup routine only consumes the thread count of the pool.
  const size_t num_threads = pthreadpool_get_threads_count(threadpool);
  return setup_binary_elementwise_nd_f16(
    multiply_op, xnn_operator_type_multiply_nd_f16,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f16.vmul,
    num_threads);
}
769 
// Sets up an N-dimensional broadcasted multiplication operator on F32 data.
// Thin wrapper: forwards all arguments to the shared F32 binary-elementwise
// setup routine together with the multiplication (vmul) microkernel parameters.
enum xnn_status xnn_setup_multiply_nd_f32(
    xnn_operator_t multiply_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    pthreadpool_t threadpool)
{
  // The shared setup routine only consumes the thread count of the pool.
  const size_t num_threads = pthreadpool_get_threads_count(threadpool);
  return setup_binary_elementwise_nd_f32(
    multiply_op, xnn_operator_type_multiply_nd_f32,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f32.vmul,
    num_threads);
}
789 
// Sets up an N-dimensional broadcasted squared-difference operator on F32
// data. Thin wrapper: forwards to the shared F32 binary-elementwise setup
// routine with the squared-difference (vsqrdiff) microkernel parameters.
enum xnn_status xnn_setup_squared_difference_nd_f32(
    xnn_operator_t squared_difference_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    pthreadpool_t threadpool)
{
  // The shared setup routine only consumes the thread count of the pool.
  const size_t num_threads = pthreadpool_get_threads_count(threadpool);
  return setup_binary_elementwise_nd_f32(
    squared_difference_op, xnn_operator_type_squared_difference_nd_f32,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f32.vsqrdiff,
    num_threads);
}
809 
// Sets up an N-dimensional broadcasted subtraction operator on F32 data.
// Thin wrapper: forwards all arguments to the shared F32 binary-elementwise
// setup routine together with the subtraction (vsub) microkernel parameters.
enum xnn_status xnn_setup_subtract_nd_f32(
    xnn_operator_t subtract_op,
    size_t num_input1_dims,
    const size_t* input1_shape,
    size_t num_input2_dims,
    const size_t* input2_shape,
    const float* input1,
    const float* input2,
    float* output,
    pthreadpool_t threadpool)
{
  // The shared setup routine only consumes the thread count of the pool.
  const size_t num_threads = pthreadpool_get_threads_count(threadpool);
  return setup_binary_elementwise_nd_f32(
    subtract_op, xnn_operator_type_subtract_nd_f32,
    num_input1_dims, input1_shape,
    num_input2_dims, input2_shape,
    input1, input2, output,
    &xnn_params.f32.vsub,
    num_threads);
}
829