• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #pragma once
2 
3 #include <ATen/Config.h>
4 
5 #include <c10/core/Device.h>
6 #include <c10/xpu/XPUFunctions.h>
7 #include <c10/xpu/XPUStream.h>
8 
9 #include <oneapi/dnnl/dnnl.hpp>
10 #include <oneapi/dnnl/dnnl_sycl.hpp>
11 #include <vector>
12 
13 namespace at::native::onednn {
14 
15 TORCH_XPU_API dnnl::memory make_onednn_memory(
16     dnnl::memory::desc md,
17     dnnl::engine& engine,
18     void* ptr);
19 
// Keep non-static and non-inline
// Declared here and defined elsewhere. Takes the requested oneDNN verbose
// `level`; NOTE(review): the bool result presumably reports whether the level
// was applied — confirm against the definition.
bool set_onednn_verbose(int level);
22 
23 // GpuEngineManager singleton
24 struct TORCH_XPU_API GpuEngineManager {
25   static GpuEngineManager& Instance(); // Singleton
26 
get_engineGpuEngineManager27   dnnl::engine& get_engine(const Device& device) {
28     TORCH_INTERNAL_ASSERT(device.type() == kXPU);
29     TORCH_INTERNAL_ASSERT(device.index() < c10::xpu::device_count());
30     return *engine_pool[device.index()];
31   }
32 
33   GpuEngineManager(GpuEngineManager const&) = delete;
34   GpuEngineManager& operator=(GpuEngineManager const&) = delete;
35 
36  protected:
GpuEngineManagerGpuEngineManager37   GpuEngineManager() {
38     int device_count = (int)c10::xpu::device_count();
39     TORCH_INTERNAL_ASSERT(device_count > 0);
40     for (int i = 0; i < device_count; i++) {
41         engine_pool.push_back(
42             std::make_shared<dnnl::engine>(dnnl::sycl_interop::make_engine(
43               c10::xpu::get_raw_device(i), c10::xpu::get_device_context()
44             )));
45     }
46   }
~GpuEngineManagerGpuEngineManager47   ~GpuEngineManager() {}
48 
49  private:
50   std::vector<std::shared_ptr<dnnl::engine>> engine_pool;
51 };
52 
53 // GpuStreamManager singleton
54 struct TORCH_XPU_API GpuStreamManager {
55   static GpuStreamManager& Instance(); // Singleton
56 
get_streamGpuStreamManager57   dnnl::stream get_stream() {
58     c10::DeviceIndex device_index = c10::xpu::current_device();
59     TORCH_INTERNAL_ASSERT(device_index < c10::xpu::device_count());
60     return dnnl::sycl_interop::make_stream(
61         GpuEngineManager::Instance().get_engine({c10::kXPU, device_index}),
62         c10::xpu::getCurrentXPUStream(device_index).queue());
63   }
64 
65   GpuStreamManager(GpuStreamManager const&) = delete;
66   GpuStreamManager& operator=(GpuStreamManager const&) = delete;
67 
68  protected:
GpuStreamManagerGpuStreamManager69   GpuStreamManager() {
70   }
~GpuStreamManagerGpuStreamManager71   ~GpuStreamManager() {}
72 
73 };
74 
75 } // namespace at::native::onednn
76