// Copyright (c) Facebook, Inc. and its affiliates. // All rights reserved. // // Copyright 2019 Google LLC // // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. #pragma once #include #include #include #include #include struct xnn_f16_output_params { uint16_t scale; uint16_t max; uint16_t min; }; union xnn_f32_output_params { struct { float max; float min; } scalar; #if XNN_ARCH_X86 || XNN_ARCH_X86_64 struct { XNN_ALIGN(16) float max[4]; XNN_ALIGN(16) float min[4]; } sse; #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 }; union xnn_f32_spchw_params { struct { float max; float min; } scalar; #if XNN_ARCH_ARM || XNN_ARCH_ARM64 struct { float min; float max; XNN_ALIGN(16) uint32_t mask_even[4]; // used by stride 2 kernels XNN_ALIGN(16) uint32_t mask_odd[4]; // used by stride 2 kernels XNN_ALIGN(16) uint32_t mask[4]; // used by stride 1 kernels } neon; #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 #if XNN_ARCH_X86 || XNN_ARCH_X86_64 struct { XNN_ALIGN(16) float max[4]; XNN_ALIGN(16) float min[4]; XNN_ALIGN(16) uint32_t mask_even[4]; // used by stride 2 kernels XNN_ALIGN(16) uint32_t mask_odd[4]; // used by stride 2 kernels XNN_ALIGN(16) uint32_t mask[4]; // used by stride 1 kernels } sse; #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 }; union xnn_u8_output_params { struct { int32_t max; int32_t min; } scalar; #if XNN_ARCH_ARM || XNN_ARCH_ARM64 struct { uint8_t max; uint8_t min; } neon; #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 #if XNN_ARCH_X86 || XNN_ARCH_X86_64 struct { XNN_ALIGN(16) uint8_t max[16]; XNN_ALIGN(16) uint8_t min[16]; } sse2; #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 }; union xnn_f32_avgpool_params { struct { float multiplier; float output_min; float output_max; } scalar; #if XNN_ARCH_X86 || XNN_ARCH_X86_64 struct { XNN_ALIGN(16) float multiplier[4]; XNN_ALIGN(16) float output_max[4]; XNN_ALIGN(16) float output_min[4]; } sse2; #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 #if XNN_ARCH_ARM || XNN_ARCH_ARM64 struct { XNN_ALIGN(16) float multiplier; XNN_ALIGN(16) float output_max; XNN_ALIGN(16) float output_min; } neon; #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 }; union xnn_f32_gavgpool_params { struct { float multiplier; float output_min; float output_max; } scalar; #if XNN_ARCH_X86 || XNN_ARCH_X86_64 struct { XNN_ALIGN(16) float multiplier[4]; XNN_ALIGN(16) float output_max[4]; XNN_ALIGN(16) float output_min[4]; XNN_ALIGN(16) uint32_t mask[4]; } sse; #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 #if XNN_ARCH_ARM || XNN_ARCH_ARM64 struct { XNN_ALIGN(16) float multiplier; XNN_ALIGN(16) float output_max; XNN_ALIGN(16) float output_min; XNN_ALIGN(16) uint32_t mask[4]; } neon; #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 */ }; union xnn_f32_hswish_params { struct { float sixth; float half; float one; } scalar; #if XNN_ARCH_X86 || XNN_ARCH_X86_64 struct { XNN_ALIGN(16) float sixth[4]; XNN_ALIGN(16) float half[4]; XNN_ALIGN(16) float one[4]; } sse; #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 }; union xnn_q8_gemm_params { struct { int32_t kernel_zero_point; int32_t input_zero_point; int32_t multiplier; int32_t remainder_mask; int32_t remainder_threshold; uint32_t shift; int32_t output_min_less_zero_point; int32_t output_max_less_zero_point; int32_t output_zero_point; } scalar; #if XNN_ARCH_ARM || XNN_ARCH_ARM64 struct { int16_t kernel_zero_point; int16_t input_zero_point; int32_t multiplier; int32_t right_shift; int16_t output_zero_point; uint8_t output_max; uint8_t output_min; } neon; #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 #if XNN_ARCH_X86 || XNN_ARCH_X86_64 struct { XNN_ALIGN(16) int16_t kernel_zero_point[8]; XNN_ALIGN(16) int16_t input_zero_point[8]; XNN_ALIGN(16) uint32_t multiplier[4]; XNN_ALIGN(16) uint64_t rounding[2]; XNN_ALIGN(16) int32_t remainder_mask[4]; XNN_ALIGN(16) int32_t remainder_threshold[4]; XNN_ALIGN(16) uint64_t shift[2]; XNN_ALIGN(16) int16_t output_zero_point[8]; XNN_ALIGN(16) uint8_t output_max[16]; XNN_ALIGN(16) uint8_t output_min[16]; } sse2; #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 }; union xnn_q8_add_params { struct { int32_t zero_point_product; uint32_t a_multiplier; uint32_t b_multiplier; uint32_t shift; int32_t remainder_mask; int32_t remainder_threshold; int32_t y_zero_point; int32_t y_max; int32_t y_min; } scalar; #if XNN_ARCH_ARM || XNN_ARCH_ARM64 struct { uint8_t a_zero_point; uint8_t b_zero_point; int16_t y_zero_point; int32_t a_multiplier; int32_t b_multiplier; int32_t right_shift; uint8_t y_max; uint8_t y_min; } neon; #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 #if XNN_ARCH_X86 || XNN_ARCH_X86_64 struct { XNN_ALIGN(16) int32_t zero_point_product[4]; XNN_ALIGN(16) uint16_t a_multiplier_lo[8]; XNN_ALIGN(16) uint16_t a_multiplier_hi[8]; XNN_ALIGN(16) uint16_t b_multiplier_lo[8]; XNN_ALIGN(16) uint16_t b_multiplier_hi[8]; XNN_ALIGN(16) int32_t remainder_mask[4]; XNN_ALIGN(16) int32_t remainder_threshold[4]; XNN_ALIGN(16) int16_t y_zero_point[8]; XNN_ALIGN(16) uint8_t y_max[16]; XNN_ALIGN(16) uint8_t y_min[16]; uint32_t shift; uint32_t a_multiplier; uint32_t b_multiplier; } sse2; #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 }; union xnn_q8_avgpool_params { struct { int32_t bias; int32_t multiplier; int64_t rounding; uint32_t right_shift; int32_t output_min_less_zero_point; int32_t output_max_less_zero_point; int32_t output_zero_point; } scalar; #if XNN_ARCH_ARM || XNN_ARCH_ARM64 struct { int32_t bias; int32_t multiplier; int64_t left_shift; int16_t output_zero_point; uint8_t output_max; uint8_t output_min; } neon; #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 #if XNN_ARCH_X86 || XNN_ARCH_X86_64 struct { XNN_ALIGN(16) int32_t bias[4]; XNN_ALIGN(16) uint32_t multiplier[4]; XNN_ALIGN(16) uint64_t rounding[2]; XNN_ALIGN(16) uint64_t right_shift[2]; XNN_ALIGN(16) int16_t output_zero_point[8]; XNN_ALIGN(16) uint8_t output_max[16]; XNN_ALIGN(16) uint8_t output_min[16]; } sse2; #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 }; union xnn_fp32_requantization_params { struct { float scale; float min_less_zero_point; float max_less_zero_point; float magic; int32_t magic_less_zero_point; } scalar; struct { float scale; float max; float min; float magic; int32_t magic_less_zero_point; } neon; struct { float scale; int16_t zero_point; uint8_t max; uint8_t min; } neonv8; struct { XNN_ALIGN(16) float scale[4]; XNN_ALIGN(16) int16_t zero_point[8]; XNN_ALIGN(16) uint8_t max[16]; XNN_ALIGN(16) uint8_t min[16]; } sse2; struct { XNN_ALIGN(16) float scale[4]; XNN_ALIGN(16) float min_less_zero_point[4]; XNN_ALIGN(16) float max_less_zero_point[4]; XNN_ALIGN(16) float magic[4]; XNN_ALIGN(16) int32_t magic_less_zero_point[4]; } psimd; }; union xnn_precise_requantization_params { struct { uint32_t multiplier; uint32_t rounding_lo; uint32_t rounding_hi; uint32_t shift_less_32; int32_t min_less_zero_point; int32_t max_less_zero_point; int32_t zero_point; } scalar; struct { int32_t multiplier; int32_t right_shift; int16_t zero_point; uint8_t max; uint8_t min; } neon; struct { XNN_ALIGN(16) uint32_t multiplier[4]; XNN_ALIGN(16) uint64_t rounding[2]; XNN_ALIGN(16) uint32_t shift[4]; XNN_ALIGN(16) int16_t zero_point[8]; XNN_ALIGN(16) uint8_t max[16]; XNN_ALIGN(16) uint8_t min[16]; } sse2; }; union xnn_q31_requantization_params { struct { int32_t multiplier; int32_t remainder_mask; int32_t remainder_threshold; uint32_t shift; int32_t min_less_zero_point; int32_t max_less_zero_point; int32_t zero_point; } scalar; #if XNN_ARCH_ARM || XNN_ARCH_ARM64 struct { int32_t multiplier; int32_t right_shift; int16_t zero_point; uint8_t max; uint8_t min; } neon; #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 #if XNN_ARCH_X86 || XNN_ARCH_X86_64 struct { XNN_ALIGN(16) uint32_t multiplier[4]; XNN_ALIGN(16) uint64_t rounding[2]; XNN_ALIGN(16) int32_t remainder_mask[4]; XNN_ALIGN(16) int32_t remainder_threshold[4]; XNN_ALIGN(16) uint64_t shift[2]; XNN_ALIGN(16) int16_t zero_point[8]; XNN_ALIGN(16) uint8_t max[16]; XNN_ALIGN(16) uint8_t min[16]; } sse2; #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 }; union xnn_requantization_params { union xnn_precise_requantization_params precise; union xnn_fp32_requantization_params fp32; union xnn_q31_requantization_params q31; }; typedef void (*xnn_ppmm_ukernel_function)( size_t mr, size_t nc, size_t kc, const void* a, const void* w, void* c, size_t cm_stride, size_t cn_stride, const void* params); typedef void (*xnn_f32_ppmm_ukernel_function)( size_t mr, size_t nc, size_t kc, const float* a, const float* w, float* c, size_t cm_stride, size_t cn_stride, const union xnn_f32_output_params* params); typedef void (*xnn_f16_ppmm_ukernel_function)( size_t mr, size_t nc, size_t kc, const void* a, const void* w, void* c, size_t cm_stride, size_t cn_stride, const struct xnn_f16_output_params* params); typedef void (*xnn_gemm_ukernel_function)( size_t mr, size_t nr, size_t k, const void* a, size_t a_stride, const void* w, void* c, size_t cm_stride, size_t cn_stride, const void* params); typedef void (*xnn_f32_gemm_ukernel_function)( size_t mr, size_t nr, size_t k, const float* a, size_t a_stride, const float* w, float* c, size_t cm_stride, size_t cn_stride, const union xnn_f32_output_params* params); typedef void (*xnn_f32_gemminc_ukernel_function)( size_t mr, size_t nr, size_t k, const float* a, size_t a_stride, const float* w, float* c, size_t cm_stride, size_t cn_stride, const float* acc, const union xnn_f32_output_params* params); typedef void (*xnn_f16_gemm_ukernel_function)( size_t mr, size_t nr, size_t k, const void* a, size_t a_stride, const void* w, void* c, size_t cm_stride, size_t cn_stride, const struct xnn_f16_output_params* params); typedef void (*xnn_q8_gemm_ukernel_function)( size_t mr, size_t nr, size_t k, const uint8_t* a, size_t a_stride, const void* w, uint8_t* c, size_t cm_stride, size_t cn_stride, const union xnn_q8_gemm_params* params); typedef void (*xnn_igemm_ukernel_function)( size_t mr, size_t nr, size_t kc, size_t ks, const void** a, const void* w, void* c, size_t cm_stride, size_t cn_stride, size_t a_offset, const void* zero, const void* params); typedef void (*xnn_f32_igemm_ukernel_function)( size_t mr, size_t nr, size_t kc, size_t ks, const float** a, const float* w, float* c, size_t cm_stride, size_t cn_stride, size_t a_offset, const float* zero, const union xnn_f32_output_params* params); typedef void (*xnn_q8_igemm_ukernel_function)( size_t mr, size_t nr, size_t kc, size_t ks, const uint8_t** a, const void* w, uint8_t* c, size_t cm_stride, size_t cn_stride, size_t a_offset, const uint8_t* zero, const union xnn_q8_gemm_params* params); typedef void (*xnn_conv_hwc_ukernel_function)( size_t input_height, size_t input_width, size_t output_y_start, size_t output_y_end, const void* input, const void* zero, const void* weights, void* output, size_t input_padding_top, size_t output_channels, size_t output_height_stride, size_t output_width_stride, const void* params); typedef void (*xnn_f32_conv_hwc_ukernel_function)( size_t input_height, size_t input_width, size_t output_y_start, size_t output_y_end, const float* input, const float* zero, const float* weights, float* output, size_t input_padding_top, size_t output_channels, size_t output_height_stride, size_t output_width_stride, const union xnn_f32_output_params* params); typedef void (*xnn_conv_hwc2spchw_ukernel_function)( size_t input_height, size_t input_width, size_t output_y_start, size_t output_y_end, const void* input, const void* zero, const void* weights, void* output, size_t input_padding_top, size_t output_channels, size_t output_height_stride, size_t output_channel_stride, const void* params); typedef void (*xnn_f32_conv_hwc2spchw_ukernel_function)( size_t input_height, size_t input_width, size_t output_y_start, size_t output_y_end, const float* input, const float* zero, const float* weights, float* output, size_t input_padding_top, size_t output_channels, size_t output_height_stride, size_t output_channel_stride, const union xnn_f32_output_params* params); typedef void (*xnn_spmm_ukernel_function)( uint32_t m, uint32_t n, const void* a, const void* w, const int32_t* dmap, const uint32_t* nmap, void* c, const void* params); typedef void (*xnn_f16_spmm_ukernel_function)( uint32_t m, uint32_t n, const void* a, const void* w, const int32_t* dmap, const uint32_t* nmap, void* c, const struct xnn_f16_output_params* params); typedef void (*xnn_f32_spmm_ukernel_function)( uint32_t m, uint32_t n, const float* a, const float* w, const int32_t* dmap, const uint32_t* nmap, float* c, const union xnn_f32_output_params* params); typedef void (*xnn_packx_ukernel_function)( size_t m, size_t k, const void* x, size_t x_stride, void* y); typedef void (*xnn_x32_packx_ukernel_function)( size_t m, size_t k, const uint32_t* x, size_t x_stride, uint32_t* y); typedef void (*xnn_pad_ukernel_function)( size_t m, size_t n, size_t l, size_t r, uint32_t c, const void* x, size_t x_stride, void* y, size_t y_stride); typedef void (*xnn_unpool_ukernel_function)( size_t p, size_t c, uint32_t f, const void* input, const uint32_t* index, void** output); typedef void (*xnn_x32_unpool_ukernel_function)( size_t p, size_t c, uint32_t f, const uint32_t* input, const uint32_t* index, uint32_t** output); typedef void (*xnn_zipc_ukernel_function)( size_t n, const void* x, void* y); typedef void (*xnn_x8_zipc_ukernel_function)( size_t n, const uint8_t* x, uint8_t* y); typedef void (*xnn_x32_zipc_ukernel_function)( size_t n, const uint32_t* x, uint32_t* y); typedef void (*xnn_zipv_ukernel_function)( size_t n, size_t m, const void* x, void* y); typedef void (*xnn_x8_zipv_ukernel_function)( size_t n, size_t m, const uint8_t* x, uint8_t* y); typedef void (*xnn_x32_zipv_ukernel_function)( size_t n, size_t m, const uint32_t* x, uint32_t* y); typedef void (*xnn_x8_lut_ukernel_function)( size_t n, const uint8_t* x, const uint8_t* t, uint8_t* y); typedef void (*xnn_dwconv_spchw_ukernel_function)( size_t output_height, size_t input_width, const void* input, const void* weights, void* output, size_t input_tuple_stride, size_t output_tuple_stride, size_t input_height_stride, size_t output_height_stride, const void* params); typedef void (*xnn_f32_dwconv_spchw_ukernel_function)( size_t output_height, size_t input_width, const float* input, const float* weights, float* output, size_t input_tuple_stride, size_t output_tuple_stride, size_t input_height_stride, size_t output_height_stride, const union xnn_f32_spchw_params* params); typedef void (*xnn_dwconv_up_ukernel_function)( size_t channels, size_t output_width, const void** input, const void* weights, void* output, size_t input_stride, size_t output_increment, const void* params); typedef void (*xnn_f32_dwconv_up_ukernel_function)( size_t channels, size_t output_width, const float** input, const float* weights, float* output, size_t input_stride, size_t output_increment, const union xnn_f32_output_params* params); typedef void (*xnn_q8_dwconv_up_ukernel_function)( size_t channels, size_t output_width, const uint8_t** input, const void* weights, uint8_t* output, size_t input_stride, size_t output_increment, const union xnn_q8_gemm_params* params); typedef void (*xnn_dwconv_mp_ukernel_function)( size_t channels, size_t output_width, const void** input, const void* weights, void* buffer, void* output, size_t input_stride, size_t output_increment, const void* params); typedef void (*xnn_f32_bilinear_ukernel_function)( size_t output_pixels, size_t channels, const float** input, size_t input_offset, const float* weights, float* output, size_t output_increment); typedef void (*xnn_bilinear_ukernel_function)( size_t output_pixels, size_t channels, const void** input, size_t input_offset, const void* weights, void* output, size_t output_increment); typedef void (*xnn_gavgpool_up_ukernel_function)( size_t m, size_t n, const void* x, size_t x_stride, const void* zero, void* y, const void* params); typedef void (*xnn_f32_gavgpool_up_ukernel_function)( size_t m, size_t n, const float* x, size_t x_stride, const float* zero, float* y, const union xnn_f32_avgpool_params* params); typedef void (*xnn_gavgpool_spchw_ukernel_function)( size_t elements, size_t channels, const float* input, float* output, const void* params); typedef void (*xnn_f32_gavgpool_spchw_ukernel_function)( size_t elements, size_t channels, const float* input, float* output, const union xnn_f32_gavgpool_params* params); typedef void (*xnn_q8_gavgpool_up_ukernel_function)( size_t m, size_t n, const uint8_t* x, size_t x_stride, const uint8_t* zero, uint8_t* y, const union xnn_q8_avgpool_params* params); typedef void (*xnn_gavgpool_mp_ukernel_function)( size_t m, size_t n, const void* x, size_t x_stride, const void* zero, void* buffer, void* y, const void* params); typedef void (*xnn_f32_gavgpool_mp_ukernel_function)( size_t m, size_t n, const float* x, size_t x_stride, const float* zero, float* buffer, float* y, const union xnn_f32_avgpool_params* params); typedef void (*xnn_q8_gavgpool_mp_ukernel_function)( size_t m, size_t n, const uint8_t* x, size_t x_stride, const uint8_t* zero, int32_t* buffer, uint8_t* y, const union xnn_q8_avgpool_params* params); typedef void (*xnn_avgpool_up_ukernel_function)( size_t n, size_t ks, size_t kc, const void** x, const void* zero, void* y, size_t x_increment, size_t y_increment, const void* params); typedef void (*xnn_f32_avgpool_up_ukernel_function)( size_t n, size_t ks, size_t kc, const float** x, const float* zero, float* y, size_t x_increment, size_t y_increment, const union xnn_f32_avgpool_params* params); typedef void (*xnn_q8_avgpool_up_ukernel_function)( size_t n, size_t ks, size_t kc, const uint8_t** x, const uint8_t* zero, uint8_t* y, size_t x_increment, size_t y_increment, const union xnn_q8_avgpool_params* params); typedef void (*xnn_avgpool_mp_ukernel_function)( size_t n, size_t ks, size_t kc, const void** x, const void* zero, void* buffer, void* y, size_t x_increment, size_t y_increment, const void* params); typedef void (*xnn_f32_avgpool_mp_ukernel_function)( size_t n, size_t ks, size_t kc, const float** x, const float* zero, float* buffer, float* y, size_t x_increment, size_t y_increment, const union xnn_f32_avgpool_params* params); typedef void (*xnn_q8_avgpool_mp_ukernel_function)( size_t n, size_t ks, size_t kc, const uint8_t** x, const uint8_t* zero, int32_t* buffer, uint8_t* y, size_t x_increment, size_t y_increment, const union xnn_q8_avgpool_params* params); typedef void (*xnn_pavgpool_up_ukernel_function)( size_t n, size_t ks, size_t kc, const void** x, const void* zero, const void* multiplier, void* y, size_t x_increment, size_t y_increment, const void* params); typedef void (*xnn_f32_pavgpool_up_ukernel_function)( size_t n, size_t ks, size_t kc, const float** x, const float* zero, const float* multiplier, float* y, size_t x_increment, size_t y_increment, const union xnn_f32_output_params* params); typedef void (*xnn_pavgpool_mp_ukernel_function)( size_t n, size_t ks, size_t kc, const void** x, const void* zero, const void* multiplier, void* buffer, void* y, size_t x_increment, size_t y_increment, const void* params); typedef void (*xnn_f32_pavgpool_mp_ukernel_function)( size_t n, size_t ks, size_t kc, const float** x, const float* zero, const float* multiplier, float* buffer, float* y, size_t x_increment, size_t y_increment, const union xnn_f32_output_params* params); typedef void (*xnn_maxpool_ukernel_function)( size_t output_pixels, size_t kernel_elements, size_t channels, const void** input, size_t input_offset, void* output, size_t input_increment, size_t output_increment, const void* params); typedef void (*xnn_f32_maxpool_ukernel_function)( size_t output_pixels, size_t kernel_elements, size_t channels, const float** input, size_t input_offset, float* output, size_t input_increment, size_t output_increment, const union xnn_f32_output_params* params); typedef void (*xnn_u8_maxpool_ukernel_function)( size_t output_pixels, size_t kernel_elements, size_t channels, const uint8_t** input, size_t input_offset, uint8_t* output, size_t input_increment, size_t output_increment, const union xnn_u8_output_params* params); typedef void (*xnn_argmaxpool_up_ukernel_function)( size_t output_pixels, size_t kernel_elements, size_t channels, const void** input, size_t input_offset, void* output, uint32_t* index, size_t input_increment, size_t output_increment, const void* params); typedef void (*xnn_f32_argmaxpool_up_ukernel_function)( size_t output_pixels, size_t kernel_elements, size_t channels, const float** input, size_t input_offset, float* output, uint32_t* index, size_t input_increment, size_t output_increment, const union xnn_f32_output_params* params); typedef void (*xnn_argmaxpool_mp_ukernel_function)( size_t output_pixels, size_t kernel_elements, size_t channels, const void** input, size_t input_offset, void* accumulation_buffer, uint32_t* index_buffer, void* output, uint32_t* index, size_t input_increment, size_t output_increment, const void* params); typedef void (*xnn_f32_argmaxpool_mp_ukernel_function)( size_t output_pixels, size_t kernel_elements, size_t channels, const float** input, size_t input_offset, float* accumulation_buffer, uint32_t* index_buffer, float* output, uint32_t* index, size_t input_increment, size_t output_increment, const union xnn_f32_output_params* params); typedef void (*xnn_univector_ukernel_function)( size_t n, const void* x, void* y, const void* params); typedef void (*xnn_f32_clamp_ukernel_function)( size_t n, const float* x, float* y, const union xnn_f32_output_params* params); typedef void (*xnn_u8_clamp_ukernel_function)( size_t n, const uint8_t* x, uint8_t* y, const union xnn_u8_output_params* params); typedef void (*xnn_f32_hswish_ukernel_function)( size_t n, const float* x, float* y, const union xnn_f32_hswish_params* params); typedef void (*xnn_rmax_ukernel_function)( size_t n, const void* x, void* y); typedef void (*xnn_u8_rmax_ukernel_function)( size_t n, const uint8_t* x, uint8_t* y); typedef void (*xnn_f32_rmax_ukernel_function)( size_t n, const float* x, float* y); typedef void (*xnn_u8_lut32norm_ukernel_function)( size_t n, const uint8_t* x, const uint32_t* t, uint8_t* y); typedef void (*xnn_vadd_ukernel_function)( size_t n, const void* a, const void* b, void* y, const void* params); typedef void (*xnn_f32_vadd_ukernel_function)( size_t n, const float* a, const float* b, float* y, const union xnn_f32_output_params* params); typedef void (*xnn_q8_vadd_ukernel_function)( size_t n, const uint8_t* a, const uint8_t* b, uint8_t* y, const union xnn_q8_add_params* params); typedef void (*xnn_vbinary_ukernel_function)( size_t n, const void* a, const void* b, void* y, const void* params); typedef void (*xnn_f32_vbinary_ukernel_function)( size_t n, const float* a, const float* b, float* y, const union xnn_f32_output_params* params); typedef void (*xnn_vunary_ukernel_function)( size_t n, const void* x, void* y, const void* params); typedef void (*xnn_f32_vunary_ukernel_function)( size_t n, const float* x, float* y, const void* params); typedef void (*xnn_vmulcaddc_ukernel_function)( size_t m, size_t c, const void* x, size_t x_stride, const void* w, void* y, size_t y_stride, const void* params); typedef void (*xnn_f32_vmulcaddc_ukernel_function)( size_t m, size_t c, const float* x, size_t x_stride, const float* w, float* y, size_t y_stride, const union xnn_f32_output_params* params); typedef void (*xnn_prelu_ukernel_function)( size_t mr, size_t n, const void* x, size_t x_stride, const void* w, void* y, size_t y_stride, const void* params); typedef void (*xnn_f32_prelu_ukernel_function)( size_t mr, size_t n, const float* x, size_t x_stride, const float* w, float* y, size_t y_stride, const union xnn_f32_output_params* params); typedef void (*xnn_f32_raddexpminusmax_ukernel_function)( size_t n, const float* input, float* sum, float max); typedef void (*xnn_f32_raddstoreexpminusmax_ukernel_function)( size_t n, const float* input, float* output, float* sum, float max); typedef void (*xnn_f32_vscaleexpminusmax_ukernel_function)( size_t n, const float* input, float* output, float max, float scale); typedef void (*xnn_f32_vscale_ukernel_function)( size_t n, const float* x, float* y, float c); // Reduce-Add Extended ("mantissa" + "exponent") Exponentials typedef void (*xnn_f32_raddextexp_ukernel_function)( size_t n, const float* input, float* sum); // Vector Scale Extended ("mantissa" + "exponent") Exponentials typedef void (*xnn_f32_vscaleextexp_ukernel_function)( size_t n, const float* input, float* output, float scale_mantissa, float scale_exponent); struct gemm_parameters { xnn_gemm_ukernel_function gemm; xnn_igemm_ukernel_function igemm; // Optional GEMM and IGEMM micro-kernels with MR=1 and the same NR and KR parameters. xnn_gemm_ukernel_function gemm1; xnn_igemm_ukernel_function igemm1; uint8_t mr; uint8_t nr; uint8_t log2_kr; uint8_t log2_sr; }; struct vbinary_parameters { xnn_vbinary_ukernel_function op_ukernel; xnn_vbinary_ukernel_function opc_ukernel; xnn_vbinary_ukernel_function ropc_ukernel; // Number of elements in a tile. // For best efficiency, micro-kernel must process a multiple of this number of elements in each call. uint8_t element_tile; }; struct spmm_parameters { xnn_spmm_ukernel_function ukernel; // Number of M-dimension elements in a tile. // Corresponds to a block of pixels in 1x1 Convolution and a block of batch size in Fully Connected operator. uint8_t mr; // Number of N-dimension elements in a tile. // Corresponds to a block of output channels/features in 1x1 Convolution and Fully Connected operator. uint8_t nr; }; struct hwc2spchw_dconv_parameters { xnn_conv_hwc2spchw_ukernel_function ukernel_with_symm_padding; // Number of output channels in a tile. // This parameter must be passed as is to weight packing function. uint8_t output_channel_tile; // Number of output height pixels in a tile. // For best efficiency, micro-kernel must produce a multiple of this number of rows in each call. uint8_t output_height_tile; // Number of output width pixes in a tile. uint8_t output_width_tile; }; struct spchw_dwconv_parameters { xnn_dwconv_spchw_ukernel_function ukernel; // Number of input width pixels in a tile. uint8_t input_width_tile; // Number of output width pixels in a tile. uint8_t output_width_tile; // Number of output height pixels in a tile. // For best efficiency, micro-kernel must produce a multiple of this number of rows in each call. uint8_t output_height_tile; }; struct spchw_gavgpool_parameters { xnn_gavgpool_spchw_ukernel_function ukernel; // Number of channels in a tile. // For best efficiency, micro-kernel must process a multiple of this number of channels in each call. uint8_t channel_tile; }; struct dwconv_parameters { union { xnn_dwconv_up_ukernel_function up; xnn_dwconv_mp_ukernel_function mp; }; uint8_t cr; uint8_t mr; uint8_t qr; }; struct gavgpool_parameters { xnn_gavgpool_up_ukernel_function up; xnn_gavgpool_mp_ukernel_function mp; uint8_t mr; }; struct avgpool_parameters { xnn_avgpool_up_ukernel_function up; xnn_avgpool_mp_ukernel_function mp; uint8_t mr; uint8_t qr; }; struct pavgpool_parameters { xnn_pavgpool_up_ukernel_function up; xnn_pavgpool_mp_ukernel_function mp; uint8_t mr; uint8_t qr; }; struct argmaxpool_parameters { union { xnn_argmaxpool_up_ukernel_function up; xnn_argmaxpool_mp_ukernel_function mp; }; uint8_t mr; uint8_t qr; }; struct maxpool_parameters { xnn_maxpool_ukernel_function ukernel; uint8_t mr; uint8_t qr; }; struct bilinear_parameters { xnn_bilinear_ukernel_function ukernel; // Number of output pixels in a tile. // For best efficiency, micro-kernel must produce a multiple of this number of pixels in each call. uint8_t pixel_tile; // Number of channels in a tile. // For best efficiency, micro-kernel must process a multiple of this number of channels in each call. uint8_t channel_tile; }; struct zip_parameters { xnn_zipc_ukernel_function x2; xnn_zipc_ukernel_function x3; xnn_zipc_ukernel_function x4; xnn_zipv_ukernel_function xm; }; struct prelu_parameters { xnn_prelu_ukernel_function ukernel; uint16_t row_tile; uint16_t channel_tile; }; struct pad_parameters { xnn_pad_ukernel_function ukernel; uint8_t mr; }; struct vmulcaddc_parameters { xnn_vmulcaddc_ukernel_function ukernel; uint8_t channel_tile; uint8_t row_tile; }; #define XNN_MAX_Q8_DWCONV_UKERNELS 1 #define XNN_MAX_F32_DWCONV_UKERNELS 3 #define XNN_MAX_F32_ARGMAXPOOL_UKERNELS 3 struct xnn_parameters { bool initialized; struct xnn_allocator allocator; struct { struct gemm_parameters gemm; struct dwconv_parameters dwconv[XNN_MAX_Q8_DWCONV_UKERNELS]; struct avgpool_parameters avgpool; struct gavgpool_parameters gavgpool; xnn_vadd_ukernel_function vadd; } q8; struct { struct maxpool_parameters maxpool; xnn_univector_ukernel_function clamp; xnn_u8_lut32norm_ukernel_function lut32norm; xnn_u8_rmax_ukernel_function rmax; } u8; struct { xnn_x8_lut_ukernel_function lut; struct zip_parameters zip; } x8; struct { struct gemm_parameters gemm; struct gemm_parameters gemm2; struct dwconv_parameters dwconv[XNN_MAX_F32_DWCONV_UKERNELS]; struct avgpool_parameters avgpool; struct pavgpool_parameters pavgpool; struct gavgpool_parameters gavgpool; struct maxpool_parameters maxpool; struct argmaxpool_parameters argmaxpool[XNN_MAX_F32_ARGMAXPOOL_UKERNELS]; // Bilinear interpolation (2D). struct bilinear_parameters bilinear; xnn_univector_ukernel_function clamp; xnn_univector_ukernel_function hswish; xnn_univector_ukernel_function sigmoid; struct prelu_parameters prelu; struct vbinary_parameters vadd; struct vbinary_parameters vdiv; struct vbinary_parameters vmax; struct vbinary_parameters vmin; struct vbinary_parameters vmul; struct vbinary_parameters vsub; struct vmulcaddc_parameters vmulcaddc; xnn_f32_raddstoreexpminusmax_ukernel_function raddstoreexpminusmax; xnn_f32_rmax_ukernel_function rmax; // Sparse Matrix-Dense Matrix Multiplication (NR=1 block). struct spmm_parameters spmm; // Sparse Matrix-Dense Matrix Multiplication (NR=2 block). struct spmm_parameters spmm2; // Sparse Matrix-Dense Matrix Multiplication (NR=4 block). struct spmm_parameters spmm4; // Direct 3x3 stride-2 Convolution with 3 input channels and HWC->SpCHW layout conversion. struct hwc2spchw_dconv_parameters hwc2spchw_dconv3x3c3s2; // Direct 3x3 stride-1 Convolution with padding 1 on left and right in SpCHW layout. struct spchw_dwconv_parameters spchw_dwconv3x3; // Direct 3x3 stride-2 Convolution with padding 1 on left and right in SpCHW layout. struct spchw_dwconv_parameters spchw_dwconv3x3s2; // Direct 5x5 stride-1 Convolution with padding 2 on left and right in SpCHW layout. struct spchw_dwconv_parameters spchw_dwconv5x5; // Direct 5x5 stride-2 Convolution with padding 2 on left and right in SpCHW layout. struct spchw_dwconv_parameters spchw_dwconv5x5s2; // Global Average Pooling in SpCHW layout. struct spchw_gavgpool_parameters spchw_gavgpool; } f32; struct { struct pad_parameters pad; xnn_unpool_ukernel_function unpool; struct zip_parameters zip; } x32; }; extern XNN_INTERNAL struct xnn_parameters xnn_params;