// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/log.h>
#include <xnnpack/operator.h>
#include <xnnpack/params-init.h>
#include <xnnpack/params.h>

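// Shared creation path for all unary element-wise operators (clamp, abs, convert, copy,
// etc.): validates the channel count and element strides, allocates a zero-initialized
// operator descriptor, stores the packed microkernel parameters (if any), and records the
// microkernel to invoke. The type-specific entry points below differ only in the
// parameters they pack and the microkernel they select.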
static enum xnn_status create_unary_elementwise_nc(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    const void* params,
    size_t params_size,
    enum xnn_operator_type operator_type,
    xnn_univector_ukernel_function ukernel,
    xnn_operator_t* unary_elementwise_op_out)
{
  xnn_operator_t unary_elementwise_op = NULL;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(operator_type));
    return xnn_status_uninitialized;
  }

  if (channels == 0) {
    xnn_log_error(
      "failed to create %s operator with %zu channels: number of channels must be non-zero",
      xnn_operator_type_to_string(operator_type), channels);
    return xnn_status_invalid_parameter;
  }

  if (input_stride < channels) {
    xnn_log_error(
      "failed to create %s operator with input element stride of %zu: "
      "stride must be at least as large as the number of channels (%zu)",
      xnn_operator_type_to_string(operator_type), input_stride, channels);
    return xnn_status_invalid_parameter;
  }

  if (output_stride < channels) {
    xnn_log_error(
      "failed to create %s operator with output element stride of %zu: "
      "stride must be at least as large as the number of channels (%zu)",
      xnn_operator_type_to_string(operator_type), output_stride, channels);
    return xnn_status_invalid_parameter;
  }

  unary_elementwise_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
  if (unary_elementwise_op == NULL) {
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator descriptor",
      sizeof(struct xnn_operator), xnn_operator_type_to_string(operator_type));
    return xnn_status_out_of_memory;
  }

  unary_elementwise_op->channels = channels;
  unary_elementwise_op->input_pixel_stride = input_stride;
  unary_elementwise_op->output_pixel_stride = output_stride;
  if (params_size != 0) {
    memcpy(&unary_elementwise_op->params, params, params_size);
  }

  unary_elementwise_op->ukernel.vunary.function = ukernel;
  unary_elementwise_op->type = operator_type;
  unary_elementwise_op->flags = flags;

  unary_elementwise_op->state = xnn_run_state_invalid;

  *unary_elementwise_op_out = unary_elementwise_op;
  return xnn_status_success;
}

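// Shared setup path for all unary element-wise operators. A batch size of zero marks the
// operator as skippable. log2_input_size and log2_output_size are the log2 of the input
// and output element sizes in bytes (they may differ, e.g. for F16<->F32 conversion), and
// params, if any, are copied into the per-run compute context.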
static enum xnn_status setup_unary_elementwise_nc(
    xnn_operator_t unary_elementwise_op,
    size_t batch_size,
    const void* input,
    void* output,
    uint32_t log2_input_size,
    uint32_t log2_output_size,
    const void* params,
    size_t params_size,
    size_t num_threads)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to setup %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(unary_elementwise_op->type));
    return xnn_status_uninitialized;
  }

  if (batch_size == 0) {
    unary_elementwise_op->state = xnn_run_state_skip;
    return xnn_status_success;
  }

  const size_t channels = unary_elementwise_op->channels;
  const size_t input_stride = unary_elementwise_op->input_pixel_stride;
  const size_t output_stride = unary_elementwise_op->output_pixel_stride;

  xnn_univector_ukernel_function ukernel = unary_elementwise_op->ukernel.vunary.function;

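  // Fast path: the XOR/OR expression below is a branch-free test that both
  // input_stride == channels and output_stride == channels, i.e. the rows are densely
  // packed. In that case (or when there is only one batch element) the whole tensor is
  // processed as a single contiguous buffer, tiled into 4096-byte blocks when more than
  // one thread is available.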
  if ((((input_stride ^ channels) | (output_stride ^ channels)) == 0) || batch_size == 1) {
    const size_t block_size = 4096;
    unary_elementwise_op->context.univector_contiguous = (struct univector_contiguous_context) {
      .x = input,
      .y = output,
      .log2_xsize = log2_input_size,
      .log2_ysize = log2_output_size,
      .ukernel = ukernel,
    };
    if (params_size != 0) {
      memcpy(&unary_elementwise_op->context.univector_contiguous.params, params, params_size);
    }

    const size_t range = (batch_size * channels) << log2_input_size;
    unary_elementwise_op->compute.type = xnn_parallelization_type_1d_tile_1d;
    unary_elementwise_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_univector_contiguous;
    unary_elementwise_op->compute.range[0] = range;
    unary_elementwise_op->compute.tile[0] = (num_threads == 1) ? range : block_size;
  } else {
    unary_elementwise_op->context.univector_strided = (struct univector_strided_context) {
      .n = channels << log2_input_size,
      .x = input,
      .x_stride = input_stride << log2_input_size,
      .y = output,
      .y_stride = output_stride << log2_output_size,
      .ukernel = ukernel,
    };
    if (params_size != 0) {
      memcpy(&unary_elementwise_op->context.univector_strided.params, params, params_size);
    }
    unary_elementwise_op->compute.type = xnn_parallelization_type_1d_tile_1d;
    unary_elementwise_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_univector_strided;
    unary_elementwise_op->compute.range[0] = batch_size;
    unary_elementwise_op->compute.tile[0] = (num_threads == 1) ? batch_size : 1;
  }
  unary_elementwise_op->state = xnn_run_state_ready;

  return xnn_status_success;
}

enum xnn_status xnn_create_clamp_nc_s8(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    int8_t output_min,
    int8_t output_max,
    uint32_t flags,
    xnn_operator_t* clamp_op_out)
{
  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRId8 ", %" PRId8 "] output range: range min must be below range max",
      xnn_operator_type_to_string(xnn_operator_type_clamp_nc_s8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  union xnn_s8_minmax_params params;
  if (xnn_params.s8.clamp.init.s8_minmax != NULL) {
    xnn_params.s8.clamp.init.s8_minmax(&params, output_min, output_max);
  }
  return create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &params, sizeof(params),
    xnn_operator_type_clamp_nc_s8,
    xnn_params.s8.clamp.ukernel,
    clamp_op_out);
}

enum xnn_status xnn_create_clamp_nc_u8(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    xnn_operator_t* clamp_op_out)
{
  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRIu8 ", %" PRIu8 "] output range: range min must be below range max",
      xnn_operator_type_to_string(xnn_operator_type_clamp_nc_u8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  union xnn_u8_minmax_params params;
  if (xnn_params.u8.clamp.init.u8_minmax != NULL) {
    xnn_params.u8.clamp.init.u8_minmax(&params, output_min, output_max);
  }
  return create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &params, sizeof(params),
    xnn_operator_type_clamp_nc_u8,
    xnn_params.u8.clamp.ukernel,
    clamp_op_out);
}

enum xnn_status xnn_create_clamp_nc_f32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_operator_t* clamp_op_out)
{
  if (isnan(output_min)) {
    xnn_log_error(
      "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_clamp_nc_f32));
    return xnn_status_invalid_parameter;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_clamp_nc_f32));
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_clamp_nc_f32), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

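  // Clamping to [0.0, +inf) is equivalent to ReLU, so prefer a dedicated ReLU microkernel
  // when one is available; otherwise fall back to the generic min/max clamp microkernel.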
  const bool relu_activation = (output_max == INFINITY) && (output_min == 0.0f);
  xnn_univector_ukernel_function clamp_ukernel = (relu_activation && (xnn_params.f32.relu != NULL)) ?
    xnn_params.f32.relu : xnn_params.f32.clamp.ukernel;

  union xnn_f32_minmax_params params;
  if (xnn_params.f32.clamp.init.f32_minmax != NULL) {
    xnn_params.f32.clamp.init.f32_minmax(&params, output_min, output_max);
  }
  return create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &params, sizeof(params),
    xnn_operator_type_clamp_nc_f32,
    clamp_ukernel,
    clamp_op_out);
}

enum xnn_status xnn_create_abs_nc_f32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* abs_op_out)
{
  union xnn_f32_abs_params params;
  if (xnn_params.f32.abs.init.f32_abs != NULL) {
    xnn_params.f32.abs.init.f32_abs(&params);
  }
  return create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &params, sizeof(params),
    xnn_operator_type_abs_nc_f32,
    xnn_params.f32.abs.ukernel,
    abs_op_out);
}

enum xnn_status xnn_create_bankers_rounding_nc_f32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* rounding_op_out)
{
  union xnn_f32_rnd_params params;
  if (xnn_params.f32.rndne.init.f32_rnd != NULL) {
    xnn_params.f32.rndne.init.f32_rnd(&params);
  }
  return create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &params, sizeof(params),
    xnn_operator_type_bankers_rounding_nc_f32,
    xnn_params.f32.rndne.ukernel,
    rounding_op_out);
}

enum xnn_status xnn_create_ceiling_nc_f32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* ceiling_op_out)
{
  union xnn_f32_rnd_params params;
  if (xnn_params.f32.rndu.init.f32_rnd != NULL) {
    xnn_params.f32.rndu.init.f32_rnd(&params);
  }
  return create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &params, sizeof(params),
    xnn_operator_type_ceiling_nc_f32,
    xnn_params.f32.rndu.ukernel,
    ceiling_op_out);
}

enum xnn_status xnn_create_convert_nc_f16_f32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* convert_op_out)
{
  union xnn_f16_f32_cvt_params params;
  if (xnn_params.vcvt.f16_to_f32.init.f16_f32_cvt != NULL) {
    xnn_params.vcvt.f16_to_f32.init.f16_f32_cvt(&params);
  }
  return create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &params, sizeof(params),
    xnn_operator_type_convert_nc_f16_f32,
    xnn_params.vcvt.f16_to_f32.ukernel,
    convert_op_out);
}

enum xnn_status xnn_create_convert_nc_f32_f16(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* convert_op_out)
{
  union xnn_f32_f16_cvt_params params;
  if (xnn_params.vcvt.f32_to_f16.init.f32_f16_cvt != NULL) {
    xnn_params.vcvt.f32_to_f16.init.f32_f16_cvt(&params);
  }
  return create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &params, sizeof(params),
    xnn_operator_type_convert_nc_f32_f16,
    xnn_params.vcvt.f32_to_f16.ukernel,
    convert_op_out);
}

enum xnn_status xnn_create_convert_nc_f32_qs8(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    float output_scale,
    int8_t output_zero_point,
    int8_t output_min,
    int8_t output_max,
    uint32_t flags,
    xnn_operator_t* convert_op_out)
{
  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale parameter: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_convert_nc_f32_qs8), output_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRId8 ", %" PRId8 "] output range: range min must be below range max",
      xnn_operator_type_to_string(xnn_operator_type_convert_nc_f32_qs8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

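  // Note: the parameter initializer receives the reciprocal of the output scale
  // (1.0f / output_scale), so the conversion microkernel can scale by multiplication
  // rather than division; the QU8 variant below does the same.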
  union xnn_f32_qs8_cvt_params params;
  if (xnn_params.vcvt.f32_to_qs8.init.f32_qs8_cvt != NULL) {
    xnn_params.vcvt.f32_to_qs8.init.f32_qs8_cvt(&params, 1.0f / output_scale, output_zero_point, output_min, output_max);
  }
  return create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &params, sizeof(params),
    xnn_operator_type_convert_nc_f32_qs8,
    xnn_params.vcvt.f32_to_qs8.ukernel,
    convert_op_out);
}

enum xnn_status xnn_create_convert_nc_f32_qu8(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    float output_scale,
    uint8_t output_zero_point,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    xnn_operator_t* convert_op_out)
{
  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale parameter: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_convert_nc_f32_qu8), output_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRIu8 ", %" PRIu8 "] output range: range min must be below range max",
      xnn_operator_type_to_string(xnn_operator_type_convert_nc_f32_qu8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  union xnn_f32_qu8_cvt_params params;
  if (xnn_params.vcvt.f32_to_qu8.init.f32_qu8_cvt != NULL) {
    xnn_params.vcvt.f32_to_qu8.init.f32_qu8_cvt(&params, 1.0f / output_scale, output_zero_point, output_min, output_max);
  }
  return create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &params, sizeof(params),
    xnn_operator_type_convert_nc_f32_qu8,
    xnn_params.vcvt.f32_to_qu8.ukernel,
    convert_op_out);
}

enum xnn_status xnn_create_convert_nc_qs8_f32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    float input_scale,
    int8_t input_zero_point,
    uint32_t flags,
    xnn_operator_t* convert_op_out)
{
  if (input_scale <= 0.0f || !isnormal(input_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale parameter: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_convert_nc_qs8_f32), input_scale);
    return xnn_status_invalid_parameter;
  }

  union xnn_qs8_f32_cvt_params params;
  if (xnn_params.vcvt.qs8_to_f32.init.qs8_f32_cvt != NULL) {
    xnn_params.vcvt.qs8_to_f32.init.qs8_f32_cvt(&params, input_scale, input_zero_point);
  }
  return create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &params, sizeof(params),
    xnn_operator_type_convert_nc_qs8_f32,
    xnn_params.vcvt.qs8_to_f32.ukernel,
    convert_op_out);
}

enum xnn_status xnn_create_convert_nc_qu8_f32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    float input_scale,
    uint8_t input_zero_point,
    uint32_t flags,
    xnn_operator_t* convert_op_out)
{
  if (input_scale <= 0.0f || !isnormal(input_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale parameter: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_convert_nc_qu8_f32), input_scale);
    return xnn_status_invalid_parameter;
  }

  union xnn_qu8_f32_cvt_params params;
  if (xnn_params.vcvt.qu8_to_f32.init.qu8_f32_cvt != NULL) {
    xnn_params.vcvt.qu8_to_f32.init.qu8_f32_cvt(&params, input_scale, input_zero_point);
  }
  return create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &params, sizeof(params),
    xnn_operator_type_convert_nc_qu8_f32,
    xnn_params.vcvt.qu8_to_f32.ukernel,
    convert_op_out);
}

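// The X8/X16/X32 copy operators are bitwise copies: they share the same generic copy
// microkernel and take no packed parameters; only the element size passed at setup time
// differs.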
enum xnn_status xnn_create_copy_nc_x8(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* copy_op_out)
{
  return create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    NULL, 0,
    xnn_operator_type_copy_nc_x8,
    xnn_params.xx.copy,
    copy_op_out);
}

enum xnn_status xnn_create_copy_nc_x16(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* copy_op_out)
{
  return create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    NULL, 0,
    xnn_operator_type_copy_nc_x16,
    xnn_params.xx.copy,
    copy_op_out);
}

enum xnn_status xnn_create_copy_nc_x32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* copy_op_out)
{
  return create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    NULL, 0,
    xnn_operator_type_copy_nc_x32,
    xnn_params.xx.copy,
    copy_op_out);
}

enum xnn_status xnn_create_elu_nc_f32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    float alpha,
    uint32_t flags,
    xnn_operator_t* elu_op_out)
{
  if (alpha <= 0.0f || !isnormal(alpha)) {
    xnn_log_error(
      "failed to create %s operator with %.7g alpha parameter: alpha must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_elu_nc_f32), alpha);
    return xnn_status_invalid_parameter;
  }

  union xnn_f32_elu_params params;
  if (xnn_params.f32.elu.init.f32_elu != NULL) {
    xnn_params.f32.elu.init.f32_elu(&params, 1.0f /* prescale */, alpha, 1.0f /* beta */);
  }
  return create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &params, sizeof(params),
    xnn_operator_type_elu_nc_f32,
    xnn_params.f32.elu.ukernel,
    elu_op_out);
}

enum xnn_status xnn_create_floor_nc_f32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* floor_op_out)
{
  union xnn_f32_rnd_params params;
  if (xnn_params.f32.rndd.init.f32_rnd != NULL) {
    xnn_params.f32.rndd.init.f32_rnd(&params);
  }
  return create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &params, sizeof(params),
    xnn_operator_type_floor_nc_f32,
    xnn_params.f32.rndd.ukernel,
    floor_op_out);
}

enum xnn_status xnn_create_hardswish_nc_f16(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* hardswish_op_out)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(xnn_operator_type_hardswish_nc_f16));
    return xnn_status_uninitialized;
  }

  if ((xnn_params.init_flags & XNN_INIT_FLAG_F16) != XNN_INIT_FLAG_F16) {
    xnn_log_error("failed to create %s operator: operations on data type are not supported",
      xnn_operator_type_to_string(xnn_operator_type_hardswish_nc_f16));
    return xnn_status_unsupported_hardware;
  }

  union xnn_f16_hswish_params params;
  if (xnn_params.f16.hswish.init.f16_hswish != NULL) {
    xnn_params.f16.hswish.init.f16_hswish(&params);
  }
  return create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &params, sizeof(params),
    xnn_operator_type_hardswish_nc_f16,
    xnn_params.f16.hswish.ukernel,
    hardswish_op_out);
}

enum xnn_status xnn_create_hardswish_nc_f32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* hardswish_op_out)
{
  union xnn_f32_hswish_params params;
  if (xnn_params.f32.hswish.init.f32_hswish != NULL) {
    xnn_params.f32.hswish.init.f32_hswish(&params);
  }
  return create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &params, sizeof(params),
    xnn_operator_type_hardswish_nc_f32,
    xnn_params.f32.hswish.ukernel,
    hardswish_op_out);
}

enum xnn_status xnn_create_leaky_relu_nc_f32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    float negative_slope,
    uint32_t flags,
    xnn_operator_t* leaky_relu_op_out)
{
  if (!isfinite(negative_slope)) {
    xnn_log_error(
      "failed to create %s operator with %f negative slope: finite number expected",
      xnn_operator_type_to_string(xnn_operator_type_leaky_relu_nc_f32),
      negative_slope);
    return xnn_status_invalid_parameter;
  }

  union xnn_f32_lrelu_params params;
  if (xnn_params.f32.lrelu.init.f32_lrelu != NULL) {
    xnn_params.f32.lrelu.init.f32_lrelu(&params, negative_slope);
  }
  return create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &params, sizeof(params),
    xnn_operator_type_leaky_relu_nc_f32,
    xnn_params.f32.lrelu.ukernel,
    leaky_relu_op_out);
}

enum xnn_status xnn_create_negate_nc_f32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* negate_op_out)
{
  union xnn_f32_neg_params params;
  if (xnn_params.f32.neg.init.f32_neg != NULL) {
    xnn_params.f32.neg.init.f32_neg(&params);
  }
  return create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &params, sizeof(params),
    xnn_operator_type_negate_nc_f32,
    xnn_params.f32.neg.ukernel,
    negate_op_out);
}

enum xnn_status xnn_create_sigmoid_nc_f32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* sigmoid_op_out)
{
  union xnn_f32_sigmoid_params params;
  if (xnn_params.f32.sigmoid.init.f32_sigmoid != NULL) {
    xnn_params.f32.sigmoid.init.f32_sigmoid(&params);
  }
  return create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &params, sizeof(params),
    xnn_operator_type_sigmoid_nc_f32,
    xnn_params.f32.sigmoid.ukernel,
    sigmoid_op_out);
}

enum xnn_status xnn_create_square_nc_f32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* square_op_out)
{
  union xnn_f32_default_params params;
  if (xnn_params.f32.sqr.init.f32_default != NULL) {
    xnn_params.f32.sqr.init.f32_default(&params);
  }
  return create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &params, sizeof(params),
    xnn_operator_type_square_nc_f32,
    xnn_params.f32.sqr.ukernel,
    square_op_out);
}

enum xnn_status xnn_create_square_root_nc_f32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* sqrt_op_out)
{
  union xnn_f32_sqrt_params params;
  if (xnn_params.f32.sqrt.init.f32_sqrt != NULL) {
    xnn_params.f32.sqrt.init.f32_sqrt(&params);
  }
  return create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &params, sizeof(params),
    xnn_operator_type_square_root_nc_f32,
    xnn_params.f32.sqrt.ukernel,
    sqrt_op_out);
}

enum xnn_status xnn_create_truncation_nc_f32(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint32_t flags,
    xnn_operator_t* truncation_op_out)
{
  union xnn_f32_rnd_params params;
  if (xnn_params.f32.rndz.init.f32_rnd != NULL) {
    xnn_params.f32.rndz.init.f32_rnd(&params);
  }
  return create_unary_elementwise_nc(
    channels, input_stride, output_stride, flags,
    &params, sizeof(params),
    xnn_operator_type_truncation_nc_f32,
    xnn_params.f32.rndz.ukernel,
    truncation_op_out);
}

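// Per-operator setup entry points. Each one checks that the operator was created with the
// matching type, marks it invalid until setup succeeds, and delegates to
// setup_unary_elementwise_nc with the input/output element sizes (as log2 of bytes) and
// the parameters packed at creation time.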
enum xnn_status xnn_setup_abs_nc_f32(
    xnn_operator_t abs_op,
    size_t batch_size,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  if (abs_op->type != xnn_operator_type_abs_nc_f32) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_abs_nc_f32),
      xnn_operator_type_to_string(abs_op->type));
    return xnn_status_invalid_parameter;
  }
  abs_op->state = xnn_run_state_invalid;

  return setup_unary_elementwise_nc(
    abs_op,
    batch_size, input, output,
    2 /* log2(sizeof(float)) */,
    2 /* log2(sizeof(float)) */,
    &abs_op->params.f32_abs, sizeof(abs_op->params.f32_abs),
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_bankers_rounding_nc_f32(
    xnn_operator_t rounding_op,
    size_t batch_size,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  if (rounding_op->type != xnn_operator_type_bankers_rounding_nc_f32) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_bankers_rounding_nc_f32),
      xnn_operator_type_to_string(rounding_op->type));
    return xnn_status_invalid_parameter;
  }
  rounding_op->state = xnn_run_state_invalid;

  return setup_unary_elementwise_nc(
    rounding_op,
    batch_size, input, output,
    2 /* log2(sizeof(float)) */,
    2 /* log2(sizeof(float)) */,
    &rounding_op->params.f32_rnd, sizeof(rounding_op->params.f32_rnd),
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_ceiling_nc_f32(
    xnn_operator_t ceiling_op,
    size_t batch_size,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  if (ceiling_op->type != xnn_operator_type_ceiling_nc_f32) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_ceiling_nc_f32),
      xnn_operator_type_to_string(ceiling_op->type));
    return xnn_status_invalid_parameter;
  }
  ceiling_op->state = xnn_run_state_invalid;

  return setup_unary_elementwise_nc(
    ceiling_op,
    batch_size, input, output,
    2 /* log2(sizeof(float)) */,
    2 /* log2(sizeof(float)) */,
    &ceiling_op->params.f32_rnd, sizeof(ceiling_op->params.f32_rnd),
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_clamp_nc_s8(
    xnn_operator_t clamp_op,
    size_t batch_size,
    const int8_t* input,
    int8_t* output,
    pthreadpool_t threadpool)
{
  if (clamp_op->type != xnn_operator_type_clamp_nc_s8) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_clamp_nc_s8),
      xnn_operator_type_to_string(clamp_op->type));
    return xnn_status_invalid_parameter;
  }
  clamp_op->state = xnn_run_state_invalid;

  return setup_unary_elementwise_nc(
    clamp_op,
    batch_size, input, output,
    0 /* log2(sizeof(int8_t)) */,
    0 /* log2(sizeof(int8_t)) */,
    &clamp_op->params.s8_minmax, sizeof(clamp_op->params.s8_minmax),
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_clamp_nc_u8(
    xnn_operator_t clamp_op,
    size_t batch_size,
    const uint8_t* input,
    uint8_t* output,
    pthreadpool_t threadpool)
{
  if (clamp_op->type != xnn_operator_type_clamp_nc_u8) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_clamp_nc_u8),
      xnn_operator_type_to_string(clamp_op->type));
    return xnn_status_invalid_parameter;
  }
  clamp_op->state = xnn_run_state_invalid;

  return setup_unary_elementwise_nc(
    clamp_op,
    batch_size, input, output,
    0 /* log2(sizeof(uint8_t)) */,
    0 /* log2(sizeof(uint8_t)) */,
    &clamp_op->params.u8_minmax, sizeof(clamp_op->params.u8_minmax),
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_clamp_nc_f32(
    xnn_operator_t clamp_op,
    size_t batch_size,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  if (clamp_op->type != xnn_operator_type_clamp_nc_f32) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_clamp_nc_f32),
      xnn_operator_type_to_string(clamp_op->type));
    return xnn_status_invalid_parameter;
  }
  clamp_op->state = xnn_run_state_invalid;

  return setup_unary_elementwise_nc(
    clamp_op,
    batch_size, input, output,
    2 /* log2(sizeof(float)) */,
    2 /* log2(sizeof(float)) */,
    &clamp_op->params.f32_minmax, sizeof(clamp_op->params.f32_minmax),
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_convert_nc_f16_f32(
    xnn_operator_t convert_op,
    size_t batch_size,
    const void* input,
    float* output,
    pthreadpool_t threadpool)
{
  if (convert_op->type != xnn_operator_type_convert_nc_f16_f32) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_convert_nc_f16_f32),
      xnn_operator_type_to_string(convert_op->type));
    return xnn_status_invalid_parameter;
  }
  convert_op->state = xnn_run_state_invalid;

  return setup_unary_elementwise_nc(
    convert_op,
    batch_size, input, output,
    1 /* log2(sizeof(uint16_t)) */,
    2 /* log2(sizeof(float)) */,
    &convert_op->params.f16_f32_cvt, sizeof(convert_op->params.f16_f32_cvt),
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_convert_nc_f32_f16(
    xnn_operator_t convert_op,
    size_t batch_size,
    const float* input,
    void* output,
    pthreadpool_t threadpool)
{
  if (convert_op->type != xnn_operator_type_convert_nc_f32_f16) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_convert_nc_f32_f16),
      xnn_operator_type_to_string(convert_op->type));
    return xnn_status_invalid_parameter;
  }
  convert_op->state = xnn_run_state_invalid;

  return setup_unary_elementwise_nc(
    convert_op,
    batch_size, input, output,
    2 /* log2(sizeof(float)) */,
    1 /* log2(sizeof(uint16_t)) */,
    &convert_op->params.f32_f16_cvt, sizeof(convert_op->params.f32_f16_cvt),
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_convert_nc_f32_qs8(
    xnn_operator_t convert_op,
    size_t batch_size,
    const float* input,
    int8_t* output,
    pthreadpool_t threadpool)
{
  if (convert_op->type != xnn_operator_type_convert_nc_f32_qs8) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_convert_nc_f32_qs8),
      xnn_operator_type_to_string(convert_op->type));
    return xnn_status_invalid_parameter;
  }
  convert_op->state = xnn_run_state_invalid;

  return setup_unary_elementwise_nc(
    convert_op,
    batch_size, input, output,
    2 /* log2(sizeof(float)) */,
    0 /* log2(sizeof(int8_t)) */,
    &convert_op->params.f32_qs8_cvt, sizeof(convert_op->params.f32_qs8_cvt),
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_convert_nc_f32_qu8(
    xnn_operator_t convert_op,
    size_t batch_size,
    const float* input,
    uint8_t* output,
    pthreadpool_t threadpool)
{
  if (convert_op->type != xnn_operator_type_convert_nc_f32_qu8) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_convert_nc_f32_qu8),
      xnn_operator_type_to_string(convert_op->type));
    return xnn_status_invalid_parameter;
  }
  convert_op->state = xnn_run_state_invalid;

  return setup_unary_elementwise_nc(
    convert_op,
    batch_size, input, output,
    2 /* log2(sizeof(float)) */,
    0 /* log2(sizeof(uint8_t)) */,
    &convert_op->params.f32_qu8_cvt, sizeof(convert_op->params.f32_qu8_cvt),
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_convert_nc_qs8_f32(
    xnn_operator_t convert_op,
    size_t batch_size,
    const int8_t* input,
    float* output,
    pthreadpool_t threadpool)
{
  if (convert_op->type != xnn_operator_type_convert_nc_qs8_f32) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_convert_nc_qs8_f32),
      xnn_operator_type_to_string(convert_op->type));
    return xnn_status_invalid_parameter;
  }
  convert_op->state = xnn_run_state_invalid;

  return setup_unary_elementwise_nc(
    convert_op,
    batch_size, input, output,
    0 /* log2(sizeof(int8_t)) */,
    2 /* log2(sizeof(float)) */,
    &convert_op->params.qs8_f32_cvt, sizeof(convert_op->params.qs8_f32_cvt),
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_convert_nc_qu8_f32(
    xnn_operator_t convert_op,
    size_t batch_size,
    const uint8_t* input,
    float* output,
    pthreadpool_t threadpool)
{
  if (convert_op->type != xnn_operator_type_convert_nc_qu8_f32) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_convert_nc_qu8_f32),
      xnn_operator_type_to_string(convert_op->type));
    return xnn_status_invalid_parameter;
  }
  convert_op->state = xnn_run_state_invalid;

  return setup_unary_elementwise_nc(
    convert_op,
    batch_size, input, output,
    0 /* log2(sizeof(uint8_t)) */,
    2 /* log2(sizeof(float)) */,
    &convert_op->params.qu8_f32_cvt, sizeof(convert_op->params.qu8_f32_cvt),
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_copy_nc_x8(
    xnn_operator_t copy_op,
    size_t batch_size,
    const void* input,
    void* output,
    pthreadpool_t threadpool)
{
  if (copy_op->type != xnn_operator_type_copy_nc_x8) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_copy_nc_x8),
      xnn_operator_type_to_string(copy_op->type));
    return xnn_status_invalid_parameter;
  }
  copy_op->state = xnn_run_state_invalid;

  return setup_unary_elementwise_nc(
    copy_op,
    batch_size, input, output,
    0 /* log2(sizeof(uint8_t)) */,
    0 /* log2(sizeof(uint8_t)) */,
    NULL, 0,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_copy_nc_x16(
    xnn_operator_t copy_op,
    size_t batch_size,
    const void* input,
    void* output,
    pthreadpool_t threadpool)
{
  if (copy_op->type != xnn_operator_type_copy_nc_x16) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_copy_nc_x16),
      xnn_operator_type_to_string(copy_op->type));
    return xnn_status_invalid_parameter;
  }
  copy_op->state = xnn_run_state_invalid;

  return setup_unary_elementwise_nc(
    copy_op,
    batch_size, input, output,
    1 /* log2(sizeof(uint16_t)) */,
    1 /* log2(sizeof(uint16_t)) */,
    NULL, 0,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_copy_nc_x32(
    xnn_operator_t copy_op,
    size_t batch_size,
    const void* input,
    void* output,
    pthreadpool_t threadpool)
{
  if (copy_op->type != xnn_operator_type_copy_nc_x32) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_copy_nc_x32),
      xnn_operator_type_to_string(copy_op->type));
    return xnn_status_invalid_parameter;
  }
  copy_op->state = xnn_run_state_invalid;

  return setup_unary_elementwise_nc(
    copy_op,
    batch_size, input, output,
    2 /* log2(sizeof(uint32_t)) */,
    2 /* log2(sizeof(uint32_t)) */,
    NULL, 0,
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_elu_nc_f32(
    xnn_operator_t elu_op,
    size_t batch_size,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  if (elu_op->type != xnn_operator_type_elu_nc_f32) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_elu_nc_f32),
      xnn_operator_type_to_string(elu_op->type));
    return xnn_status_invalid_parameter;
  }
  elu_op->state = xnn_run_state_invalid;

  return setup_unary_elementwise_nc(
    elu_op,
    batch_size, input, output,
    2 /* log2(sizeof(float)) */,
    2 /* log2(sizeof(float)) */,
    &elu_op->params.f32_elu, sizeof(elu_op->params.f32_elu),
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_floor_nc_f32(
    xnn_operator_t floor_op,
    size_t batch_size,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  if (floor_op->type != xnn_operator_type_floor_nc_f32) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_floor_nc_f32),
      xnn_operator_type_to_string(floor_op->type));
    return xnn_status_invalid_parameter;
  }
  floor_op->state = xnn_run_state_invalid;

  return setup_unary_elementwise_nc(
    floor_op,
    batch_size, input, output,
    2 /* log2(sizeof(float)) */,
    2 /* log2(sizeof(float)) */,
    &floor_op->params.f32_rnd, sizeof(floor_op->params.f32_rnd),
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_hardswish_nc_f16(
    xnn_operator_t hardswish_op,
    size_t batch_size,
    const void* input,
    void* output,
    pthreadpool_t threadpool)
{
  if (hardswish_op->type != xnn_operator_type_hardswish_nc_f16) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_hardswish_nc_f16),
      xnn_operator_type_to_string(hardswish_op->type));
    return xnn_status_invalid_parameter;
  }
  hardswish_op->state = xnn_run_state_invalid;

  return setup_unary_elementwise_nc(
    hardswish_op,
    batch_size, input, output,
    1 /* log2(sizeof(half)) */,
    1 /* log2(sizeof(half)) */,
    &hardswish_op->params.f16_hswish, sizeof(hardswish_op->params.f16_hswish),
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_hardswish_nc_f32(
    xnn_operator_t hardswish_op,
    size_t batch_size,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  if (hardswish_op->type != xnn_operator_type_hardswish_nc_f32) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_hardswish_nc_f32),
      xnn_operator_type_to_string(hardswish_op->type));
    return xnn_status_invalid_parameter;
  }
  hardswish_op->state = xnn_run_state_invalid;

  return setup_unary_elementwise_nc(
    hardswish_op,
    batch_size, input, output,
    2 /* log2(sizeof(float)) */,
    2 /* log2(sizeof(float)) */,
    &hardswish_op->params.f32_hswish, sizeof(hardswish_op->params.f32_hswish),
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_leaky_relu_nc_f32(
    xnn_operator_t leaky_relu_op,
    size_t batch_size,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  if (leaky_relu_op->type != xnn_operator_type_leaky_relu_nc_f32) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_leaky_relu_nc_f32),
      xnn_operator_type_to_string(leaky_relu_op->type));
    return xnn_status_invalid_parameter;
  }
  leaky_relu_op->state = xnn_run_state_invalid;

  return setup_unary_elementwise_nc(
    leaky_relu_op,
    batch_size, input, output,
    2 /* log2(sizeof(float)) */,
    2 /* log2(sizeof(float)) */,
    &leaky_relu_op->params.f32_lrelu, sizeof(leaky_relu_op->params.f32_lrelu),
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_negate_nc_f32(
    xnn_operator_t negate_op,
    size_t batch_size,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  if (negate_op->type != xnn_operator_type_negate_nc_f32) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_negate_nc_f32),
      xnn_operator_type_to_string(negate_op->type));
    return xnn_status_invalid_parameter;
  }
  negate_op->state = xnn_run_state_invalid;

  return setup_unary_elementwise_nc(
    negate_op,
    batch_size, input, output,
    2 /* log2(sizeof(float)) */,
    2 /* log2(sizeof(float)) */,
    &negate_op->params.f32_neg, sizeof(negate_op->params.f32_neg),
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_sigmoid_nc_f32(
    xnn_operator_t sigmoid_op,
    size_t batch_size,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  if (sigmoid_op->type != xnn_operator_type_sigmoid_nc_f32) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_sigmoid_nc_f32),
      xnn_operator_type_to_string(sigmoid_op->type));
    return xnn_status_invalid_parameter;
  }
  sigmoid_op->state = xnn_run_state_invalid;

  return setup_unary_elementwise_nc(
    sigmoid_op,
    batch_size, input, output,
    2 /* log2(sizeof(float)) */,
    2 /* log2(sizeof(float)) */,
    &sigmoid_op->params.f32_sigmoid, sizeof(sigmoid_op->params.f32_sigmoid),
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_square_nc_f32(
    xnn_operator_t square_op,
    size_t batch_size,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  if (square_op->type != xnn_operator_type_square_nc_f32) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_square_nc_f32),
      xnn_operator_type_to_string(square_op->type));
    return xnn_status_invalid_parameter;
  }
  square_op->state = xnn_run_state_invalid;

  return setup_unary_elementwise_nc(
    square_op,
    batch_size, input, output,
    2 /* log2(sizeof(float)) */,
    2 /* log2(sizeof(float)) */,
    &square_op->params.f32_default, sizeof(square_op->params.f32_default),
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_square_root_nc_f32(
    xnn_operator_t sqrt_op,
    size_t batch_size,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  if (sqrt_op->type != xnn_operator_type_square_root_nc_f32) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_square_root_nc_f32),
      xnn_operator_type_to_string(sqrt_op->type));
    return xnn_status_invalid_parameter;
  }
  sqrt_op->state = xnn_run_state_invalid;

  return setup_unary_elementwise_nc(
    sqrt_op,
    batch_size, input, output,
    2 /* log2(sizeof(float)) */,
    2 /* log2(sizeof(float)) */,
    &sqrt_op->params.f32_sqrt, sizeof(sqrt_op->params.f32_sqrt),
    pthreadpool_get_threads_count(threadpool));
}

enum xnn_status xnn_setup_truncation_nc_f32(
    xnn_operator_t truncation_op,
    size_t batch_size,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  if (truncation_op->type != xnn_operator_type_truncation_nc_f32) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(xnn_operator_type_truncation_nc_f32),
      xnn_operator_type_to_string(truncation_op->type));
    return xnn_status_invalid_parameter;
  }
  truncation_op->state = xnn_run_state_invalid;

  return setup_unary_elementwise_nc(
    truncation_op,
    batch_size, input, output,
    2 /* log2(sizeof(float)) */,
    2 /* log2(sizeof(float)) */,
    &truncation_op->params.f32_rnd, sizeof(truncation_op->params.f32_rnd),
    pthreadpool_get_threads_count(threadpool));
}