// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <stddef.h>
#include <stdint.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/cache.h>
#include <xnnpack/compute.h>
#include <xnnpack/microparams.h>
#include <xnnpack/operator-type.h>
#include <xnnpack/params.h>

struct xnn_ukernel_conv2d {
  union {
    xnn_conv_hwc2chw_ukernel_function hwc2chw_function;
    xnn_conv_hwc_ukernel_function hwc_function;
  };
  uint8_t output_height_tile;
  uint8_t output_channel_tile;
};

struct xnn_ukernel_dwconv {
  union {
    xnn_dwconv_unipass_ukernel_function unipass_function;
    xnn_dwconv_multipass_ukernel_function multipass_function;
  };
  uint8_t primary_tile;
  uint8_t incremental_tile;
};

// Direct 2D Depthwise Convolution
struct xnn_ukernel_dwconv2d {
  union {
    xnn_dwconv2d_chw_ukernel_function chw_function;
  };
  uint8_t output_width_tile;
};

struct xnn_ukernel_gemm {
  struct xnn_hmp_gemm_ukernel gemm_cases[XNN_MAX_MR];
  uint8_t mr;
  uint8_t nr;
  uint8_t kr;
  uint8_t sr;
};

struct xnn_ukernel_igemm {
  struct xnn_hmp_igemm_ukernel igemm_cases[XNN_MAX_MR];
  struct xnn_hmp_gemm_ukernel gemm_cases[XNN_MAX_MR];
  uint8_t mr;
  uint8_t nr;
  uint8_t kr;
  uint8_t sr;
};

struct xnn_ukernel_spmm {
  xnn_spmm_ukernel_function function;
  uint8_t mr;
};

struct xnn_ukernel_vmulcaddc {
  xnn_vmulcaddc_ukernel_function function;
  uint8_t mr;
};

struct xnn_ukernel_vbinary {
  xnn_vbinary_ukernel_function op_function;
  xnn_vbinary_ukernel_function opc_function;
  xnn_vbinary_ukernel_function ropc_function;
};

struct xnn_ukernel_vunary {
  xnn_vunary_ukernel_function function;
};

struct xnn_ukernel {
  enum xnn_ukernel_type type;
  union {
    struct xnn_ukernel_conv2d conv2d;
    struct xnn_ukernel_dwconv dwconv;
    struct xnn_ukernel_dwconv2d dwconv2d;
    struct xnn_ukernel_gemm gemm;
    struct xnn_ukernel_igemm igemm;
    struct xnn_ukernel_spmm spmm;
    struct xnn_ukernel_vmulcaddc vmulcaddc;
    struct xnn_ukernel_vbinary vbinary;
    struct xnn_ukernel_vunary vunary;
  };
};

enum xnn_run_state {
  xnn_run_state_invalid = 0,
  xnn_run_state_ready,
  xnn_run_state_skip,
};

struct subconvolution_params {
  void* weights;
  size_t w_stride;
  const void** indirection_buffer;
  void* output;
  size_t slice_width;
  size_t slice_height;
  size_t indirection_y_stride;
  size_t indirection_x_stride;
  // scaled_kernel_size := kernel_size * mr * sizeof(void*). See the worked example at the end of this header.
  size_t scaled_kernel_size;
};

struct xnn_operator {
  size_t batch_size;
  uint32_t padding_top;
  uint32_t padding_right;
  uint32_t padding_bottom;
  uint32_t padding_left;
  uint32_t kernel_height;
  uint32_t kernel_width;
  uint32_t stride_height;
  uint32_t stride_width;
  uint32_t dilation_height;
  uint32_t dilation_width;
  uint32_t groups;
  size_t group_channels;
  size_t group_input_channels;
  size_t group_output_channels;
  size_t channels;

  uint32_t pad_value;

  size_t input_height;
  size_t input_width;
  size_t input_pixel_stride;
  const void* input;
  const void* input2;
  const void** indirection_buffer;

  size_t output_height;
  size_t output_width;
  size_t output_pixel_stride;
  void* output;

  union {
    // Pointer to allocated packed weights. Use this if weights_cache is NULL.
    void* pointer;
    // Offset into the weights cache where the packed weights are. Only valid if weights_cache is not NULL.
    size_t offset;
  } packed_weights;
  // Total number of non-zero kernel elements when weights use sparse representation.
  size_t num_nonzero_values;
  // Total number of non-zero kernel blocks when weights use sparse representation.
  size_t num_nonzero_blocks;
  // Total number of output channel blocks when weights use sparse representation.
  size_t num_output_channel_blocks;
  // Input channel corresponding to the first non-zero kernel element.
  size_t first_input_channel;

  float input_scale;
  float output_scale;
  int32_t input_zero_point;

  size_t valid_batch_size;
  size_t last_input_height;
  size_t last_input_width;
  const void* last_input;
  size_t last_output_height;
  size_t last_output_width;
  void* last_output;

  uint32_t block_size;

  void* zero_buffer;
  void* lookup_table;
  void* pixelwise_buffer;
  struct subconvolution_params* subconvolution_buffer;
  uint32_t flags;

  union {
    union xnn_f16_abs_params f16_abs;
    union xnn_f16_f32_cvt_params f16_f32_cvt;
    union xnn_f16_hswish_params f16_hswish;
    union xnn_f16_elu_params f16_elu;
    union xnn_f16_lrelu_params f16_lrelu;
    union xnn_f16_neg_params f16_neg;
    union xnn_f16_sigmoid_params f16_sigmoid;
    union xnn_f32_abs_params f32_abs;
    union xnn_f32_default_params f32_default;
    union xnn_f32_elu_params f32_elu;
    union xnn_f32_lrelu_params f32_lrelu;
    union xnn_f32_neg_params f32_neg;
    union xnn_f32_rnd_params f32_rnd;
    union xnn_f32_sigmoid_params f32_sigmoid;
    union xnn_f32_sqrt_params f32_sqrt;
    // Parameters for Global Average Pooling in CHW layout
    union xnn_f32_gavgpool_params f32_gavgpool;
    union xnn_f32_hswish_params f32_hswish;
    // Pixelwise Average Pooling normally uses f16_minmax_params, but also initializes
    // f16_scaleminmax_params in case it needs to switch to the Global Average Pooling operation.
    struct {
      union xnn_f16_minmax_params f16_minmax;
      union xnn_f16_scaleminmax_params f16_scaleminmax;
    };
    // Pixelwise Average Pooling normally uses f32_minmax_params, but also initializes
    // f32_scaleminmax_params in case it needs to switch to the Global Average Pooling operation.
    struct {
      union xnn_f32_minmax_params f32_minmax;
      union xnn_f32_scaleminmax_params f32_scaleminmax;
    };
    union xnn_f32_chw_params f32_chw;
    union xnn_f32_f16_cvt_params f32_f16_cvt;
    union xnn_f32_qs8_cvt_params f32_qs8_cvt;
    union xnn_f32_qu8_cvt_params f32_qu8_cvt;
    union xnn_qs8_cvt_params qs8_cvt;
    union xnn_qs8_f32_cvt_params qs8_f32_cvt;
    union xnn_qu8_cvt_params qu8_cvt;
    union xnn_qu8_f32_cvt_params qu8_f32_cvt;
    union xnn_qs8_conv_minmax_params qs8_conv_minmax;
    // Average Pooling normally uses qs8_avgpool_params, but also initializes qs8_gavgpool_params
    // in case it needs to switch to the Global Average Pooling operation.
    struct {
      union xnn_qs8_avgpool_minmax_params qs8_avgpool;
      union xnn_qs8_avgpool_minmax_params qs8_gavgpool;
    };
    // Quantized Add parameters are sensitive to the order of inputs, so we initialize an extra copy with the reversed order.
    struct {
      union xnn_qs8_add_minmax_params qs8_add;
      union xnn_qs8_add_minmax_params qs8_radd;
    };
    struct {
      union xnn_qs8_mul_minmax_params qs8_mul;
      union xnn_qs8_mul_minmax_params qs8_rmul;
    };
    struct {
      union xnn_qu8_add_minmax_params qu8_add;
      union xnn_qu8_add_minmax_params qu8_radd;
    };
    struct {
      union xnn_qu8_mul_minmax_params qu8_mul;
      union xnn_qu8_mul_minmax_params qu8_rmul;
    };
    union xnn_qu8_conv_minmax_params qu8_conv_minmax;
    // Average Pooling normally uses qu8_avgpool_params, but also initializes qu8_gavgpool_params
    // in case it needs to switch to the Global Average Pooling operation.
    struct {
      union xnn_qu8_avgpool_minmax_params qu8_avgpool;
      union xnn_qu8_avgpool_minmax_params qu8_gavgpool;
    };
    union xnn_qs8_lrelu_params qs8_lrelu;
    union xnn_qu8_lrelu_params qu8_lrelu;
    union xnn_s8_minmax_params s8_minmax;
    union xnn_u8_minmax_params u8_minmax;
  } params;
  size_t num_post_operation_params;
  void* post_operation_params;
  enum xnn_operator_type type;
  struct xnn_ukernel ukernel;

  struct compute_parameters compute;
  struct compute_parameters compute2;
  union {
    struct argmax_pooling_context argmax_pooling;
    struct average_pooling_context average_pooling;
    struct channel_shuffle_context channel_shuffle;
    struct conv2d_context conv2d;
    struct dwconv2d_context dwconv2d;
    struct dwconv_context dwconv;
    struct elementwise_binary_context elementwise_binary;
    struct gemm_context gemm;
    struct global_average_pooling_nwc_context global_average_pooling_nwc;
    struct global_average_pooling_ncw_context global_average_pooling_ncw;
    struct igemm_context igemm;
    struct lut_contiguous_context lut_contiguous;
    struct lut_strided_context lut_strided;
    struct max_pooling_context max_pooling;
    struct pad_context pad;
    struct pixelwise_average_pooling_context pixelwise_average_pooling;
    struct prelu_context prelu;
    struct resize_bilinear_context resize_bilinear;
    struct resize_bilinear_chw_context resize_bilinear_chw;
    struct spmm_context spmm;
    struct subconv_context subconv;
    struct subgemm_context subgemm;
    struct transpose_context transpose;
    struct floating_point_softmax_context floating_point_softmax;
    struct u8_softmax_context u8_softmax;
    struct univector_contiguous_context univector_contiguous;
    struct univector_strided_context univector_strided;
    struct unpooling_context unpooling;
    struct vmulcaddc_context vmulcaddc;
  } context;
  struct xnn_code_cache* code_cache;
  struct xnn_weights_cache* weights_cache;
  enum xnn_run_state state;
};

static inline void* packed_weights(struct xnn_operator* op) {
  if (op->weights_cache == NULL) {
    return op->packed_weights.pointer;
  } else {
    return (void*) ((uintptr_t) op->weights_cache->cache.weights.start + op->packed_weights.offset);
  }
}

static inline bool use_weights_cache(struct xnn_operator* op) {
  return op->weights_cache != NULL;
}

// Get a pointer to a region to pack weights into. If a weights cache is available, use it, returning a pointer into
// the cache's buffer; otherwise, allocate and return a pointer to a new region. Returns NULL on error.
XNN_INTERNAL void* xnn_get_pointer_to_write_weights(
  xnn_operator_t op,
  size_t aligned_weights_size,
  int padding_byte);

#ifdef __cplusplus
extern "C" {
#endif

XNN_INTERNAL size_t xnn_compute_convolution_output_dimension(
  size_t padded_input_dimension,
  size_t kernel_dimension,
  size_t dilation_dimension,
  size_t subsampling_dimension);

XNN_INTERNAL size_t xnn_compute_deconvolution_output_dimension(
  size_t input_dimension,
  size_t output_padding_dimension,
  size_t adjustment_dimension,
  size_t kernel_dimension,
  size_t dilation_dimension,
  size_t stride_dimension);

XNN_INTERNAL size_t xnn_compute_unpooling_output_dimension(
  size_t input_dimension,
  size_t input_padding_dimension,
  size_t kernel_dimension);

XNN_INTERNAL uint32_t xnn_get_heuristic_mr_gemm(
  size_t batch_size, uint32_t max_mr, uint32_t nr, struct xnn_hmp_gemm_ukernel *gemm_cases);

XNN_INTERNAL uint32_t xnn_get_heuristic_mr_igemm(
  size_t batch_size, uint32_t max_mr, uint32_t nr, struct xnn_hmp_igemm_ukernel *igemm_cases);

#ifdef __cplusplus
}  // extern "C"
#endif
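
// Usage sketch for the packed-weights indirection above (illustration only; the packing step and any field names
// outside this header are assumptions, not definitions from this file):
//
//   // Reserve a region to pack weights into: a slot in the weights cache when one is attached to the operator,
//   // or a fresh allocation otherwise.
//   void* weights = xnn_get_pointer_to_write_weights(op, aligned_weights_size, /*padding_byte=*/0);
//   if (weights == NULL) {
//     // handle allocation failure
//   }
//   // ... pack kernel and bias data into `weights` ...
//
//   // Later, kernels resolve the packed weights through the packed_weights union via packed_weights(op),
//   // which adds packed_weights.offset to the cache's base address when a weights cache is in use.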
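
// Illustration of the output-dimension helpers declared above. The authoritative definitions live in the
// implementation, not in this header; the formula and numbers below are the conventional convolution arithmetic,
// shown only as an example:
//
//   // effective_kernel = (kernel - 1) * dilation + 1
//   // convolution: output = (padded_input - effective_kernel) / subsampling + 1
//   const size_t padded_input_width = input_width + padding_left + padding_right;
//   const size_t output_width = xnn_compute_convolution_output_dimension(
//     padded_input_width, kernel_width, dilation_width, stride_width);
//
//   // e.g. padded_input = 9, kernel = 3, dilation = 1, subsampling (stride) = 2:
//   //      (9 - ((3 - 1) * 1 + 1)) / 2 + 1 = (9 - 3) / 2 + 1 = 4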
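
// Worked example for subconvolution_params::scaled_kernel_size (declared above; the numbers are illustrative
// assumptions): with kernel_size = 9 (a 3x3 tap set), an mr of 4, and 8-byte pointers,
// scaled_kernel_size = kernel_size * mr * sizeof(void*) = 9 * 4 * 8 = 288 bytes.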