• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/operator.h>
#include <xnnpack/log.h>
16 
17 
// Callback used to populate a lookup table: maps a dequantized input value `x`
// to the dequantized output value. `params` is an opaque pointer forwarded
// unchanged from the table builder (it may be NULL for parameterless functions).
typedef float (*xnn_lut_init_fn)(float x, const void* params);
19 
// Creates an operator that evaluates an element-wise function on 8-bit data
// through a 256-entry lookup table built at creation time.
//
// For each of the 256 input codes in [input_min, input_min + 256) the code is
// dequantized with (input_zero_point, input_scale), passed to init_fn together
// with init_params, then re-quantized with (output_zero_point, output_scale)
// and clamped to [output_min, output_max]. The table is indexed by the low
// 8 bits of the input code, so signed and unsigned inputs share one layout.
//
// Returns:
//   xnn_status_uninitialized     - XNNPACK has not been initialized.
//   xnn_status_invalid_parameter - zero channels, a stride smaller than the
//                                  channel count, a non-positive or non-normal
//                                  scale, or output_min >= output_max.
//   xnn_status_out_of_memory     - the operator or its table could not be allocated.
//   xnn_status_success           - *lut_elementwise_op_out receives the new operator.
// On failure *lut_elementwise_op_out is left untouched and any partially
// constructed operator is passed to xnn_delete_operator.
static enum xnn_status create_lut_elementwise_nc(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    int32_t input_zero_point,
    float input_scale,
    int32_t input_min,
    long output_zero_point,
    float output_scale,
    long output_min,
    long output_max,
    uint32_t flags,
    xnn_lut_init_fn init_fn,
    const void* init_params,
    enum xnn_operator_type operator_type,
    xnn_operator_t* lut_elementwise_op_out)
{
  xnn_operator_t lut_elementwise_op = NULL;
  enum xnn_status status = xnn_status_uninitialized;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(operator_type));
    goto error;
  }

  status = xnn_status_invalid_parameter;

  if (channels == 0) {
    xnn_log_error(
      "failed to create %s operator with %zu channels: number of channels must be non-zero",
      xnn_operator_type_to_string(operator_type), channels);
    goto error;
  }

  if (input_stride < channels) {
    xnn_log_error(
      "failed to create %s operator with input element stride of %zu: "
      "stride must be at least as large as the number of channels (%zu)",
      xnn_operator_type_to_string(operator_type), input_stride, channels);
    goto error;
  }

  if (output_stride < channels) {
    xnn_log_error(
      "failed to create %s operator with output element stride of %zu: "
      "stride must be at least as large as the number of channels (%zu)",
      xnn_operator_type_to_string(operator_type), output_stride, channels);
    goto error;
  }

  if (input_scale <= 0.0f || !isnormal(input_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(operator_type), input_scale);
    goto error;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(operator_type), output_scale);
    goto error;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%ld, %ld] output range: range min must be below range max",
      xnn_operator_type_to_string(operator_type), output_min, output_max);
    goto error;
  }

  status = xnn_status_out_of_memory;

  lut_elementwise_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
  if (lut_elementwise_op == NULL) {
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator descriptor",
      sizeof(struct xnn_operator), xnn_operator_type_to_string(operator_type));
    goto error;
  }

  lut_elementwise_op->lookup_table = xnn_allocate_simd_memory(256 * sizeof(uint8_t));
  if (lut_elementwise_op->lookup_table == NULL) {
    xnn_log_error(
      "failed to allocate 256 bytes for %s operator lookup table",
      xnn_operator_type_to_string(operator_type));
    goto error;
  }

  uint8_t* lookup_table = lut_elementwise_op->lookup_table;
  const float inv_output_scale = 1.0f / output_scale;
  // Clamp bounds expressed relative to the zero point, so the clamp can be
  // applied BEFORE rounding. lrintf() has an unspecified result for values
  // outside the range of long (and for NaN/infinity), and adding the zero
  // point afterwards could overflow; clamping first avoids both. Because the
  // bounds are integers, in-range values round exactly as they did before.
  const float scaled_output_min = (float) (output_min - output_zero_point);
  const float scaled_output_max = (float) (output_max - output_zero_point);
  for (int32_t i = input_min; i < input_min + 256; i++) {
    const float dequantized_input = (i - input_zero_point) * input_scale;
    const float dequantized_output = init_fn(dequantized_input, init_params);
    float scaled_output = dequantized_output * inv_output_scale;
    // Negated comparison so that a NaN is mapped to the lower bound.
    if (!(scaled_output > scaled_output_min)) {
      scaled_output = scaled_output_min;
    }
    if (scaled_output > scaled_output_max) {
      scaled_output = scaled_output_max;
    }
    const long quantized_output = lrintf(scaled_output) + output_zero_point;
    // Index by the low 8 bits of the input code: signed codes wrap around.
    lookup_table[(uint8_t) i] = (uint8_t) quantized_output;
  }

  lut_elementwise_op->channels = channels;
  lut_elementwise_op->input_pixel_stride = input_stride;
  lut_elementwise_op->output_pixel_stride = output_stride;

  lut_elementwise_op->type = operator_type;
  lut_elementwise_op->flags = flags;

  lut_elementwise_op->state = xnn_run_state_invalid;

  *lut_elementwise_op_out = lut_elementwise_op;
  return xnn_status_success;

