• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <cstring>
2 #include <iostream>
3 #include <sstream>
4 
5 #include <ATen/native/vulkan/api/Adapter.h>
6 #include <ATen/native/vulkan/api/Runtime.h>
7 
8 namespace at {
9 namespace native {
10 namespace vulkan {
11 namespace api {
12 
13 namespace {
14 
find_requested_layers_and_extensions(std::vector<const char * > & enabled_layers,std::vector<const char * > & enabled_extensions,const std::vector<const char * > & requested_layers,const std::vector<const char * > & requested_extensions)15 void find_requested_layers_and_extensions(
16     std::vector<const char*>& enabled_layers,
17     std::vector<const char*>& enabled_extensions,
18     const std::vector<const char*>& requested_layers,
19     const std::vector<const char*>& requested_extensions) {
20   // Get supported instance layers
21   uint32_t layer_count = 0;
22   VK_CHECK(vkEnumerateInstanceLayerProperties(&layer_count, nullptr));
23 
24   std::vector<VkLayerProperties> layer_properties(layer_count);
25   VK_CHECK(vkEnumerateInstanceLayerProperties(
26       &layer_count, layer_properties.data()));
27 
28   // Search for requested layers
29   for (const auto& requested_layer : requested_layers) {
30     for (const auto& layer : layer_properties) {
31       if (strcmp(requested_layer, layer.layerName) == 0) {
32         enabled_layers.push_back(requested_layer);
33         break;
34       }
35     }
36   }
37 
38   // Get supported instance extensions
39   uint32_t extension_count = 0;
40   VK_CHECK(vkEnumerateInstanceExtensionProperties(
41       nullptr, &extension_count, nullptr));
42 
43   std::vector<VkExtensionProperties> extension_properties(extension_count);
44   VK_CHECK(vkEnumerateInstanceExtensionProperties(
45       nullptr, &extension_count, extension_properties.data()));
46 
47   // Search for requested extensions
48   for (const auto& requested_extension : requested_extensions) {
49     for (const auto& extension : extension_properties) {
50       if (strcmp(requested_extension, extension.extensionName) == 0) {
51         enabled_extensions.push_back(requested_extension);
52         break;
53       }
54     }
55   }
56 }
57 
create_instance(const RuntimeConfiguration & config)58 VkInstance create_instance(const RuntimeConfiguration& config) {
59   const VkApplicationInfo application_info{
60       VK_STRUCTURE_TYPE_APPLICATION_INFO, // sType
61       nullptr, // pNext
62       "PyTorch Vulkan Backend", // pApplicationName
63       0, // applicationVersion
64       nullptr, // pEngineName
65       0, // engineVersion
66       VK_API_VERSION_1_0, // apiVersion
67   };
68 
69   std::vector<const char*> enabled_layers;
70   std::vector<const char*> enabled_extensions;
71 
72   if (config.enableValidationMessages) {
73     std::vector<const char*> requested_layers{
74         // "VK_LAYER_LUNARG_api_dump",
75         "VK_LAYER_KHRONOS_validation",
76     };
77     std::vector<const char*> requested_extensions{
78 #ifdef VK_EXT_debug_report
79         VK_EXT_DEBUG_REPORT_EXTENSION_NAME,
80 #endif /* VK_EXT_debug_report */
81     };
82 
83     find_requested_layers_and_extensions(
84         enabled_layers,
85         enabled_extensions,
86         requested_layers,
87         requested_extensions);
88   }
89 
90   const VkInstanceCreateInfo instance_create_info{
91       VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO, // sType
92       nullptr, // pNext
93       0u, // flags
94       &application_info, // pApplicationInfo
95       static_cast<uint32_t>(enabled_layers.size()), // enabledLayerCount
96       enabled_layers.data(), // ppEnabledLayerNames
97       static_cast<uint32_t>(enabled_extensions.size()), // enabledExtensionCount
98       enabled_extensions.data(), // ppEnabledExtensionNames
99   };
100 
101   VkInstance instance{};
102   VK_CHECK(vkCreateInstance(&instance_create_info, nullptr, &instance));
103   VK_CHECK_COND(instance, "Invalid Vulkan instance!");
104 
105 #ifdef USE_VULKAN_VOLK
106   volkLoadInstance(instance);
107 #endif /* USE_VULKAN_VOLK */
108 
109   return instance;
110 }
111 
create_physical_devices(VkInstance instance)112 std::vector<Runtime::DeviceMapping> create_physical_devices(
113     VkInstance instance) {
114   if (VK_NULL_HANDLE == instance) {
115     return std::vector<Runtime::DeviceMapping>();
116   }
117 
118   uint32_t device_count = 0;
119   VK_CHECK(vkEnumeratePhysicalDevices(instance, &device_count, nullptr));
120 
121   std::vector<VkPhysicalDevice> devices(device_count);
122   VK_CHECK(vkEnumeratePhysicalDevices(instance, &device_count, devices.data()));
123 
124   std::vector<Runtime::DeviceMapping> device_mappings;
125   device_mappings.reserve(device_count);
126   for (VkPhysicalDevice physical_device : devices) {
127     device_mappings.emplace_back(PhysicalDevice(physical_device), -1);
128   }
129 
130   return device_mappings;
131 }
132 
debug_report_callback_fn(const VkDebugReportFlagsEXT flags,const VkDebugReportObjectTypeEXT,const uint64_t,const size_t,const int32_t message_code,const char * const layer_prefix,const char * const message,void * const)133 VKAPI_ATTR VkBool32 VKAPI_CALL debug_report_callback_fn(
134     const VkDebugReportFlagsEXT flags,
135     const VkDebugReportObjectTypeEXT /* object_type */,
136     const uint64_t /* object */,
137     const size_t /* location */,
138     const int32_t message_code,
139     const char* const layer_prefix,
140     const char* const message,
141     void* const /* user_data */) {
142   (void)flags;
143 
144   std::stringstream stream;
145   stream << layer_prefix << " " << message_code << " " << message << std::endl;
146   const std::string log = stream.str();
147 
148   std::cout << log;
149 
150   return VK_FALSE;
151 }
152 
create_debug_report_callback(VkInstance instance,const RuntimeConfiguration config)153 VkDebugReportCallbackEXT create_debug_report_callback(
154     VkInstance instance,
155     const RuntimeConfiguration config) {
156   if (VK_NULL_HANDLE == instance || !config.enableValidationMessages) {
157     return VkDebugReportCallbackEXT{};
158   }
159 
160   const VkDebugReportCallbackCreateInfoEXT debugReportCallbackCreateInfo{
161       VK_STRUCTURE_TYPE_DEBUG_REPORT_CALLBACK_CREATE_INFO_EXT, // sType
162       nullptr, // pNext
163       VK_DEBUG_REPORT_INFORMATION_BIT_EXT | VK_DEBUG_REPORT_WARNING_BIT_EXT |
164           VK_DEBUG_REPORT_PERFORMANCE_WARNING_BIT_EXT |
165           VK_DEBUG_REPORT_ERROR_BIT_EXT |
166           VK_DEBUG_REPORT_DEBUG_BIT_EXT, // flags
167       debug_report_callback_fn, // pfnCallback
168       nullptr, // pUserData
169   };
170 
171   const auto vkCreateDebugReportCallbackEXT =
172       (PFN_vkCreateDebugReportCallbackEXT)vkGetInstanceProcAddr(
173           instance, "vkCreateDebugReportCallbackEXT");
174 
175   VK_CHECK_COND(
176       vkCreateDebugReportCallbackEXT,
177       "Could not load vkCreateDebugReportCallbackEXT");
178 
179   VkDebugReportCallbackEXT debug_report_callback{};
180   VK_CHECK(vkCreateDebugReportCallbackEXT(
181       instance,
182       &debugReportCallbackCreateInfo,
183       nullptr,
184       &debug_report_callback));
185 
186   VK_CHECK_COND(debug_report_callback, "Invalid Vulkan debug report callback!");
187 
188   return debug_report_callback;
189 }
190 
191 //
192 // Adapter selection methods
193 //
194 
select_first(const std::vector<Runtime::DeviceMapping> & devices)195 uint32_t select_first(const std::vector<Runtime::DeviceMapping>& devices) {
196   if (devices.empty()) {
197     return devices.size() + 1; // return out of range to signal invalidity
198   }
199 
200   // Select the first adapter that has compute capability
201   for (size_t i = 0; i < devices.size(); ++i) {
202     if (devices[i].first.num_compute_queues > 0) {
203       return i;
204     }
205   }
206 
207   return devices.size() + 1;
208 }
209 
210 //
211 // Global runtime initialization
212 //
213 
init_global_vulkan_runtime()214 std::unique_ptr<Runtime> init_global_vulkan_runtime() {
215   // Load Vulkan drivers
216 #if defined(USE_VULKAN_VOLK)
217   if (VK_SUCCESS != volkInitialize()) {
218     return std::unique_ptr<Runtime>(nullptr);
219   }
220 #elif defined(USE_VULKAN_WRAPPER)
221   if (!InitVulkan()) {
222     return std::unique_ptr<Runtime>(nullptr);
223   }
224 #endif /* USE_VULKAN_VOLK, USE_VULKAN_WRAPPER */
225 
226   const bool enableValidationMessages =
227 #if defined(VULKAN_DEBUG)
228       true;
229 #else
230       false;
231 #endif /* VULKAN_DEBUG */
232   const bool initDefaultDevice = true;
233   const uint32_t numRequestedQueues = 1; // TODO: raise this value
234 
235   const RuntimeConfiguration default_config{
236       enableValidationMessages,
237       initDefaultDevice,
238       AdapterSelector::First,
239       numRequestedQueues,
240   };
241 
242   try {
243     return std::make_unique<Runtime>(Runtime(default_config));
244   } catch (...) {
245   }
246 
247   return std::unique_ptr<Runtime>(nullptr);
248 }
249 
250 } // namespace
251 
Runtime(const RuntimeConfiguration config)252 Runtime::Runtime(const RuntimeConfiguration config)
253     : config_(config),
254       instance_(create_instance(config_)),
255       device_mappings_(create_physical_devices(instance_)),
256       adapters_{},
257       default_adapter_i_(UINT32_MAX),
258       debug_report_callback_(create_debug_report_callback(instance_, config_)) {
259   // List of adapters will never exceed the number of physical devices
260   adapters_.reserve(device_mappings_.size());
261 
262   if (config.initDefaultDevice) {
263     try {
264       switch (config.defaultSelector) {
265         case AdapterSelector::First:
266           default_adapter_i_ = create_adapter(select_first);
267       }
268     } catch (...) {
269     }
270   }
271 }
272 
~Runtime()273 Runtime::~Runtime() {
274   if (VK_NULL_HANDLE == instance_) {
275     return;
276   }
277 
278   // Clear adapters list to trigger device destruction before destroying
279   // VkInstance
280   adapters_.clear();
281 
282   // Instance must be destroyed last as its used to destroy the debug report
283   // callback.
284   if (debug_report_callback_) {
285     const auto vkDestroyDebugReportCallbackEXT =
286         (PFN_vkDestroyDebugReportCallbackEXT)vkGetInstanceProcAddr(
287             instance_, "vkDestroyDebugReportCallbackEXT");
288 
289     if (vkDestroyDebugReportCallbackEXT) {
290       vkDestroyDebugReportCallbackEXT(
291           instance_, debug_report_callback_, nullptr);
292     }
293 
294     debug_report_callback_ = {};
295   }
296 
297   vkDestroyInstance(instance_, nullptr);
298   instance_ = VK_NULL_HANDLE;
299 }
300 
Runtime(Runtime && other)301 Runtime::Runtime(Runtime&& other) noexcept
302     : config_(other.config_),
303       instance_(other.instance_),
304       adapters_(std::move(other.adapters_)),
305       default_adapter_i_(other.default_adapter_i_),
306       debug_report_callback_(other.debug_report_callback_) {
307   other.instance_ = VK_NULL_HANDLE;
308   other.debug_report_callback_ = {};
309 }
310 
create_adapter(const Selector & selector)311 uint32_t Runtime::create_adapter(const Selector& selector) {
312   VK_CHECK_COND(
313       !device_mappings_.empty(),
314       "Pytorch Vulkan Runtime: Could not initialize adapter because no "
315       "devices were found by the Vulkan instance.");
316 
317   uint32_t physical_device_i = selector(device_mappings_);
318   VK_CHECK_COND(
319       physical_device_i < device_mappings_.size(),
320       "Pytorch Vulkan Runtime: no suitable device adapter was selected! "
321       "Device could not be initialized");
322 
323   Runtime::DeviceMapping& device_mapping = device_mappings_[physical_device_i];
324   // If an Adapter has already been created, return that
325   int32_t adapter_i = device_mapping.second;
326   if (adapter_i >= 0) {
327     return adapter_i;
328   }
329   // Otherwise, create an adapter for the selected physical device
330   adapter_i = utils::safe_downcast<int32_t>(adapters_.size());
331   adapters_.emplace_back(
332       new Adapter(instance_, device_mapping.first, config_.numRequestedQueues));
333   device_mapping.second = adapter_i;
334 
335   return adapter_i;
336 }
337 
runtime()338 Runtime* runtime() {
339   // The global vulkan runtime is declared as a static local variable within a
340   // non-static function to ensure it has external linkage. If it were a global
341   // static variable there would be one copy per translation unit that includes
342   // Runtime.h as it would have internal linkage.
343   static const std::unique_ptr<Runtime> p_runtime =
344       init_global_vulkan_runtime();
345 
346   VK_CHECK_COND(
347       p_runtime,
348       "Pytorch Vulkan Runtime: The global runtime could not be retrieved "
349       "because it failed to initialize.");
350 
351   return p_runtime.get();
352 }
353 
354 } // namespace api
355 } // namespace vulkan
356 } // namespace native
357 } // namespace at
358