• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <ATen/native/vulkan/api/Adapter.h>
2 #include <ATen/native/vulkan/api/Command.h>
3 
4 #include <mutex>
5 
6 namespace at {
7 namespace native {
8 namespace vulkan {
9 namespace api {
10 
11 //
12 // CommandBuffer
13 //
14 
CommandBuffer(VkCommandBuffer handle,const VkCommandBufferUsageFlags flags)15 CommandBuffer::CommandBuffer(
16     VkCommandBuffer handle,
17     const VkCommandBufferUsageFlags flags)
18     : handle_(handle),
19       flags_(flags),
20       state_(CommandBuffer::State::NEW),
21       bound_{} {}
22 
CommandBuffer(CommandBuffer && other)23 CommandBuffer::CommandBuffer(CommandBuffer&& other) noexcept
24     : handle_(other.handle_),
25       flags_(other.flags_),
26       state_(CommandBuffer::State::INVALID),
27       bound_(other.bound_) {
28   other.handle_ = VK_NULL_HANDLE;
29   other.bound_.reset();
30 }
31 
operator =(CommandBuffer && other)32 CommandBuffer& CommandBuffer::operator=(CommandBuffer&& other) noexcept {
33   handle_ = other.handle_;
34   flags_ = other.flags_;
35   state_ = other.state_;
36   bound_ = other.bound_;
37 
38   other.handle_ = VK_NULL_HANDLE;
39   other.bound_.reset();
40   other.state_ = CommandBuffer::State::INVALID;
41 
42   return *this;
43 }
44 
begin()45 void CommandBuffer::begin() {
46   VK_CHECK_COND(
47       state_ == CommandBuffer::State::NEW,
48       "Vulkan CommandBuffer: called begin() on a command buffer whose state "
49       "is not NEW.");
50 
51   const VkCommandBufferBeginInfo begin_info{
52       VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO,
53       nullptr,
54       flags_,
55       nullptr,
56   };
57 
58   VK_CHECK(vkBeginCommandBuffer(handle_, &begin_info));
59   state_ = CommandBuffer::State::RECORDING;
60 }
61 
end()62 void CommandBuffer::end() {
63   VK_CHECK_COND(
64       state_ == CommandBuffer::State::RECORDING ||
65           state_ == CommandBuffer::State::SUBMITTED,
66       "Vulkan CommandBuffer: called end() on a command buffer whose state "
67       "is not RECORDING or SUBMITTED.");
68 
69   if (state_ == CommandBuffer::State::RECORDING) {
70     VK_CHECK(vkEndCommandBuffer(handle_));
71   }
72   state_ = CommandBuffer::State::READY;
73 }
74 
bind_pipeline(VkPipeline pipeline,VkPipelineLayout pipeline_layout,const utils::uvec3 local_workgroup_size)75 void CommandBuffer::bind_pipeline(
76     VkPipeline pipeline,
77     VkPipelineLayout pipeline_layout,
78     const utils::uvec3 local_workgroup_size) {
79   VK_CHECK_COND(
80       state_ == CommandBuffer::State::RECORDING,
81       "Vulkan CommandBuffer: called bind_pipeline() on a command buffer whose state "
82       "is not RECORDING.");
83 
84   if (pipeline != bound_.pipeline) {
85     vkCmdBindPipeline(handle_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
86 
87     bound_.pipeline = pipeline;
88   }
89 
90   bound_.pipeline_layout = pipeline_layout;
91   bound_.local_workgroup_size = local_workgroup_size;
92 
93   state_ = CommandBuffer::State::PIPELINE_BOUND;
94 }
95 
bind_descriptors(VkDescriptorSet descriptors)96 void CommandBuffer::bind_descriptors(VkDescriptorSet descriptors) {
97   VK_CHECK_COND(
98       state_ == CommandBuffer::State::PIPELINE_BOUND,
99       "Vulkan CommandBuffer: called bind_descriptors() on a command buffer whose state "
100       "is not PIPELINE_BOUND.");
101 
102   if (descriptors != bound_.descriptors) {
103     vkCmdBindDescriptorSets(
104         handle_, // commandBuffer
105         VK_PIPELINE_BIND_POINT_COMPUTE, // pipelineBindPoint
106         bound_.pipeline_layout, // layout
107         0u, // firstSet
108         1u, // descriptorSetCount
109         &descriptors, // pDescriptorSets
110         0u, // dynamicOffsetCount
111         nullptr); // pDynamicOffsets
112   }
113 
114   bound_.descriptors = descriptors;
115 
116   state_ = CommandBuffer::State::DESCRIPTORS_BOUND;
117 }
118 
insert_barrier(PipelineBarrier & pipeline_barrier)119 void CommandBuffer::insert_barrier(PipelineBarrier& pipeline_barrier) {
120   VK_CHECK_COND(
121       state_ == CommandBuffer::State::DESCRIPTORS_BOUND ||
122           state_ == CommandBuffer::State::RECORDING,
123       "Vulkan CommandBuffer: called insert_barrier() on a command buffer whose state "
124       "is not DESCRIPTORS_BOUND or RECORDING.");
125 
126   if (pipeline_barrier) {
127     if (!pipeline_barrier.buffer_barrier_handles.empty()) {
128       pipeline_barrier.buffer_barrier_handles.clear();
129     }
130     for (const api::BufferMemoryBarrier& memory_barrier :
131          pipeline_barrier.buffers) {
132       pipeline_barrier.buffer_barrier_handles.push_back(memory_barrier.handle);
133     }
134 
135     if (!pipeline_barrier.image_barrier_handles.empty()) {
136       pipeline_barrier.image_barrier_handles.clear();
137     }
138     for (const api::ImageMemoryBarrier& memory_barrier :
139          pipeline_barrier.images) {
140       pipeline_barrier.image_barrier_handles.push_back(memory_barrier.handle);
141     }
142     vkCmdPipelineBarrier(
143         handle_, // commandBuffer
144         pipeline_barrier.stage.src, // srcStageMask
145         pipeline_barrier.stage.dst, // dstStageMask
146         0u, // dependencyFlags
147         0u, // memoryBarrierCount
148         nullptr, // pMemoryBarriers
149         pipeline_barrier.buffers.size(), // bufferMemoryBarrierCount
150         !pipeline_barrier.buffers.empty()
151             ? pipeline_barrier.buffer_barrier_handles.data()
152             : nullptr, // pMemoryBarriers
153         pipeline_barrier.images.size(), // imageMemoryBarrierCount
154         !pipeline_barrier.images.empty()
155             ? pipeline_barrier.image_barrier_handles.data()
156             : nullptr); // pImageMemoryBarriers
157   }
158 
159   state_ = CommandBuffer::State::BARRIERS_INSERTED;
160 }
161 
dispatch(const utils::uvec3 & global_workgroup_size)162 void CommandBuffer::dispatch(const utils::uvec3& global_workgroup_size) {
163   VK_CHECK_COND(
164       state_ == CommandBuffer::State::BARRIERS_INSERTED,
165       "Vulkan CommandBuffer: called dispatch() on a command buffer whose state "
166       "is not BARRIERS_INSERTED.");
167 
168   vkCmdDispatch(
169       handle_,
170       utils::div_up(
171           global_workgroup_size.data[0u], bound_.local_workgroup_size.data[0u]),
172       utils::div_up(
173           global_workgroup_size.data[1u], bound_.local_workgroup_size.data[1u]),
174       utils::div_up(
175           global_workgroup_size.data[2u],
176           bound_.local_workgroup_size.data[2u]));
177 
178   state_ = CommandBuffer::State::RECORDING;
179 }
180 
copy_buffer_to_buffer(const api::VulkanBuffer & source,const api::VulkanBuffer & destination,const api::utils::uvec3 & copy_range,const api::utils::uvec3 & src_offset,const api::utils::uvec3 & dst_offset)181 void CommandBuffer::copy_buffer_to_buffer(
182     const api::VulkanBuffer& source,
183     const api::VulkanBuffer& destination,
184     const api::utils::uvec3& copy_range,
185     const api::utils::uvec3& src_offset,
186     const api::utils::uvec3& dst_offset) {
187   VK_CHECK_COND(
188       state_ == CommandBuffer::State::BARRIERS_INSERTED,
189       "Vulkan CommandBuffer: called copy_buffer_to_buffer() on a command buffer whose state "
190       "is not BARRIERS_INSERTED.");
191 
192   const VkBufferCopy copy_details{
193       src_offset.data[0u], // srcOffset
194       dst_offset.data[0u], // dstOffset
195       copy_range.data[0u], // size
196   };
197 
198   vkCmdCopyBuffer(
199       handle_, source.handle(), destination.handle(), 1u, &copy_details);
200 
201   state_ = CommandBuffer::State::RECORDING;
202 }
203 
copy_texture_to_texture(const api::VulkanImage & source,const api::VulkanImage & destination,const api::utils::uvec3 & copy_range,const api::utils::uvec3 & src_offset,const api::utils::uvec3 & dst_offset)204 void CommandBuffer::copy_texture_to_texture(
205     const api::VulkanImage& source,
206     const api::VulkanImage& destination,
207     const api::utils::uvec3& copy_range,
208     const api::utils::uvec3& src_offset,
209     const api::utils::uvec3& dst_offset) {
210   VK_CHECK_COND(
211       state_ == CommandBuffer::State::BARRIERS_INSERTED,
212       "Vulkan CommandBuffer: called copy_texture_to_texture() on a command buffer whose state "
213       "is not BARRIERS_INSERTED.");
214 
215   const VkImageSubresourceLayers src_subresource_layers{
216       VK_IMAGE_ASPECT_COLOR_BIT, // aspectMask
217       0u, // mipLevel
218       0u, // baseArrayLayer
219       1u, // layerCount
220   };
221 
222   const VkImageSubresourceLayers dst_subresource_layers{
223       VK_IMAGE_ASPECT_COLOR_BIT, // aspectMask
224       0u, // mipLevel
225       0u, // baseArrayLayer
226       1u, // layerCount
227   };
228 
229   const VkImageCopy copy_details{
230       src_subresource_layers, // srcSubresource
231       create_offset3d(src_offset), // srcOffset
232       dst_subresource_layers, // dstSubresource
233       create_offset3d(dst_offset), // dstOffset
234       create_extent3d(copy_range), // extent
235   };
236 
237   vkCmdCopyImage(
238       handle_,
239       source.handle(),
240       source.layout(),
241       destination.handle(),
242       destination.layout(),
243       1u,
244       &copy_details);
245 
246   state_ = CommandBuffer::State::RECORDING;
247 }
248 
copy_texture_to_buffer(const api::VulkanImage & source,const api::VulkanBuffer & destination,const api::utils::uvec3 & copy_range,const api::utils::uvec3 & src_offset,const api::utils::uvec3 & dst_offset)249 void CommandBuffer::copy_texture_to_buffer(
250     const api::VulkanImage& source,
251     const api::VulkanBuffer& destination,
252     const api::utils::uvec3& copy_range,
253     const api::utils::uvec3& src_offset,
254     const api::utils::uvec3& dst_offset) {
255   VK_CHECK_COND(
256       state_ == CommandBuffer::State::BARRIERS_INSERTED,
257       "Vulkan CommandBuffer: called copy_texture_to_buffer() on a command buffer whose state "
258       "is not BARRIERS_INSERTED.");
259 
260   const VkImageSubresourceLayers src_subresource_layers{
261       VK_IMAGE_ASPECT_COLOR_BIT, // aspectMask
262       0u, // mipLevel
263       0u, // baseArrayLayer
264       1u, // layerCount
265   };
266 
267   const VkBufferImageCopy copy_details{
268       dst_offset.data[0u], // bufferOffset
269       dst_offset.data[1u], // bufferRowLength
270       dst_offset.data[2u], // bufferImageHeight
271       src_subresource_layers, // imageSubresource
272       create_offset3d(src_offset), // imageOffset
273       create_extent3d(copy_range), // imageExtent
274   };
275 
276   vkCmdCopyImageToBuffer(
277       handle_,
278       source.handle(),
279       source.layout(),
280       destination.handle(),
281       1u,
282       &copy_details);
283 
284   state_ = CommandBuffer::State::RECORDING;
285 }
286 
copy_buffer_to_texture(const api::VulkanBuffer & source,const api::VulkanImage & destination,const api::utils::uvec3 & copy_range,const api::utils::uvec3 & src_offset,const api::utils::uvec3 & dst_offset)287 void CommandBuffer::copy_buffer_to_texture(
288     const api::VulkanBuffer& source,
289     const api::VulkanImage& destination,
290     const api::utils::uvec3& copy_range,
291     const api::utils::uvec3& src_offset,
292     const api::utils::uvec3& dst_offset) {
293   VK_CHECK_COND(
294       state_ == CommandBuffer::State::BARRIERS_INSERTED,
295       "Vulkan CommandBuffer: called copy_buffer_to_texture() on a command buffer whose state "
296       "is not BARRIERS_INSERTED.");
297 
298   const VkImageSubresourceLayers dst_subresource_layers{
299       VK_IMAGE_ASPECT_COLOR_BIT, // aspectMask
300       0u, // mipLevel
301       0u, // baseArrayLayer
302       1u, // layerCount
303   };
304 
305   const VkBufferImageCopy copy_details{
306       src_offset.data[0u], // bufferOffset
307       src_offset.data[1u], // bufferRowLength
308       src_offset.data[2u], // bufferImageHeight
309       dst_subresource_layers, // imageSubresource
310       create_offset3d(dst_offset), // imageOffset
311       create_extent3d(copy_range), // imageExtent
312   };
313 
314   vkCmdCopyBufferToImage(
315       handle_,
316       source.handle(),
317       destination.handle(),
318       destination.layout(),
319       1u,
320       &copy_details);
321 
322   state_ = CommandBuffer::State::RECORDING;
323 }
324 
write_timestamp(VkQueryPool querypool,const uint32_t idx) const325 void CommandBuffer::write_timestamp(VkQueryPool querypool, const uint32_t idx)
326     const {
327   VK_CHECK_COND(
328       state_ == CommandBuffer::State::RECORDING,
329       "Vulkan CommandBuffer: called write_timestamp() on a command buffer whose state "
330       "is not RECORDING.");
331 
332   vkCmdWriteTimestamp(
333       handle_, VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT, querypool, idx);
334 }
335 
reset_querypool(VkQueryPool querypool,const uint32_t first_idx,const uint32_t count) const336 void CommandBuffer::reset_querypool(
337     VkQueryPool querypool,
338     const uint32_t first_idx,
339     const uint32_t count) const {
340   VK_CHECK_COND(
341       state_ == CommandBuffer::State::RECORDING,
342       "Vulkan CommandBuffer: called reset_querypool() on a command buffer whose state "
343       "is not RECORDING.");
344 
345   vkCmdResetQueryPool(handle_, querypool, first_idx, count);
346 }
347 
get_submit_handle(const bool final_use)348 VkCommandBuffer CommandBuffer::get_submit_handle(const bool final_use) {
349   VK_CHECK_COND(
350       state_ == CommandBuffer::State::READY,
351       "Vulkan CommandBuffer: called begin() on a command buffer whose state "
352       "is not READY.");
353 
354   VkCommandBuffer handle = handle_;
355 
356   if (!is_reusable() || final_use) {
357     invalidate();
358   }
359   state_ = CommandBuffer::State::SUBMITTED;
360 
361   return handle;
362 }
363 
364 //
365 // CommandPool
366 //
367 
CommandPool(VkDevice device,const uint32_t queue_family_idx,const CommandPoolConfig & config)368 CommandPool::CommandPool(
369     VkDevice device,
370     const uint32_t queue_family_idx,
371     const CommandPoolConfig& config)
372     : device_(device),
373       queue_family_idx_(queue_family_idx),
374       pool_(VK_NULL_HANDLE),
375       config_(config),
376       mutex_{},
377       buffers_{},
378       in_use_(0u) {
379   const VkCommandPoolCreateInfo create_info{
380       VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO,
381       nullptr,
382       VK_COMMAND_POOL_CREATE_TRANSIENT_BIT,
383       queue_family_idx_,
384   };
385 
386   VK_CHECK(vkCreateCommandPool(device_, &create_info, nullptr, &pool_));
387 
388   // Pre-allocate some command buffers
389   allocate_new_batch(config_.cmdPoolInitialSize);
390 }
391 
~CommandPool()392 CommandPool::~CommandPool() {
393   if (VK_NULL_HANDLE == pool_) {
394     return;
395   }
396   vkDestroyCommandPool(device_, pool_, nullptr);
397 }
398 
get_new_cmd(bool reusable)399 CommandBuffer CommandPool::get_new_cmd(bool reusable) {
400   std::lock_guard<std::mutex> lock(mutex_);
401 
402   // No-ops if there are command buffers available
403   allocate_new_batch(config_.cmdPoolBatchSize);
404 
405   VkCommandBuffer handle = buffers_[in_use_];
406 
407   VkCommandBufferUsageFlags cmd_flags = 0u;
408   if (!reusable) {
409     cmd_flags |= VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
410   }
411 
412   in_use_++;
413   return CommandBuffer(handle, cmd_flags);
414 }
415 
flush()416 void CommandPool::flush() {
417   std::lock_guard<std::mutex> lock(mutex_);
418   VK_CHECK(vkResetCommandPool(device_, pool_, 0u));
419   in_use_ = 0u;
420 }
421 
allocate_new_batch(const uint32_t count)422 void CommandPool::allocate_new_batch(const uint32_t count) {
423   // No-ops if there are still command buffers available
424   if (in_use_ < buffers_.size()) {
425     return;
426   }
427 
428   buffers_.resize(buffers_.size() + count);
429 
430   const VkCommandBufferAllocateInfo allocate_info{
431       VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO, // sType
432       nullptr, // pNext
433       pool_, // commandPool
434       VK_COMMAND_BUFFER_LEVEL_PRIMARY, // level
435       count, // commandBufferCount
436   };
437 
438   VK_CHECK(vkAllocateCommandBuffers(
439       device_, &allocate_info, buffers_.data() + in_use_));
440 }
441 
442 } // namespace api
443 } // namespace vulkan
444 } // namespace native
445 } // namespace at
446