1 // Copyright (c) Facebook, Inc. and its affiliates. 2 // All rights reserved. 3 // 4 // Copyright 2019 Google LLC 5 // 6 // This source code is licensed under the BSD-style license found in the 7 // LICENSE file in the root directory of this source tree. 8 9 #pragma once 10 11 #include <stddef.h> 12 #include <stdint.h> 13 14 #include <pthreadpool.h> 15 16 #include <xnnpack/params.h> 17 #include <xnnpack/compute.h> 18 19 20 enum xnn_ukernel_type { 21 xnn_ukernel_type_none = 0, 22 xnn_ukernel_type_add, 23 xnn_ukernel_type_argmax_pooling, 24 xnn_ukernel_type_average_pooling, 25 xnn_ukernel_type_binary_elementwise, 26 xnn_ukernel_type_channel_shuffle, 27 xnn_ukernel_type_clamp, 28 xnn_ukernel_type_dconv2d_hwc2spchw, 29 xnn_ukernel_type_dwconv, 30 xnn_ukernel_type_gemm, 31 xnn_ukernel_type_global_average_pooling, 32 xnn_ukernel_type_hswish, 33 xnn_ukernel_type_igemm, 34 xnn_ukernel_type_lut, 35 xnn_ukernel_type_max_pooling, 36 xnn_ukernel_type_pad, 37 xnn_ukernel_type_pixelwise_average_pooling, 38 xnn_ukernel_type_prelu, 39 xnn_ukernel_type_sigmoid, 40 xnn_ukernel_type_softmax, 41 xnn_ukernel_type_spmm, 42 xnn_ukernel_type_subconv2d, 43 xnn_ukernel_type_unpooling, 44 xnn_ukernel_type_vmulcaddc, 45 }; 46 47 enum xnn_operator_type { 48 xnn_operator_type_none = 0, 49 xnn_operator_type_add_nc_f32, 50 xnn_operator_type_add_nd_f32, 51 xnn_operator_type_add_nc_q8, 52 xnn_operator_type_argmax_pooling_nhwc_f32, 53 xnn_operator_type_average_pooling_nhwc_f32, 54 xnn_operator_type_average_pooling_nhwc_q8, 55 xnn_operator_type_channel_pad_nc_x32, 56 xnn_operator_type_channel_shuffle_nc_x32, 57 xnn_operator_type_channel_shuffle_nc_x8, 58 xnn_operator_type_clamp_nc_f32, 59 xnn_operator_type_clamp_nc_u8, 60 xnn_operator_type_convolution_nhwc_f32, 61 xnn_operator_type_convolution_nhwc_q8, 62 xnn_operator_type_convolution_nchw_f32, 63 xnn_operator_type_deconvolution_nhwc_f32, 64 xnn_operator_type_deconvolution_nhwc_q8, 65 xnn_operator_type_divide_nd_f32, 66 xnn_operator_type_fully_connected_nc_f32, 67 xnn_operator_type_fully_connected_nc_q8, 68 xnn_operator_type_global_average_pooling_nwc_f32, 69 xnn_operator_type_global_average_pooling_nwc_q8, 70 xnn_operator_type_global_average_pooling_ncw_f32, 71 xnn_operator_type_hardswish_nc_f32, 72 xnn_operator_type_leaky_relu_nc_q8, 73 xnn_operator_type_max_pooling_nhwc_f32, 74 xnn_operator_type_max_pooling_nhwc_u8, 75 xnn_operator_type_maximum_nd_f32, 76 xnn_operator_type_minimum_nd_f32, 77 xnn_operator_type_multiply_nd_f32, 78 xnn_operator_type_prelu_nc_f32, 79 xnn_operator_type_resize_bilinear_nhwc_f32, 80 xnn_operator_type_sigmoid_nc_f32, 81 xnn_operator_type_sigmoid_nc_q8, 82 xnn_operator_type_softmax_nc_f32, 83 xnn_operator_type_softmax_nc_q8, 84 xnn_operator_type_subtract_nd_f32, 85 xnn_operator_type_unpooling_nhwc_x32, 86 }; 87 88 struct xnn_ukernel_dconv2d { 89 union { 90 xnn_conv_hwc2spchw_ukernel_function hwc2spchw_function; 91 xnn_conv_hwc_ukernel_function hwc_function; 92 }; 93 uint8_t output_height_tile; 94 uint8_t output_channel_tile; 95 }; 96 97 struct xnn_ukernel_dwconv { 98 union { 99 xnn_dwconv_up_ukernel_function unipass_function; 100 xnn_dwconv_mp_ukernel_function multipass_function; 101 }; 102 uint8_t mr; 103 uint8_t qr; 104 }; 105 106 // Direct 2D Depthwise Convolution 107 struct xnn_ukernel_dwconv2d { 108 union { 109 xnn_dwconv_spchw_ukernel_function spchw_function; 110 }; 111 uint8_t input_width_tile; 112 uint8_t output_width_tile; 113 }; 114 115 struct xnn_ukernel_gemm { 116 xnn_gemm_ukernel_function default_function; 117 xnn_gemm_ukernel_function mr1_function; 118 uint8_t mr; 119 uint8_t nr; 120 uint8_t kr; 121 }; 122 123 struct xnn_ukernel_igemm { 124 xnn_igemm_ukernel_function default_function; 125 xnn_igemm_ukernel_function mr1_function; 126 uint8_t mr; 127 uint8_t nr; 128 uint8_t kr; 129 }; 130 131 struct xnn_ukernel_spmm { 132 xnn_spmm_ukernel_function function; 133 uint8_t mr; 134 }; 135 136 struct xnn_ukernel_vmulcaddc { 137 xnn_vmulcaddc_ukernel_function function; 138 uint8_t mr; 139 }; 140 141 struct xnn_ukernel { 142 enum xnn_ukernel_type type; 143 union { 144 struct xnn_ukernel_dconv2d dconv2d; 145 struct xnn_ukernel_dwconv dwconv; 146 struct xnn_ukernel_dwconv2d dwconv2d; 147 struct xnn_ukernel_gemm gemm; 148 struct xnn_ukernel_igemm igemm; 149 struct xnn_ukernel_spmm spmm; 150 struct xnn_ukernel_vmulcaddc vmulcaddc; 151 }; 152 }; 153 154 enum xnn_run_state { 155 xnn_run_state_invalid = 0, 156 xnn_run_state_ready, 157 xnn_run_state_skip, 158 }; 159 160 struct subconvolution_params { 161 void* weights; 162 size_t w_stride; 163 const void** indirection_buffer; 164 void* output; 165 size_t slice_width; 166 size_t slice_height; 167 size_t indirection_y_stride; 168 size_t indirection_x_stride; 169 // scaled_kernel_size := kernel_size * mr * sizeof(void*). 170 size_t scaled_kernel_size; 171 }; 172 173 struct xnn_operator { 174 size_t batch_size; 175 uint32_t padding_top; 176 uint32_t padding_right; 177 uint32_t padding_bottom; 178 uint32_t padding_left; 179 uint32_t kernel_height; 180 uint32_t kernel_width; 181 uint32_t stride_height; 182 uint32_t stride_width; 183 uint32_t dilation_height; 184 uint32_t dilation_width; 185 uint32_t groups; 186 size_t group_channels; 187 size_t group_input_channels; 188 size_t group_output_channels; 189 size_t channels; 190 191 size_t pad_before_channels; 192 size_t pad_after_channels; 193 uint32_t pad_value; 194 195 size_t input_height; 196 size_t input_width; 197 size_t input_pixel_stride; 198 const void* input; 199 const void** indirection_buffer; 200 201 size_t input2_pixel_stride; 202 const void* input2; 203 204 size_t output_height; 205 size_t output_width; 206 size_t output_pixel_stride; 207 void* output; 208 209 void* packed_weights; 210 // Total number of non-zero kernel elements when weights use sparse representation. 211 size_t num_nonzero_values; 212 // Total number of non-zero kernel blocks when weights use sparse representation. 213 size_t num_nonzero_blocks; 214 // Total number of output channel blocks when weights use sparse representation. 215 size_t num_output_channel_blocks; 216 // Input channel corresponding to the first non-zero kernel element. 217 size_t first_input_channel; 218 219 float input_scale; 220 float output_scale; 221 uint8_t input_zero_point; 222 uint8_t kernel_zero_point; 223 uint8_t output_zero_point; 224 uint8_t output_min; 225 uint8_t output_max; 226 227 size_t valid_batch_size; 228 size_t last_input_height; 229 size_t last_input_width; 230 const void* last_input; 231 size_t last_output_height; 232 size_t last_output_width; 233 void* last_output; 234 235 void* zero_buffer; 236 void* lookup_table; 237 void* pixelwise_buffer; 238 struct subconvolution_params* subconvolution_buffer; 239 uint32_t flags; 240 241 union { 242 union xnn_f32_avgpool_params f32_avgpool_params; 243 union xnn_f32_gavgpool_params f32_gavgpool_params; 244 union xnn_f32_hswish_params f32_hswish_params; 245 union xnn_f32_output_params f32_output_params; 246 union xnn_f32_spchw_params f32_spchw_params; 247 union xnn_q8_add_params q8_add_params; 248 union xnn_q8_avgpool_params q8_avgpool_params; 249 union xnn_q8_gemm_params q8_gemm_params; 250 union xnn_u8_output_params u8_output_params; 251 }; 252 enum xnn_operator_type type; 253 struct xnn_ukernel ukernel; 254 255 struct compute_parameters compute; 256 struct compute_parameters compute2; 257 union { 258 struct add_contiguous_context add_contiguous; 259 struct add_strided_context add_strided; 260 struct argmax_pooling_context argmax_pooling; 261 struct average_pooling_context average_pooling; 262 struct channel_pad_context channel_pad; 263 struct channel_shuffle_context channel_shuffle; 264 struct dconv2d_context dconv2d; 265 struct dwconv2d_context dwconv2d; 266 struct dwconv_context dwconv; 267 struct elementwise_binary_context elementwise_binary; 268 struct gemm_context gemm; 269 struct global_average_pooling_nwc_context global_average_pooling_nwc; 270 struct global_average_pooling_ncw_context global_average_pooling_ncw; 271 struct igemm_context igemm; 272 struct lut_contiguous_context lut_contiguous; 273 struct lut_strided_context lut_strided; 274 struct max_pooling_context max_pooling; 275 struct pixelwise_average_pooling_context pixelwise_average_pooling; 276 struct prelu_context prelu; 277 struct resize_bilinear_context resize_bilinear; 278 struct spmm_context spmm; 279 struct subconv_context subconv; 280 struct f32_three_pass_softmax_context f32_three_pass_softmax; 281 struct u8_softmax_context u8_softmax; 282 struct univector_contiguous_context univector_contiguous; 283 struct univector_strided_context univector_strided; 284 struct unpooling_context unpooling; 285 struct vmulcaddc_context vmulcaddc; 286 } context; 287 288 enum xnn_run_state state; 289 }; 290