This directory contains files copied from the
`src/fastertransformer/cutlass_extensions/include/cutlass_extensions`
directory of the source tree of the
[*FasterTransformer*](https://github.com/NVIDIA/FasterTransformer)
project. They support the mixed data types GEMM implementation in the
`aten/src/ATen/native/cuda/MixedDTypesLinear.cu` file of the *PyTorch*
source tree. Not all files from the given directory of the
*FasterTransformer* project are included here, only the ones necessary
to support the mentioned functionality.

The original copy of these files was made from commit `f8e42aa` of the
*FasterTransformer* project. The changes from the original files are
minimal, just enough to support *CUTLASS* 3.x (as of the mentioned
commit, the *FasterTransformer* project was based on *CUTLASS* 2.10).
However, the copies of the files in the *PyTorch* source tree are
linted using *PyTorch* lint rules, so at this stage they differ quite
a bit from the original files. Thus, to keep track of the original
changes, here is the diff between the two sets of files, before
linting:

```
Only in FasterTransformer/src/fastertransformer/cutlass_extensions/include/cutlass_extensions: compute_occupancy.h
Only in FasterTransformer/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/epilogue: epilogue_quant_helper.h
Only in FasterTransformer/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/epilogue: threadblock
diff -r FasterTransformer/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h pytorch/aten/src/ATen/native/cuda/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h
157c157,158
< struct Params {
---
> struct Params
> {
183d183
< CUTLASS_HOST_DEVICE
186d185
< CUTLASS_HOST_DEVICE
188,190c187,188
< cutlass::gemm::GemmCoord const& grid_tiled_shape,
< const int gemm_k_size,
< void* workspace = nullptr):
---
> int device_sms,
> int sm_occupancy):
192d189
< grid_tiled_shape(grid_tiled_shape),
205,206d201
< semaphore(static_cast<int*>(workspace)),
< gemm_k_size(gemm_k_size),
210a206,227
> ThreadblockSwizzle swizzle;
> grid_tiled_shape = swizzle.get_tiled_shape(
> args.problem_size,
> {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
> args.batch_count);
>
> gemm_k_size = args.problem_size.k();
> }
>
> size_t get_workspace_size() const
> {
> return 0;
> }
>
> Status init_workspace(void *workspace,cudaStream_t stream = nullptr)
> {
> return Status::kSuccess;
> }
>
> dim3 get_grid_dims() const
> {
> return ThreadblockSwizzle().get_grid_shape(grid_tiled_shape);
278,283d294
< static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
< {
<
< return 0;
< }
<
464a476,482
> CUTLASS_DEVICE
> static void invoke(Params const &params, SharedStorage &shared_storage)
> {
> GemmFpAIntB op;
> op(params, shared_storage);
> }
>
492c510
< } // namespace cutlass
\ No newline at end of file
---
> } // namespace cutlass
Only in FasterTransformer/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/gemm/kernel: gemm_moe_problem_visitor.h
Only in FasterTransformer/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/gemm/kernel: gemm_with_epilogue_visitor.h
Only in FasterTransformer/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/gemm/kernel: moe_cutlass_kernel.h
Only in FasterTransformer/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/gemm/kernel: moe_problem_visitor.h
diff -r FasterTransformer/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h pytorch/aten/src/ATen/native/cuda/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h
55c55,58
< #include <src/fastertransformer/utils/cuda_bf16_wrapper.h>
---
> //#include <src/fastertransformer/utils/cuda_bf16_wrapper.h>
> //#ifdef ENABLE_BF16
> #include <cuda_bf16.h>
> //#endif
155c158,159
< #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
---
> //#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
> #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
470c474
< ////////////////////////////////////////////////////////////////////////////////
\ No newline at end of file
---
> ////////////////////////////////////////////////////////////////////////////////
```
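Most of the `fpA_intB_gemm.h` changes above adapt the kernel-level
`GemmFpAIntB` class to the interface that the `GemmUniversalBase`
device-level adapter of newer *CUTLASS* versions expects: the
`Params(args, device_sms, sm_occupancy)` constructor and the
`get_workspace_size()` / `init_workspace()` / `get_grid_dims()`
members added in the `210a206,227` hunk are queried by the adapter
before launch, while the static `invoke()` method added in the
`464a476,482` hunk is what the `cutlass::Kernel2` global entry point
calls. As a rough sketch of that launch path (simplified from
`cutlass/device_kernel.h`; the exact definition may differ between
*CUTLASS* versions):

```
// Simplified sketch of the global kernel entry point in
// cutlass/device_kernel.h; not the verbatim CUTLASS source.
template <typename Operator>
__global__ void Kernel2(typename Operator::Params params) {
  // Dynamic shared memory, sized by the host-side adapter at launch.
  extern __shared__ int SharedStorageBase[];
  typename Operator::SharedStorage* shared_storage =
      reinterpret_cast<typename Operator::SharedStorage*>(SharedStorageBase);
  // The kernel-level class must expose a static invoke() method --
  // this is what the 464a476,482 hunk above adds to GemmFpAIntB.
  Operator::invoke(params, *shared_storage);
}
```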
As mentioned [here](https://github.com/NVIDIA/cutlass/discussions/911)
and [here](https://github.com/NVIDIA/cutlass/issues/1060), *CUTLASS*
itself is expected to eventually include the functionality provided by
these extensions, so this whole directory will hopefully be removed
from the *PyTorch* source tree at some later time.