1 #include <cstring>
2 #include <iostream>
3 #include <sstream>
4
5 #include <ATen/native/vulkan/api/Adapter.h>
6 #include <ATen/native/vulkan/api/Runtime.h>
7
8 namespace at {
9 namespace native {
10 namespace vulkan {
11 namespace api {
12
13 namespace {
14
find_requested_layers_and_extensions(std::vector<const char * > & enabled_layers,std::vector<const char * > & enabled_extensions,const std::vector<const char * > & requested_layers,const std::vector<const char * > & requested_extensions)15 void find_requested_layers_and_extensions(
16 std::vector<const char*>& enabled_layers,
17 std::vector<const char*>& enabled_extensions,
18 const std::vector<const char*>& requested_layers,
19 const std::vector<const char*>& requested_extensions) {
20 // Get supported instance layers
21 uint32_t layer_count = 0;
22 VK_CHECK(vkEnumerateInstanceLayerProperties(&layer_count, nullptr));
23
24 std::vector<VkLayerProperties> layer_properties(layer_count);
25 VK_CHECK(vkEnumerateInstanceLayerProperties(
26 &layer_count, layer_properties.data()));
27
28 // Search for requested layers
29 for (const auto& requested_layer : requested_layers) {
30 for (const auto& layer : layer_properties) {
31 if (strcmp(requested_layer, layer.layerName) == 0) {
32 enabled_layers.push_back(requested_layer);
33 break;
34 }
35 }
36 }
37
38 // Get supported instance extensions
39 uint32_t extension_count = 0;
40 VK_CHECK(vkEnumerateInstanceExtensionProperties(
41 nullptr, &extension_count, nullptr));
42
43 std::vector<VkExtensionProperties> extension_properties(extension_count);
44 VK_CHECK(vkEnumerateInstanceExtensionProperties(
45 nullptr, &extension_count, extension_properties.data()));
46
47 // Search for requested extensions
48 for (const auto& requested_extension : requested_extensions) {
49 for (const auto& extension : extension_properties) {
50 if (strcmp(requested_extension, extension.extensionName) == 0) {
51 enabled_extensions.push_back(requested_extension);
52 break;
53 }
54 }
55 }
56 }
57
create_instance(const RuntimeConfiguration & config)58 VkInstance create_instance(const RuntimeConfiguration& config) {
59 const VkApplicationInfo application_info{
60 VK_STRUCTURE_TYPE_APPLICATION_INFO, // sType
61 nullptr, // pNext
62 "PyTorch Vulkan Backend", // pApplicationName
63 0, // applicationVersion
64 nullptr, // pEngineName
65 0, // engineVersion
66 VK_API_VERSION_1_0, // apiVersion
67 };
68
69 std::vector<const char*> enabled_layers;
70 std::vector<const char*> enabled_extensions;
71
72 if (config.enableValidationMessages) {
73 std::vector<const char*> requested_layers{
74 // "VK_LAYER_LUNARG_api_dump",
75 "VK_LAYER_KHRONOS_validation",
76 };
77 std::vector<const char*> requested_extensions{
78 #ifdef VK_EXT_debug_report
79 VK_EXT_DEBUG_REPORT_EXTENSION_NAME,
80 #endif /* VK_EXT_debug_report */
81 };
82
83 find_requested_layers_and_extensions(
84 enabled_layers,
85 enabled_extensions,
86 requested_layers,
87 requested_extensions);
88 }
89
90 const VkInstanceCreateInfo instance_create_info{
91 VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO, // sType
92 nullptr, // pNext
93 0u, // flags
94 &application_info, // pApplicationInfo
95 static_cast<uint32_t>(enabled_layers.size()), // enabledLayerCount
96 enabled_layers.data(), // ppEnabledLayerNames
97 static_cast<uint32_t>(enabled_extensions.size()), // enabledExtensionCount
98 enabled_extensions.data(), // ppEnabledExtensionNames
99 };
100
101 VkInstance instance{};
102 VK_CHECK(vkCreateInstance(&instance_create_info, nullptr, &instance));
103 VK_CHECK_COND(instance, "Invalid Vulkan instance!");
104
105 #ifdef USE_VULKAN_VOLK
106 volkLoadInstance(instance);
107 #endif /* USE_VULKAN_VOLK */
108
109 return instance;
110 }
111
create_physical_devices(VkInstance instance)112 std::vector<Runtime::DeviceMapping> create_physical_devices(
113 VkInstance instance) {
114 if (VK_NULL_HANDLE == instance) {
115 return std::vector<Runtime::DeviceMapping>();
116 }
117
118 uint32_t device_count = 0;
119 VK_CHECK(vkEnumeratePhysicalDevices(instance, &device_count, nullptr));
120
121 std::vector<VkPhysicalDevice> devices(device_count);
122 VK_CHECK(vkEnumeratePhysicalDevices(instance, &device_count, devices.data()));
123
124 std::vector<Runtime::DeviceMapping> device_mappings;
125 device_mappings.reserve(device_count);
126 for (VkPhysicalDevice physical_device : devices) {
127 device_mappings.emplace_back(PhysicalDevice(physical_device), -1);
128 }
129
130 return device_mappings;
131 }
132
debug_report_callback_fn(const VkDebugReportFlagsEXT flags,const VkDebugReportObjectTypeEXT,const uint64_t,const size_t,const int32_t message_code,const char * const layer_prefix,const char * const message,void * const)133 VKAPI_ATTR VkBool32 VKAPI_CALL debug_report_callback_fn(
134 const VkDebugReportFlagsEXT flags,
135 const VkDebugReportObjectTypeEXT /* object_type */,
136 const uint64_t /* object */,
137 const size_t /* location */,
138 const int32_t message_code,
139 const char* const layer_prefix,
140 const char* const message,
141 void* const /* user_data */) {
142 (void)flags;
143
144 std::stringstream stream;
145 stream << layer_prefix << " " << message_code << " " << message << std::endl;
146 const std::string log = stream.str();
147
148 std::cout << log;
149
150 return VK_FALSE;
151 }
152
create_debug_report_callback(VkInstance instance,const RuntimeConfiguration config)153 VkDebugReportCallbackEXT create_debug_report_callback(
154 VkInstance instance,
155 const RuntimeConfiguration config) {
156 if (VK_NULL_HANDLE == instance || !config.enableValidationMessages) {
157 return VkDebugReportCallbackEXT{};
158 }
159
160 const VkDebugReportCallbackCreateInfoEXT debugReportCallbackCreateInfo{
161 VK_STRUCTURE_TYPE_DEBUG_REPORT_CALLBACK_CREATE_INFO_EXT, // sType
162 nullptr, // pNext
163 VK_DEBUG_REPORT_INFORMATION_BIT_EXT | VK_DEBUG_REPORT_WARNING_BIT_EXT |
164 VK_DEBUG_REPORT_PERFORMANCE_WARNING_BIT_EXT |
165 VK_DEBUG_REPORT_ERROR_BIT_EXT |
166 VK_DEBUG_REPORT_DEBUG_BIT_EXT, // flags
167 debug_report_callback_fn, // pfnCallback
168 nullptr, // pUserData
169 };
170
171 const auto vkCreateDebugReportCallbackEXT =
172 (PFN_vkCreateDebugReportCallbackEXT)vkGetInstanceProcAddr(
173 instance, "vkCreateDebugReportCallbackEXT");
174
175 VK_CHECK_COND(
176 vkCreateDebugReportCallbackEXT,
177 "Could not load vkCreateDebugReportCallbackEXT");
178
179 VkDebugReportCallbackEXT debug_report_callback{};
180 VK_CHECK(vkCreateDebugReportCallbackEXT(
181 instance,
182 &debugReportCallbackCreateInfo,
183 nullptr,
184 &debug_report_callback));
185
186 VK_CHECK_COND(debug_report_callback, "Invalid Vulkan debug report callback!");
187
188 return debug_report_callback;
189 }
190
191 //
192 // Adapter selection methods
193 //
194
select_first(const std::vector<Runtime::DeviceMapping> & devices)195 uint32_t select_first(const std::vector<Runtime::DeviceMapping>& devices) {
196 if (devices.empty()) {
197 return devices.size() + 1; // return out of range to signal invalidity
198 }
199
200 // Select the first adapter that has compute capability
201 for (size_t i = 0; i < devices.size(); ++i) {
202 if (devices[i].first.num_compute_queues > 0) {
203 return i;
204 }
205 }
206
207 return devices.size() + 1;
208 }
209
210 //
211 // Global runtime initialization
212 //
213
init_global_vulkan_runtime()214 std::unique_ptr<Runtime> init_global_vulkan_runtime() {
215 // Load Vulkan drivers
216 #if defined(USE_VULKAN_VOLK)
217 if (VK_SUCCESS != volkInitialize()) {
218 return std::unique_ptr<Runtime>(nullptr);
219 }
220 #elif defined(USE_VULKAN_WRAPPER)
221 if (!InitVulkan()) {
222 return std::unique_ptr<Runtime>(nullptr);
223 }
224 #endif /* USE_VULKAN_VOLK, USE_VULKAN_WRAPPER */
225
226 const bool enableValidationMessages =
227 #if defined(VULKAN_DEBUG)
228 true;
229 #else
230 false;
231 #endif /* VULKAN_DEBUG */
232 const bool initDefaultDevice = true;
233 const uint32_t numRequestedQueues = 1; // TODO: raise this value
234
235 const RuntimeConfiguration default_config{
236 enableValidationMessages,
237 initDefaultDevice,
238 AdapterSelector::First,
239 numRequestedQueues,
240 };
241
242 try {
243 return std::make_unique<Runtime>(Runtime(default_config));
244 } catch (...) {
245 }
246
247 return std::unique_ptr<Runtime>(nullptr);
248 }
249
250 } // namespace
251
Runtime(const RuntimeConfiguration config)252 Runtime::Runtime(const RuntimeConfiguration config)
253 : config_(config),
254 instance_(create_instance(config_)),
255 device_mappings_(create_physical_devices(instance_)),
256 adapters_{},
257 default_adapter_i_(UINT32_MAX),
258 debug_report_callback_(create_debug_report_callback(instance_, config_)) {
259 // List of adapters will never exceed the number of physical devices
260 adapters_.reserve(device_mappings_.size());
261
262 if (config.initDefaultDevice) {
263 try {
264 switch (config.defaultSelector) {
265 case AdapterSelector::First:
266 default_adapter_i_ = create_adapter(select_first);
267 }
268 } catch (...) {
269 }
270 }
271 }
272
~Runtime()273 Runtime::~Runtime() {
274 if (VK_NULL_HANDLE == instance_) {
275 return;
276 }
277
278 // Clear adapters list to trigger device destruction before destroying
279 // VkInstance
280 adapters_.clear();
281
282 // Instance must be destroyed last as its used to destroy the debug report
283 // callback.
284 if (debug_report_callback_) {
285 const auto vkDestroyDebugReportCallbackEXT =
286 (PFN_vkDestroyDebugReportCallbackEXT)vkGetInstanceProcAddr(
287 instance_, "vkDestroyDebugReportCallbackEXT");
288
289 if (vkDestroyDebugReportCallbackEXT) {
290 vkDestroyDebugReportCallbackEXT(
291 instance_, debug_report_callback_, nullptr);
292 }
293
294 debug_report_callback_ = {};
295 }
296
297 vkDestroyInstance(instance_, nullptr);
298 instance_ = VK_NULL_HANDLE;
299 }
300
Runtime(Runtime && other)301 Runtime::Runtime(Runtime&& other) noexcept
302 : config_(other.config_),
303 instance_(other.instance_),
304 adapters_(std::move(other.adapters_)),
305 default_adapter_i_(other.default_adapter_i_),
306 debug_report_callback_(other.debug_report_callback_) {
307 other.instance_ = VK_NULL_HANDLE;
308 other.debug_report_callback_ = {};
309 }
310
create_adapter(const Selector & selector)311 uint32_t Runtime::create_adapter(const Selector& selector) {
312 VK_CHECK_COND(
313 !device_mappings_.empty(),
314 "Pytorch Vulkan Runtime: Could not initialize adapter because no "
315 "devices were found by the Vulkan instance.");
316
317 uint32_t physical_device_i = selector(device_mappings_);
318 VK_CHECK_COND(
319 physical_device_i < device_mappings_.size(),
320 "Pytorch Vulkan Runtime: no suitable device adapter was selected! "
321 "Device could not be initialized");
322
323 Runtime::DeviceMapping& device_mapping = device_mappings_[physical_device_i];
324 // If an Adapter has already been created, return that
325 int32_t adapter_i = device_mapping.second;
326 if (adapter_i >= 0) {
327 return adapter_i;
328 }
329 // Otherwise, create an adapter for the selected physical device
330 adapter_i = utils::safe_downcast<int32_t>(adapters_.size());
331 adapters_.emplace_back(
332 new Adapter(instance_, device_mapping.first, config_.numRequestedQueues));
333 device_mapping.second = adapter_i;
334
335 return adapter_i;
336 }
337
runtime()338 Runtime* runtime() {
339 // The global vulkan runtime is declared as a static local variable within a
340 // non-static function to ensure it has external linkage. If it were a global
341 // static variable there would be one copy per translation unit that includes
342 // Runtime.h as it would have internal linkage.
343 static const std::unique_ptr<Runtime> p_runtime =
344 init_global_vulkan_runtime();
345
346 VK_CHECK_COND(
347 p_runtime,
348 "Pytorch Vulkan Runtime: The global runtime could not be retrieved "
349 "because it failed to initialize.");
350
351 return p_runtime.get();
352 }
353
354 } // namespace api
355 } // namespace vulkan
356 } // namespace native
357 } // namespace at
358