1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #pragma once 10 11 // @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName 12 13 #include <executorch/backends/vulkan/runtime/vk_api/vk_api.h> 14 15 #include <executorch/backends/vulkan/runtime/utils/VecUtils.h> 16 17 #include <executorch/backends/vulkan/runtime/vk_api/Descriptor.h> 18 #include <executorch/backends/vulkan/runtime/vk_api/Pipeline.h> 19 #include <executorch/backends/vulkan/runtime/vk_api/Shader.h> 20 21 #include <executorch/backends/vulkan/runtime/vk_api/memory/Buffer.h> 22 #include <executorch/backends/vulkan/runtime/vk_api/memory/Image.h> 23 24 namespace vkcompute { 25 namespace vkapi { 26 27 class CommandBuffer final { 28 public: 29 explicit CommandBuffer(VkCommandBuffer, const VkCommandBufferUsageFlags); 30 31 CommandBuffer(const CommandBuffer&) = delete; 32 CommandBuffer& operator=(const CommandBuffer&) = delete; 33 34 CommandBuffer(CommandBuffer&&) noexcept; 35 CommandBuffer& operator=(CommandBuffer&&) noexcept; 36 37 ~CommandBuffer() = default; 38 39 // The lifecycle of a command buffer is as follows: 40 enum State { 41 INVALID, // Used to indicate the command buffer is moved from 42 NEW, // Set during constructor 43 RECORDING, // Set during call to begin() and dispatch() 44 PIPELINE_BOUND, // Set during call to bind_pipeline() 45 DESCRIPTORS_BOUND, // Set during call to bind_descriptors() 46 BARRIERS_INSERTED, // Set during call to insert_barrier() 47 READY, // Set during call to end() 48 SUBMITTED, // Set during call to get_submit_handle() 49 }; 50 51 struct Bound { 52 VkPipeline pipeline; 53 VkPipelineLayout pipeline_layout; 54 utils::uvec3 local_workgroup_size; 55 VkDescriptorSet descriptors; 56 BoundBound57 explicit Bound() 58 : pipeline{VK_NULL_HANDLE}, 59 pipeline_layout{VK_NULL_HANDLE}, 60 local_workgroup_size{0u, 0u, 0u}, 61 descriptors{VK_NULL_HANDLE} {} 62 resetBound63 inline void reset() { 64 pipeline = VK_NULL_HANDLE; 65 pipeline_layout = VK_NULL_HANDLE; 66 local_workgroup_size = {0u, 0u, 0u}; 67 descriptors = VK_NULL_HANDLE; 68 } 69 }; 70 71 private: 72 VkCommandBuffer handle_; 73 VkCommandBufferUsageFlags flags_; 74 State state_; 75 Bound bound_; 76 77 public: is_reusable()78 inline bool is_reusable() { 79 return !(flags_ & VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT); 80 } 81 invalidate()82 inline void invalidate() { 83 handle_ = VK_NULL_HANDLE; 84 bound_.reset(); 85 } 86 87 void begin(); 88 void end(); 89 90 void bind_pipeline(VkPipeline, VkPipelineLayout, const utils::uvec3); 91 void bind_descriptors(VkDescriptorSet); 92 93 void insert_barrier(PipelineBarrier& pipeline_barrier); 94 void dispatch(const utils::uvec3&); 95 void blit(vkapi::VulkanImage& src, vkapi::VulkanImage& dst); 96 97 void write_timestamp(VkQueryPool, const uint32_t) const; 98 void reset_querypool(VkQueryPool, const uint32_t, const uint32_t) const; 99 100 VkCommandBuffer get_submit_handle(const bool final_use = false); 101 102 inline operator bool() const { 103 return handle_ != VK_NULL_HANDLE; 104 } 105 }; 106 107 struct CommandPoolConfig final { 108 uint32_t cmd_pool_initial_size; 109 uint32_t cmd_pool_batch_size; 110 }; 111 112 class CommandPool final { 113 public: 114 explicit CommandPool(VkDevice, const uint32_t, const CommandPoolConfig&); 115 116 CommandPool(const CommandPool&) = delete; 117 CommandPool& operator=(const CommandPool&) = delete; 118 119 CommandPool(CommandPool&&) = delete; 120 CommandPool& operator=(CommandPool&&) = delete; 121 122 ~CommandPool(); 123 124 private: 125 VkDevice device_; 126 uint32_t queue_family_idx_; 127 VkCommandPool pool_; 128 CommandPoolConfig config_; 129 // New Buffers 130 std::mutex mutex_; 131 std::vector<VkCommandBuffer> buffers_; 132 size_t in_use_; 133 134 public: 135 CommandBuffer get_new_cmd(bool reusable = false); 136 137 void flush(); 138 139 private: 140 void allocate_new_batch(const uint32_t); 141 }; 142 143 } // namespace vkapi 144 } // namespace vkcompute 145