1 #include <ATen/native/vulkan/api/Adapter.h>
2 #include <ATen/native/vulkan/api/Command.h>
3
4 #include <mutex>
5
6 namespace at {
7 namespace native {
8 namespace vulkan {
9 namespace api {
10
11 //
12 // CommandBuffer
13 //
14
CommandBuffer(VkCommandBuffer handle,const VkCommandBufferUsageFlags flags)15 CommandBuffer::CommandBuffer(
16 VkCommandBuffer handle,
17 const VkCommandBufferUsageFlags flags)
18 : handle_(handle),
19 flags_(flags),
20 state_(CommandBuffer::State::NEW),
21 bound_{} {}
22
CommandBuffer(CommandBuffer && other)23 CommandBuffer::CommandBuffer(CommandBuffer&& other) noexcept
24 : handle_(other.handle_),
25 flags_(other.flags_),
26 state_(CommandBuffer::State::INVALID),
27 bound_(other.bound_) {
28 other.handle_ = VK_NULL_HANDLE;
29 other.bound_.reset();
30 }
31
operator =(CommandBuffer && other)32 CommandBuffer& CommandBuffer::operator=(CommandBuffer&& other) noexcept {
33 handle_ = other.handle_;
34 flags_ = other.flags_;
35 state_ = other.state_;
36 bound_ = other.bound_;
37
38 other.handle_ = VK_NULL_HANDLE;
39 other.bound_.reset();
40 other.state_ = CommandBuffer::State::INVALID;
41
42 return *this;
43 }
44
begin()45 void CommandBuffer::begin() {
46 VK_CHECK_COND(
47 state_ == CommandBuffer::State::NEW,
48 "Vulkan CommandBuffer: called begin() on a command buffer whose state "
49 "is not NEW.");
50
51 const VkCommandBufferBeginInfo begin_info{
52 VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO,
53 nullptr,
54 flags_,
55 nullptr,
56 };
57
58 VK_CHECK(vkBeginCommandBuffer(handle_, &begin_info));
59 state_ = CommandBuffer::State::RECORDING;
60 }
61
end()62 void CommandBuffer::end() {
63 VK_CHECK_COND(
64 state_ == CommandBuffer::State::RECORDING ||
65 state_ == CommandBuffer::State::SUBMITTED,
66 "Vulkan CommandBuffer: called end() on a command buffer whose state "
67 "is not RECORDING or SUBMITTED.");
68
69 if (state_ == CommandBuffer::State::RECORDING) {
70 VK_CHECK(vkEndCommandBuffer(handle_));
71 }
72 state_ = CommandBuffer::State::READY;
73 }
74
bind_pipeline(VkPipeline pipeline,VkPipelineLayout pipeline_layout,const utils::uvec3 local_workgroup_size)75 void CommandBuffer::bind_pipeline(
76 VkPipeline pipeline,
77 VkPipelineLayout pipeline_layout,
78 const utils::uvec3 local_workgroup_size) {
79 VK_CHECK_COND(
80 state_ == CommandBuffer::State::RECORDING,
81 "Vulkan CommandBuffer: called bind_pipeline() on a command buffer whose state "
82 "is not RECORDING.");
83
84 if (pipeline != bound_.pipeline) {
85 vkCmdBindPipeline(handle_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
86
87 bound_.pipeline = pipeline;
88 }
89
90 bound_.pipeline_layout = pipeline_layout;
91 bound_.local_workgroup_size = local_workgroup_size;
92
93 state_ = CommandBuffer::State::PIPELINE_BOUND;
94 }
95
bind_descriptors(VkDescriptorSet descriptors)96 void CommandBuffer::bind_descriptors(VkDescriptorSet descriptors) {
97 VK_CHECK_COND(
98 state_ == CommandBuffer::State::PIPELINE_BOUND,
99 "Vulkan CommandBuffer: called bind_descriptors() on a command buffer whose state "
100 "is not PIPELINE_BOUND.");
101
102 if (descriptors != bound_.descriptors) {
103 vkCmdBindDescriptorSets(
104 handle_, // commandBuffer
105 VK_PIPELINE_BIND_POINT_COMPUTE, // pipelineBindPoint
106 bound_.pipeline_layout, // layout
107 0u, // firstSet
108 1u, // descriptorSetCount
109 &descriptors, // pDescriptorSets
110 0u, // dynamicOffsetCount
111 nullptr); // pDynamicOffsets
112 }
113
114 bound_.descriptors = descriptors;
115
116 state_ = CommandBuffer::State::DESCRIPTORS_BOUND;
117 }
118
insert_barrier(PipelineBarrier & pipeline_barrier)119 void CommandBuffer::insert_barrier(PipelineBarrier& pipeline_barrier) {
120 VK_CHECK_COND(
121 state_ == CommandBuffer::State::DESCRIPTORS_BOUND ||
122 state_ == CommandBuffer::State::RECORDING,
123 "Vulkan CommandBuffer: called insert_barrier() on a command buffer whose state "
124 "is not DESCRIPTORS_BOUND or RECORDING.");
125
126 if (pipeline_barrier) {
127 if (!pipeline_barrier.buffer_barrier_handles.empty()) {
128 pipeline_barrier.buffer_barrier_handles.clear();
129 }
130 for (const api::BufferMemoryBarrier& memory_barrier :
131 pipeline_barrier.buffers) {
132 pipeline_barrier.buffer_barrier_handles.push_back(memory_barrier.handle);
133 }
134
135 if (!pipeline_barrier.image_barrier_handles.empty()) {
136 pipeline_barrier.image_barrier_handles.clear();
137 }
138 for (const api::ImageMemoryBarrier& memory_barrier :
139 pipeline_barrier.images) {
140 pipeline_barrier.image_barrier_handles.push_back(memory_barrier.handle);
141 }
142 vkCmdPipelineBarrier(
143 handle_, // commandBuffer
144 pipeline_barrier.stage.src, // srcStageMask
145 pipeline_barrier.stage.dst, // dstStageMask
146 0u, // dependencyFlags
147 0u, // memoryBarrierCount
148 nullptr, // pMemoryBarriers
149 pipeline_barrier.buffers.size(), // bufferMemoryBarrierCount
150 !pipeline_barrier.buffers.empty()
151 ? pipeline_barrier.buffer_barrier_handles.data()
152 : nullptr, // pMemoryBarriers
153 pipeline_barrier.images.size(), // imageMemoryBarrierCount
154 !pipeline_barrier.images.empty()
155 ? pipeline_barrier.image_barrier_handles.data()
156 : nullptr); // pImageMemoryBarriers
157 }
158
159 state_ = CommandBuffer::State::BARRIERS_INSERTED;
160 }
161
dispatch(const utils::uvec3 & global_workgroup_size)162 void CommandBuffer::dispatch(const utils::uvec3& global_workgroup_size) {
163 VK_CHECK_COND(
164 state_ == CommandBuffer::State::BARRIERS_INSERTED,
165 "Vulkan CommandBuffer: called dispatch() on a command buffer whose state "
166 "is not BARRIERS_INSERTED.");
167
168 vkCmdDispatch(
169 handle_,
170 utils::div_up(
171 global_workgroup_size.data[0u], bound_.local_workgroup_size.data[0u]),
172 utils::div_up(
173 global_workgroup_size.data[1u], bound_.local_workgroup_size.data[1u]),
174 utils::div_up(
175 global_workgroup_size.data[2u],
176 bound_.local_workgroup_size.data[2u]));
177
178 state_ = CommandBuffer::State::RECORDING;
179 }
180
copy_buffer_to_buffer(const api::VulkanBuffer & source,const api::VulkanBuffer & destination,const api::utils::uvec3 & copy_range,const api::utils::uvec3 & src_offset,const api::utils::uvec3 & dst_offset)181 void CommandBuffer::copy_buffer_to_buffer(
182 const api::VulkanBuffer& source,
183 const api::VulkanBuffer& destination,
184 const api::utils::uvec3& copy_range,
185 const api::utils::uvec3& src_offset,
186 const api::utils::uvec3& dst_offset) {
187 VK_CHECK_COND(
188 state_ == CommandBuffer::State::BARRIERS_INSERTED,
189 "Vulkan CommandBuffer: called copy_buffer_to_buffer() on a command buffer whose state "
190 "is not BARRIERS_INSERTED.");
191
192 const VkBufferCopy copy_details{
193 src_offset.data[0u], // srcOffset
194 dst_offset.data[0u], // dstOffset
195 copy_range.data[0u], // size
196 };
197
198 vkCmdCopyBuffer(
199 handle_, source.handle(), destination.handle(), 1u, ©_details);
200
201 state_ = CommandBuffer::State::RECORDING;
202 }
203
copy_texture_to_texture(const api::VulkanImage & source,const api::VulkanImage & destination,const api::utils::uvec3 & copy_range,const api::utils::uvec3 & src_offset,const api::utils::uvec3 & dst_offset)204 void CommandBuffer::copy_texture_to_texture(
205 const api::VulkanImage& source,
206 const api::VulkanImage& destination,
207 const api::utils::uvec3& copy_range,
208 const api::utils::uvec3& src_offset,
209 const api::utils::uvec3& dst_offset) {
210 VK_CHECK_COND(
211 state_ == CommandBuffer::State::BARRIERS_INSERTED,
212 "Vulkan CommandBuffer: called copy_texture_to_texture() on a command buffer whose state "
213 "is not BARRIERS_INSERTED.");
214
215 const VkImageSubresourceLayers src_subresource_layers{
216 VK_IMAGE_ASPECT_COLOR_BIT, // aspectMask
217 0u, // mipLevel
218 0u, // baseArrayLayer
219 1u, // layerCount
220 };
221
222 const VkImageSubresourceLayers dst_subresource_layers{
223 VK_IMAGE_ASPECT_COLOR_BIT, // aspectMask
224 0u, // mipLevel
225 0u, // baseArrayLayer
226 1u, // layerCount
227 };
228
229 const VkImageCopy copy_details{
230 src_subresource_layers, // srcSubresource
231 create_offset3d(src_offset), // srcOffset
232 dst_subresource_layers, // dstSubresource
233 create_offset3d(dst_offset), // dstOffset
234 create_extent3d(copy_range), // extent
235 };
236
237 vkCmdCopyImage(
238 handle_,
239 source.handle(),
240 source.layout(),
241 destination.handle(),
242 destination.layout(),
243 1u,
244 ©_details);
245
246 state_ = CommandBuffer::State::RECORDING;
247 }
248
copy_texture_to_buffer(const api::VulkanImage & source,const api::VulkanBuffer & destination,const api::utils::uvec3 & copy_range,const api::utils::uvec3 & src_offset,const api::utils::uvec3 & dst_offset)249 void CommandBuffer::copy_texture_to_buffer(
250 const api::VulkanImage& source,
251 const api::VulkanBuffer& destination,
252 const api::utils::uvec3& copy_range,
253 const api::utils::uvec3& src_offset,
254 const api::utils::uvec3& dst_offset) {
255 VK_CHECK_COND(
256 state_ == CommandBuffer::State::BARRIERS_INSERTED,
257 "Vulkan CommandBuffer: called copy_texture_to_buffer() on a command buffer whose state "
258 "is not BARRIERS_INSERTED.");
259
260 const VkImageSubresourceLayers src_subresource_layers{
261 VK_IMAGE_ASPECT_COLOR_BIT, // aspectMask
262 0u, // mipLevel
263 0u, // baseArrayLayer
264 1u, // layerCount
265 };
266
267 const VkBufferImageCopy copy_details{
268 dst_offset.data[0u], // bufferOffset
269 dst_offset.data[1u], // bufferRowLength
270 dst_offset.data[2u], // bufferImageHeight
271 src_subresource_layers, // imageSubresource
272 create_offset3d(src_offset), // imageOffset
273 create_extent3d(copy_range), // imageExtent
274 };
275
276 vkCmdCopyImageToBuffer(
277 handle_,
278 source.handle(),
279 source.layout(),
280 destination.handle(),
281 1u,
282 ©_details);
283
284 state_ = CommandBuffer::State::RECORDING;
285 }
286
copy_buffer_to_texture(const api::VulkanBuffer & source,const api::VulkanImage & destination,const api::utils::uvec3 & copy_range,const api::utils::uvec3 & src_offset,const api::utils::uvec3 & dst_offset)287 void CommandBuffer::copy_buffer_to_texture(
288 const api::VulkanBuffer& source,
289 const api::VulkanImage& destination,
290 const api::utils::uvec3& copy_range,
291 const api::utils::uvec3& src_offset,
292 const api::utils::uvec3& dst_offset) {
293 VK_CHECK_COND(
294 state_ == CommandBuffer::State::BARRIERS_INSERTED,
295 "Vulkan CommandBuffer: called copy_buffer_to_texture() on a command buffer whose state "
296 "is not BARRIERS_INSERTED.");
297
298 const VkImageSubresourceLayers dst_subresource_layers{
299 VK_IMAGE_ASPECT_COLOR_BIT, // aspectMask
300 0u, // mipLevel
301 0u, // baseArrayLayer
302 1u, // layerCount
303 };
304
305 const VkBufferImageCopy copy_details{
306 src_offset.data[0u], // bufferOffset
307 src_offset.data[1u], // bufferRowLength
308 src_offset.data[2u], // bufferImageHeight
309 dst_subresource_layers, // imageSubresource
310 create_offset3d(dst_offset), // imageOffset
311 create_extent3d(copy_range), // imageExtent
312 };
313
314 vkCmdCopyBufferToImage(
315 handle_,
316 source.handle(),
317 destination.handle(),
318 destination.layout(),
319 1u,
320 ©_details);
321
322 state_ = CommandBuffer::State::RECORDING;
323 }
324
write_timestamp(VkQueryPool querypool,const uint32_t idx) const325 void CommandBuffer::write_timestamp(VkQueryPool querypool, const uint32_t idx)
326 const {
327 VK_CHECK_COND(
328 state_ == CommandBuffer::State::RECORDING,
329 "Vulkan CommandBuffer: called write_timestamp() on a command buffer whose state "
330 "is not RECORDING.");
331
332 vkCmdWriteTimestamp(
333 handle_, VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT, querypool, idx);
334 }
335
reset_querypool(VkQueryPool querypool,const uint32_t first_idx,const uint32_t count) const336 void CommandBuffer::reset_querypool(
337 VkQueryPool querypool,
338 const uint32_t first_idx,
339 const uint32_t count) const {
340 VK_CHECK_COND(
341 state_ == CommandBuffer::State::RECORDING,
342 "Vulkan CommandBuffer: called reset_querypool() on a command buffer whose state "
343 "is not RECORDING.");
344
345 vkCmdResetQueryPool(handle_, querypool, first_idx, count);
346 }
347
get_submit_handle(const bool final_use)348 VkCommandBuffer CommandBuffer::get_submit_handle(const bool final_use) {
349 VK_CHECK_COND(
350 state_ == CommandBuffer::State::READY,
351 "Vulkan CommandBuffer: called begin() on a command buffer whose state "
352 "is not READY.");
353
354 VkCommandBuffer handle = handle_;
355
356 if (!is_reusable() || final_use) {
357 invalidate();
358 }
359 state_ = CommandBuffer::State::SUBMITTED;
360
361 return handle;
362 }
363
364 //
365 // CommandPool
366 //
367
CommandPool(VkDevice device,const uint32_t queue_family_idx,const CommandPoolConfig & config)368 CommandPool::CommandPool(
369 VkDevice device,
370 const uint32_t queue_family_idx,
371 const CommandPoolConfig& config)
372 : device_(device),
373 queue_family_idx_(queue_family_idx),
374 pool_(VK_NULL_HANDLE),
375 config_(config),
376 mutex_{},
377 buffers_{},
378 in_use_(0u) {
379 const VkCommandPoolCreateInfo create_info{
380 VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO,
381 nullptr,
382 VK_COMMAND_POOL_CREATE_TRANSIENT_BIT,
383 queue_family_idx_,
384 };
385
386 VK_CHECK(vkCreateCommandPool(device_, &create_info, nullptr, &pool_));
387
388 // Pre-allocate some command buffers
389 allocate_new_batch(config_.cmdPoolInitialSize);
390 }
391
~CommandPool()392 CommandPool::~CommandPool() {
393 if (VK_NULL_HANDLE == pool_) {
394 return;
395 }
396 vkDestroyCommandPool(device_, pool_, nullptr);
397 }
398
get_new_cmd(bool reusable)399 CommandBuffer CommandPool::get_new_cmd(bool reusable) {
400 std::lock_guard<std::mutex> lock(mutex_);
401
402 // No-ops if there are command buffers available
403 allocate_new_batch(config_.cmdPoolBatchSize);
404
405 VkCommandBuffer handle = buffers_[in_use_];
406
407 VkCommandBufferUsageFlags cmd_flags = 0u;
408 if (!reusable) {
409 cmd_flags |= VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
410 }
411
412 in_use_++;
413 return CommandBuffer(handle, cmd_flags);
414 }
415
flush()416 void CommandPool::flush() {
417 std::lock_guard<std::mutex> lock(mutex_);
418 VK_CHECK(vkResetCommandPool(device_, pool_, 0u));
419 in_use_ = 0u;
420 }
421
allocate_new_batch(const uint32_t count)422 void CommandPool::allocate_new_batch(const uint32_t count) {
423 // No-ops if there are still command buffers available
424 if (in_use_ < buffers_.size()) {
425 return;
426 }
427
428 buffers_.resize(buffers_.size() + count);
429
430 const VkCommandBufferAllocateInfo allocate_info{
431 VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO, // sType
432 nullptr, // pNext
433 pool_, // commandPool
434 VK_COMMAND_BUFFER_LEVEL_PRIMARY, // level
435 count, // commandBufferCount
436 };
437
438 VK_CHECK(vkAllocateCommandBuffers(
439 device_, &allocate_info, buffers_.data() + in_use_));
440 }
441
442 } // namespace api
443 } // namespace vulkan
444 } // namespace native
445 } // namespace at
446