1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <assert.h>
7 #include <math.h>
8 #include <stddef.h>
9 #include <stdint.h>
10 #include <stdlib.h>
11 
12 #include <xnnpack.h>
13 #include <xnnpack/allocator.h>
14 #include <xnnpack/log.h>
15 #include <xnnpack/operator.h>
16 #include <xnnpack/params-init.h>
17 #include <xnnpack/params.h>
18 
19 
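// Shared factory for all unary element-wise operators: checks that XNNPACK is
// initialized, validates the channel count and strides, allocates a zeroed
// operator descriptor, and stores the micro-kernel together with its optional
// packed parameters.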
20 static enum xnn_status create_unary_elementwise_nc(
21     size_t channels,
22     size_t input_stride,
23     size_t output_stride,
24     uint32_t flags,
25     const void* params,
26     size_t params_size,
27     enum xnn_operator_type operator_type,
28     xnn_univector_ukernel_function ukernel,
29     xnn_operator_t* unary_elementwise_op_out)
30 {
31   xnn_operator_t unary_elementwise_op = NULL;
32 
33   if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
34     xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
35       xnn_operator_type_to_string(operator_type));
36     return xnn_status_uninitialized;
37   }
38 
39   if (channels == 0) {
40     xnn_log_error(
41       "failed to create %s operator with %zu channels: number of channels must be non-zero",
42       xnn_operator_type_to_string(operator_type), channels);
43     return xnn_status_invalid_parameter;
44   }
45 
46   if (input_stride < channels) {
47     xnn_log_error(
48       "failed to create %s operator with input element stride of %zu: "
49       "stride must be at least as large as the number of channels (%zu)",
50       xnn_operator_type_to_string(operator_type), input_stride, channels);
51     return xnn_status_invalid_parameter;
52   }
53 
54   if (output_stride < channels) {
55     xnn_log_error(
56       "failed to create %s operator with output element stride of %zu: "
57       "stride must be at least as large as the number of channels (%zu)",
58       xnn_operator_type_to_string(operator_type), output_stride, channels);
59     return xnn_status_invalid_parameter;
60   }
61 
62   unary_elementwise_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
63   if (unary_elementwise_op == NULL) {
64     xnn_log_error(
65       "failed to allocate %zu bytes for %s operator descriptor",
66       sizeof(struct xnn_operator), xnn_operator_type_to_string(operator_type));
67     return xnn_status_out_of_memory;
68   }
69 
70   unary_elementwise_op->channels = channels;
71   unary_elementwise_op->input_pixel_stride = input_stride;
72   unary_elementwise_op->output_pixel_stride = output_stride;
73   if (params_size != 0) {
74     memcpy(&unary_elementwise_op->params, params, params_size);
75   }
76 
77   unary_elementwise_op->ukernel.vunary.function = ukernel;
78   unary_elementwise_op->type = operator_type;
79   unary_elementwise_op->flags = flags;
80 
81   unary_elementwise_op->state = xnn_run_state_invalid;
82 
83   *unary_elementwise_op_out = unary_elementwise_op;
84   return xnn_status_success;
85 }
86 
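// Shared setup for all unary element-wise operators: records the input/output
// pointers and element sizes (passed as log2 of the byte size) and selects a
// parallelization strategy before marking the operator ready to run.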
87 static enum xnn_status setup_unary_elementwise_nc(
88     xnn_operator_t unary_elementwise_op,
89     size_t batch_size,
90     const void* input,
91     void* output,
92     uint32_t log2_input_size,
93     uint32_t log2_output_size,
94     const void* params,
95     size_t params_size,
96     size_t num_threads)
97 {
98   if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
99     xnn_log_error("failed to setup %s operator: XNNPACK is not initialized",
100       xnn_operator_type_to_string(unary_elementwise_op->type));
101     return xnn_status_uninitialized;
102   }
103 
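  // An empty batch is a no-op: mark the operator to be skipped at run time.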
104   if (batch_size == 0) {
105     unary_elementwise_op->state = xnn_run_state_skip;
106     return xnn_status_success;
107   }
108 
109   const size_t channels = unary_elementwise_op->channels;
110   const size_t input_stride = unary_elementwise_op->input_pixel_stride;
111   const size_t output_stride = unary_elementwise_op->output_pixel_stride;
112 
113   xnn_univector_ukernel_function ukernel = unary_elementwise_op->ukernel.vunary.function;
114 
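  // Fast path: the XOR test below is zero only when both strides equal the
  // channel count, i.e. the rows are densely packed. In that case, or when the
  // batch has a single element, the data is processed as one contiguous buffer,
  // tiled in 4096-byte blocks when more than one thread is available.
  // Otherwise each batch element is processed as a separate strided row.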
115   if ((((input_stride ^ channels) | (output_stride ^ channels)) == 0) || batch_size == 1) {
116     const size_t block_size = 4096;
117     unary_elementwise_op->context.univector_contiguous = (struct univector_contiguous_context) {
118       .x = input,
119       .y = output,
120       .log2_xsize = log2_input_size,
121       .log2_ysize = log2_output_size,
122       .ukernel = ukernel,
123     };
124     if (params_size != 0) {
125       memcpy(&unary_elementwise_op->context.univector_contiguous.params, params, params_size);
126     }
127 
128     const size_t range = (batch_size * channels) << log2_input_size;
129     unary_elementwise_op->compute.type = xnn_parallelization_type_1d_tile_1d;
130     unary_elementwise_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_univector_contiguous;
131     unary_elementwise_op->compute.range[0] = range;
132     unary_elementwise_op->compute.tile[0] = (num_threads == 1) ? range : block_size;
133   } else {
134     unary_elementwise_op->context.univector_strided = (struct univector_strided_context) {
135       .n = channels << log2_input_size,
136       .x = input,
137       .x_stride = input_stride << log2_input_size,
138       .y = output,
139       .y_stride = output_stride << log2_output_size,
140       .ukernel = ukernel,
141     };
142     if (params_size != 0) {
143       memcpy(&unary_elementwise_op->context.univector_strided.params, params, params_size);
144     }
145     unary_elementwise_op->compute.type = xnn_parallelization_type_1d_tile_1d;
146     unary_elementwise_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_univector_strided;
147     unary_elementwise_op->compute.range[0] = batch_size;
148     unary_elementwise_op->compute.tile[0] = (num_threads == 1) ? batch_size : 1;
149   }
150   unary_elementwise_op->state = xnn_run_state_ready;
151 
152   return xnn_status_success;
153 }
154 
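// The xnn_create_*_nc_* functions below are thin type-specific wrappers: each
// validates its operator-specific arguments, initializes the matching
// micro-kernel parameter structure (when an init routine is available), and
// delegates to create_unary_elementwise_nc.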
155 enum xnn_status xnn_create_clamp_nc_s8(
156     size_t channels,
157     size_t input_stride,
158     size_t output_stride,
159     int8_t output_min,
160     int8_t output_max,
161     uint32_t flags,
162     xnn_operator_t* clamp_op_out)
163 {
164   if (output_min >= output_max) {
165     xnn_log_error(
166       "failed to create %s operator with [%" PRId8 ", %" PRId8 "] output range: range min must be below range max",
167       xnn_operator_type_to_string(xnn_operator_type_clamp_nc_s8), output_min, output_max);
168     return xnn_status_invalid_parameter;
169   }
170 
171   union xnn_s8_minmax_params params;
172   if (xnn_params.s8.clamp.init.s8_minmax != NULL) {
173     xnn_params.s8.clamp.init.s8_minmax(&params, output_min, output_max);
174   }
175   return create_unary_elementwise_nc(
176     channels, input_stride, output_stride, flags,
177     &params, sizeof(params),
178     xnn_operator_type_clamp_nc_s8,
179     xnn_params.s8.clamp.ukernel,
180     clamp_op_out);
181 }
182 
183 enum xnn_status xnn_create_clamp_nc_u8(
184     size_t channels,
185     size_t input_stride,
186     size_t output_stride,
187     uint8_t output_min,
188     uint8_t output_max,
189     uint32_t flags,
190     xnn_operator_t* clamp_op_out)
191 {
192   if (output_min >= output_max) {
193     xnn_log_error(
194       "failed to create %s operator with [%" PRIu8 ", %" PRIu8 "] output range: range min must be below range max",
195       xnn_operator_type_to_string(xnn_operator_type_clamp_nc_u8), output_min, output_max);
196     return xnn_status_invalid_parameter;
197   }
198 
199   union xnn_u8_minmax_params params;
200   if (xnn_params.u8.clamp.init.u8_minmax != NULL) {
201     xnn_params.u8.clamp.init.u8_minmax(&params, output_min, output_max);
202   }
203   return create_unary_elementwise_nc(
204     channels, input_stride, output_stride, flags,
205     &params, sizeof(params),
206     xnn_operator_type_clamp_nc_u8,
207     xnn_params.u8.clamp.ukernel,
208     clamp_op_out);
209 }
210 
211 enum xnn_status xnn_create_clamp_nc_f32(
212     size_t channels,
213     size_t input_stride,
214     size_t output_stride,
215     float output_min,
216     float output_max,
217     uint32_t flags,
218     xnn_operator_t* clamp_op_out)
219 {
220   if (isnan(output_min)) {
221     xnn_log_error(
222       "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN",
223       xnn_operator_type_to_string(xnn_operator_type_clamp_nc_f32));
224     return xnn_status_invalid_parameter;
225   }
226 
227   if (isnan(output_max)) {
228     xnn_log_error(
229       "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN",
230       xnn_operator_type_to_string(xnn_operator_type_clamp_nc_f32));
231     return xnn_status_invalid_parameter;
232   }
233 
234   if (output_min >= output_max) {
235     xnn_log_error(
236       "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
237       xnn_operator_type_to_string(xnn_operator_type_clamp_nc_f32), output_min, output_max);
238     return xnn_status_invalid_parameter;
239   }
240 
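  // A clamp to [0.0, +inf) is a ReLU; use the specialized ReLU micro-kernel
  // when one is available, otherwise fall back to the generic clamp kernel.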
241   const bool relu_activation = (output_max == INFINITY) && (output_min == 0.0f);
242   xnn_univector_ukernel_function clamp_ukernel = (relu_activation && (xnn_params.f32.relu != NULL)) ?
243     xnn_params.f32.relu : xnn_params.f32.clamp.ukernel;
244 
245   union xnn_f32_minmax_params params;
246   if (xnn_params.f32.clamp.init.f32_minmax != NULL) {
247     xnn_params.f32.clamp.init.f32_minmax(&params, output_min, output_max);
248   }
249   return create_unary_elementwise_nc(
250     channels, input_stride, output_stride, flags,
251     &params, sizeof(params),
252     xnn_operator_type_clamp_nc_f32,
253     clamp_ukernel,
254     clamp_op_out);
255 }
256 
257 enum xnn_status xnn_create_abs_nc_f32(
258     size_t channels,
259     size_t input_stride,
260     size_t output_stride,
261     uint32_t flags,
262     xnn_operator_t* abs_op_out)
263 {
264   union xnn_f32_abs_params params;
265   if (xnn_params.f32.abs.init.f32_abs != NULL) {
266     xnn_params.f32.abs.init.f32_abs(&params);
267   }
268   return create_unary_elementwise_nc(
269     channels, input_stride, output_stride, flags,
270     &params, sizeof(params),
271     xnn_operator_type_abs_nc_f32,
272     xnn_params.f32.abs.ukernel,
273     abs_op_out);
274 }
275 
276 enum xnn_status xnn_create_bankers_rounding_nc_f32(
277     size_t channels,
278     size_t input_stride,
279     size_t output_stride,
280     uint32_t flags,
281     xnn_operator_t* rounding_op_out)
282 {
283   union xnn_f32_rnd_params params;
284   if (xnn_params.f32.rndne.init.f32_rnd != NULL) {
285     xnn_params.f32.rndne.init.f32_rnd(&params);
286   }
287   return create_unary_elementwise_nc(
288     channels, input_stride, output_stride, flags,
289     &params, sizeof(params),
290     xnn_operator_type_bankers_rounding_nc_f32,
291     xnn_params.f32.rndne.ukernel,
292     rounding_op_out);
293 }
294 
295 enum xnn_status xnn_create_ceiling_nc_f32(
296     size_t channels,
297     size_t input_stride,
298     size_t output_stride,
299     uint32_t flags,
300     xnn_operator_t* ceiling_op_out)
301 {
302   union xnn_f32_rnd_params params;
303   if (xnn_params.f32.rndu.init.f32_rnd != NULL) {
304     xnn_params.f32.rndu.init.f32_rnd(&params);
305   }
306   return create_unary_elementwise_nc(
307     channels, input_stride, output_stride, flags,
308     &params, sizeof(params),
309     xnn_operator_type_ceiling_nc_f32,
310     xnn_params.f32.rndu.ukernel,
311     ceiling_op_out);
312 }
313 
314 enum xnn_status xnn_create_convert_nc_f16_f32(
315   size_t channels,
316   size_t input_stride,
317   size_t output_stride,
318   uint32_t flags,
319   xnn_operator_t* convert_op_out)
320 {
321   union xnn_f16_f32_cvt_params params;
322   if (xnn_params.vcvt.f16_to_f32.init.f16_f32_cvt != NULL) {
323     xnn_params.vcvt.f16_to_f32.init.f16_f32_cvt(&params);
324   }
325   return create_unary_elementwise_nc(
326     channels, input_stride, output_stride, flags,
327     &params, sizeof(params),
328     xnn_operator_type_convert_nc_f16_f32,
329     xnn_params.vcvt.f16_to_f32.ukernel,
330     convert_op_out);
331 }
332 
333 enum xnn_status xnn_create_convert_nc_f32_f16(
334   size_t channels,
335   size_t input_stride,
336   size_t output_stride,
337   uint32_t flags,
338   xnn_operator_t* convert_op_out)
339 {
340   union xnn_f32_f16_cvt_params params;
341   if (xnn_params.vcvt.f32_to_f16.init.f32_f16_cvt != NULL) {
342     xnn_params.vcvt.f32_to_f16.init.f32_f16_cvt(&params);
343   }
344   return create_unary_elementwise_nc(
345     channels, input_stride, output_stride, flags,
346     &params, sizeof(params),
347     xnn_operator_type_convert_nc_f32_f16,
348     xnn_params.vcvt.f32_to_f16.ukernel,
349     convert_op_out);
350 }
351 
352 enum xnn_status xnn_create_convert_nc_f32_qs8(
353   size_t channels,
354   size_t input_stride,
355   size_t output_stride,
356   float output_scale,
357   int8_t output_zero_point,
358   int8_t output_min,
359   int8_t output_max,
360   uint32_t flags,
361   xnn_operator_t* convert_op_out)
362 {
363   if (output_scale <= 0.0f || !isnormal(output_scale)) {
364     xnn_log_error(
365       "failed to create %s operator with %.7g output scale parameter: scale must be finite, normalized, and positive",
366       xnn_operator_type_to_string(xnn_operator_type_convert_nc_f32_qs8), output_scale);
367     return xnn_status_invalid_parameter;
368   }
369 
370   if (output_min >= output_max) {
371     xnn_log_error(
372       "failed to create %s operator with [%" PRId8 ", %" PRId8 "] output range: range min must be below range max",
373       xnn_operator_type_to_string(xnn_operator_type_convert_nc_f32_qs8), output_min, output_max);
374     return xnn_status_invalid_parameter;
375   }
376 
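  // Note: the init routine below receives 1.0f / output_scale, i.e. the
  // micro-kernel parameters hold the reciprocal of the output scale.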
377   union xnn_f32_qs8_cvt_params params;
378   if (xnn_params.vcvt.f32_to_qs8.init.f32_qs8_cvt != NULL) {
379     xnn_params.vcvt.f32_to_qs8.init.f32_qs8_cvt(&params, 1.0f / output_scale, output_zero_point, output_min, output_max);
380   }
381   return create_unary_elementwise_nc(
382     channels, input_stride, output_stride, flags,
383     &params, sizeof(params),
384     xnn_operator_type_convert_nc_f32_qs8,
385     xnn_params.vcvt.f32_to_qs8.ukernel,
386     convert_op_out);
387 }
388 
389 enum xnn_status xnn_create_convert_nc_f32_qu8(
390   size_t channels,
391   size_t input_stride,
392   size_t output_stride,
393   float output_scale,
394   uint8_t output_zero_point,
395   uint8_t output_min,
396   uint8_t output_max,
397   uint32_t flags,
398   xnn_operator_t* convert_op_out)
399 {
400   if (output_scale <= 0.0f || !isnormal(output_scale)) {
401     xnn_log_error(
402       "failed to create %s operator with %.7g output scale parameter: scale must be finite, normalized, and positive",
403       xnn_operator_type_to_string(xnn_operator_type_convert_nc_f32_qu8), output_scale);
404     return xnn_status_invalid_parameter;
405   }
406 
407   if (output_min >= output_max) {
408     xnn_log_error(
409       "failed to create %s operator with [%" PRIu8 ", %" PRIu8 "] output range: range min must be below range max",
410       xnn_operator_type_to_string(xnn_operator_type_convert_nc_f32_qu8), output_min, output_max);
411     return xnn_status_invalid_parameter;
412   }
413 
414   union xnn_f32_qu8_cvt_params params;
415   if (xnn_params.vcvt.f32_to_qu8.init.f32_qu8_cvt != NULL) {
416     xnn_params.vcvt.f32_to_qu8.init.f32_qu8_cvt(&params, 1.0f / output_scale, output_zero_point, output_min, output_max);
417   }
418   return create_unary_elementwise_nc(
419     channels, input_stride, output_stride, flags,
420     &params, sizeof(params),
421     xnn_operator_type_convert_nc_f32_qu8,
422     xnn_params.vcvt.f32_to_qu8.ukernel,
423     convert_op_out);
424 }
425 
426 enum xnn_status xnn_create_convert_nc_qs8_f32(
427   size_t channels,
428   size_t input_stride,
429   size_t output_stride,
430   float input_scale,
431   int8_t input_zero_point,
432   uint32_t flags,
433   xnn_operator_t* convert_op_out)
434 {
435   if (input_scale <= 0.0f || !isnormal(input_scale)) {
436     xnn_log_error(
437       "failed to create %s operator with %.7g input scale parameter: scale must be finite, normalized, and positive",
438       xnn_operator_type_to_string(xnn_operator_type_convert_nc_qs8_f32), input_scale);
439     return xnn_status_invalid_parameter;
440   }
441 
442   union xnn_qs8_f32_cvt_params params;
443   if (xnn_params.vcvt.qs8_to_f32.init.qs8_f32_cvt != NULL) {
444     xnn_params.vcvt.qs8_to_f32.init.qs8_f32_cvt(&params, input_scale, input_zero_point);
445   }
446   return create_unary_elementwise_nc(
447     channels, input_stride, output_stride, flags,
448     &params, sizeof(params),
449     xnn_operator_type_convert_nc_qs8_f32,
450     xnn_params.vcvt.qs8_to_f32.ukernel,
451     convert_op_out);
452 }
453 
454 enum xnn_status xnn_create_convert_nc_qu8_f32(
455   size_t channels,
456   size_t input_stride,
457   size_t output_stride,
458   float input_scale,
459   uint8_t input_zero_point,
460   uint32_t flags,
461   xnn_operator_t* convert_op_out)
462 {
463   if (input_scale <= 0.0f || !isnormal(input_scale)) {
464     xnn_log_error(
465       "failed to create %s operator with %.7g input scale parameter: scale must be finite, normalized, and positive",
466       xnn_operator_type_to_string(xnn_operator_type_convert_nc_qu8_f32), input_scale);
467     return xnn_status_invalid_parameter;
468   }
469 
470   union xnn_qu8_f32_cvt_params params;
471   if (xnn_params.vcvt.qu8_to_f32.init.qu8_f32_cvt != NULL) {
472     xnn_params.vcvt.qu8_to_f32.init.qu8_f32_cvt(&params, input_scale, input_zero_point);
473   }
474   return create_unary_elementwise_nc(
475     channels, input_stride, output_stride, flags,
476     &params, sizeof(params),
477     xnn_operator_type_convert_nc_qu8_f32,
478     xnn_params.vcvt.qu8_to_f32.ukernel,
479     convert_op_out);
480 }
481 
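// The copy operators for 8-, 16-, and 32-bit elements all use the same untyped
// byte-copy micro-kernel (xnn_params.xx.copy) and take no parameters; the
// element width is conveyed at setup time via the log2 element sizes.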
482 enum xnn_status xnn_create_copy_nc_x8(
483     size_t channels,
484     size_t input_stride,
485     size_t output_stride,
486     uint32_t flags,
487     xnn_operator_t* copy_op_out)
488 {
489   return create_unary_elementwise_nc(
490     channels, input_stride, output_stride, flags,
491     NULL, 0,
492     xnn_operator_type_copy_nc_x8,
493     xnn_params.xx.copy,
494     copy_op_out);
495 }
496 
497 enum xnn_status xnn_create_copy_nc_x16(
498     size_t channels,
499     size_t input_stride,
500     size_t output_stride,
501     uint32_t flags,
502     xnn_operator_t* copy_op_out)
503 {
504   return create_unary_elementwise_nc(
505     channels, input_stride, output_stride, flags,
506     NULL, 0,
507     xnn_operator_type_copy_nc_x16,
508     xnn_params.xx.copy,
509     copy_op_out);
510 }
511 
512 enum xnn_status xnn_create_copy_nc_x32(
513     size_t channels,
514     size_t input_stride,
515     size_t output_stride,
516     uint32_t flags,
517     xnn_operator_t* copy_op_out)
518 {
519   return create_unary_elementwise_nc(
520     channels, input_stride, output_stride, flags,
521     NULL, 0,
522     xnn_operator_type_copy_nc_x32,
523     xnn_params.xx.copy,
524     copy_op_out);
525 }
526 
527 enum xnn_status xnn_create_elu_nc_f32(
528   size_t channels,
529   size_t input_stride,
530   size_t output_stride,
531   float alpha,
532   uint32_t flags,
533   xnn_operator_t* elu_op_out)
534 {
535   if (alpha <= 0.0f || !isnormal(alpha)) {
536     xnn_log_error(
537       "failed to create %s operator with %.7g alpha parameter: alpha must be finite, normalized, and positive",
538       xnn_operator_type_to_string(xnn_operator_type_elu_nc_f32), alpha);
539     return xnn_status_invalid_parameter;
540   }
541 
542   union xnn_f32_elu_params params;
543   if (xnn_params.f32.elu.init.f32_elu != NULL) {
544     xnn_params.f32.elu.init.f32_elu(&params, 1.0f /* prescale */, alpha, 1.0f /* beta */);
545   }
546   return create_unary_elementwise_nc(
547     channels, input_stride, output_stride, flags,
548     &params, sizeof(params),
549     xnn_operator_type_elu_nc_f32,
550     xnn_params.f32.elu.ukernel,
551     elu_op_out);
552 }
553 
554 enum xnn_status xnn_create_floor_nc_f32(
555     size_t channels,
556     size_t input_stride,
557     size_t output_stride,
558     uint32_t flags,
559     xnn_operator_t* floor_op_out)
560 {
561   union xnn_f32_rnd_params params;
562   if (xnn_params.f32.rndd.init.f32_rnd != NULL) {
563     xnn_params.f32.rndd.init.f32_rnd(&params);
564   }
565   return create_unary_elementwise_nc(
566     channels, input_stride, output_stride, flags,
567     &params, sizeof(params),
568     xnn_operator_type_floor_nc_f32,
569     xnn_params.f32.rndd.ukernel,
570     floor_op_out);
571 }
572 
573 enum xnn_status xnn_create_hardswish_nc_f16(
574     size_t channels,
575     size_t input_stride,
576     size_t output_stride,
577     uint32_t flags,
578     xnn_operator_t* hardswish_op_out)
579 {
580   if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
581     xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
582       xnn_operator_type_to_string(xnn_operator_type_hardswish_nc_f16));
583     return xnn_status_uninitialized;
584   }
585 
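  // F16 operators additionally require half-precision micro-kernel support,
  // reported by the runtime via XNN_INIT_FLAG_F16.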
586   if ((xnn_params.init_flags & XNN_INIT_FLAG_F16) != XNN_INIT_FLAG_F16) {
587     xnn_log_error("failed to create %s operator: operations on data type are not supported",
588       xnn_operator_type_to_string(xnn_operator_type_hardswish_nc_f16));
589     return xnn_status_unsupported_hardware;
590   }
591 
592   union xnn_f16_hswish_params params;
593   if (xnn_params.f16.hswish.init.f16_hswish != NULL) {
594     xnn_params.f16.hswish.init.f16_hswish(&params);
595   }
596   return create_unary_elementwise_nc(
597     channels, input_stride, output_stride, flags,
598     &params, sizeof(params),
599     xnn_operator_type_hardswish_nc_f16,
600     xnn_params.f16.hswish.ukernel,
601     hardswish_op_out);
602 }
603 
604 enum xnn_status xnn_create_hardswish_nc_f32(
605     size_t channels,
606     size_t input_stride,
607     size_t output_stride,
608     uint32_t flags,
609     xnn_operator_t* hardswish_op_out)
610 {
611   union xnn_f32_hswish_params params;
612   if (xnn_params.f32.hswish.init.f32_hswish != NULL) {
613     xnn_params.f32.hswish.init.f32_hswish(&params);
614   }
615   return create_unary_elementwise_nc(
616     channels, input_stride, output_stride, flags,
617     &params, sizeof(params),
618     xnn_operator_type_hardswish_nc_f32,
619     xnn_params.f32.hswish.ukernel,
620     hardswish_op_out);
621 }
622 
623 enum xnn_status xnn_create_leaky_relu_nc_f32(
624   size_t channels,
625   size_t input_stride,
626   size_t output_stride,
627   float negative_slope,
628   uint32_t flags,
629   xnn_operator_t* leaky_relu_op_out)
630 {
631   if (!isfinite(negative_slope)) {
632     xnn_log_error(
633       "failed to create %s operator with %f negative slope: finite number expected",
634       xnn_operator_type_to_string(xnn_operator_type_leaky_relu_nc_f32),
635       negative_slope);
636     return xnn_status_invalid_parameter;
637   }
638 
639   union xnn_f32_lrelu_params params;
640   if (xnn_params.f32.lrelu.init.f32_lrelu != NULL) {
641     xnn_params.f32.lrelu.init.f32_lrelu(&params, negative_slope);
642   }
643   return create_unary_elementwise_nc(
644     channels, input_stride, output_stride, flags,
645     &params, sizeof(params),
646     xnn_operator_type_leaky_relu_nc_f32,
647     xnn_params.f32.lrelu.ukernel,
648     leaky_relu_op_out);
649 }
650 
651 enum xnn_status xnn_create_negate_nc_f32(
652     size_t channels,
653     size_t input_stride,
654     size_t output_stride,
655     uint32_t flags,
656     xnn_operator_t* negate_op_out)
657 {
658   union xnn_f32_neg_params params;
659   if (xnn_params.f32.neg.init.f32_neg != NULL) {
660     xnn_params.f32.neg.init.f32_neg(&params);
661   }
662   return create_unary_elementwise_nc(
663     channels, input_stride, output_stride, flags,
664     &params, sizeof(params),
665     xnn_operator_type_negate_nc_f32,
666     xnn_params.f32.neg.ukernel,
667     negate_op_out);
668 }
669 
670 enum xnn_status xnn_create_sigmoid_nc_f32(
671     size_t channels,
672     size_t input_stride,
673     size_t output_stride,
674     uint32_t flags,
675     xnn_operator_t* sigmoid_op_out)
676 {
677   union xnn_f32_sigmoid_params params;
678   if (xnn_params.f32.sigmoid.init.f32_sigmoid != NULL) {
679     xnn_params.f32.sigmoid.init.f32_sigmoid(&params);
680   }
681   return create_unary_elementwise_nc(
682     channels, input_stride, output_stride, flags,
683     &params, sizeof(params),
684     xnn_operator_type_sigmoid_nc_f32,
685     xnn_params.f32.sigmoid.ukernel,
686     sigmoid_op_out);
687 }
688 
689 enum xnn_status xnn_create_square_nc_f32(
690     size_t channels,
691     size_t input_stride,
692     size_t output_stride,
693     uint32_t flags,
694     xnn_operator_t* square_op_out)
695 {
696   union xnn_f32_default_params params;
697   if (xnn_params.f32.sqr.init.f32_default != NULL) {
698     xnn_params.f32.sqr.init.f32_default(&params);
699   }
700   return create_unary_elementwise_nc(
701     channels, input_stride, output_stride, flags,
702     &params, sizeof(params),
703     xnn_operator_type_square_nc_f32,
704     xnn_params.f32.sqr.ukernel,
705     square_op_out);
706 }
707 
708 enum xnn_status xnn_create_square_root_nc_f32(
709     size_t channels,
710     size_t input_stride,
711     size_t output_stride,
712     uint32_t flags,
713     xnn_operator_t* sqrt_op_out)
714 {
715   union xnn_f32_sqrt_params params;
716   if (xnn_params.f32.sqrt.init.f32_sqrt != NULL) {
717     xnn_params.f32.sqrt.init.f32_sqrt(&params);
718   }
719   return create_unary_elementwise_nc(
720     channels, input_stride, output_stride, flags,
721     &params, sizeof(params),
722     xnn_operator_type_square_root_nc_f32,
723     xnn_params.f32.sqrt.ukernel,
724     sqrt_op_out);
725 }
726 
727 enum xnn_status xnn_create_truncation_nc_f32(
728     size_t channels,
729     size_t input_stride,
730     size_t output_stride,
731     uint32_t flags,
732     xnn_operator_t* truncation_op_out)
733 {
734   union xnn_f32_rnd_params params;
735   if (xnn_params.f32.rndz.init.f32_rnd != NULL) {
736     xnn_params.f32.rndz.init.f32_rnd(&params);
737   }
738   return create_unary_elementwise_nc(
739     channels, input_stride, output_stride, flags,
740     &params, sizeof(params),
741     xnn_operator_type_truncation_nc_f32,
742     xnn_params.f32.rndz.ukernel,
743     truncation_op_out);
744 }
745 
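// The xnn_setup_*_nc_* functions below follow a common pattern: verify that the
// operator was created with the matching type, reset its state, and delegate to
// setup_unary_elementwise_nc with the log2 element sizes of the input and
// output data types.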
746 enum xnn_status xnn_setup_abs_nc_f32(
747     xnn_operator_t abs_op,
748     size_t batch_size,
749     const float* input,
750     float* output,
751     pthreadpool_t threadpool)
752 {
753   if (abs_op->type != xnn_operator_type_abs_nc_f32) {
754     xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
755       xnn_operator_type_to_string(xnn_operator_type_abs_nc_f32),
756       xnn_operator_type_to_string(abs_op->type));
757     return xnn_status_invalid_parameter;
758   }
759   abs_op->state = xnn_run_state_invalid;
760 
761   return setup_unary_elementwise_nc(
762     abs_op,
763     batch_size, input, output,
764     2 /* log2(sizeof(float)) */,
765     2 /* log2(sizeof(float)) */,
766     &abs_op->params.f32_abs, sizeof(abs_op->params.f32_abs),
767     pthreadpool_get_threads_count(threadpool));
768 }
769 
770 enum xnn_status xnn_setup_bankers_rounding_nc_f32(
771     xnn_operator_t rounding_op,
772     size_t batch_size,
773     const float* input,
774     float* output,
775     pthreadpool_t threadpool)
776 {
777   if (rounding_op->type != xnn_operator_type_bankers_rounding_nc_f32) {
778     xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
779       xnn_operator_type_to_string(xnn_operator_type_bankers_rounding_nc_f32),
780       xnn_operator_type_to_string(rounding_op->type));
781     return xnn_status_invalid_parameter;
782   }
783   rounding_op->state = xnn_run_state_invalid;
784 
785   return setup_unary_elementwise_nc(
786     rounding_op,
787     batch_size, input, output,
788     2 /* log2(sizeof(float)) */,
789     2 /* log2(sizeof(float)) */,
790     &rounding_op->params.f32_rnd, sizeof(rounding_op->params.f32_rnd),
791     pthreadpool_get_threads_count(threadpool));
792 }
793 
794 enum xnn_status xnn_setup_ceiling_nc_f32(
795     xnn_operator_t ceiling_op,
796     size_t batch_size,
797     const float* input,
798     float* output,
799     pthreadpool_t threadpool)
800 {
801   if (ceiling_op->type != xnn_operator_type_ceiling_nc_f32) {
802     xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
803       xnn_operator_type_to_string(xnn_operator_type_ceiling_nc_f32),
804       xnn_operator_type_to_string(ceiling_op->type));
805     return xnn_status_invalid_parameter;
806   }
807   ceiling_op->state = xnn_run_state_invalid;
808 
809   return setup_unary_elementwise_nc(
810     ceiling_op,
811     batch_size, input, output,
812     2 /* log2(sizeof(float)) */,
813     2 /* log2(sizeof(float)) */,
814     &ceiling_op->params.f32_rnd, sizeof(ceiling_op->params.f32_rnd),
815     pthreadpool_get_threads_count(threadpool));
816 }
817 
818 enum xnn_status xnn_setup_clamp_nc_s8(
819     xnn_operator_t clamp_op,
820     size_t batch_size,
821     const int8_t* input,
822     int8_t* output,
823     pthreadpool_t threadpool)
824 {
825   if (clamp_op->type != xnn_operator_type_clamp_nc_s8) {
826     xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
827       xnn_operator_type_to_string(xnn_operator_type_clamp_nc_s8),
828       xnn_operator_type_to_string(clamp_op->type));
829     return xnn_status_invalid_parameter;
830   }
831   clamp_op->state = xnn_run_state_invalid;
832 
833   return setup_unary_elementwise_nc(
834     clamp_op,
835     batch_size, input, output,
836     0 /* log2(sizeof(int8_t)) */,
837     0 /* log2(sizeof(int8_t)) */,
838     &clamp_op->params.s8_minmax, sizeof(clamp_op->params.s8_minmax),
839     pthreadpool_get_threads_count(threadpool));
840 }
841 
842 enum xnn_status xnn_setup_clamp_nc_u8(
843     xnn_operator_t clamp_op,
844     size_t batch_size,
845     const uint8_t* input,
846     uint8_t* output,
847     pthreadpool_t threadpool)
848 {
849   if (clamp_op->type != xnn_operator_type_clamp_nc_u8) {
850     xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
851       xnn_operator_type_to_string(xnn_operator_type_clamp_nc_u8),
852       xnn_operator_type_to_string(clamp_op->type));
853     return xnn_status_invalid_parameter;
854   }
855   clamp_op->state = xnn_run_state_invalid;
856 
857   return setup_unary_elementwise_nc(
858     clamp_op,
859     batch_size, input, output,
860     0 /* log2(sizeof(uint8_t)) */,
861     0 /* log2(sizeof(uint8_t)) */,
862     &clamp_op->params.u8_minmax, sizeof(clamp_op->params.u8_minmax),
863     pthreadpool_get_threads_count(threadpool));
864 }
865 
866 enum xnn_status xnn_setup_clamp_nc_f32(
867     xnn_operator_t clamp_op,
868     size_t batch_size,
869     const float* input,
870     float* output,
871     pthreadpool_t threadpool)
872 {
873   if (clamp_op->type != xnn_operator_type_clamp_nc_f32) {
874     xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
875       xnn_operator_type_to_string(xnn_operator_type_clamp_nc_f32),
876       xnn_operator_type_to_string(clamp_op->type));
877     return xnn_status_invalid_parameter;
878   }
879   clamp_op->state = xnn_run_state_invalid;
880 
881   return setup_unary_elementwise_nc(
882     clamp_op,
883     batch_size, input, output,
884     2 /* log2(sizeof(float)) */,
885     2 /* log2(sizeof(float)) */,
886     &clamp_op->params.f32_minmax, sizeof(clamp_op->params.f32_minmax),
887     pthreadpool_get_threads_count(threadpool));
888 }
889 
890 enum xnn_status xnn_setup_convert_nc_f16_f32(
891   xnn_operator_t convert_op,
892   size_t batch_size,
893   const void* input,
894   float* output,
895   pthreadpool_t threadpool)
896 {
897   if (convert_op->type != xnn_operator_type_convert_nc_f16_f32) {
898     xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
899       xnn_operator_type_to_string(xnn_operator_type_convert_nc_f16_f32),
900       xnn_operator_type_to_string(convert_op->type));
901     return xnn_status_invalid_parameter;
902   }
903   convert_op->state = xnn_run_state_invalid;
904 
905   return setup_unary_elementwise_nc(
906     convert_op,
907     batch_size, input, output,
908     1 /* log2(sizeof(uint16_t)) */,
909     2 /* log2(sizeof(float)) */,
910     &convert_op->params.f16_f32_cvt, sizeof(convert_op->params.f16_f32_cvt),
911     pthreadpool_get_threads_count(threadpool));
912 }
913 
914 enum xnn_status xnn_setup_convert_nc_f32_f16(
915   xnn_operator_t convert_op,
916   size_t batch_size,
917   const float* input,
918   void* output,
919   pthreadpool_t threadpool)
920 {
921   if (convert_op->type != xnn_operator_type_convert_nc_f32_f16) {
922     xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
923       xnn_operator_type_to_string(xnn_operator_type_convert_nc_f32_f16),
924       xnn_operator_type_to_string(convert_op->type));
925     return xnn_status_invalid_parameter;
926   }
927   convert_op->state = xnn_run_state_invalid;
928 
929   return setup_unary_elementwise_nc(
930     convert_op,
931     batch_size, input, output,
932     2 /* log2(sizeof(float)) */,
933     1 /* log2(sizeof(uint16_t)) */,
934     &convert_op->params.f32_f16_cvt, sizeof(convert_op->params.f32_f16_cvt),
935     pthreadpool_get_threads_count(threadpool));
936 }
937 
938 enum xnn_status xnn_setup_convert_nc_f32_qs8(
939   xnn_operator_t convert_op,
940   size_t batch_size,
941   const float* input,
942   int8_t* output,
943   pthreadpool_t threadpool)
944 {
945   if (convert_op->type != xnn_operator_type_convert_nc_f32_qs8) {
946     xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
947       xnn_operator_type_to_string(xnn_operator_type_convert_nc_f32_qs8),
948       xnn_operator_type_to_string(convert_op->type));
949     return xnn_status_invalid_parameter;
950   }
951   convert_op->state = xnn_run_state_invalid;
952 
953   return setup_unary_elementwise_nc(
954     convert_op,
955     batch_size, input, output,
956     2 /* log2(sizeof(float)) */,
957     0 /* log2(sizeof(int8_t)) */,
958     &convert_op->params.f32_qs8_cvt, sizeof(convert_op->params.f32_qs8_cvt),
959     pthreadpool_get_threads_count(threadpool));
960 }
961 
962 enum xnn_status xnn_setup_convert_nc_f32_qu8(
963   xnn_operator_t convert_op,
964   size_t batch_size,
965   const float* input,
966   uint8_t* output,
967   pthreadpool_t threadpool)
968 {
969   if (convert_op->type != xnn_operator_type_convert_nc_f32_qu8) {
970     xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
971       xnn_operator_type_to_string(xnn_operator_type_convert_nc_f32_qu8),
972       xnn_operator_type_to_string(convert_op->type));
973     return xnn_status_invalid_parameter;
974   }
975   convert_op->state = xnn_run_state_invalid;
976 
977   return setup_unary_elementwise_nc(
978     convert_op,
979     batch_size, input, output,
980     2 /* log2(sizeof(float)) */,
981     0 /* log2(sizeof(uint8_t)) */,
982     &convert_op->params.f32_qu8_cvt, sizeof(convert_op->params.f32_qu8_cvt),
983     pthreadpool_get_threads_count(threadpool));
984 }
985 
986 enum xnn_status xnn_setup_convert_nc_qs8_f32(
987   xnn_operator_t convert_op,
988   size_t batch_size,
989   const int8_t* input,
990   float* output,
991   pthreadpool_t threadpool)
992 {
993   if (convert_op->type != xnn_operator_type_convert_nc_qs8_f32) {
994     xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
995       xnn_operator_type_to_string(xnn_operator_type_convert_nc_qs8_f32),
996       xnn_operator_type_to_string(convert_op->type));
997     return xnn_status_invalid_parameter;
998   }
999   convert_op->state = xnn_run_state_invalid;
1000 
1001   return setup_unary_elementwise_nc(
1002     convert_op,
1003     batch_size, input, output,
1004     0 /* log2(sizeof(int8_t)) */,
1005     2 /* log2(sizeof(float)) */,
1006     &convert_op->params.qs8_f32_cvt, sizeof(convert_op->params.qs8_f32_cvt),
1007     pthreadpool_get_threads_count(threadpool));
1008 }
1009 
1010 enum xnn_status xnn_setup_convert_nc_qu8_f32(
1011   xnn_operator_t convert_op,
1012   size_t batch_size,
1013   const uint8_t* input,
1014   float* output,
1015   pthreadpool_t threadpool)
1016 {
1017   if (convert_op->type != xnn_operator_type_convert_nc_qu8_f32) {
1018     xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
1019       xnn_operator_type_to_string(xnn_operator_type_convert_nc_qu8_f32),
1020       xnn_operator_type_to_string(convert_op->type));
1021     return xnn_status_invalid_parameter;
1022   }
1023   convert_op->state = xnn_run_state_invalid;
1024 
1025   return setup_unary_elementwise_nc(
1026     convert_op,
1027     batch_size, input, output,
1028     0 /* log2(sizeof(uint8_t)) */,
1029     2 /* log2(sizeof(float)) */,
1030     &convert_op->params.qu8_f32_cvt, sizeof(convert_op->params.qu8_f32_cvt),
1031     pthreadpool_get_threads_count(threadpool));
1032 }
1033 
1034 enum xnn_status xnn_setup_copy_nc_x8(
1035     xnn_operator_t copy_op,
1036     size_t batch_size,
1037     const void* input,
1038     void* output,
1039     pthreadpool_t threadpool)
1040 {
1041   if (copy_op->type != xnn_operator_type_copy_nc_x8) {
1042     xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
1043       xnn_operator_type_to_string(xnn_operator_type_copy_nc_x8),
1044       xnn_operator_type_to_string(copy_op->type));
1045     return xnn_status_invalid_parameter;
1046   }
1047   copy_op->state = xnn_run_state_invalid;
1048 
1049   return setup_unary_elementwise_nc(
1050     copy_op,
1051     batch_size, input, output,
1052     0 /* log2(sizeof(uint8_t)) */,
1053     0 /* log2(sizeof(uint8_t)) */,
1054     NULL, 0,
1055     pthreadpool_get_threads_count(threadpool));
1056 }
1057 
1058 enum xnn_status xnn_setup_copy_nc_x16(
1059     xnn_operator_t copy_op,
1060     size_t batch_size,
1061     const void* input,
1062     void* output,
1063     pthreadpool_t threadpool)
1064 {
1065   if (copy_op->type != xnn_operator_type_copy_nc_x16) {
1066     xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
1067       xnn_operator_type_to_string(xnn_operator_type_copy_nc_x16),
1068       xnn_operator_type_to_string(copy_op->type));
1069     return xnn_status_invalid_parameter;
1070   }
1071   copy_op->state = xnn_run_state_invalid;
1072 
1073   return setup_unary_elementwise_nc(
1074     copy_op,
1075     batch_size, input, output,
1076     1 /* log2(sizeof(uint16_t)) */,
1077     1 /* log2(sizeof(uint16_t)) */,
1078     NULL, 0,
1079     pthreadpool_get_threads_count(threadpool));
1080 }
1081 
1082 enum xnn_status xnn_setup_copy_nc_x32(
1083     xnn_operator_t copy_op,
1084     size_t batch_size,
1085     const void* input,
1086     void* output,
1087     pthreadpool_t threadpool)
1088 {
1089   if (copy_op->type != xnn_operator_type_copy_nc_x32) {
1090     xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
1091       xnn_operator_type_to_string(xnn_operator_type_copy_nc_x32),
1092       xnn_operator_type_to_string(copy_op->type));
1093     return xnn_status_invalid_parameter;
1094   }
1095   copy_op->state = xnn_run_state_invalid;
1096 
1097   return setup_unary_elementwise_nc(
1098     copy_op,
1099     batch_size, input, output,
1100     2 /* log2(sizeof(uint32_t)) */,
1101     2 /* log2(sizeof(uint32_t)) */,
1102     NULL, 0,
1103     pthreadpool_get_threads_count(threadpool));
1104 }
1105 
1106 enum xnn_status xnn_setup_elu_nc_f32(
1107     xnn_operator_t elu_op,
1108     size_t batch_size,
1109     const float* input,
1110     float* output,
1111     pthreadpool_t threadpool)
1112 {
1113   if (elu_op->type != xnn_operator_type_elu_nc_f32) {
1114     xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
1115       xnn_operator_type_to_string(xnn_operator_type_elu_nc_f32),
1116       xnn_operator_type_to_string(elu_op->type));
1117     return xnn_status_invalid_parameter;
1118   }
1119   elu_op->state = xnn_run_state_invalid;
1120 
1121   return setup_unary_elementwise_nc(
1122     elu_op,
1123     batch_size, input, output,
1124     2 /* log2(sizeof(float)) */,
1125     2 /* log2(sizeof(float)) */,
1126     &elu_op->params.f32_elu, sizeof(elu_op->params.f32_elu),
1127     pthreadpool_get_threads_count(threadpool));
1128 }
1129 
1130 enum xnn_status xnn_setup_floor_nc_f32(
1131     xnn_operator_t floor_op,
1132     size_t batch_size,
1133     const float* input,
1134     float* output,
1135     pthreadpool_t threadpool)
1136 {
1137   if (floor_op->type != xnn_operator_type_floor_nc_f32) {
1138     xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
1139       xnn_operator_type_to_string(xnn_operator_type_floor_nc_f32),
1140       xnn_operator_type_to_string(floor_op->type));
1141     return xnn_status_invalid_parameter;
1142   }
1143   floor_op->state = xnn_run_state_invalid;
1144 
1145   return setup_unary_elementwise_nc(
1146     floor_op,
1147     batch_size, input, output,
1148     2 /* log2(sizeof(float)) */,
1149     2 /* log2(sizeof(float)) */,
1150     &floor_op->params.f32_rnd, sizeof(floor_op->params.f32_rnd),
1151     pthreadpool_get_threads_count(threadpool));
1152 }
1153 
1154 enum xnn_status xnn_setup_hardswish_nc_f16(
1155     xnn_operator_t hardswish_op,
1156     size_t batch_size,
1157     const void* input,
1158     void* output,
1159     pthreadpool_t threadpool)
1160 {
1161   if (hardswish_op->type != xnn_operator_type_hardswish_nc_f16) {
1162     xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
1163       xnn_operator_type_to_string(xnn_operator_type_hardswish_nc_f16),
1164       xnn_operator_type_to_string(hardswish_op->type));
1165     return xnn_status_invalid_parameter;
1166   }
1167   hardswish_op->state = xnn_run_state_invalid;
1168 
1169   return setup_unary_elementwise_nc(
1170     hardswish_op,
1171     batch_size, input, output,
1172     1 /* log2(sizeof(half)) */,
1173     1 /* log2(sizeof(half)) */,
1174     &hardswish_op->params.f16_hswish, sizeof(hardswish_op->params.f16_hswish),
1175     pthreadpool_get_threads_count(threadpool));
1176 }
1177 
1178 enum xnn_status xnn_setup_hardswish_nc_f32(
1179     xnn_operator_t hardswish_op,
1180     size_t batch_size,
1181     const float* input,
1182     float* output,
1183     pthreadpool_t threadpool)
1184 {
1185   if (hardswish_op->type != xnn_operator_type_hardswish_nc_f32) {
1186     xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
1187       xnn_operator_type_to_string(xnn_operator_type_hardswish_nc_f32),
1188       xnn_operator_type_to_string(hardswish_op->type));
1189     return xnn_status_invalid_parameter;
1190   }
1191   hardswish_op->state = xnn_run_state_invalid;
1192 
1193   return setup_unary_elementwise_nc(
1194     hardswish_op,
1195     batch_size, input, output,
1196     2 /* log2(sizeof(float)) */,
1197     2 /* log2(sizeof(float)) */,
1198     &hardswish_op->params.f32_hswish, sizeof(hardswish_op->params.f32_hswish),
1199     pthreadpool_get_threads_count(threadpool));
1200 }
1201 
1202 enum xnn_status xnn_setup_leaky_relu_nc_f32(
1203   xnn_operator_t leaky_relu_op,
1204   size_t batch_size,
1205   const float* input,
1206   float* output,
1207   pthreadpool_t threadpool)
1208 {
1209   if (leaky_relu_op->type != xnn_operator_type_leaky_relu_nc_f32) {
1210     xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
1211       xnn_operator_type_to_string(xnn_operator_type_leaky_relu_nc_f32),
1212       xnn_operator_type_to_string(leaky_relu_op->type));
1213     return xnn_status_invalid_parameter;
1214   }
1215   leaky_relu_op->state = xnn_run_state_invalid;
1216 
1217   return setup_unary_elementwise_nc(
1218     leaky_relu_op,
1219     batch_size, input, output,
1220     2 /* log2(sizeof(float)) */,
1221     2 /* log2(sizeof(float)) */,
1222     &leaky_relu_op->params.f32_lrelu, sizeof(leaky_relu_op->params.f32_lrelu),
1223     pthreadpool_get_threads_count(threadpool));
1224 }
1225 
1226 enum xnn_status xnn_setup_negate_nc_f32(
1227     xnn_operator_t negate_op,
1228     size_t batch_size,
1229     const float* input,
1230     float* output,
1231     pthreadpool_t threadpool)
1232 {
1233   if (negate_op->type != xnn_operator_type_negate_nc_f32) {
1234     xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
1235       xnn_operator_type_to_string(xnn_operator_type_negate_nc_f32),
1236       xnn_operator_type_to_string(negate_op->type));
1237     return xnn_status_invalid_parameter;
1238   }
1239   negate_op->state = xnn_run_state_invalid;
1240 
1241   return setup_unary_elementwise_nc(
1242     negate_op,
1243     batch_size, input, output,
1244     2 /* log2(sizeof(float)) */,
1245     2 /* log2(sizeof(float)) */,
1246     &negate_op->params.f32_neg, sizeof(negate_op->params.f32_neg),
1247     pthreadpool_get_threads_count(threadpool));
1248 }
1249 
1250 enum xnn_status xnn_setup_sigmoid_nc_f32(
1251     xnn_operator_t sigmoid_op,
1252     size_t batch_size,
1253     const float* input,
1254     float* output,
1255     pthreadpool_t threadpool)
1256 {
1257   if (sigmoid_op->type != xnn_operator_type_sigmoid_nc_f32) {
1258     xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
1259       xnn_operator_type_to_string(xnn_operator_type_sigmoid_nc_f32),
1260       xnn_operator_type_to_string(sigmoid_op->type));
1261     return xnn_status_invalid_parameter;
1262   }
1263   sigmoid_op->state = xnn_run_state_invalid;
1264 
1265   return setup_unary_elementwise_nc(
1266     sigmoid_op,
1267     batch_size, input, output,
1268     2 /* log2(sizeof(float)) */,
1269     2 /* log2(sizeof(float)) */,
1270     &sigmoid_op->params.f32_sigmoid, sizeof(sigmoid_op->params.f32_sigmoid),
1271     pthreadpool_get_threads_count(threadpool));
1272 }
1273 
1274 enum xnn_status xnn_setup_square_nc_f32(
1275     xnn_operator_t square_op,
1276     size_t batch_size,
1277     const float* input,
1278     float* output,
1279     pthreadpool_t threadpool)
1280 {
1281   if (square_op->type != xnn_operator_type_square_nc_f32) {
1282     xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
1283       xnn_operator_type_to_string(xnn_operator_type_square_nc_f32),
1284       xnn_operator_type_to_string(square_op->type));
1285     return xnn_status_invalid_parameter;
1286   }
1287   square_op->state = xnn_run_state_invalid;
1288 
1289   return setup_unary_elementwise_nc(
1290     square_op,
1291     batch_size, input, output,
1292     2 /* log2(sizeof(float)) */,
1293     2 /* log2(sizeof(float)) */,
1294     &square_op->params.f32_default, sizeof(square_op->params.f32_default),
1295     pthreadpool_get_threads_count(threadpool));
1296 }
1297 
1298 enum xnn_status xnn_setup_square_root_nc_f32(
1299     xnn_operator_t sqrt_op,
1300     size_t batch_size,
1301     const float* input,
1302     float* output,
1303     pthreadpool_t threadpool)
1304 {
1305   if (sqrt_op->type != xnn_operator_type_square_root_nc_f32) {
1306     xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
1307       xnn_operator_type_to_string(xnn_operator_type_square_root_nc_f32),
1308       xnn_operator_type_to_string(sqrt_op->type));
1309     return xnn_status_invalid_parameter;
1310   }
1311   sqrt_op->state = xnn_run_state_invalid;
1312 
1313   return setup_unary_elementwise_nc(
1314     sqrt_op,
1315     batch_size, input, output,
1316     2 /* log2(sizeof(float)) */,
1317     2 /* log2(sizeof(float)) */,
1318     &sqrt_op->params.f32_sqrt, sizeof(sqrt_op->params.f32_sqrt),
1319     pthreadpool_get_threads_count(threadpool));
1320 }
1321 
1322 enum xnn_status xnn_setup_truncation_nc_f32(
1323     xnn_operator_t truncation_op,
1324     size_t batch_size,
1325     const float* input,
1326     float* output,
1327     pthreadpool_t threadpool)
1328 {
1329   if (truncation_op->type != xnn_operator_type_truncation_nc_f32) {
1330     xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
1331       xnn_operator_type_to_string(xnn_operator_type_truncation_nc_f32),
1332       xnn_operator_type_to_string(truncation_op->type));
1333     return xnn_status_invalid_parameter;
1334   }
1335   truncation_op->state = xnn_run_state_invalid;
1336 
1337   return setup_unary_elementwise_nc(
1338     truncation_op,
1339     batch_size, input, output,
1340     2 /* log2(sizeof(float)) */,
1341     2 /* log2(sizeof(float)) */,
1342     &truncation_op->params.f32_rnd, sizeof(truncation_op->params.f32_rnd),
1343     pthreadpool_get_threads_count(threadpool));
1344 }
1345
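/*
 * Illustrative usage sketch (example dimensions are arbitrary and error
 * handling is elided): the create/setup functions in this file pair with the
 * generic xnn_run_operator() and xnn_delete_operator() calls from the public
 * <xnnpack.h> API.
 *
 *   #include <xnnpack.h>
 *
 *   float input[8 * 32], output[8 * 32];   // batch_size = 8, channels = 32
 *   xnn_initialize(NULL);                  // NULL selects the default allocator
 *   xnn_operator_t sigmoid_op = NULL;
 *   xnn_create_sigmoid_nc_f32(32, 32, 32,  // channels, input_stride, output_stride
 *                             0, &sigmoid_op);
 *   xnn_setup_sigmoid_nc_f32(sigmoid_op, 8, input, output, NULL);
 *   xnn_run_operator(sigmoid_op, NULL);    // NULL threadpool: single-threaded
 *   xnn_delete_operator(sigmoid_op);
 */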