• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #pragma once
2 
3 #include <torch/csrc/Export.h>
4 #include <torch/csrc/jit/codegen/fuser/fused_kernel.h>
5 
6 #include <cuda.h>
7 #include <cuda_runtime.h>
8 #include <nvrtc.h>
9 
10 #include <cstdint>
11 #include <string>
12 #include <vector>
13 
14 namespace torch::jit::fuser::cuda {
15 
16 // query codegen output arch and target
17 TORCH_CUDA_CU_API void codegenOutputQuery(
18     const cudaDeviceProp* const prop,
19     int& major,
20     int& minor,
21     bool& compile_to_sass);
22 
23 // A class holding metadata for an actual CUDA function.
24 // Note: CUDA functions are per device.
25 struct TORCH_CUDA_CU_API FusedKernelCUDA
26     : public ::torch::jit::fuser::FusedKernel {
27   FusedKernelCUDA(
28       at::DeviceIndex device,
29       std::string name,
30       std::string code,
31       std::vector<TensorDesc> input_desc,
32       std::vector<TensorDesc> output_desc,
33       std::vector<PartitionDesc> chunk_desc,
34       std::vector<PartitionDesc> concat_desc,
35       bool has_random);
36 
37   ~FusedKernelCUDA() override;
38 
39   void launch_raw(const uint32_t numel, std::vector<void*>& arguments)
40       const override;
41 
backendFusedKernelCUDA42   at::Backend backend() const override {
43     return at::Backend::CUDA;
44   }
45 
46  private:
47   static constexpr auto kBlockSize = 128;
48 
49   // Note: per device to store device properties and compute launch heuristics
50   //  Acquiring these values at launch time would be too slow
51   at::DeviceIndex device_;
52   int maxBlocks_{};
53   cudaDeviceProp* prop_{};
54   std::vector<char> ptx_;
55   CUmodule module_{};
56   CUfunction function_{};
57 };
58 
59 } // namespace torch::jit::fuser::cuda
60