• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #pragma once
2 
3 #include <ATen/core/Tensor.h>
4 #include <ATen/native/DispatchStub.h>
5 
6 namespace at::native {
7 
8 using weight_to_int4pack_fn = void(*)(const Tensor&, const Tensor&, int, int);
9 using int4pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, int, const Tensor&, int, int);
10 using int8pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&);
11 
12 DECLARE_DISPATCH(weight_to_int4pack_fn, weight_to_int4pack_stub);
13 DECLARE_DISPATCH(int4pack_mm_fn, int4pack_mm_stub);
14 DECLARE_DISPATCH(int8pack_mm_fn, int8pack_mm_stub);
15 
16 } // namespace at::native
17