error:
  xnn_delete_operator(lut_elementwise_op);
  return status;
}
137 
calculate_elu(float x,const float * alpha_ptr)138 static float calculate_elu(float x, const float* alpha_ptr) {
139   const float alpha = *alpha_ptr;
140   return signbit(x) ? alpha * expm1f(x) : x;
141 }
142 
// Creates a signed 8-bit quantized ELU operator backed by a lookup table.
//
// alpha must be positive, finite and normalized; otherwise the call fails with
// xnn_status_invalid_parameter. All remaining validation and the table
// construction are delegated to create_lut_elementwise_nc, whose status is
// returned as-is.
enum xnn_status xnn_create_elu_nc_qs8(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    float alpha,
    int8_t input_zero_point,
    float input_scale,
    int8_t output_zero_point,
    float output_scale,
    int8_t output_min,
    int8_t output_max,
    uint32_t flags,
    xnn_operator_t* elu_op_out)
{
  const enum xnn_operator_type op_type = xnn_operator_type_elu_nc_qs8;

  const int alpha_is_valid = alpha > 0.0f && isnormal(alpha);
  if (!alpha_is_valid) {
    xnn_log_error(
      "failed to create %s operator with %.7g alpha parameter: alpha must be finite, normalized, and positive",
      xnn_operator_type_to_string(op_type), alpha);
    return xnn_status_invalid_parameter;
  }

  // Widen the 8-bit quantization parameters to the types the table builder expects.
  const int32_t lut_input_zero_point = (int32_t) input_zero_point;
  const long lut_output_zero_point = (long) output_zero_point;
  const long lut_output_min = (long) output_min;
  const long lut_output_max = (long) output_max;

  // Signed inputs: the table covers the codes starting at INT8_MIN.
  return create_lut_elementwise_nc(
    channels, input_stride, output_stride,
    lut_input_zero_point, input_scale, INT8_MIN,
    lut_output_zero_point, output_scale,
    lut_output_min, lut_output_max,
    flags,
    (xnn_lut_init_fn) &calculate_elu, &alpha,
    op_type, elu_op_out);
}
173 
// Leaky ReLU reference function used to fill the lookup table:
// returns x * negative_slope when the sign bit of x is set, otherwise x.
//
// `params` must point to a float holding the negative slope. The signature
// matches xnn_lut_init_fn exactly so the function can be invoked through that
// pointer type without calling through an incompatible function type.
static float calculate_leaky_relu(float x, const void* params) {
  const float negative_slope = *((const float*) params);
  return signbit(x) ? x * negative_slope : x;
}
178 
// Creates an unsigned 8-bit quantized Leaky ReLU operator backed by a lookup table.
//
// Fails with xnn_status_invalid_parameter when negative_slope is not a
// positive normalized float or exceeds 1.0, and with
// xnn_status_unsupported_parameter when input_scale / output_scale falls
// outside [2**-8, 2**8). Remaining validation and table construction are
// delegated to create_lut_elementwise_nc.
enum xnn_status xnn_create_leaky_relu_nc_qu8(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    float negative_slope,
    uint8_t input_zero_point,
    float input_scale,
    uint8_t output_zero_point,
    float output_scale,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    xnn_operator_t* leaky_relu_op_out)
{
  const enum xnn_operator_type op_type = xnn_operator_type_leaky_relu_nc_qu8;

  if (negative_slope <= 0.0f || !isnormal(negative_slope)) {
    xnn_log_error(
      "failed to create %s operator with %.7g negative slope: slope must be finite, normalized, and positive",
      xnn_operator_type_to_string(op_type), negative_slope);
    return xnn_status_invalid_parameter;
  }

  if (negative_slope > 1.0f) {
    xnn_log_error(
      "failed to create %s operator with %.7g negative slope: slope must not exceed 1.0",
      xnn_operator_type_to_string(op_type), negative_slope);
    return xnn_status_invalid_parameter;
  }

  // Note: a NaN ratio passes this check (both comparisons are false); the
  // scales themselves are validated later by the table builder.
  const float scale_ratio = input_scale / output_scale;
  if (scale_ratio < 0x1.0p-8f || scale_ratio >= 0x1.0p+8f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input-to-output scale ratio: "
      "scale ratio must be in [2**-8, 2**8) range",
      xnn_operator_type_to_string(op_type), scale_ratio);
    return xnn_status_unsupported_parameter;
  }

  // Widen the unsigned 8-bit quantization parameters without sign extension.
  const int32_t lut_input_zero_point = (int32_t) (uint32_t) input_zero_point;
  const long lut_output_zero_point = (long) (unsigned long) output_zero_point;
  const long lut_output_min = (long) (unsigned long) output_min;
  const long lut_output_max = (long) (unsigned long) output_max;

  return create_lut_elementwise_nc(
    channels, input_stride, output_stride,
    lut_input_zero_point, input_scale, 0 /* input min */,
    lut_output_zero_point, output_scale,
    lut_output_min, lut_output_max,
    flags,
    (xnn_lut_init_fn) &calculate_leaky_relu, &negative_slope,
    op_type, leaky_relu_op_out);
}
225 
calculate_sigmoid(float x,const void * params)226 static float calculate_sigmoid(float x, const void* params) {
227   return signbit(x) ? 1.0f / (1.0f + expf(-x)) : 1.0f - 1.0f / (1.0f + expf(x));
228 }
229 
// Creates a signed 8-bit quantized Sigmoid operator backed by a lookup table.
//
// Only an output scale of exactly 1/256 and an output zero point of -128 are
// accepted; anything else fails with xnn_status_unsupported_parameter.
// Remaining validation and table construction are delegated to
// create_lut_elementwise_nc, whose status is returned as-is.
enum xnn_status xnn_create_sigmoid_nc_qs8(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    int8_t input_zero_point,
    float input_scale,
    int8_t output_zero_point,
    float output_scale,
    int8_t output_min,
    int8_t output_max,
    uint32_t flags,
    xnn_operator_t* sigmoid_op_out)
{
  if (output_scale != 0x1.0p-8f) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: only output scale of 1/256 is supported",
      xnn_operator_type_to_string(xnn_operator_type_sigmoid_nc_qs8), output_scale);
    return xnn_status_unsupported_parameter;
  }

  if (output_zero_point != -128) {
    // The zero point is signed: print it with PRId8 so negative values are
    // not reported as large unsigned numbers.
    xnn_log_error(
      "failed to create %s operator with %" PRId8 " output zero point: only output zero point of -128 is supported",
      xnn_operator_type_to_string(xnn_operator_type_sigmoid_nc_qs8), output_zero_point);
    return xnn_status_unsupported_parameter;
  }

  // Signed inputs: the table covers the codes starting at INT8_MIN.
  return create_lut_elementwise_nc(
    channels, input_stride, output_stride,
    (int32_t) input_zero_point, input_scale, INT8_MIN,
    (long) output_zero_point, output_scale,
    (long) output_min, (long) output_max,
    flags,
    (xnn_lut_init_fn) &calculate_sigmoid, NULL,
    xnn_operator_type_sigmoid_nc_qs8, sigmoid_op_out);
}
266 
// Creates an unsigned 8-bit quantized Sigmoid operator backed by a lookup table.
//
// Only an output scale of exactly 1/256 and an output zero point of 0 are
// accepted; anything else fails with xnn_status_unsupported_parameter.
// Remaining validation and table construction are delegated to
// create_lut_elementwise_nc.
enum xnn_status xnn_create_sigmoid_nc_qu8(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint8_t input_zero_point,
    float input_scale,
    uint8_t output_zero_point,
    float output_scale,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    xnn_operator_t* sigmoid_op_out)
{
  const enum xnn_operator_type op_type = xnn_operator_type_sigmoid_nc_qu8;

  if (output_scale != 0x1.0p-8f) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: only output scale of 1/256 is supported",
      xnn_operator_type_to_string(op_type), output_scale);
    return xnn_status_unsupported_parameter;
  }

  if (output_zero_point != 0) {
    xnn_log_error(
      "failed to create %s operator with %" PRIu8 " output zero point: only output zero point of 0 is supported",
      xnn_operator_type_to_string(op_type), output_zero_point);
    return xnn_status_unsupported_parameter;
  }

  // Widen the unsigned 8-bit quantization parameters without sign extension.
  const int32_t lut_input_zero_point = (int32_t) (uint32_t) input_zero_point;
  const long lut_output_zero_point = (long) (unsigned long) output_zero_point;
  const long lut_output_min = (long) (unsigned long) output_min;
  const long lut_output_max = (long) (unsigned long) output_max;

  return create_lut_elementwise_nc(
    channels, input_stride, output_stride,
    lut_input_zero_point, input_scale, 0 /* input min */,
    lut_output_zero_point, output_scale,
    lut_output_min, lut_output_max,
    flags,
    (xnn_lut_init_fn) &calculate_sigmoid, NULL,
    op_type, sigmoid_op_out);
}
303 
calculate_tanh(float x,const void * params)304 static float calculate_tanh(float x, const void* params) {
305   return tanhf(x);
306 }
307 
// Creates a signed 8-bit quantized TanH operator backed by a lookup table.
//
// Only an output scale of exactly 1/128 and an output zero point of 0 are
// accepted; anything else fails with xnn_status_unsupported_parameter.
// Remaining validation and table construction are delegated to
// create_lut_elementwise_nc, whose status is returned as-is.
enum xnn_status xnn_create_tanh_nc_qs8(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    int8_t input_zero_point,
    float input_scale,
    int8_t output_zero_point,
    float output_scale,
    int8_t output_min,
    int8_t output_max,
    uint32_t flags,
    xnn_operator_t* tanh_op_out)
{
  if (output_scale != 0x1.0p-7f) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: only output scale of 1/128 is supported",
      xnn_operator_type_to_string(xnn_operator_type_tanh_nc_qs8), output_scale);
    return xnn_status_unsupported_parameter;
  }

  if (output_zero_point != 0) {
    // The zero point is signed: print it with PRId8 so negative values are
    // not reported as large unsigned numbers.
    xnn_log_error(
      "failed to create %s operator with %" PRId8 " output zero point: only output zero point of 0 is supported",
      xnn_operator_type_to_string(xnn_operator_type_tanh_nc_qs8), output_zero_point);
    return xnn_status_unsupported_parameter;
  }

  // Signed inputs: the table covers the codes starting at INT8_MIN.
  return create_lut_elementwise_nc(
    channels, input_stride, output_stride,
    (int32_t) input_zero_point, input_scale, INT8_MIN,
    (long) output_zero_point, output_scale,
    (long) output_min, (long) output_max,
    flags,
    (xnn_lut_init_fn) &calculate_tanh, NULL,
    xnn_operator_type_tanh_nc_qs8, tanh_op_out);
}
344 
// Creates an unsigned 8-bit quantized TanH operator backed by a lookup table.
//
// Only an output scale of exactly 1/128 and an output zero point of 128 are
// accepted; anything else fails with xnn_status_unsupported_parameter.
// Remaining validation and table construction are delegated to
// create_lut_elementwise_nc.
enum xnn_status xnn_create_tanh_nc_qu8(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint8_t input_zero_point,
    float input_scale,
    uint8_t output_zero_point,
    float output_scale,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    xnn_operator_t* tanh_op_out)
{
  const enum xnn_operator_type op_type = xnn_operator_type_tanh_nc_qu8;

  if (output_scale != 0x1.0p-7f) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: only output scale of 1/128 is supported",
      xnn_operator_type_to_string(op_type), output_scale);
    return xnn_status_unsupported_parameter;
  }

  if (output_zero_point != 128) {
    xnn_log_error(
      "failed to create %s operator with %" PRIu8 " output zero point: only output zero point of 128 is supported",
      xnn_operator_type_to_string(op_type), output_zero_point);
    return xnn_status_unsupported_parameter;
  }

  // Widen the unsigned 8-bit quantization parameters without sign extension.
  const int32_t lut_input_zero_point = (int32_t) (uint32_t) input_zero_point;
  const long lut_output_zero_point = (long) (unsigned long) output_zero_point;
  const long lut_output_min = (long) (unsigned long) output_min;
  const long lut_output_max = (long) (unsigned long) output_max;

  return create_lut_elementwise_nc(
    channels, input_stride, output_stride,
    lut_input_zero_point, input_scale, 0 /* input min */,
    lut_output_zero_point, output_scale,
    lut_output_min, lut_output_max,
    flags,
    (xnn_lut_init_fn) &calculate_tanh, NULL,
    op_type, tanh_op_out);
}
381 
// Binds input/output buffers and a batch size to a lookup-table operator and
// fills in its compute descriptor.
//
// Returns xnn_status_invalid_parameter when the operator's type differs from
// expected_operator_type, xnn_status_uninitialized when XNNPACK is not
// initialized, and xnn_status_success otherwise. A zero batch_size marks the
// operator as "skip". When both strides equal the channel count, or there is
// a single batch element, the data is processed as one contiguous byte range
// split into tiles; otherwise each batch element is processed separately
// using the configured strides.
static enum xnn_status setup_lut_elementwise_nc(
    xnn_operator_t lut_elementwise_op,
    enum xnn_operator_type expected_operator_type,
    size_t batch_size,
    const void* input,
    void* output)
{
  const enum xnn_operator_type actual_operator_type = lut_elementwise_op->type;
  if (actual_operator_type != expected_operator_type) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(expected_operator_type),
      xnn_operator_type_to_string(actual_operator_type));
    return xnn_status_invalid_parameter;
  }
  // Until setup completes the operator must not be runnable.
  lut_elementwise_op->state = xnn_run_state_invalid;

  const int xnnpack_initialized = (xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) != 0;
  if (!xnnpack_initialized) {
    xnn_log_error(
      "failed to setup %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(expected_operator_type));
    return xnn_status_uninitialized;
  }

  if (batch_size == 0) {
    lut_elementwise_op->state = xnn_run_state_skip;
    return xnn_status_success;
  }

  const size_t num_channels = lut_elementwise_op->channels;
  const size_t x_stride = lut_elementwise_op->input_pixel_stride;
  const size_t y_stride = lut_elementwise_op->output_pixel_stride;
  const int is_densely_packed = (x_stride == num_channels) && (y_stride == num_channels);
  if (is_densely_packed || batch_size == 1) {
    // Contiguous case: treat the whole batch as a flat byte range.
    const size_t tile_size = 1024;
    lut_elementwise_op->context.lut_contiguous = (struct lut_contiguous_context) {
      .x = input,
      .x_stride = x_stride * sizeof(uint8_t),
      .t = lut_elementwise_op->lookup_table,
      .y = output,
      .y_stride = y_stride * sizeof(uint8_t),
      .ukernel = xnn_params.x8.lut,
    };
    lut_elementwise_op->compute.type = xnn_parallelization_type_1d_tile_1d;
    lut_elementwise_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_lut_contiguous;
    lut_elementwise_op->compute.range[0] = batch_size * num_channels * sizeof(uint8_t);
    lut_elementwise_op->compute.tile[0] = tile_size;
  } else {
    // Strided case: one task per batch element.
    lut_elementwise_op->context.lut_strided = (struct lut_strided_context) {
      .n = num_channels,
      .x = input,
      .x_stride = x_stride * sizeof(uint8_t),
      .t = lut_elementwise_op->lookup_table,
      .y = output,
      .y_stride = y_stride * sizeof(uint8_t),
      .ukernel = xnn_params.x8.lut,
    };
    lut_elementwise_op->compute.type = xnn_parallelization_type_1d;
    lut_elementwise_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_lut_strided;
    lut_elementwise_op->compute.range[0] = batch_size;
    lut_elementwise_op->compute.tile[0] = 0;
  }
  lut_elementwise_op->state = xnn_run_state_ready;

  return xnn_status_success;
}
445 
// Sets up a signed 8-bit quantized ELU operator for the given batch and buffers.
//
// Forwards to setup_lut_elementwise_nc, which rejects operators whose type is
// not xnn_operator_type_elu_nc_qs8. `threadpool` is not used by this function.
// The operator parameter is named elu_op (it was previously misnamed
// sigmoid_op); parameter names do not affect C callers.
enum xnn_status xnn_setup_elu_nc_qs8(
    xnn_operator_t elu_op,
    size_t batch_size,
    const int8_t* input,
    int8_t* output,
    pthreadpool_t threadpool)
{
  return setup_lut_elementwise_nc(
    elu_op, xnn_operator_type_elu_nc_qs8,
    batch_size, input, output);
}
457 
// Sets up an unsigned 8-bit quantized Leaky ReLU operator for the given batch
// and buffers. Forwards to setup_lut_elementwise_nc, which rejects operators
// whose type is not xnn_operator_type_leaky_relu_nc_qu8.
enum xnn_status xnn_setup_leaky_relu_nc_qu8(
    xnn_operator_t leaky_relu_op,
    size_t batch_size,
    const uint8_t* input,
    uint8_t* output,
    pthreadpool_t threadpool)
{
  // The thread pool is not needed at setup time.
  (void) threadpool;

  const enum xnn_status status = setup_lut_elementwise_nc(
    leaky_relu_op, xnn_operator_type_leaky_relu_nc_qu8,
    batch_size, input, output);
  return status;
}
469 
// Sets up a signed 8-bit quantized Sigmoid operator for the given batch and
// buffers. Forwards to setup_lut_elementwise_nc, which rejects operators
// whose type is not xnn_operator_type_sigmoid_nc_qs8.
enum xnn_status xnn_setup_sigmoid_nc_qs8(
    xnn_operator_t sigmoid_op,
    size_t batch_size,
    const int8_t* input,
    int8_t* output,
    pthreadpool_t threadpool)
{
  // The thread pool is not needed at setup time.
  (void) threadpool;

  const enum xnn_status status = setup_lut_elementwise_nc(
    sigmoid_op, xnn_operator_type_sigmoid_nc_qs8,
    batch_size, input, output);
  return status;
}
481 
// Sets up an unsigned 8-bit quantized Sigmoid operator for the given batch
// and buffers. Forwards to setup_lut_elementwise_nc, which rejects operators
// whose type is not xnn_operator_type_sigmoid_nc_qu8.
enum xnn_status xnn_setup_sigmoid_nc_qu8(
    xnn_operator_t sigmoid_op,
    size_t batch_size,
    const uint8_t* input,
    uint8_t* output,
    pthreadpool_t threadpool)
{
  // The thread pool is not needed at setup time.
  (void) threadpool;

  const enum xnn_status status = setup_lut_elementwise_nc(
    sigmoid_op, xnn_operator_type_sigmoid_nc_qu8,
    batch_size, input, output);
  return status;
}
493 
// Sets up a signed 8-bit quantized TanH operator for the given batch and
// buffers. Forwards to setup_lut_elementwise_nc, which rejects operators
// whose type is not xnn_operator_type_tanh_nc_qs8.
enum xnn_status xnn_setup_tanh_nc_qs8(
    xnn_operator_t tanh_op,
    size_t batch_size,
    const int8_t* input,
    int8_t* output,
    pthreadpool_t threadpool)
{
  // The thread pool is not needed at setup time.
  (void) threadpool;

  const enum xnn_status status = setup_lut_elementwise_nc(
    tanh_op, xnn_operator_type_tanh_nc_qs8,
    batch_size, input, output);
  return status;
}
505 
// Sets up an unsigned 8-bit quantized TanH operator for the given batch and
// buffers. Forwards to setup_lut_elementwise_nc, which rejects operators
// whose type is not xnn_operator_type_tanh_nc_qu8.
enum xnn_status xnn_setup_tanh_nc_qu8(
    xnn_operator_t tanh_op,
    size_t batch_size,
    const uint8_t* input,
    uint8_t* output,
    pthreadpool_t threadpool)
{
  // The thread pool is not needed at setup time.
  (void) threadpool;

  const enum xnn_status status = setup_lut_elementwise_nc(
    tanh_op, xnn_operator_type_tanh_nc_qu8,
    batch_size, input, output);
  return status;
}
517