1 #pragma once 2 3 #include <ATen/Config.h> 4 5 #include <c10/core/Device.h> 6 #include <c10/xpu/XPUFunctions.h> 7 #include <c10/xpu/XPUStream.h> 8 9 #include <oneapi/dnnl/dnnl.hpp> 10 #include <oneapi/dnnl/dnnl_sycl.hpp> 11 #include <vector> 12 13 namespace at::native::onednn { 14 15 TORCH_XPU_API dnnl::memory make_onednn_memory( 16 dnnl::memory::desc md, 17 dnnl::engine& engine, 18 void* ptr); 19 20 // Keep non-static and non-inline 21 bool set_onednn_verbose(int level); 22 23 // GpuEngineManager singleton 24 struct TORCH_XPU_API GpuEngineManager { 25 static GpuEngineManager& Instance(); // Singleton 26 get_engineGpuEngineManager27 dnnl::engine& get_engine(const Device& device) { 28 TORCH_INTERNAL_ASSERT(device.type() == kXPU); 29 TORCH_INTERNAL_ASSERT(device.index() < c10::xpu::device_count()); 30 return *engine_pool[device.index()]; 31 } 32 33 GpuEngineManager(GpuEngineManager const&) = delete; 34 GpuEngineManager& operator=(GpuEngineManager const&) = delete; 35 36 protected: GpuEngineManagerGpuEngineManager37 GpuEngineManager() { 38 int device_count = (int)c10::xpu::device_count(); 39 TORCH_INTERNAL_ASSERT(device_count > 0); 40 for (int i = 0; i < device_count; i++) { 41 engine_pool.push_back( 42 std::make_shared<dnnl::engine>(dnnl::sycl_interop::make_engine( 43 c10::xpu::get_raw_device(i), c10::xpu::get_device_context() 44 ))); 45 } 46 } ~GpuEngineManagerGpuEngineManager47 ~GpuEngineManager() {} 48 49 private: 50 std::vector<std::shared_ptr<dnnl::engine>> engine_pool; 51 }; 52 53 // GpuStreamManager singleton 54 struct TORCH_XPU_API GpuStreamManager { 55 static GpuStreamManager& Instance(); // Singleton 56 get_streamGpuStreamManager57 dnnl::stream get_stream() { 58 c10::DeviceIndex device_index = c10::xpu::current_device(); 59 TORCH_INTERNAL_ASSERT(device_index < c10::xpu::device_count()); 60 return dnnl::sycl_interop::make_stream( 61 GpuEngineManager::Instance().get_engine({c10::kXPU, device_index}), 62 c10::xpu::getCurrentXPUStream(device_index).queue()); 63 } 64 65 GpuStreamManager(GpuStreamManager const&) = delete; 66 GpuStreamManager& operator=(GpuStreamManager const&) = delete; 67 68 protected: GpuStreamManagerGpuStreamManager69 GpuStreamManager() { 70 } ~GpuStreamManagerGpuStreamManager71 ~GpuStreamManager() {} 72 73 }; 74 75 } // namespace at::native::onednn 76