#include <ATen/native/vulkan/api/Pipeline.h>

namespace at {
namespace native {
namespace vulkan {
namespace api {

//
// Utility Functions
//

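// Maps a combination of pipeline stages and memory access types to the
// VkAccessFlags bitmask used when recording memory barriers.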
VkAccessFlags vk_access(
    const PipelineStageFlags stage,
    const MemoryAccessFlags access) {
  VkAccessFlags vk_access = 0u;

  if (access & MemoryAccessType::READ) {
    if (stage & PipelineStage::COMPUTE) {
      vk_access |= VK_ACCESS_SHADER_READ_BIT;
    }

    if (stage & PipelineStage::HOST) {
      vk_access |= VK_ACCESS_HOST_READ_BIT;
    }

    if (stage & PipelineStage::TRANSFER) {
      vk_access |= VK_ACCESS_TRANSFER_READ_BIT;
    }
  }

  if (access & MemoryAccessType::WRITE) {
    if (stage & PipelineStage::COMPUTE) {
      vk_access |= VK_ACCESS_SHADER_WRITE_BIT;
    }

    if (stage & PipelineStage::HOST) {
      vk_access |= VK_ACCESS_HOST_WRITE_BIT;
    }

    if (stage & PipelineStage::TRANSFER) {
      vk_access |= VK_ACCESS_TRANSFER_WRITE_BIT;
    }
  }

  return vk_access;
}

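// Translates the API's PipelineStageFlags into the equivalent Vulkan
// VkPipelineStageFlags bitmask.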
VkPipelineStageFlags vk_stage(const PipelineStageFlags stage) {
  VkPipelineStageFlags vk_stage = 0u;

  if (stage & PipelineStage::COMPUTE) {
    vk_stage |= VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT;
  }

  if (stage & PipelineStage::HOST) {
    vk_stage |= VK_PIPELINE_STAGE_HOST_BIT;
  }

  if (stage & PipelineStage::TRANSFER) {
    vk_stage |= VK_PIPELINE_STAGE_TRANSFER_BIT;
  }

  return vk_stage;
}

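// Selects the image layout expected for a given stage/access combination:
// compute reads map to the read-only optimal layout, other compute access to
// GENERAL; transfer reads/writes map to the transfer src/dst optimal layouts;
// any other combination throws.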
VkImageLayout vk_layout(
    const PipelineStageFlags stage,
    const MemoryAccessFlags access) {
  switch (stage) {
    case PipelineStage::COMPUTE:
      switch (access) {
        case MemoryAccessType::READ:
          return VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL;
        default:
          return VK_IMAGE_LAYOUT_GENERAL;
      }
      break;
    case PipelineStage::TRANSFER:
      switch (access) {
        case MemoryAccessType::READ:
          return VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL;
        case MemoryAccessType::WRITE:
          return VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL;
        default:
          VK_THROW("Invalid memory access type for transfer stage!");
      }
      break;
    default:
      VK_THROW("Cannot determine appropriate image layout");
  }

  return VK_IMAGE_LAYOUT_UNDEFINED;
}

//
// PipelineLayout
//

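// Creates a VkPipelineLayout with a single descriptor set layout and no push
// constant ranges; the handle is released in the destructor.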
PipelineLayout::PipelineLayout(
    VkDevice device,
    VkDescriptorSetLayout descriptor_layout)
    : device_(device), handle_{VK_NULL_HANDLE} {
  // TODO: Enable push constants
  const VkPipelineLayoutCreateInfo pipeline_layout_create_info{
      VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO, // sType
      nullptr, // pNext
      0u, // flags
      1u, // setLayoutCount
      &descriptor_layout, // pSetLayouts
      0u, // pushConstantRangeCount
      nullptr, // pPushConstantRanges
  };

  VK_CHECK(vkCreatePipelineLayout(
      device_, &pipeline_layout_create_info, nullptr, &handle_));
}

PipelineLayout::PipelineLayout(PipelineLayout&& other) noexcept
    : device_(other.device_), handle_(other.handle_) {
  other.handle_ = VK_NULL_HANDLE;
}

PipelineLayout::~PipelineLayout() {
  if (VK_NULL_HANDLE == handle_) {
    return;
  }
  vkDestroyPipelineLayout(device_, handle_, nullptr);
  handle_ = VK_NULL_HANDLE;
}

void swap(PipelineLayout& lhs, PipelineLayout& rhs) noexcept {
  VkDevice tmp_device = lhs.device_;
  VkPipelineLayout tmp_handle = lhs.handle_;

  lhs.device_ = rhs.device_;
  lhs.handle_ = rhs.handle_;

  rhs.device_ = tmp_device;
  rhs.handle_ = tmp_handle;
}

//
// ComputePipeline
//

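// Builds a compute pipeline from the shader module and pipeline layout held
// in the descriptor. The local work group size is fed to the shader as
// specialization constants 0, 1 and 2 (X, Y, Z).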
ComputePipeline::ComputePipeline(
    VkDevice device,
    const ComputePipeline::Descriptor& descriptor,
    VkPipelineCache pipeline_cache)
    : device_(device), handle_{VK_NULL_HANDLE} {
  // NOLINTNEXTLINE
  constexpr VkSpecializationMapEntry specialization_map_entries[3]{
      // X
      {
          0u,
          offsetof(utils::uvec3, data[0u]),
          sizeof(utils::uvec3::data[0u]),
      },
      // Y
      {
          1u,
          offsetof(utils::uvec3, data[1u]),
          sizeof(utils::uvec3::data[1u]),
      },
      // Z
      {
          2u,
          offsetof(utils::uvec3, data[2u]),
          sizeof(utils::uvec3::data[2u]),
      },
  };

  const VkSpecializationInfo specialization_info{
      3u, // mapEntryCount
      specialization_map_entries, // pMapEntries
      sizeof(descriptor.local_work_group), // dataSize
      &descriptor.local_work_group, // pData
  };

  const VkPipelineShaderStageCreateInfo shader_stage_create_info{
      VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, // sType
      nullptr, // pNext
      0u, // flags
      VK_SHADER_STAGE_COMPUTE_BIT, // stage
      descriptor.shader_module, // module
      "main", // pName
      &specialization_info, // pSpecializationInfo
  };

  const VkComputePipelineCreateInfo compute_pipeline_create_info{
      VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO, // sType
      nullptr, // pNext
      0u, // flags
      shader_stage_create_info, // stage
      descriptor.pipeline_layout, // layout
      VK_NULL_HANDLE, // basePipelineHandle
      0u, // basePipelineIndex
  };

  VK_CHECK(vkCreateComputePipelines(
      device_,
      pipeline_cache,
      1u,
      &compute_pipeline_create_info,
      nullptr,
      &handle_));
}

ComputePipeline::ComputePipeline(ComputePipeline&& other) noexcept
    : device_(other.device_), handle_(other.handle_) {
  other.handle_ = VK_NULL_HANDLE;
}

ComputePipeline::~ComputePipeline() {
  if (VK_NULL_HANDLE == handle_) {
    return;
  }
  vkDestroyPipeline(device_, handle_, nullptr);
  handle_ = VK_NULL_HANDLE;
}

void swap(ComputePipeline& lhs, ComputePipeline& rhs) noexcept {
  VkDevice tmp_device = lhs.device_;
  VkPipeline tmp_handle = lhs.handle_;

  lhs.device_ = rhs.device_;
  lhs.handle_ = rhs.handle_;

  rhs.device_ = tmp_device;
  rhs.handle_ = tmp_handle;
}

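// Two descriptors compare equal when they reference the same pipeline layout
// and shader module and request the same local work group size; this is what
// key comparisons in the compute pipeline cache below rely on.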
static bool operator==(
    const ComputePipeline::Descriptor& _1,
    const ComputePipeline::Descriptor& _2) {
  return (
      _1.pipeline_layout == _2.pipeline_layout &&
      _1.shader_module == _2.shader_module &&
      _1.local_work_group == _2.local_work_group);
}

//
// PipelineLayoutCache
//

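// Thread-safe cache of pipeline layouts keyed by descriptor set layout;
// entries are created lazily in retrieve() and released in purge().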
PipelineLayoutCache::PipelineLayoutCache(VkDevice device)
    : cache_mutex_{}, device_(device), cache_{} {}

PipelineLayoutCache::PipelineLayoutCache(PipelineLayoutCache&& other) noexcept
    : cache_mutex_{}, device_(other.device_), cache_(std::move(other.cache_)) {
  std::lock_guard<std::mutex> lock(other.cache_mutex_);
}

PipelineLayoutCache::~PipelineLayoutCache() {
  purge();
}

VkPipelineLayout PipelineLayoutCache::retrieve(
    const PipelineLayoutCache::Key& key) {
  std::lock_guard<std::mutex> lock(cache_mutex_);

  auto it = cache_.find(key);
  if (cache_.cend() == it) {
    it = cache_.insert({key, PipelineLayoutCache::Value(device_, key)}).first;
  }

  return it->second.handle();
}

void PipelineLayoutCache::purge() {
  std::lock_guard<std::mutex> lock(cache_mutex_);
  cache_.clear();
}

//
// ComputePipelineCache
//

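// Thread-safe cache of compute pipelines. A VkPipelineCache is created up
// front so pipeline compilation results can be reused across
// vkCreateComputePipelines calls.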
ComputePipelineCache::ComputePipelineCache(VkDevice device)
    : cache_mutex_{},
      device_(device),
      pipeline_cache_{VK_NULL_HANDLE},
      cache_{} {
  const VkPipelineCacheCreateInfo pipeline_cache_create_info{
      VK_STRUCTURE_TYPE_PIPELINE_CACHE_CREATE_INFO, // sType
      nullptr, // pNext
      0u, // flags
      0u, // initialDataSize
      nullptr, // pInitialData
  };

  VK_CHECK(vkCreatePipelineCache(
      device, &pipeline_cache_create_info, nullptr, &pipeline_cache_));
}

ComputePipelineCache::ComputePipelineCache(
    ComputePipelineCache&& other) noexcept
    : cache_mutex_{},
      device_(other.device_),
      pipeline_cache_(other.pipeline_cache_),
      cache_(std::move(other.cache_)) {
  std::lock_guard<std::mutex> lock(other.cache_mutex_);

  other.pipeline_cache_ = VK_NULL_HANDLE;
}

ComputePipelineCache::~ComputePipelineCache() {
  purge();

  if (VK_NULL_HANDLE == pipeline_cache_) {
    return;
  }
  vkDestroyPipelineCache(device_, pipeline_cache_, nullptr);
  pipeline_cache_ = VK_NULL_HANDLE;
}

VkPipeline ComputePipelineCache::retrieve(
    const ComputePipelineCache::Key& key) {
  std::lock_guard<std::mutex> lock(cache_mutex_);

  auto it = cache_.find(key);
  if (cache_.cend() == it) {
    it = cache_
             .insert(
                 {key,
                  ComputePipelineCache::Value(device_, key, pipeline_cache_)})
             .first;
  }

  return it->second.handle();
}

void ComputePipelineCache::purge() {
  // Lock the cache mutex for consistency with retrieve() and with
  // PipelineLayoutCache::purge().
  std::lock_guard<std::mutex> lock(cache_mutex_);
  cache_.clear();
}

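// Illustrative sketch (not part of this file): a caller that already owns a
// VkDevice, a VkDescriptorSetLayout and a compiled VkShaderModule could
// combine the two caches roughly as follows. Variable names are hypothetical
// and the Descriptor field order is assumed from operator== above.
//
//   PipelineLayoutCache layout_cache(device);
//   ComputePipelineCache pipeline_cache(device);
//
//   VkPipelineLayout layout = layout_cache.retrieve(descriptor_set_layout);
//   VkPipeline pipeline = pipeline_cache.retrieve(
//       {layout, shader_module, {4u, 4u, 4u}});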
} // namespace api
} // namespace vulkan
} // namespace native
} // namespace at