1 // Copyright (c) Facebook, Inc. and its affiliates. 2 // All rights reserved. 3 // 4 // Copyright 2019 Google LLC 5 // 6 // This source code is licensed under the BSD-style license found in the 7 // LICENSE file in the root directory of this source tree. 8 9 #pragma once 10 11 #include <stddef.h> 12 #include <stdint.h> 13 14 #include <pthreadpool.h> 15 16 #include <xnnpack/params.h> 17 #include <xnnpack/compute.h> 18 19 20 enum xnn_ukernel_type { 21 xnn_ukernel_type_default = 0, 22 xnn_ukernel_type_average_pooling, 23 xnn_ukernel_type_conv2d_hwc2chw, 24 xnn_ukernel_type_dwconv, 25 xnn_ukernel_type_gemm, 26 xnn_ukernel_type_igemm, 27 xnn_ukernel_type_pixelwise_average_pooling, 28 xnn_ukernel_type_spmm, 29 xnn_ukernel_type_subconv2d, 30 xnn_ukernel_type_vmulcaddc, 31 }; 32 33 enum xnn_operator_type { 34 xnn_operator_type_invalid = 0, 35 xnn_operator_type_abs_nc_f32, 36 xnn_operator_type_add_nd_f16, 37 xnn_operator_type_add_nd_f32, 38 xnn_operator_type_add_nd_qs8, 39 xnn_operator_type_argmax_pooling_nhwc_f32, 40 xnn_operator_type_average_pooling_nhwc_f32, 41 xnn_operator_type_average_pooling_nhwc_qu8, 42 xnn_operator_type_bankers_rounding_nc_f32, 43 xnn_operator_type_channel_shuffle_nc_x32, 44 xnn_operator_type_channel_shuffle_nc_x8, 45 xnn_operator_type_clamp_nc_f32, 46 xnn_operator_type_clamp_nc_u8, 47 xnn_operator_type_ceiling_nc_f32, 48 xnn_operator_type_constant_pad_nd_x32, 49 xnn_operator_type_convolution_nchw_f32, 50 xnn_operator_type_convolution_nhwc_f16, 51 xnn_operator_type_convolution_nhwc_f32, 52 xnn_operator_type_convolution_nhwc_qs8, 53 xnn_operator_type_convolution_nhwc_qu8, 54 xnn_operator_type_copy_nc_x32, 55 xnn_operator_type_deconvolution_nhwc_f32, 56 xnn_operator_type_deconvolution_nhwc_qu8, 57 xnn_operator_type_depth_to_space_nchw2nhwc_x32, 58 xnn_operator_type_depth_to_space_nhwc_x32, 59 xnn_operator_type_divide_nd_f32, 60 xnn_operator_type_elu_nc_f32, 61 xnn_operator_type_fully_connected_nc_f32, 62 xnn_operator_type_fully_connected_nc_qu8, 63 xnn_operator_type_floor_nc_f32, 64 xnn_operator_type_global_average_pooling_nwc_f16, 65 xnn_operator_type_global_average_pooling_nwc_f32, 66 xnn_operator_type_global_average_pooling_nwc_qs8, 67 xnn_operator_type_global_average_pooling_nwc_qu8, 68 xnn_operator_type_global_average_pooling_ncw_f32, 69 xnn_operator_type_hardswish_nc_f16, 70 xnn_operator_type_hardswish_nc_f32, 71 xnn_operator_type_leaky_relu_nc_f32, 72 xnn_operator_type_leaky_relu_nc_qu8, 73 xnn_operator_type_max_pooling_nhwc_f32, 74 xnn_operator_type_max_pooling_nhwc_u8, 75 xnn_operator_type_maximum_nd_f32, 76 xnn_operator_type_minimum_nd_f32, 77 xnn_operator_type_multiply_nd_f16, 78 xnn_operator_type_multiply_nd_f32, 79 xnn_operator_type_negate_nc_f32, 80 xnn_operator_type_prelu_nc_f32, 81 xnn_operator_type_resize_bilinear_nchw_f32, 82 xnn_operator_type_resize_bilinear_nhwc_f32, 83 xnn_operator_type_sigmoid_nc_f32, 84 xnn_operator_type_sigmoid_nc_qu8, 85 xnn_operator_type_softmax_nc_f32, 86 xnn_operator_type_softmax_nc_qu8, 87 xnn_operator_type_square_nc_f32, 88 xnn_operator_type_square_root_nc_f32, 89 xnn_operator_type_squared_difference_nd_f32, 90 xnn_operator_type_subtract_nd_f32, 91 xnn_operator_type_truncation_nc_f32, 92 xnn_operator_type_unpooling_nhwc_x32, 93 }; 94 95 struct xnn_ukernel_conv2d { 96 union { 97 xnn_conv_hwc2chw_ukernel_function hwc2chw_function; 98 xnn_conv_hwc_ukernel_function hwc_function; 99 }; 100 uint8_t output_height_tile; 101 uint8_t output_channel_tile; 102 }; 103 104 struct xnn_ukernel_dwconv { 105 union { 106 xnn_dwconv_unipass_ukernel_function unipass_function; 107 xnn_dwconv_multipass_ukernel_function multipass_function; 108 }; 109 uint8_t primary_tile; 110 uint8_t incremental_tile; 111 }; 112 113 // Direct 2D Depthwise Convolution 114 struct xnn_ukernel_dwconv2d { 115 union { 116 xnn_dwconv2d_chw_ukernel_function chw_function; 117 }; 118 uint8_t output_width_tile; 119 }; 120 121 struct xnn_ukernel_gemm { 122 struct xnn_hmp_gemm_ukernel general_case; 123 struct xnn_hmp_gemm_ukernel mr1_case; 124 uint8_t mr; 125 uint8_t nr; 126 uint8_t kr; 127 }; 128 129 struct xnn_ukernel_igemm { 130 struct xnn_hmp_igemm_ukernel general_case; 131 struct xnn_hmp_igemm_ukernel mr1_case; 132 struct xnn_hmp_gemm_ukernel gemm_case; 133 uint8_t mr; 134 uint8_t nr; 135 uint8_t kr; 136 }; 137 138 struct xnn_ukernel_spmm { 139 xnn_spmm_ukernel_function function; 140 uint8_t mr; 141 }; 142 143 struct xnn_ukernel_vmulcaddc { 144 xnn_vmulcaddc_ukernel_function function; 145 uint8_t mr; 146 }; 147 148 struct xnn_ukernel_vbinary { 149 xnn_vbinary_ukernel_function op_function; 150 xnn_vbinary_ukernel_function opc_function; 151 xnn_vbinary_ukernel_function ropc_function; 152 }; 153 154 struct xnn_ukernel_vunary { 155 xnn_vunary_ukernel_function function; 156 }; 157 158 struct xnn_ukernel { 159 enum xnn_ukernel_type type; 160 union { 161 struct xnn_ukernel_conv2d conv2d; 162 struct xnn_ukernel_dwconv dwconv; 163 struct xnn_ukernel_dwconv2d dwconv2d; 164 struct xnn_ukernel_gemm gemm; 165 struct xnn_ukernel_igemm igemm; 166 struct xnn_ukernel_spmm spmm; 167 struct xnn_ukernel_vmulcaddc vmulcaddc; 168 struct xnn_ukernel_vbinary vbinary; 169 struct xnn_ukernel_vunary vunary; 170 }; 171 }; 172 173 enum xnn_run_state { 174 xnn_run_state_invalid = 0, 175 xnn_run_state_ready, 176 xnn_run_state_skip, 177 }; 178 179 struct subconvolution_params { 180 void* weights; 181 size_t w_stride; 182 const void** indirection_buffer; 183 void* output; 184 size_t slice_width; 185 size_t slice_height; 186 size_t indirection_y_stride; 187 size_t indirection_x_stride; 188 // scaled_kernel_size := kernel_size * mr * sizeof(void*). 189 size_t scaled_kernel_size; 190 }; 191 192 struct xnn_operator { 193 size_t batch_size; 194 uint32_t padding_top; 195 uint32_t padding_right; 196 uint32_t padding_bottom; 197 uint32_t padding_left; 198 uint32_t kernel_height; 199 uint32_t kernel_width; 200 uint32_t stride_height; 201 uint32_t stride_width; 202 uint32_t dilation_height; 203 uint32_t dilation_width; 204 uint32_t groups; 205 size_t group_channels; 206 size_t group_input_channels; 207 size_t group_output_channels; 208 size_t channels; 209 210 size_t pad_before_channels; 211 size_t pad_after_channels; 212 uint32_t pad_value; 213 214 size_t input_height; 215 size_t input_width; 216 size_t input_pixel_stride; 217 const void* input; 218 const void* input2; 219 const void** indirection_buffer; 220 221 size_t output_height; 222 size_t output_width; 223 size_t output_pixel_stride; 224 void* output; 225 226 void* packed_weights; 227 // Total number of non-zero kernel elements when weights use sparse representation. 228 size_t num_nonzero_values; 229 // Total number of non-zero kernel blocks when weights use sparse representation. 230 size_t num_nonzero_blocks; 231 // Total number of output channel blocks when weights use sparse representation. 232 size_t num_output_channel_blocks; 233 // Input channel corresponding to the first non-zero kernel element. 234 size_t first_input_channel; 235 236 float input_scale; 237 float output_scale; 238 int32_t input_zero_point; 239 uint8_t output_zero_point; 240 uint8_t output_min; 241 uint8_t output_max; 242 243 size_t valid_batch_size; 244 size_t last_input_height; 245 size_t last_input_width; 246 const void* last_input; 247 size_t last_output_height; 248 size_t last_output_width; 249 void* last_output; 250 251 uint32_t block_size; 252 253 void* zero_buffer; 254 void* lookup_table; 255 void* pixelwise_buffer; 256 struct subconvolution_params* subconvolution_buffer; 257 uint32_t flags; 258 259 union { 260 union xnn_f32_abs_params f32_abs; 261 union xnn_f32_elu_params f32_elu; 262 union xnn_f32_lrelu_params f32_lrelu; 263 union xnn_f32_neg_params f32_neg; 264 union xnn_f32_rnd_params f32_rnd; 265 // Parameters for Global Average Pooling in CHW layout 266 union xnn_f32_gavgpool_params f32_gavgpool; 267 struct xnn_f16_hswish_params f16_hswish; 268 union xnn_f32_hswish_params f32_hswish; 269 struct { 270 struct xnn_f16_minmax_params f16_minmax; 271 struct xnn_f16_scaleminmax_params f16_scaleminmax; 272 }; 273 // Pixelwise Average Pooling normally use f32_minmax_params, but also initialize 274 // f32_scaleminmax_params in case it needs to switch to Global Average Pooling operation. 275 struct { 276 union xnn_f32_minmax_params f32_minmax; 277 union xnn_f32_scaleminmax_params f32_scaleminmax; 278 }; 279 union xnn_f32_chw_params f32_chw; 280 union xnn_qs8_gemm_params qs8_gemm; 281 // Average Pooling normally use qs8_avgpool_params, but also initialize qs8_gavgpool_params in case it needs to switch 282 // to Global Average Pooling operation. 283 struct { 284 union xnn_qs8_avgpool_params qs8_avgpool; 285 union xnn_qs8_avgpool_params qs8_gavgpool; 286 }; 287 // Quantized Add parameters are sensitive to order of inputs, so we initialize an extra copy with the reversed order. 288 struct { 289 union xnn_qs8_add_params qs8_add; 290 union xnn_qs8_add_params qs8_radd; 291 }; 292 union xnn_qu8_add_params qu8_add; 293 union xnn_qu8_gemm_params qu8_gemm; 294 // Average Pooling normally use qu8_avgpool_params, but also initialize qu8_gavgpool_params in case it needs to switch 295 // to Global Average Pooling operation. 296 struct { 297 union xnn_qu8_avgpool_params qu8_avgpool; 298 union xnn_qu8_avgpool_params qu8_gavgpool; 299 }; 300 union xnn_u8_minmax_params u8_minmax; 301 } params; 302 enum xnn_operator_type type; 303 struct xnn_ukernel ukernel; 304 305 struct compute_parameters compute; 306 struct compute_parameters compute2; 307 union { 308 struct argmax_pooling_context argmax_pooling; 309 struct average_pooling_context average_pooling; 310 struct channel_shuffle_context channel_shuffle; 311 struct conv2d_context conv2d; 312 struct dwconv2d_context dwconv2d; 313 struct dwconv_context dwconv; 314 struct depthtospace2d_chw2hwc_context depthtospace2d_chw; 315 struct depthtospace2d_hwc_context depthtospace2d_hwc; 316 struct elementwise_binary_context elementwise_binary; 317 struct gemm_context gemm; 318 struct global_average_pooling_nwc_context global_average_pooling_nwc; 319 struct global_average_pooling_ncw_context global_average_pooling_ncw; 320 struct igemm_context igemm; 321 struct lut_contiguous_context lut_contiguous; 322 struct lut_strided_context lut_strided; 323 struct max_pooling_context max_pooling; 324 struct pad_context pad; 325 struct pixelwise_average_pooling_context pixelwise_average_pooling; 326 struct prelu_context prelu; 327 struct resize_bilinear_context resize_bilinear; 328 struct resize_bilinear_chw_context resize_bilinear_chw; 329 struct spmm_context spmm; 330 struct subconv_context subconv; 331 struct subgemm_context subgemm; 332 struct f32_three_pass_softmax_context f32_three_pass_softmax; 333 struct u8_softmax_context u8_softmax; 334 struct univector_contiguous_context univector_contiguous; 335 struct univector_strided_context univector_strided; 336 struct unpooling_context unpooling; 337 struct vmulcaddc_context vmulcaddc; 338 } context; 339 340 enum xnn_run_state state; 341 }; 342