1 // Copyright 2021 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/operator.h>
#include <xnnpack/log.h>
16
17
// Scalar reference function used to populate a lookup table: it receives a
// dequantized input value plus an opaque pointer to operator-specific
// parameters (may be NULL when the function needs none) and returns the
// dequantized output value.
typedef float (*xnn_lut_init_fn)(float, const void*);
19
// Creates an 8-bit elementwise operator backed by a 256-entry lookup table.
//
// The table is filled by evaluating init_fn on every representable input
// value: each of the 256 inputs starting at input_min is dequantized with
// (input_zero_point, input_scale), passed to init_fn together with
// init_params, and the result is requantized with (output_zero_point,
// output_scale) and clamped to [output_min, output_max].
//
// Returns:
//   xnn_status_uninitialized      - XNNPACK has not been initialized.
//   xnn_status_invalid_parameter  - channels is zero, a stride is smaller than
//                                   channels, a scale is not a positive normal
//                                   number, or output_min >= output_max.
//   xnn_status_out_of_memory      - the operator descriptor or the lookup table
//                                   could not be allocated.
//   xnn_status_success            - the new operator was stored in
//                                   *lut_elementwise_op_out.
//
// *lut_elementwise_op_out is written only on success; on failure the partially
// constructed operator (possibly NULL) is handed to xnn_delete_operator.
static enum xnn_status create_lut_elementwise_nc(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    int32_t input_zero_point,
    float input_scale,
    int32_t input_min,
    long output_zero_point,
    float output_scale,
    long output_min,
    long output_max,
    uint32_t flags,
    xnn_lut_init_fn init_fn,
    const void* init_params,
    enum xnn_operator_type operator_type,
    xnn_operator_t* lut_elementwise_op_out)
{
  xnn_operator_t lut_elementwise_op = NULL;
  enum xnn_status status = xnn_status_uninitialized;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(operator_type));
    goto error;
  }

  status = xnn_status_invalid_parameter;

  if (channels == 0) {
    xnn_log_error(
      "failed to create %s operator with %zu channels: number of channels must be non-zero",
      xnn_operator_type_to_string(operator_type), channels);
    goto error;
  }

  if (input_stride < channels) {
    xnn_log_error(
      "failed to create %s operator with input element stride of %zu: "
      "stride must be at least as large as the number of channels (%zu)",
      xnn_operator_type_to_string(operator_type), input_stride, channels);
    goto error;
  }

  if (output_stride < channels) {
    xnn_log_error(
      "failed to create %s operator with output element stride of %zu: "
      "stride must be at least as large as the number of channels (%zu)",
      xnn_operator_type_to_string(operator_type), output_stride, channels);
    goto error;
  }

  if (input_scale <= 0.0f || !isnormal(input_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(operator_type), input_scale);
    goto error;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(operator_type), output_scale);
    goto error;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%ld, %ld] output range: range min must be below range max",
      xnn_operator_type_to_string(operator_type), output_min, output_max);
    goto error;
  }

  status = xnn_status_out_of_memory;

  lut_elementwise_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
  if (lut_elementwise_op == NULL) {
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator descriptor",
      sizeof(struct xnn_operator), xnn_operator_type_to_string(operator_type));
    goto error;
  }

  lut_elementwise_op->lookup_table = xnn_allocate_simd_memory(256 * sizeof(uint8_t));
  if (lut_elementwise_op->lookup_table == NULL) {
    xnn_log_error(
      "failed to allocate 256 bytes for %s operator lookup table",
      xnn_operator_type_to_string(operator_type));
    goto error;
  }

  uint8_t* lookup_table = lut_elementwise_op->lookup_table;
  const float inv_output_scale = 1.0f / output_scale;
  // Output bounds expressed relative to the zero point, i.e. in the same
  // domain as the scaled (but not yet rounded) function value.
  const float scaled_output_min = (float) (output_min - output_zero_point);
  const float scaled_output_max = (float) (output_max - output_zero_point);
  for (int32_t i = input_min; i < input_min + 256; i++) {
    const float dequantized_input = (i - input_zero_point) * input_scale;
    const float dequantized_output = init_fn(dequantized_input, init_params);
    // Clamp before rounding: lrintf() gives an unspecified result for NaN,
    // infinities and values that do not fit in long, and adding the zero point
    // afterwards could overflow. Because both bounds are integers, in-range
    // values produce exactly the same table entry as clamping after rounding.
    // fmaxf() returns the non-NaN argument, so a NaN maps to output_min.
    float scaled_output = dequantized_output * inv_output_scale;
    scaled_output = fmaxf(scaled_output, scaled_output_min);
    scaled_output = fminf(scaled_output, scaled_output_max);
    const long quantized_output = lrintf(scaled_output) + output_zero_point;
    // The table is indexed by the raw 8-bit pattern of the input, so signed
    // inputs wrap around into the upper half of the table.
    lookup_table[(uint8_t) i] = (uint8_t) quantized_output;
  }

  lut_elementwise_op->channels = channels;
  lut_elementwise_op->input_pixel_stride = input_stride;
  lut_elementwise_op->output_pixel_stride = output_stride;

  lut_elementwise_op->type = operator_type;
  lut_elementwise_op->flags = flags;

  lut_elementwise_op->state = xnn_run_state_invalid;

  *lut_elementwise_op_out = lut_elementwise_op;
  return xnn_status_success;

