// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <stddef.h>
#include <stdint.h>

#include <pthreadpool.h>

#include <xnnpack/allocator.h>
#include <xnnpack/params.h>
#include <xnnpack/compute.h>


enum xnn_ukernel_type {
  xnn_ukernel_type_default = 0,
  xnn_ukernel_type_average_pooling,
  xnn_ukernel_type_conv2d_hwc2chw,
  xnn_ukernel_type_dwconv,
  xnn_ukernel_type_gemm,
  xnn_ukernel_type_igemm,
  xnn_ukernel_type_pixelwise_average_pooling,
  xnn_ukernel_type_spmm,
  xnn_ukernel_type_subconv2d,
  xnn_ukernel_type_vmulcaddc,
};

enum xnn_operator_type {
  xnn_operator_type_invalid = 0,
  xnn_operator_type_abs_nc_f32,
  xnn_operator_type_add_nd_f16,
  xnn_operator_type_add_nd_f32,
  xnn_operator_type_add_nd_qs8,
  xnn_operator_type_add_nd_qu8,
  xnn_operator_type_argmax_pooling_nhwc_f32,
  xnn_operator_type_average_pooling_nhwc_f32,
  xnn_operator_type_average_pooling_nhwc_qu8,
  xnn_operator_type_bankers_rounding_nc_f32,
  xnn_operator_type_channel_shuffle_nc_x8,
  xnn_operator_type_channel_shuffle_nc_x32,
  xnn_operator_type_clamp_nc_f32,
  xnn_operator_type_clamp_nc_s8,
  xnn_operator_type_clamp_nc_u8,
  xnn_operator_type_ceiling_nc_f32,
  xnn_operator_type_constant_pad_nd_x8,
  xnn_operator_type_constant_pad_nd_x16,
  xnn_operator_type_constant_pad_nd_x32,
  xnn_operator_type_convert_nc_f16_f32,
  xnn_operator_type_convert_nc_f32_f16,
  xnn_operator_type_convert_nc_f32_qs8,
  xnn_operator_type_convert_nc_f32_qu8,
  xnn_operator_type_convert_nc_qs8_f32,
  xnn_operator_type_convert_nc_qu8_f32,
  xnn_operator_type_convolution_nchw_f32,
  xnn_operator_type_convolution_nhwc_f16,
  xnn_operator_type_convolution_nhwc_f32,
  xnn_operator_type_convolution_nhwc_qc8,
  xnn_operator_type_convolution_nhwc_qs8,
  xnn_operator_type_convolution_nhwc_qu8,
  xnn_operator_type_copy_nc_x8,
  xnn_operator_type_copy_nc_x16,
  xnn_operator_type_copy_nc_x32,
  xnn_operator_type_deconvolution_nhwc_f32,
  xnn_operator_type_deconvolution_nhwc_qs8,
  xnn_operator_type_deconvolution_nhwc_qu8,
  xnn_operator_type_depth_to_space_nchw2nhwc_x32,
  xnn_operator_type_depth_to_space_nhwc_x32,
  xnn_operator_type_divide_nd_f32,
  xnn_operator_type_elu_nc_f32,
  xnn_operator_type_elu_nc_qs8,
  xnn_operator_type_fully_connected_nc_f16,
  xnn_operator_type_fully_connected_nc_f32,
  xnn_operator_type_fully_connected_nc_qs8,
  xnn_operator_type_fully_connected_nc_qu8,
  xnn_operator_type_floor_nc_f32,
  xnn_operator_type_global_average_pooling_nwc_f16,
  xnn_operator_type_global_average_pooling_nwc_f32,
  xnn_operator_type_global_average_pooling_nwc_qs8,
  xnn_operator_type_global_average_pooling_nwc_qu8,
  xnn_operator_type_global_average_pooling_ncw_f32,
  xnn_operator_type_hardswish_nc_f16,
  xnn_operator_type_hardswish_nc_f32,
  xnn_operator_type_leaky_relu_nc_f32,
  xnn_operator_type_leaky_relu_nc_qu8,
  xnn_operator_type_max_pooling_nhwc_f16,
  xnn_operator_type_max_pooling_nhwc_f32,
  xnn_operator_type_max_pooling_nhwc_s8,
  xnn_operator_type_max_pooling_nhwc_u8,
  xnn_operator_type_maximum_nd_f32,
  xnn_operator_type_minimum_nd_f32,
  xnn_operator_type_multiply_nd_f16,
  xnn_operator_type_multiply_nd_f32,
  xnn_operator_type_multiply_nd_qs8,
  xnn_operator_type_multiply_nd_qu8,
  xnn_operator_type_negate_nc_f32,
  xnn_operator_type_prelu_nc_f16,
  xnn_operator_type_prelu_nc_f32,
  xnn_operator_type_resize_bilinear_nchw_f32,
  xnn_operator_type_resize_bilinear_nhwc_f32,
  xnn_operator_type_resize_bilinear_nhwc_s8,
  xnn_operator_type_resize_bilinear_nhwc_u8,
  xnn_operator_type_sigmoid_nc_f32,
  xnn_operator_type_sigmoid_nc_qs8,
  xnn_operator_type_sigmoid_nc_qu8,
  xnn_operator_type_softmax_nc_f32,
  xnn_operator_type_softmax_nc_qu8,
  xnn_operator_type_square_nc_f32,
  xnn_operator_type_square_root_nc_f32,
  xnn_operator_type_squared_difference_nd_f32,
  xnn_operator_type_subtract_nd_f32,
  xnn_operator_type_subtract_nd_qs8,
  xnn_operator_type_subtract_nd_qu8,
  xnn_operator_type_tanh_nc_qs8,
  xnn_operator_type_tanh_nc_qu8,
  xnn_operator_type_truncation_nc_f32,
  xnn_operator_type_unpooling_nhwc_x32,
};

struct xnn_ukernel_conv2d {
  union {
    xnn_conv_hwc2chw_ukernel_function hwc2chw_function;
    xnn_conv_hwc_ukernel_function hwc_function;
  };
  uint8_t output_height_tile;
  uint8_t output_channel_tile;
};

struct xnn_ukernel_dwconv {
  union {
    xnn_dwconv_unipass_ukernel_function unipass_function;
    xnn_dwconv_multipass_ukernel_function multipass_function;
  };
  uint8_t primary_tile;
  uint8_t incremental_tile;
};

// Direct 2D Depthwise Convolution
struct xnn_ukernel_dwconv2d {
  union {
    xnn_dwconv2d_chw_ukernel_function chw_function;
  };
  uint8_t output_width_tile;
};

struct xnn_ukernel_gemm {
  struct xnn_hmp_gemm_ukernel general_case;
  struct xnn_hmp_gemm_ukernel mr1_case;
#if XNN_PLATFORM_JIT
  struct xnn_code_buffer general_code_buffer;
  struct xnn_code_buffer mr1_code_buffer;
#endif  // XNN_PLATFORM_JIT
  uint8_t mr;
  uint8_t nr;
  uint8_t kr;
  uint8_t sr;
};

struct xnn_ukernel_igemm {
  struct xnn_hmp_igemm_ukernel general_case;
  struct xnn_hmp_igemm_ukernel mr1_case;
  struct xnn_hmp_gemm_ukernel gemm_case;
#if XNN_PLATFORM_JIT
  struct xnn_code_buffer general_code_buffer;
  struct xnn_code_buffer mr1_code_buffer;
#endif  // XNN_PLATFORM_JIT
  uint8_t mr;
  uint8_t nr;
  uint8_t kr;
  uint8_t sr;
};

struct xnn_ukernel_spmm {
  xnn_spmm_ukernel_function function;
  uint8_t mr;
};

struct xnn_ukernel_vmulcaddc {
  xnn_vmulcaddc_ukernel_function function;
  uint8_t mr;
};

struct xnn_ukernel_vbinary {
  xnn_vbinary_ukernel_function op_function;
  xnn_vbinary_ukernel_function opc_function;
  xnn_vbinary_ukernel_function ropc_function;
};

struct xnn_ukernel_vunary {
  xnn_vunary_ukernel_function function;
};

struct xnn_ukernel {
  enum xnn_ukernel_type type;
  union {
    struct xnn_ukernel_conv2d conv2d;
    struct xnn_ukernel_dwconv dwconv;
    struct xnn_ukernel_dwconv2d dwconv2d;
    struct xnn_ukernel_gemm gemm;
    struct xnn_ukernel_igemm igemm;
    struct xnn_ukernel_spmm spmm;
    struct xnn_ukernel_vmulcaddc vmulcaddc;
    struct xnn_ukernel_vbinary vbinary;
    struct xnn_ukernel_vunary vunary;
  };
};

enum xnn_run_state {
  xnn_run_state_invalid = 0,
  xnn_run_state_ready,
  xnn_run_state_skip,
};

struct subconvolution_params {
  void* weights;
  size_t w_stride;
  const void** indirection_buffer;
  void* output;
  size_t slice_width;
  size_t slice_height;
  size_t indirection_y_stride;
  size_t indirection_x_stride;
  // scaled_kernel_size := kernel_size * mr * sizeof(void*).
  size_t scaled_kernel_size;
};

struct xnn_operator {
  size_t batch_size;
  uint32_t padding_top;
  uint32_t padding_right;
  uint32_t padding_bottom;
  uint32_t padding_left;
  uint32_t kernel_height;
  uint32_t kernel_width;
  uint32_t stride_height;
  uint32_t stride_width;
  uint32_t dilation_height;
  uint32_t dilation_width;
  uint32_t groups;
  size_t group_channels;
  size_t group_input_channels;
  size_t group_output_channels;
  size_t channels;

  size_t pad_before_channels;
  size_t pad_after_channels;
  uint32_t pad_value;

  size_t input_height;
  size_t input_width;
  size_t input_pixel_stride;
  const void* input;
  const void* input2;
  const void** indirection_buffer;

  size_t output_height;
  size_t output_width;
  size_t output_pixel_stride;
  void* output;

  void* packed_weights;
  // Total number of non-zero kernel elements when weights use sparse representation.
  size_t num_nonzero_values;
  // Total number of non-zero kernel blocks when weights use sparse representation.
  size_t num_nonzero_blocks;
  // Total number of output channel blocks when weights use sparse representation.
  size_t num_output_channel_blocks;
  // Input channel corresponding to the first non-zero kernel element.
  size_t first_input_channel;

  float input_scale;
  float output_scale;
  int32_t input_zero_point;
  uint8_t output_zero_point;
  uint8_t output_min;
  uint8_t output_max;

  size_t valid_batch_size;
  size_t last_input_height;
  size_t last_input_width;
  const void* last_input;
  size_t last_output_height;
  size_t last_output_width;
  void* last_output;

  uint32_t block_size;

  void* zero_buffer;
  void* lookup_table;
  void* pixelwise_buffer;
  struct subconvolution_params* subconvolution_buffer;
  uint32_t flags;

  union {
    union xnn_f16_f32_cvt_params f16_f32_cvt;
    union xnn_f16_hswish_params f16_hswish;
    union xnn_f32_abs_params f32_abs;
    union xnn_f32_default_params f32_default;
    union xnn_f32_elu_params f32_elu;
    union xnn_f32_lrelu_params f32_lrelu;
    union xnn_f32_neg_params f32_neg;
    union xnn_f32_rnd_params f32_rnd;
    union xnn_f32_sigmoid_params f32_sigmoid;
    union xnn_f32_sqrt_params f32_sqrt;
    // Parameters for Global Average Pooling in CHW layout
    union xnn_f32_gavgpool_params f32_gavgpool;
    union xnn_f32_hswish_params f32_hswish;
    union xnn_f16_minmax_params f16_minmax;
    union xnn_f16_scaleminmax_params f16_scaleminmax;
    // Pixelwise Average Pooling normally uses f32_minmax_params, but also initializes
    // f32_scaleminmax_params in case it needs to switch to a Global Average Pooling operation.
    struct {
      union xnn_f32_minmax_params f32_minmax;
      union xnn_f32_scaleminmax_params f32_scaleminmax;
    };
    union xnn_f32_chw_params f32_chw;
    union xnn_f32_f16_cvt_params f32_f16_cvt;
    union xnn_f32_qs8_cvt_params f32_qs8_cvt;
    union xnn_f32_qu8_cvt_params f32_qu8_cvt;
    union xnn_qs8_f32_cvt_params qs8_f32_cvt;
    union xnn_qu8_f32_cvt_params qu8_f32_cvt;
    union xnn_qs8_conv_minmax_params qs8_conv_minmax;
    // Average Pooling normally uses qs8_avgpool_params, but also initializes qs8_gavgpool_params
    // in case it needs to switch to a Global Average Pooling operation.
    struct {
      union xnn_qs8_avgpool_minmax_params qs8_avgpool;
      union xnn_qs8_avgpool_minmax_params qs8_gavgpool;
    };
    // Quantized Add parameters are sensitive to the order of inputs, so we initialize an extra copy
    // with the reversed order.
    struct {
      union xnn_qs8_addsub_minmax_params qs8_addsub;
      union xnn_qs8_addsub_minmax_params qs8_raddsub;
    };
    struct {
      union xnn_qs8_mul_minmax_params qs8_mul;
      union xnn_qs8_mul_minmax_params qs8_rmul;
    };
    struct {
      union xnn_qu8_addsub_minmax_params qu8_addsub;
      union xnn_qu8_addsub_minmax_params qu8_raddsub;
    };
    struct {
      union xnn_qu8_mul_minmax_params qu8_mul;
      union xnn_qu8_mul_minmax_params qu8_rmul;
    };
    union xnn_qu8_conv_minmax_params qu8_conv_minmax;
    // Average Pooling normally uses qu8_avgpool_params, but also initializes qu8_gavgpool_params
    // in case it needs to switch to a Global Average Pooling operation.
    struct {
      union xnn_qu8_avgpool_minmax_params qu8_avgpool;
      union xnn_qu8_avgpool_minmax_params qu8_gavgpool;
    };
    union xnn_s8_minmax_params s8_minmax;
    union xnn_u8_minmax_params u8_minmax;
  } params;
  enum xnn_operator_type type;
  struct xnn_ukernel ukernel;

  struct compute_parameters compute;
  struct compute_parameters compute2;
  union {
    struct argmax_pooling_context argmax_pooling;
    struct average_pooling_context average_pooling;
    struct channel_shuffle_context channel_shuffle;
    struct conv2d_context conv2d;
    struct dwconv2d_context dwconv2d;
    struct dwconv_context dwconv;
    struct depthtospace2d_chw2hwc_context depthtospace2d_chw;
    struct depthtospace2d_hwc_context depthtospace2d_hwc;
    struct elementwise_binary_context elementwise_binary;
    struct gemm_context gemm;
    struct global_average_pooling_nwc_context global_average_pooling_nwc;
    struct global_average_pooling_ncw_context global_average_pooling_ncw;
    struct igemm_context igemm;
    struct lut_contiguous_context lut_contiguous;
    struct lut_strided_context lut_strided;
    struct max_pooling_context max_pooling;
    struct pad_context pad;
    struct pixelwise_average_pooling_context pixelwise_average_pooling;
    struct prelu_context prelu;
    struct resize_bilinear_context resize_bilinear;
    struct resize_bilinear_chw_context resize_bilinear_chw;
    struct spmm_context spmm;
    struct subconv_context subconv;
    struct subgemm_context subgemm;
    struct f32_three_pass_softmax_context f32_three_pass_softmax;
    struct u8_softmax_context u8_softmax;
    struct univector_contiguous_context univector_contiguous;
    struct univector_strided_context univector_strided;
    struct unpooling_context unpooling;
    struct vmulcaddc_context vmulcaddc;
  } context;

  enum xnn_run_state state;
};
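
// Usage sketch: how an xnn_operator instance, as defined above, typically flows through the
// public API declared in <xnnpack.h>. Creation fills the shape-independent fields and selects
// microkernels (ukernel), setup fills the per-invocation context/compute fields and moves the
// operator to xnn_run_state_ready, and xnn_run_operator dispatches the prepared computation.
// This is an illustrative sketch only (shown for the f32 Clamp operator); verify the exact
// create/setup signatures against the xnnpack.h that matches this header's version. Guarded out
// so this internal header stays declaration-only.
#if 0
#include <stddef.h>

#include <xnnpack.h>

static enum xnn_status run_clamp_example(
  const float* input, float* output, size_t batch_size, size_t channels)
{
  // One-time library initialization (default allocator).
  enum xnn_status status = xnn_initialize(NULL /* allocator */);
  if (status != xnn_status_success) {
    return status;
  }

  // Create: allocates the xnn_operator and records channels, strides, and the clamping range.
  xnn_operator_t clamp_op = NULL;
  status = xnn_create_clamp_nc_f32(
    channels, /*input_stride=*/channels, /*output_stride=*/channels,
    /*output_min=*/0.0f, /*output_max=*/6.0f, /*flags=*/0, &clamp_op);
  if (status != xnn_status_success) {
    return status;
  }

  // Setup: binds batch size and input/output pointers for this invocation.
  status = xnn_setup_clamp_nc_f32(clamp_op, batch_size, input, output, /*threadpool=*/NULL);
  if (status != xnn_status_success) {
    xnn_delete_operator(clamp_op);
    return status;
  }

  // Run: executes the prepared computation (single-threaded here, no threadpool).
  status = xnn_run_operator(clamp_op, /*threadpool=*/NULL);

  xnn_delete_operator(clamp_op);
  return status;
}
#endif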