#pragma once

// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName

#ifdef USE_VULKAN_API

#include <ATen/native/vulkan/api/vk_api.h>

#include <ATen/native/vulkan/api/Descriptor.h>
#include <ATen/native/vulkan/api/Pipeline.h>
#include <ATen/native/vulkan/api/Resource.h>
#include <ATen/native/vulkan/api/Shader.h>
#include <ATen/native/vulkan/api/Utils.h>

// Standard library names used directly in this header (size_t, uint32_t,
// std::mutex, std::vector). Include them explicitly instead of relying on
// transitive includes from the project headers above.
#include <cstddef>
#include <cstdint>
#include <mutex>
#include <vector>

namespace at {
namespace native {
namespace vulkan {
namespace api {

// Wraps a VkCommandBuffer handle together with the usage flags it was created
// with, a lifecycle State, and a record of the pipeline / descriptor set most
// recently bound to it. Copy operations are deleted; instances are move-only.
// The destructor is defaulted, so this wrapper does not free the handle itself.
class CommandBuffer final {
 public:
  explicit CommandBuffer(VkCommandBuffer, const VkCommandBufferUsageFlags);

  CommandBuffer(const CommandBuffer&) = delete;
  CommandBuffer& operator=(const CommandBuffer&) = delete;

  CommandBuffer(CommandBuffer&&) noexcept;
  CommandBuffer& operator=(CommandBuffer&&) noexcept;

  ~CommandBuffer() = default;

  // The lifecycle of a command buffer is as follows:
  enum State {
    INVALID, // Used to indicate the command buffer is moved from
    NEW, // Set during constructor
    RECORDING, // Set during call to begin(), dispatch(), and
               // copy_*_to_*()
    PIPELINE_BOUND, // Set during call to bind_pipeline()
    DESCRIPTORS_BOUND, // Set during call to bind_descriptors()
    BARRIERS_INSERTED, // Set during call to insert_barrier()
    READY, // Set during call to end()
    SUBMITTED, // Set during call to get_submit_handle()
  };

  // Record of what is currently bound to the command buffer. Every field
  // starts out null / zero, and reset() restores that initial value.
  struct Bound {
    VkPipeline pipeline;
    VkPipelineLayout pipeline_layout;
    utils::uvec3 local_workgroup_size;
    VkDescriptorSet descriptors;

    explicit Bound()
        : pipeline{VK_NULL_HANDLE},
          pipeline_layout{VK_NULL_HANDLE},
          local_workgroup_size{0u, 0u, 0u},
          descriptors{VK_NULL_HANDLE} {}

    // Restores every field to the value set by the default constructor.
    inline void reset() {
      pipeline = VK_NULL_HANDLE;
      pipeline_layout = VK_NULL_HANDLE;
      local_workgroup_size = {0u, 0u, 0u};
      descriptors = VK_NULL_HANDLE;
    }
  };

 private:
  VkCommandBuffer handle_;
  VkCommandBufferUsageFlags flags_;
  State state_;
  Bound bound_;

 public:
  // Returns true unless the buffer was created with
  // VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT. Only reads flags_, so it is
  // callable on const instances.
  inline bool is_reusable() const {
    return !(flags_ & VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT);
  }

  // Drops the wrapped handle (operator bool becomes false afterwards) and
  // clears the bound pipeline/descriptor record. Note: state_ is NOT modified
  // here.
  inline void invalidate() {
    handle_ = VK_NULL_HANDLE;
    bound_.reset();
  }

  void begin();
  void end();

  void bind_pipeline(VkPipeline, VkPipelineLayout, const utils::uvec3);
  void bind_descriptors(VkDescriptorSet);

  void insert_barrier(PipelineBarrier& pipeline_barrier);
  void dispatch(const utils::uvec3&);

  void copy_buffer_to_buffer(
      const api::VulkanBuffer&,
      const api::VulkanBuffer&,
      const api::utils::uvec3&,
      const api::utils::uvec3&,
      const api::utils::uvec3&);

  void copy_texture_to_texture(
      const api::VulkanImage&,
      const api::VulkanImage&,
      const api::utils::uvec3&,
      const api::utils::uvec3&,
      const api::utils::uvec3&);

  void copy_texture_to_buffer(
      const api::VulkanImage&,
      const api::VulkanBuffer&,
      const api::utils::uvec3&,
      const api::utils::uvec3&,
      const api::utils::uvec3&);

  void copy_buffer_to_texture(
      const api::VulkanBuffer&,
      const api::VulkanImage&,
      const api::utils::uvec3&,
      const api::utils::uvec3&,
      const api::utils::uvec3&);

  void write_timestamp(VkQueryPool, const uint32_t) const;
  void reset_querypool(VkQueryPool, const uint32_t, const uint32_t) const;

  // Returns the raw handle for queue submission. Per the State comments this
  // moves the buffer to SUBMITTED.
  // NOTE(review): final_use presumably signals that the buffer will not be
  // submitted again (so it may be invalidated) — confirm in Command.cpp.
  VkCommandBuffer get_submit_handle(const bool final_use = false);

  // True while the wrapper holds a non-null VkCommandBuffer.
  inline operator bool() const {
    return VK_NULL_HANDLE != handle_;
  }
};

// Sizing parameters for a CommandPool.
// NOTE(review): presumably cmdPoolInitialSize is the number of command buffers
// allocated up front and cmdPoolBatchSize the number added per
// allocate_new_batch() call — confirm in Command.cpp.
struct CommandPoolConfig final {
  uint32_t cmdPoolInitialSize;
  uint32_t cmdPoolBatchSize;
};

// Holds a VkCommandPool for one device / queue family, plus a list of
// VkCommandBuffer handles from which CommandBuffer wrappers are produced.
// Neither copyable nor movable.
class CommandPool final {
 public:
  explicit CommandPool(VkDevice, const uint32_t, const CommandPoolConfig&);

  CommandPool(const CommandPool&) = delete;
  CommandPool& operator=(const CommandPool&) = delete;

  CommandPool(CommandPool&&) = delete;
  CommandPool& operator=(CommandPool&&) = delete;

  ~CommandPool();

 private:
  VkDevice device_;
  uint32_t queue_family_idx_;
  VkCommandPool pool_;
  CommandPoolConfig config_;
  // New Buffers
  // NOTE(review): mutex_ presumably guards buffers_ and in_use_ — confirm.
  std::mutex mutex_;
  std::vector<VkCommandBuffer> buffers_;
  // NOTE(review): presumably the number of buffers_ entries already handed out
  // by get_new_cmd() — confirm.
  size_t in_use_;

 public:
  // NOTE(review): `reusable` presumably controls whether the returned buffer
  // is created without VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT (see
  // CommandBuffer::is_reusable) — confirm in Command.cpp.
  CommandBuffer get_new_cmd(bool reusable = false);

  void flush();

 private:
  void allocate_new_batch(const uint32_t);
};

} // namespace api
} // namespace vulkan
} // namespace native
} // namespace at

#endif /* USE_VULKAN_API */