1 // Copyright (c) Facebook, Inc. and its affiliates. 2 // All rights reserved. 3 // 4 // Copyright 2019 Google LLC 5 // 6 // This source code is licensed under the BSD-style license found in the 7 // LICENSE file in the root directory of this source tree. 8 9 #pragma once 10 11 #include <stdint.h> 12 #include <stddef.h> 13 14 #include <xnnpack/common.h> 15 #include <xnnpack/operator.h> 16 17 18 #ifdef __cplusplus 19 extern "C" { 20 #endif 21 22 23 struct xnn_qu8_packing_params { 24 uint8_t input_zero_point; 25 uint8_t kernel_zero_point; 26 }; 27 28 struct xnn_qs8_packing_params { 29 int8_t input_zero_point; 30 }; 31 32 33 typedef void (*xnn_pack_gemm_goi_w_function)( 34 size_t g, 35 size_t nc, 36 size_t kc, 37 size_t nr, 38 size_t kr, 39 size_t sr, 40 const void* k, 41 const void* b, 42 void* packed_w, 43 const void* params); 44 45 XNN_INTERNAL void xnn_pack_f32_gemm_goi_w( 46 size_t g, 47 size_t nc, 48 size_t kc, 49 size_t nr, 50 size_t kr, 51 size_t sr, 52 const float* k, 53 const float* b, 54 float* packed_w, 55 const void* params); 56 57 XNN_INTERNAL void xnn_pack_f16_gemm_goi_w( 58 size_t g, 59 size_t nc, 60 size_t kc, 61 size_t nr, 62 size_t kr, 63 size_t sr, 64 const uint16_t* k, 65 const uint16_t* b, 66 uint16_t* packed_w, 67 const void* params); 68 69 XNN_INTERNAL void xnn_pack_qu8_gemm_goi_w( 70 size_t g, 71 size_t nc, 72 size_t kc, 73 size_t nr, 74 size_t kr, 75 size_t sr, 76 const uint8_t* k, 77 const int32_t* b, 78 void* packed_w, 79 const struct xnn_qu8_packing_params* params); 80 81 XNN_INTERNAL void xnn_pack_qs8_gemm_goi_w( 82 size_t g, 83 size_t nc, 84 size_t kc, 85 size_t nr, 86 size_t kr, 87 size_t sr, 88 const int8_t* k, 89 const int32_t* b, 90 void* packed_w, 91 const struct xnn_qs8_packing_params* params); 92 93 XNN_INTERNAL void xnn_pack_qs8_gemm_xw_goi_w( 94 size_t g, 95 size_t nc, 96 size_t kc, 97 size_t nr, 98 size_t kr, 99 size_t sr, 100 const int8_t* k, 101 const int32_t* b, 102 void* packed_w, 103 const struct xnn_qs8_packing_params* params); 104 105 106 typedef void (*xnn_pack_gemm_io_w_function)( 107 size_t nc, 108 size_t kc, 109 size_t nr, 110 size_t kr, 111 size_t sr, 112 const void* k, 113 const void* b, 114 void* packed_w, 115 const void* params); 116 117 XNN_INTERNAL void xnn_pack_f32_gemm_io_w( 118 size_t nc, 119 size_t kc, 120 size_t nr, 121 size_t kr, 122 size_t sr, 123 const float* k, 124 const float* b, 125 float* packed_w, 126 const void* params); 127 128 XNN_INTERNAL void xnn_pack_f16_gemm_io_w( 129 size_t nc, 130 size_t kc, 131 size_t nr, 132 size_t kr, 133 size_t sr, 134 const uint16_t* k, 135 const uint16_t* b, 136 uint16_t* packed_w, 137 const void* params); 138 139 XNN_INTERNAL void xnn_pack_qu8_gemm_io_w( 140 size_t nc, 141 size_t kc, 142 size_t nr, 143 size_t kr, 144 size_t sr, 145 const uint8_t* k, 146 const int32_t* b, 147 void* packed_w, 148 const struct xnn_qu8_packing_params* params); 149 150 151 typedef void (*xnn_pack_conv_goki_w_function)( 152 size_t g, 153 size_t nc, 154 size_t ks, 155 size_t kc, 156 size_t nr, 157 size_t kr, 158 size_t sr, 159 const void* k, 160 const void* b, 161 void* packed_w, 162 const void* params); 163 164 XNN_INTERNAL void xnn_pack_f32_conv_goki_w( 165 size_t g, 166 size_t nc, 167 size_t ks, 168 size_t kc, 169 size_t nr, 170 size_t kr, 171 size_t sr, 172 const float* k, 173 const float* b, 174 float* packed_w, 175 const void* params); 176 177 XNN_INTERNAL void xnn_pack_f16_conv_goki_w( 178 size_t g, 179 size_t nc, 180 size_t ks, 181 size_t kc, 182 size_t nr, 183 size_t kr, 184 size_t sr, 185 const uint16_t* k, 186 const uint16_t* b, 187 uint16_t* packed_w, 188 const void* params); 189 190 XNN_INTERNAL void xnn_pack_qu8_conv_goki_w( 191 size_t g, 192 size_t nc, 193 size_t ks, 194 size_t kc, 195 size_t nr, 196 size_t kr, 197 size_t sr, 198 const uint8_t* k, 199 const int32_t* b, 200 void* packed_w, 201 const struct xnn_qu8_packing_params* params); 202 203 XNN_INTERNAL void xnn_pack_qs8_conv_goki_w( 204 size_t g, 205 size_t nc, 206 size_t ks, 207 size_t kc, 208 size_t nr, 209 size_t kr, 210 size_t sr, 211 const int8_t* k, 212 const int32_t* b, 213 void* packed_w, 214 const struct xnn_qs8_packing_params* params); 215 216 217 typedef void (*xnn_pack_conv_kgo_w_function)( 218 size_t g, 219 size_t nc, 220 size_t ks, 221 size_t nr, 222 size_t kr, 223 const void* k, 224 const void* b, 225 void* packed_w, 226 const void* params); 227 228 XNN_INTERNAL void xnn_pack_f32_conv_kgo_w( 229 size_t g, 230 size_t nc, 231 size_t ks, 232 size_t nr, 233 size_t kr, 234 const float* k, 235 const float* b, 236 float* packed_w, 237 const void* params); 238 239 XNN_INTERNAL void xnn_pack_f16_conv_kgo_w( 240 size_t g, 241 size_t nc, 242 size_t ks, 243 size_t nr, 244 size_t kr, 245 const uint16_t* k, 246 const uint16_t* b, 247 uint16_t* packed_w, 248 const void* params); 249 250 XNN_INTERNAL void xnn_pack_qu8_conv_kgo_w( 251 size_t g, 252 size_t nc, 253 size_t ks, 254 size_t nr, 255 size_t kr, 256 const uint8_t* k, 257 const int32_t* b, 258 void* packed_w, 259 const struct xnn_qu8_packing_params* params); 260 261 XNN_INTERNAL void xnn_pack_qs8_conv_kgo_w( 262 size_t g, 263 size_t nc, 264 size_t ks, 265 size_t nr, 266 size_t kr, 267 const int8_t* k, 268 const int32_t* b, 269 void* packed_w, 270 const struct xnn_qs8_packing_params* params); 271 272 273 typedef void (*xnn_pack_deconv_goki_w_function)( 274 size_t g, 275 size_t nc, 276 size_t kh, 277 size_t kw, 278 size_t kc, 279 size_t sh, 280 size_t sw, 281 size_t nr, 282 size_t kr, 283 size_t sr, 284 const void* k, 285 const void* b, 286 void* packed_w, 287 struct subconvolution_params* subconv_params, 288 const void* params); 289 290 XNN_INTERNAL void xnn_pack_f32_deconv_goki_w( 291 size_t g, 292 size_t nc, 293 size_t kh, 294 size_t kw, 295 size_t kc, 296 size_t sh, 297 size_t sw, 298 size_t nr, 299 size_t kr, 300 size_t sr, 301 const float* k, 302 const float* b, 303 float* packed_w, 304 struct subconvolution_params* subconv_params, 305 const void* params); 306 307 XNN_INTERNAL void xnn_pack_f16_deconv_goki_w( 308 size_t g, 309 size_t nc, 310 size_t kh, 311 size_t kw, 312 size_t kc, 313 size_t sh, 314 size_t sw, 315 size_t nr, 316 size_t kr, 317 size_t sr, 318 const uint16_t* k, 319 const uint16_t* b, 320 uint16_t* packed_w, 321 struct subconvolution_params* subconv_params, 322 const void* params); 323 324 XNN_INTERNAL void xnn_pack_qu8_deconv_goki_w( 325 size_t g, 326 size_t nc, 327 size_t kh, 328 size_t kw, 329 size_t kc, 330 size_t sh, 331 size_t sw, 332 size_t nr, 333 size_t kr, 334 size_t sr, 335 const uint8_t* k, 336 const int32_t* b, 337 void* packed_w, 338 struct subconvolution_params* subconv_params, 339 const struct xnn_qu8_packing_params* params); 340 341 342 typedef void (*xnn_pack_dwconv_ghw_w_function)( 343 size_t h, 344 size_t w, 345 size_t c, 346 size_t cr, 347 const void* k, 348 const void* b, 349 void* packed_w, 350 const void* params); 351 352 XNN_INTERNAL void xnn_pack_f32_dwconv_ghw_w( 353 size_t h, 354 size_t w, 355 size_t c, 356 size_t cr, 357 const float* k, 358 const float* b, 359 float* packed_w, 360 const void* params); 361 362 XNN_INTERNAL void xnn_pack_f16_dwconv_ghw_w( 363 size_t h, 364 size_t w, 365 size_t c, 366 size_t cr, 367 const uint16_t* k, 368 const uint16_t* b, 369 uint16_t* packed_w, 370 const void* params); 371 372 XNN_INTERNAL void xnn_pack_qu8_dwconv_ghw_w( 373 size_t h, 374 size_t w, 375 size_t c, 376 size_t cr, 377 const uint8_t* k, 378 const int32_t* b, 379 void* packed_w, 380 const struct xnn_qu8_packing_params* params); 381 382 XNN_INTERNAL void xnn_pack_qs8_dwconv_ghw_w( 383 size_t h, 384 size_t w, 385 size_t c, 386 size_t cr, 387 const int8_t* k, 388 const int32_t* b, 389 void* packed_w, 390 const struct xnn_qs8_packing_params* params); 391 392 393 typedef void (*xnn_pack_dwconv_hwg_w_function)( 394 size_t h, 395 size_t w, 396 size_t c, 397 size_t cr, 398 const void* k, 399 const void* b, 400 void* packed_w, 401 const void* params); 402 403 XNN_INTERNAL void xnn_pack_f32_dwconv_hwg_w( 404 size_t h, 405 size_t w, 406 size_t c, 407 size_t cr, 408 const float* k, 409 const float* b, 410 float* packed_w, 411 const void* params); 412 413 XNN_INTERNAL void xnn_pack_f16_dwconv_hwg_w( 414 size_t h, 415 size_t w, 416 size_t c, 417 size_t cr, 418 const uint16_t* k, 419 const uint16_t* b, 420 uint16_t* packed_w, 421 const void* params); 422 423 XNN_INTERNAL void xnn_pack_qu8_dwconv_hwg_w( 424 size_t h, 425 size_t w, 426 size_t c, 427 size_t cr, 428 const uint8_t* k, 429 const int32_t* b, 430 void* packed_w, 431 const struct xnn_qu8_packing_params* params); 432 433 XNN_INTERNAL void xnn_pack_qs8_dwconv_hwg_w( 434 size_t h, 435 size_t w, 436 size_t c, 437 size_t cr, 438 const int8_t* k, 439 const int32_t* b, 440 void* packed_w, 441 const struct xnn_qs8_packing_params* params); 442 443 444 XNN_INTERNAL void xnn_pack_f32_gemminc_goi_w( 445 size_t g, 446 size_t nc, 447 size_t kc, 448 size_t nr, 449 size_t kr, 450 size_t sr, 451 const float* k, 452 float* packed_w, 453 const void* params); 454 455 XNN_INTERNAL void xnn_pack_f16_gemminc_goi_w( 456 size_t g, 457 size_t nc, 458 size_t kc, 459 size_t nr, 460 size_t kr, 461 size_t sr, 462 const uint16_t* k, 463 uint16_t* packed_w, 464 const void* params); 465 466 467 XNN_INTERNAL void xnn_pack_f32_dconv_oki_w( 468 size_t nc, 469 size_t kc, 470 size_t nr, 471 size_t kh, 472 size_t kw, 473 const float* k, 474 const float* b, 475 float* packed_w, 476 const void* params); 477 478 XNN_INTERNAL void xnn_pack_f16_dconv_oki_w( 479 size_t nc, 480 size_t kc, 481 size_t nr, 482 size_t kh, 483 size_t kw, 484 const uint16_t* k, 485 const uint16_t* b, 486 uint16_t* packed_w, 487 const void* params); 488 489 490 XNN_INTERNAL void xnn_pack_f32_chw_dwconv_ghw_w( 491 size_t kernel_size, 492 size_t groups, 493 const float* kernel, 494 const float* bias, 495 float* packed_weights, 496 const void* params); 497 498 XNN_INTERNAL void xnn_pack_f16_chw_dwconv_ghw_w( 499 size_t kernel_size, 500 size_t groups, 501 const uint16_t* kernel, 502 const uint16_t* bias, 503 uint16_t* packed_weights, 504 const void* params); 505 506 507 XNN_INTERNAL void xnn_pack_f32_chw_dwconv_hwg_w( 508 size_t kernel_size, 509 size_t groups, 510 const float* kernel, 511 const float* bias, 512 float* packed_weights, 513 const void* params); 514 515 516 typedef void (*xnn_pack_vmulcaddc_w_function)( 517 size_t c, 518 size_t cr, 519 const void* s, 520 const void* b, 521 void* packed_w, 522 const void* params); 523 524 XNN_INTERNAL void xnn_pack_f32_vmulcaddc_w( 525 size_t c, 526 size_t cr, 527 const float* s, 528 const float* b, 529 float* packed_w, 530 const void* params); 531 532 XNN_INTERNAL void xnn_pack_f16_vmulcaddc_w( 533 size_t c, 534 size_t cr, 535 const uint16_t* s, 536 const uint16_t* b, 537 uint16_t* packed_w, 538 const void* params); 539 540 #ifdef __cplusplus 541 } // extern "C" 542 #endif 543