/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #pragma once // @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName #include #include #include #include #include #include #define SV(x) ::vkcompute::vkapi::SpecVar(x) namespace vkcompute { namespace vkapi { struct SpecVar final { enum class Type : uint8_t { FLOAT, INT, UINT, BOOL, }; union Value { int32_t as_int32; uint32_t as_uint32; float as_float; bool as_bool; }; Value value; Type type; SpecVar(); SpecVar(const float val); SpecVar(const int32_t val); SpecVar(const uint32_t val); SpecVar(const bool val); uint32_t val_size() const; uint32_t val_offset() const; }; bool operator==(const SpecVar& lhs, const SpecVar& rhs); bool operator!=(const SpecVar& lhs, const SpecVar& rhs); class SpecVarList final { std::vector vars; public: SpecVarList(); SpecVarList(std::initializer_list init_list); inline const SpecVar& at(const size_t index) const { return vars.at(index); } inline const SpecVar* data() const { return vars.data(); } inline uint32_t size() const { return utils::safe_downcast(vars.size()); } inline uint32_t data_nbytes() const { return vars.size() * sizeof(SpecVar); } void append(const SpecVarList& other); std::vector generate_map_entries() const; friend bool operator==(const SpecVarList& lhs, const SpecVarList& rhs); }; bool operator==(const SpecVarList& lhs, const SpecVarList& rhs); struct PipelineBarrier final { struct Stages final { VkPipelineStageFlags src; VkPipelineStageFlags dst; } stage; std::vector buffers; std::vector images; std::vector buffer_barrier_handles; std::vector image_barrier_handles; inline operator bool() const { return (0u != stage.src) || (0u != stage.dst) || !buffers.empty() || !images.empty(); } }; using PipelineStageFlags = uint8_t; enum PipelineStage : PipelineStageFlags { NO_STAGE = 0u << 0u, COMPUTE = 1u << 0u, HOST = 1u << 1u, TRANSFER = 1u << 2u, }; VkAccessFlags vk_access(const PipelineStageFlags, const MemoryAccessFlags); VkPipelineStageFlags vk_stage(const PipelineStageFlags); VkImageLayout vk_layout(const PipelineStageFlags, const MemoryAccessFlags); class PipelineLayout final { public: explicit PipelineLayout(VkDevice, VkDescriptorSetLayout); PipelineLayout(const PipelineLayout&) = delete; PipelineLayout& operator=(const PipelineLayout&) = delete; PipelineLayout(PipelineLayout&&) noexcept; PipelineLayout& operator=(PipelineLayout&&) = delete; ~PipelineLayout(); private: VkDevice device_; VkPipelineLayout handle_; public: VkPipelineLayout handle() const { return handle_; } // We need to define a custom swap function since this class // does not allow for move assignment. The swap function will // be used in the hash map. friend void swap(PipelineLayout& lhs, PipelineLayout& rhs) noexcept; }; class ComputePipeline final { public: struct Descriptor final { VkPipelineLayout pipeline_layout; VkShaderModule shader_module; SpecVarList specialization_constants; }; explicit ComputePipeline( VkDevice device, const Descriptor& descriptor, VkPipelineCache pipeline_cache); ComputePipeline(const ComputePipeline&) = delete; ComputePipeline& operator=(const ComputePipeline&) = delete; ComputePipeline(ComputePipeline&&) noexcept; ComputePipeline& operator=(ComputePipeline&&) = delete; ~ComputePipeline(); private: VkDevice device_; VkPipeline handle_; public: inline VkPipeline handle() const { return handle_; } // We need to define a custom swap function since this class // does not allow for move assignment. The swap function will // be used in the hash map. friend void swap(ComputePipeline& lhs, ComputePipeline& rhs) noexcept; }; class PipelineLayoutCache final { public: explicit PipelineLayoutCache(VkDevice device); PipelineLayoutCache(const PipelineLayoutCache&) = delete; PipelineLayoutCache& operator=(const PipelineLayoutCache&) = delete; PipelineLayoutCache(PipelineLayoutCache&&) noexcept; PipelineLayoutCache& operator=(PipelineLayoutCache&&) = delete; ~PipelineLayoutCache(); using Key = VkDescriptorSetLayout; using Value = PipelineLayout; struct Hasher { inline size_t operator()(VkDescriptorSetLayout descriptor_layout) const { return std::hash()(descriptor_layout); } }; private: // Multiple threads could potentially be adding entries into the cache, so use // a mutex to manage access std::mutex cache_mutex_; VkDevice device_; std::unordered_map cache_; public: VkPipelineLayout retrieve(const Key&); void purge(); }; class ComputePipelineCache final { public: explicit ComputePipelineCache( VkDevice device, const std::string& cache_data_path); ComputePipelineCache(const ComputePipelineCache&) = delete; ComputePipelineCache& operator=(const ComputePipelineCache&) = delete; ComputePipelineCache(ComputePipelineCache&&) noexcept; ComputePipelineCache& operator=(ComputePipelineCache&&) = delete; ~ComputePipelineCache(); using Key = ComputePipeline::Descriptor; using Value = ComputePipeline; struct Hasher { inline size_t operator()( const ComputePipeline::Descriptor& descriptor) const { size_t seed = 0; seed = utils::hash_combine( seed, std::hash()(descriptor.pipeline_layout)); seed = utils::hash_combine( seed, std::hash()(descriptor.shader_module)); const SpecVarList& spec_vars = descriptor.specialization_constants; seed = utils::hash_combine(seed, std::hash()(spec_vars.size())); for (int i = 0; i < spec_vars.size(); ++i) { const SpecVar& spec_var = spec_vars.at(i); size_t new_seed = 0; switch (spec_var.type) { case SpecVar::Type::FLOAT: new_seed = std::hash()(spec_var.value.as_float); break; case SpecVar::Type::INT: new_seed = std::hash()(spec_var.value.as_int32); break; case SpecVar::Type::UINT: new_seed = std::hash()(spec_var.value.as_uint32); break; case SpecVar::Type::BOOL: new_seed = std::hash()(spec_var.value.as_bool); break; } seed = utils::hash_combine(seed, new_seed); } return seed; } }; void save_cache(); private: std::vector load_cache(); // Multiple threads could potentially be adding entries into the cache, so use // a mutex to manage access std::mutex cache_mutex_; VkDevice device_; VkPipelineCache pipeline_cache_; std::unordered_map cache_; const std::string cache_data_path_; public: VkPipeline retrieve(const Key&); void purge(); }; // // Impl // } // namespace vkapi } // namespace vkcompute