error:
  xnn_delete_operator(lut_elementwise_op);
  return status;
}
137
calculate_elu(float x,const float * alpha_ptr)138 static float calculate_elu(float x, const float* alpha_ptr) {
139 const float alpha = *alpha_ptr;
140 return signbit(x) ? alpha * expm1f(x) : x;
141 }
142
// Creates a signed 8-bit quantized ELU operator implemented via lookup table.
//
// alpha must be a positive normal float, otherwise
// xnn_status_invalid_parameter is returned. All other arguments are validated
// by create_lut_elementwise_nc, whose status is returned unchanged.
enum xnn_status xnn_create_elu_nc_qs8(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    float alpha,
    int8_t input_zero_point,
    float input_scale,
    int8_t output_zero_point,
    float output_scale,
    int8_t output_min,
    int8_t output_max,
    uint32_t flags,
    xnn_operator_t* elu_op_out)
{
  if (!isnormal(alpha) || alpha <= 0.0f) {
    xnn_log_error(
      "failed to create %s operator with %.7g alpha parameter: alpha must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_elu_nc_qs8), alpha);
    return xnn_status_invalid_parameter;
  }

  // Widen the signed 8-bit quantization parameters to the types expected by
  // the shared constructor.
  const int32_t wide_input_zero_point = (int32_t) input_zero_point;
  const long wide_output_zero_point = (long) output_zero_point;
  const long wide_output_min = (long) output_min;
  const long wide_output_max = (long) output_max;

  // The table is generated during this call, so passing the address of the
  // by-value alpha parameter is sufficient.
  return create_lut_elementwise_nc(
    channels, input_stride, output_stride,
    wide_input_zero_point, input_scale, INT8_MIN,
    wide_output_zero_point, output_scale,
    wide_output_min, wide_output_max,
    flags,
    (xnn_lut_init_fn) calculate_elu, &alpha,
    xnn_operator_type_elu_nc_qs8, elu_op_out);
}
173
// Reference Leaky ReLU used to fill the lookup table: returns x when the sign
// bit of x is clear, and x * negative_slope otherwise.
//
// negative_slope_ptr points to a float holding the slope. The parameter is
// typed as const void* so the function exactly matches xnn_lut_init_fn;
// calling a function through a pointer of an incompatible type is undefined
// behaviour.
static float calculate_leaky_relu(float x, const void* negative_slope_ptr) {
  const float negative_slope = *((const float*) negative_slope_ptr);
  return signbit(x) ? x * negative_slope : x;
}
178
// Creates an unsigned 8-bit quantized Leaky ReLU operator implemented via
// lookup table.
//
// Returns xnn_status_invalid_parameter when negative_slope is not a positive
// normal float or exceeds 1.0, and xnn_status_unsupported_parameter when the
// input-to-output scale ratio falls outside [2**-8, 2**8). Otherwise the
// status of create_lut_elementwise_nc is returned unchanged.
enum xnn_status xnn_create_leaky_relu_nc_qu8(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    float negative_slope,
    uint8_t input_zero_point,
    float input_scale,
    uint8_t output_zero_point,
    float output_scale,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    xnn_operator_t* leaky_relu_op_out)
{
  if (!isnormal(negative_slope) || negative_slope <= 0.0f) {
    xnn_log_error(
      "failed to create %s operator with %.7g negative slope: slope must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_leaky_relu_nc_qu8), negative_slope);
    return xnn_status_invalid_parameter;
  }

  if (negative_slope > 1.0f) {
    xnn_log_error(
      "failed to create %s operator with %.7g negative slope: slope must not exceed 1.0",
      xnn_operator_type_to_string(xnn_operator_type_leaky_relu_nc_qu8), negative_slope);
    return xnn_status_invalid_parameter;
  }

  // Both comparisons are false for a NaN ratio, so such a value falls through
  // to the scale validation performed by the shared constructor.
  const float scale_ratio = input_scale / output_scale;
  if (scale_ratio >= 0x1.0p+8f || scale_ratio < 0x1.0p-8f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input-to-output scale ratio: "
      "scale ratio must be in [2**-8, 2**8) range",
      xnn_operator_type_to_string(xnn_operator_type_leaky_relu_nc_qu8), scale_ratio);
    return xnn_status_unsupported_parameter;
  }

  // Zero-extend the unsigned 8-bit quantization parameters to the types
  // expected by the shared constructor.
  const int32_t wide_input_zero_point = (int32_t) (uint32_t) input_zero_point;
  const long wide_output_zero_point = (long) (unsigned long) output_zero_point;
  const long wide_output_min = (long) (unsigned long) output_min;
  const long wide_output_max = (long) (unsigned long) output_max;

  return create_lut_elementwise_nc(
    channels, input_stride, output_stride,
    wide_input_zero_point, input_scale, 0 /* input min */,
    wide_output_zero_point, output_scale,
    wide_output_min, wide_output_max,
    flags,
    (xnn_lut_init_fn) calculate_leaky_relu, &negative_slope,
    xnn_operator_type_leaky_relu_nc_qu8, leaky_relu_op_out);
}
225
calculate_sigmoid(float x,const void * params)226 static float calculate_sigmoid(float x, const void* params) {
227 return signbit(x) ? 1.0f / (1.0f + expf(-x)) : 1.0f - 1.0f / (1.0f + expf(x));
228 }
229
// Creates a signed 8-bit quantized Sigmoid operator implemented via lookup
// table.
//
// Only an output scale of exactly 1/256 and an output zero point of -128 are
// accepted; anything else yields xnn_status_unsupported_parameter. All other
// arguments are validated by create_lut_elementwise_nc, whose status is
// returned unchanged.
enum xnn_status xnn_create_sigmoid_nc_qs8(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    int8_t input_zero_point,
    float input_scale,
    int8_t output_zero_point,
    float output_scale,
    int8_t output_min,
    int8_t output_max,
    uint32_t flags,
    xnn_operator_t* sigmoid_op_out)
{
  if (output_scale != 0x1.0p-8f) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: only output scale of 1/256 is supported",
      xnn_operator_type_to_string(xnn_operator_type_sigmoid_nc_qs8), output_scale);
    return xnn_status_unsupported_parameter;
  }

  if (output_zero_point != -128) {
    // The zero point is signed, so it must be formatted with PRId8; an
    // unsigned conversion would print negative values as large positive ones.
    xnn_log_error(
      "failed to create %s operator with %" PRId8 " output zero point: only output zero point of -128 is supported",
      xnn_operator_type_to_string(xnn_operator_type_sigmoid_nc_qs8), output_zero_point);
    return xnn_status_unsupported_parameter;
  }

  return create_lut_elementwise_nc(
    channels, input_stride, output_stride,
    (int32_t) input_zero_point, input_scale, INT8_MIN,
    (long) output_zero_point, output_scale,
    (long) output_min, (long) output_max,
    flags,
    (xnn_lut_init_fn) &calculate_sigmoid, NULL,
    xnn_operator_type_sigmoid_nc_qs8, sigmoid_op_out);
}
266
// Creates an unsigned 8-bit quantized Sigmoid operator implemented via lookup
// table.
//
// Only an output scale of exactly 1/256 and an output zero point of 0 are
// accepted; anything else yields xnn_status_unsupported_parameter. All other
// arguments are validated by create_lut_elementwise_nc, whose status is
// returned unchanged.
enum xnn_status xnn_create_sigmoid_nc_qu8(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint8_t input_zero_point,
    float input_scale,
    uint8_t output_zero_point,
    float output_scale,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    xnn_operator_t* sigmoid_op_out)
{
  const float supported_output_scale = 0x1.0p-8f;
  if (output_scale != supported_output_scale) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: only output scale of 1/256 is supported",
      xnn_operator_type_to_string(xnn_operator_type_sigmoid_nc_qu8), output_scale);
    return xnn_status_unsupported_parameter;
  }

  if (output_zero_point != 0) {
    xnn_log_error(
      "failed to create %s operator with %" PRIu8 " output zero point: only output zero point of 0 is supported",
      xnn_operator_type_to_string(xnn_operator_type_sigmoid_nc_qu8), output_zero_point);
    return xnn_status_unsupported_parameter;
  }

  // Zero-extend the unsigned 8-bit quantization parameters to the types
  // expected by the shared constructor.
  const int32_t wide_input_zero_point = (int32_t) (uint32_t) input_zero_point;
  const long wide_output_zero_point = (long) (unsigned long) output_zero_point;
  const long wide_output_min = (long) (unsigned long) output_min;
  const long wide_output_max = (long) (unsigned long) output_max;

  return create_lut_elementwise_nc(
    channels, input_stride, output_stride,
    wide_input_zero_point, input_scale, 0 /* input min */,
    wide_output_zero_point, output_scale,
    wide_output_min, wide_output_max,
    flags,
    (xnn_lut_init_fn) calculate_sigmoid, NULL,
    xnn_operator_type_sigmoid_nc_qu8, sigmoid_op_out);
}
303
calculate_tanh(float x,const void * params)304 static float calculate_tanh(float x, const void* params) {
305 return tanhf(x);
306 }
307
// Creates a signed 8-bit quantized TanH operator implemented via lookup table.
//
// Only an output scale of exactly 1/128 and an output zero point of 0 are
// accepted; anything else yields xnn_status_unsupported_parameter. All other
// arguments are validated by create_lut_elementwise_nc, whose status is
// returned unchanged.
enum xnn_status xnn_create_tanh_nc_qs8(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    int8_t input_zero_point,
    float input_scale,
    int8_t output_zero_point,
    float output_scale,
    int8_t output_min,
    int8_t output_max,
    uint32_t flags,
    xnn_operator_t* tanh_op_out)
{
  if (output_scale != 0x1.0p-7f) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: only output scale of 1/128 is supported",
      xnn_operator_type_to_string(xnn_operator_type_tanh_nc_qs8), output_scale);
    return xnn_status_unsupported_parameter;
  }

  if (output_zero_point != 0) {
    // The zero point is signed, so it must be formatted with PRId8; an
    // unsigned conversion would print negative values as large positive ones.
    xnn_log_error(
      "failed to create %s operator with %" PRId8 " output zero point: only output zero point of 0 is supported",
      xnn_operator_type_to_string(xnn_operator_type_tanh_nc_qs8), output_zero_point);
    return xnn_status_unsupported_parameter;
  }

  return create_lut_elementwise_nc(
    channels, input_stride, output_stride,
    (int32_t) input_zero_point, input_scale, INT8_MIN,
    (long) output_zero_point, output_scale,
    (long) output_min, (long) output_max,
    flags,
    (xnn_lut_init_fn) &calculate_tanh, NULL,
    xnn_operator_type_tanh_nc_qs8, tanh_op_out);
}
344
// Creates an unsigned 8-bit quantized TanH operator implemented via lookup
// table.
//
// Only an output scale of exactly 1/128 and an output zero point of 128 are
// accepted; anything else yields xnn_status_unsupported_parameter. All other
// arguments are validated by create_lut_elementwise_nc, whose status is
// returned unchanged.
enum xnn_status xnn_create_tanh_nc_qu8(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint8_t input_zero_point,
    float input_scale,
    uint8_t output_zero_point,
    float output_scale,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    xnn_operator_t* tanh_op_out)
{
  const float supported_output_scale = 0x1.0p-7f;
  if (output_scale != supported_output_scale) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: only output scale of 1/128 is supported",
      xnn_operator_type_to_string(xnn_operator_type_tanh_nc_qu8), output_scale);
    return xnn_status_unsupported_parameter;
  }

  if (output_zero_point != 128) {
    xnn_log_error(
      "failed to create %s operator with %" PRIu8 " output zero point: only output zero point of 128 is supported",
      xnn_operator_type_to_string(xnn_operator_type_tanh_nc_qu8), output_zero_point);
    return xnn_status_unsupported_parameter;
  }

  // Zero-extend the unsigned 8-bit quantization parameters to the types
  // expected by the shared constructor.
  const int32_t wide_input_zero_point = (int32_t) (uint32_t) input_zero_point;
  const long wide_output_zero_point = (long) (unsigned long) output_zero_point;
  const long wide_output_min = (long) (unsigned long) output_min;
  const long wide_output_max = (long) (unsigned long) output_max;

  return create_lut_elementwise_nc(
    channels, input_stride, output_stride,
    wide_input_zero_point, input_scale, 0 /* input min */,
    wide_output_zero_point, output_scale,
    wide_output_min, wide_output_max,
    flags,
    (xnn_lut_init_fn) calculate_tanh, NULL,
    xnn_operator_type_tanh_nc_qu8, tanh_op_out);
}
381
// Binds input/output buffers and a batch size to a lookup-table operator and
// prepares its compute descriptor.
//
// Returns xnn_status_invalid_parameter when the operator's type differs from
// expected_operator_type, xnn_status_uninitialized when XNNPACK has not been
// initialized, and xnn_status_success otherwise. A zero batch_size succeeds
// and marks the operator as xnn_run_state_skip.
static enum xnn_status setup_lut_elementwise_nc(
    xnn_operator_t lut_elementwise_op,
    enum xnn_operator_type expected_operator_type,
    size_t batch_size,
    const void* input,
    void* output)
{
  if (expected_operator_type != lut_elementwise_op->type) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(expected_operator_type),
      xnn_operator_type_to_string(lut_elementwise_op->type));
    return xnn_status_invalid_parameter;
  }
  // Until setup completes, the operator must not be runnable.
  lut_elementwise_op->state = xnn_run_state_invalid;

  if (!(xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK)) {
    xnn_log_error(
      "failed to setup %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(expected_operator_type));
    return xnn_status_uninitialized;
  }

  if (batch_size == 0) {
    lut_elementwise_op->state = xnn_run_state_skip;
    return xnn_status_success;
  }

  const size_t num_channels = lut_elementwise_op->channels;
  const size_t x_stride = lut_elementwise_op->input_pixel_stride;
  const size_t y_stride = lut_elementwise_op->output_pixel_stride;

  // With dense rows (both strides equal to the channel count) or a single
  // row, the data is processed as one flat byte array split into tiles.
  const int is_flat = (x_stride == num_channels && y_stride == num_channels) || batch_size == 1;
  if (is_flat) {
    const size_t tile_bytes = 1024;
    lut_elementwise_op->context.lut_contiguous = (struct lut_contiguous_context) {
      .x = input,
      .x_stride = x_stride * sizeof(uint8_t),
      .t = lut_elementwise_op->lookup_table,
      .y = output,
      .y_stride = y_stride * sizeof(uint8_t),
      .ukernel = xnn_params.x8.lut,
    };
    lut_elementwise_op->compute.type = xnn_parallelization_type_1d_tile_1d;
    lut_elementwise_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_lut_contiguous;
    lut_elementwise_op->compute.range[0] = batch_size * num_channels * sizeof(uint8_t);
    lut_elementwise_op->compute.tile[0] = tile_bytes;
  } else {
    // Otherwise the work is split per batch element, each covering
    // num_channels bytes at the configured strides.
    lut_elementwise_op->context.lut_strided = (struct lut_strided_context) {
      .n = num_channels,
      .x = input,
      .x_stride = x_stride * sizeof(uint8_t),
      .t = lut_elementwise_op->lookup_table,
      .y = output,
      .y_stride = y_stride * sizeof(uint8_t),
      .ukernel = xnn_params.x8.lut,
    };
    lut_elementwise_op->compute.type = xnn_parallelization_type_1d;
    lut_elementwise_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_lut_strided;
    lut_elementwise_op->compute.range[0] = batch_size;
    lut_elementwise_op->compute.tile[0] = 0;
  }
  lut_elementwise_op->state = xnn_run_state_ready;

  return xnn_status_success;
}
445
// Sets up a signed 8-bit quantized ELU operator for the given batch size and
// input/output buffers. The work is delegated to setup_lut_elementwise_nc,
// which also verifies that elu_op is an ELU QS8 operator; its status is
// returned unchanged. threadpool is not used by this function.
enum xnn_status xnn_setup_elu_nc_qs8(
    xnn_operator_t elu_op,
    size_t batch_size,
    const int8_t* input,
    int8_t* output,
    pthreadpool_t threadpool)
{
  return setup_lut_elementwise_nc(
    elu_op, xnn_operator_type_elu_nc_qs8,
    batch_size, input, output);
}
457
// Sets up an unsigned 8-bit quantized Leaky ReLU operator for the given batch
// size and input/output buffers by delegating to setup_lut_elementwise_nc,
// whose status is returned unchanged. threadpool is not used by this function.
enum xnn_status xnn_setup_leaky_relu_nc_qu8(
    xnn_operator_t leaky_relu_op,
    size_t batch_size,
    const uint8_t* input,
    uint8_t* output,
    pthreadpool_t threadpool)
{
  (void) threadpool;
  const enum xnn_status setup_status = setup_lut_elementwise_nc(
    leaky_relu_op, xnn_operator_type_leaky_relu_nc_qu8,
    batch_size, input, output);
  return setup_status;
}
469
// Sets up a signed 8-bit quantized Sigmoid operator for the given batch size
// and input/output buffers by delegating to setup_lut_elementwise_nc, whose
// status is returned unchanged. threadpool is not used by this function.
enum xnn_status xnn_setup_sigmoid_nc_qs8(
    xnn_operator_t sigmoid_op,
    size_t batch_size,
    const int8_t* input,
    int8_t* output,
    pthreadpool_t threadpool)
{
  (void) threadpool;
  const enum xnn_status setup_status = setup_lut_elementwise_nc(
    sigmoid_op, xnn_operator_type_sigmoid_nc_qs8,
    batch_size, input, output);
  return setup_status;
}
481
// Sets up an unsigned 8-bit quantized Sigmoid operator for the given batch
// size and input/output buffers by delegating to setup_lut_elementwise_nc,
// whose status is returned unchanged. threadpool is not used by this function.
enum xnn_status xnn_setup_sigmoid_nc_qu8(
    xnn_operator_t sigmoid_op,
    size_t batch_size,
    const uint8_t* input,
    uint8_t* output,
    pthreadpool_t threadpool)
{
  (void) threadpool;
  const enum xnn_status setup_status = setup_lut_elementwise_nc(
    sigmoid_op, xnn_operator_type_sigmoid_nc_qu8,
    batch_size, input, output);
  return setup_status;
}
493
// Sets up a signed 8-bit quantized TanH operator for the given batch size and
// input/output buffers by delegating to setup_lut_elementwise_nc, whose
// status is returned unchanged. threadpool is not used by this function.
enum xnn_status xnn_setup_tanh_nc_qs8(
    xnn_operator_t tanh_op,
    size_t batch_size,
    const int8_t* input,
    int8_t* output,
    pthreadpool_t threadpool)
{
  (void) threadpool;
  const enum xnn_status setup_status = setup_lut_elementwise_nc(
    tanh_op, xnn_operator_type_tanh_nc_qs8,
    batch_size, input, output);
  return setup_status;
}
505
// Sets up an unsigned 8-bit quantized TanH operator for the given batch size
// and input/output buffers by delegating to setup_lut_elementwise_nc, whose
// status is returned unchanged. threadpool is not used by this function.
enum xnn_status xnn_setup_tanh_nc_qu8(
    xnn_operator_t tanh_op,
    size_t batch_size,
    const uint8_t* input,
    uint8_t* output,
    pthreadpool_t threadpool)
{
  (void) threadpool;
  const enum xnn_status setup_status = setup_lut_elementwise_nc(
    tanh_op, xnn_operator_type_tanh_nc_qu8,
    batch_size, input, output);
  return setup_status;
}
517