1 #include <ATen/native/vulkan/api/Pipeline.h>
2
3 namespace at {
4 namespace native {
5 namespace vulkan {
6 namespace api {
7
8 //
9 // Utility Functions
10 //
11
vk_access(const PipelineStageFlags stage,const MemoryAccessFlags access)12 VkAccessFlags vk_access(
13 const PipelineStageFlags stage,
14 const MemoryAccessFlags access) {
15 VkAccessFlags vk_access = 0u;
16
17 if (access & MemoryAccessType::READ) {
18 if (stage & PipelineStage::COMPUTE) {
19 vk_access |= VK_ACCESS_SHADER_READ_BIT;
20 }
21
22 if (stage & PipelineStage::HOST) {
23 vk_access |= VK_ACCESS_HOST_READ_BIT;
24 }
25
26 if (stage & PipelineStage::TRANSFER) {
27 vk_access |= VK_ACCESS_TRANSFER_READ_BIT;
28 }
29 }
30
31 if (access & MemoryAccessType::WRITE) {
32 if (stage & PipelineStage::COMPUTE) {
33 vk_access |= VK_ACCESS_SHADER_WRITE_BIT;
34 }
35
36 if (stage & PipelineStage::HOST) {
37 vk_access |= VK_ACCESS_HOST_WRITE_BIT;
38 }
39
40 if (stage & PipelineStage::TRANSFER) {
41 vk_access |= VK_ACCESS_TRANSFER_WRITE_BIT;
42 }
43 }
44
45 return vk_access;
46 }
47
vk_stage(const PipelineStageFlags stage)48 VkPipelineStageFlags vk_stage(const PipelineStageFlags stage) {
49 VkPipelineStageFlags vk_stage = 0u;
50
51 if (stage & PipelineStage::COMPUTE) {
52 vk_stage |= VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT;
53 }
54
55 if (stage & PipelineStage::HOST) {
56 vk_stage |= VK_PIPELINE_STAGE_HOST_BIT;
57 }
58
59 if (stage & PipelineStage::TRANSFER) {
60 vk_stage |= VK_PIPELINE_STAGE_TRANSFER_BIT;
61 }
62
63 return vk_stage;
64 }
65
vk_layout(const PipelineStageFlags stage,const MemoryAccessFlags access)66 VkImageLayout vk_layout(
67 const PipelineStageFlags stage,
68 const MemoryAccessFlags access) {
69 switch (stage) {
70 case PipelineStage::COMPUTE:
71 switch (access) {
72 case MemoryAccessType::READ:
73 return VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL;
74 default:
75 return VK_IMAGE_LAYOUT_GENERAL;
76 }
77 break;
78 case PipelineStage::TRANSFER:
79 switch (access) {
80 case MemoryAccessType::READ:
81 return VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL;
82 case MemoryAccessType::WRITE:
83 return VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL;
84 default:
85 VK_THROW("Invalid memory access type for transfer stage!");
86 }
87 break;
88 default:
89 VK_THROW("Cannot determine appropriate image layout");
90 }
91
92 return VK_IMAGE_LAYOUT_UNDEFINED;
93 }
94
95 //
96 // PipelineLayout
97 //
98
PipelineLayout(VkDevice device,VkDescriptorSetLayout descriptor_layout)99 PipelineLayout::PipelineLayout(
100 VkDevice device,
101 VkDescriptorSetLayout descriptor_layout)
102 : device_(device), handle_{VK_NULL_HANDLE} {
103 // TODO: Enable push constants
104 const VkPipelineLayoutCreateInfo pipeline_layout_create_info{
105 VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO, // sType
106 nullptr, // pNext
107 0u, // flags
108 1u, // setLayoutCount
109 &descriptor_layout, // pSetLayouts
110 0u, // pushConstantRangeCount
111 nullptr, // pPushConstantRanges
112 };
113
114 VK_CHECK(vkCreatePipelineLayout(
115 device_, &pipeline_layout_create_info, nullptr, &handle_));
116 }
117
PipelineLayout(PipelineLayout && other)118 PipelineLayout::PipelineLayout(PipelineLayout&& other) noexcept
119 : device_(other.device_), handle_(other.handle_) {
120 other.handle_ = VK_NULL_HANDLE;
121 }
122
~PipelineLayout()123 PipelineLayout::~PipelineLayout() {
124 if (VK_NULL_HANDLE == handle_) {
125 return;
126 }
127 vkDestroyPipelineLayout(device_, handle_, nullptr);
128 handle_ = VK_NULL_HANDLE;
129 }
130
swap(PipelineLayout & lhs,PipelineLayout & rhs)131 void swap(PipelineLayout& lhs, PipelineLayout& rhs) noexcept {
132 VkDevice tmp_device = lhs.device_;
133 VkPipelineLayout tmp_handle = lhs.handle_;
134
135 lhs.device_ = rhs.device_;
136 lhs.handle_ = rhs.handle_;
137
138 rhs.device_ = tmp_device;
139 rhs.handle_ = tmp_handle;
140 }
141
142 //
143 // ComputePipeline
144 //
145
// Builds a compute pipeline from the shader module and pipeline layout in
// `descriptor`. The local work group size is fed to the shader through
// specialization constants 0, 1, 2 (x, y, z of descriptor.local_work_group),
// so one SPIR-V binary can be specialized per work-group size.
// Throws (via VK_CHECK) if vkCreateComputePipelines fails.
ComputePipeline::ComputePipeline(
    VkDevice device,
    const ComputePipeline::Descriptor& descriptor,
    VkPipelineCache pipeline_cache)
    : device_(device), handle_{VK_NULL_HANDLE} {
  // Each entry maps a specialization constant ID to the byte offset/size of
  // one component of a utils::uvec3 inside the pData blob below.
  // NOLINTNEXTLINE
  constexpr VkSpecializationMapEntry specialization_map_entries[3]{
      // X
      {
          0u,
          offsetof(utils::uvec3, data[0u]),
          sizeof(utils::uvec3::data[0u]),
      },
      // Y
      {
          1u,
          offsetof(utils::uvec3, data[1u]),
          sizeof(utils::uvec3::data[1u]),
      },
      // Z
      {
          2u,
          offsetof(utils::uvec3, data[2u]),
          sizeof(utils::uvec3::data[2u]),
      },
  };

  // pData points at descriptor.local_work_group; the map entries above
  // select its three components.
  const VkSpecializationInfo specialization_info{
      3u, // mapEntryCount
      specialization_map_entries, // pMapEntries
      sizeof(descriptor.local_work_group), // dataSize
      &descriptor.local_work_group, // pData
  };

  // Single compute stage; the shader entry point is required to be "main".
  const VkPipelineShaderStageCreateInfo shader_stage_create_info{
      VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, // sType
      nullptr, // pNext
      0u, // flags
      VK_SHADER_STAGE_COMPUTE_BIT, // stage
      descriptor.shader_module, // module
      "main", // pName
      &specialization_info, // pSpecializationInfo
  };

  const VkComputePipelineCreateInfo compute_pipeline_create_info{
      VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO, // sType
      nullptr, // pNext
      0u, // flags
      shader_stage_create_info, // stage
      descriptor.pipeline_layout, // layout
      VK_NULL_HANDLE, // basePipelineHandle
      0u, // basePipelineIndex
  };

  VK_CHECK(vkCreateComputePipelines(
      device_,
      pipeline_cache,
      1u,
      &compute_pipeline_create_info,
      nullptr,
      &handle_));
}
208
ComputePipeline(ComputePipeline && other)209 ComputePipeline::ComputePipeline(ComputePipeline&& other) noexcept
210 : device_(other.device_), handle_(other.handle_) {
211 other.handle_ = VK_NULL_HANDLE;
212 }
213
~ComputePipeline()214 ComputePipeline::~ComputePipeline() {
215 if (VK_NULL_HANDLE == handle_) {
216 return;
217 }
218 vkDestroyPipeline(device_, handle_, nullptr);
219 handle_ = VK_NULL_HANDLE;
220 }
221
swap(ComputePipeline & lhs,ComputePipeline & rhs)222 void swap(ComputePipeline& lhs, ComputePipeline& rhs) noexcept {
223 VkDevice tmp_device = lhs.device_;
224 VkPipeline tmp_handle = lhs.handle_;
225
226 lhs.device_ = rhs.device_;
227 lhs.handle_ = rhs.handle_;
228
229 rhs.device_ = tmp_device;
230 rhs.handle_ = tmp_handle;
231 }
232
operator ==(const ComputePipeline::Descriptor & _1,const ComputePipeline::Descriptor & _2)233 static bool operator==(
234 const ComputePipeline::Descriptor& _1,
235 const ComputePipeline::Descriptor& _2) {
236 return (
237 _1.pipeline_layout == _2.pipeline_layout &&
238 _1.shader_module == _2.shader_module &&
239 _1.local_work_group == _2.local_work_group);
240 }
241
242 //
243 // PipelineLayoutCache
244 //
245
// Creates an empty pipeline layout cache bound to `device`.
PipelineLayoutCache::PipelineLayoutCache(VkDevice device)
    : cache_mutex_{}, device_(device), cache_{} {}
248
// Move constructor: takes over the other cache's device and entries.
// NOTE(review): other.cache_mutex_ is locked only AFTER cache_ has been
// moved in the initializer list — the lock drains in-flight operations on
// `other` but does not guard the move itself. Confirm callers guarantee no
// concurrent access to `other` while it is being moved from.
PipelineLayoutCache::PipelineLayoutCache(PipelineLayoutCache&& other) noexcept
    : cache_mutex_{}, device_(other.device_), cache_(std::move(other.cache_)) {
  std::lock_guard<std::mutex> lock(other.cache_mutex_);
}
253
// Drops every cached entry (and thereby the Vulkan objects they own)
// via purge().
PipelineLayoutCache::~PipelineLayoutCache() {
  purge();
}
257
retrieve(const PipelineLayoutCache::Key & key)258 VkPipelineLayout PipelineLayoutCache::retrieve(
259 const PipelineLayoutCache::Key& key) {
260 std::lock_guard<std::mutex> lock(cache_mutex_);
261
262 auto it = cache_.find(key);
263 if (cache_.cend() == it) {
264 it = cache_.insert({key, PipelineLayoutCache::Value(device_, key)}).first;
265 }
266
267 return it->second.handle();
268 }
269
// Clears all cached entries under cache_mutex_; the entries' destructors
// release the Vulkan objects they own.
void PipelineLayoutCache::purge() {
  std::lock_guard<std::mutex> lock(cache_mutex_);
  cache_.clear();
}
274
275 //
276 // ComputePipelineCache
277 //
278
ComputePipelineCache(VkDevice device)279 ComputePipelineCache::ComputePipelineCache(VkDevice device)
280 : cache_mutex_{},
281 device_(device),
282 pipeline_cache_{VK_NULL_HANDLE},
283 cache_{} {
284 const VkPipelineCacheCreateInfo pipeline_cache_create_info{
285 VK_STRUCTURE_TYPE_PIPELINE_CACHE_CREATE_INFO, // sType
286 nullptr, // pNext
287 0u, // flags
288 0u, // initialDataSize
289 nullptr, // pInitialData
290 };
291
292 VK_CHECK(vkCreatePipelineCache(
293 device, &pipeline_cache_create_info, nullptr, &pipeline_cache_));
294 }
295
// Move constructor: takes over the other cache's device, driver pipeline
// cache handle, and entries, leaving `other` with a null handle so its
// destructor is a no-op.
// NOTE(review): other.cache_mutex_ is locked only AFTER cache_ and
// pipeline_cache_ have been taken in the initializer list — the lock drains
// in-flight operations on `other` but does not guard the move itself.
// Confirm callers guarantee no concurrent access during the move.
ComputePipelineCache::ComputePipelineCache(
    ComputePipelineCache&& other) noexcept
    : cache_mutex_{},
      device_(other.device_),
      pipeline_cache_(other.pipeline_cache_),
      cache_(std::move(other.cache_)) {
  std::lock_guard<std::mutex> lock(other.cache_mutex_);

  other.pipeline_cache_ = VK_NULL_HANDLE;
}
306
~ComputePipelineCache()307 ComputePipelineCache::~ComputePipelineCache() {
308 purge();
309
310 if (VK_NULL_HANDLE == pipeline_cache_) {
311 return;
312 }
313 vkDestroyPipelineCache(device_, pipeline_cache_, nullptr);
314 pipeline_cache_ = VK_NULL_HANDLE;
315 }
316
retrieve(const ComputePipelineCache::Key & key)317 VkPipeline ComputePipelineCache::retrieve(
318 const ComputePipelineCache::Key& key) {
319 std::lock_guard<std::mutex> lock(cache_mutex_);
320
321 auto it = cache_.find(key);
322 if (cache_.cend() == it) {
323 it = cache_
324 .insert(
325 {key,
326 ComputePipelineCache::Value(device_, key, pipeline_cache_)})
327 .first;
328 }
329
330 return it->second.handle();
331 }
332
purge()333 void ComputePipelineCache::purge() {
334 cache_.clear();
335 }
336
337 } // namespace api
338 } // namespace vulkan
339 } // namespace native
340 } // namespace at
341