/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #include #include namespace vkcompute { namespace vkapi { // // ShaderInfo // ShaderInfo::ShaderInfo() : src_code{ nullptr, 0u, } {} ShaderInfo::ShaderInfo( std::string name, const uint32_t* const spirv_bin, const uint32_t size, std::vector layout, const utils::uvec3 tile_size) : src_code{ spirv_bin, size, }, kernel_name{std::move(name)}, kernel_layout{std::move(layout)}, out_tile_size(tile_size) { } bool operator==(const ShaderInfo& _1, const ShaderInfo& _2) { return ( _1.src_code.bin == _2.src_code.bin && _1.src_code.size == _2.src_code.size); } // // ShaderLayout // ShaderLayout::ShaderLayout( VkDevice device, const ShaderLayout::Signature& signature) : device_(device), handle_{VK_NULL_HANDLE} { std::vector bindings; uint32_t binding_num = 0u; for (const VkDescriptorType type : signature) { bindings.push_back({ binding_num++, // binding type, // descriptorType 1u, // descriptorCount VK_SHADER_STAGE_COMPUTE_BIT, // stageFlags nullptr, // pImmutableSamplers }); } const VkDescriptorSetLayoutCreateInfo descriptor_set_layout_create_info{ VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO, // sType nullptr, // pNext 0u, // flags static_cast(bindings.size()), // bindingCount bindings.data(), // pBindings }; VK_CHECK(vkCreateDescriptorSetLayout( device_, &descriptor_set_layout_create_info, nullptr, &handle_)); } ShaderLayout::ShaderLayout(ShaderLayout&& other) noexcept : device_(other.device_), handle_(other.handle_) { other.handle_ = VK_NULL_HANDLE; } ShaderLayout::~ShaderLayout() { if (handle_ == VK_NULL_HANDLE) { return; } vkDestroyDescriptorSetLayout(device_, handle_, nullptr); handle_ = VK_NULL_HANDLE; } void swap(ShaderLayout& lhs, ShaderLayout& rhs) noexcept { VkDevice tmp_device = lhs.device_; VkDescriptorSetLayout tmp_handle = lhs.handle_; lhs.device_ = rhs.device_; lhs.handle_ = rhs.handle_; rhs.device_ = tmp_device; rhs.handle_ = tmp_handle; } // // ShaderModule // ShaderModule::ShaderModule(VkDevice device, const ShaderInfo& source) : device_(device), handle_{VK_NULL_HANDLE} { const uint32_t* code = source.src_code.bin; uint32_t size = source.src_code.size; const VkShaderModuleCreateInfo shader_module_create_info{ VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO, // sType nullptr, // pNext 0u, // flags size, // codeSize code, // pCode }; VK_CHECK(vkCreateShaderModule( device_, &shader_module_create_info, nullptr, &handle_)); } ShaderModule::ShaderModule(ShaderModule&& other) noexcept : device_(other.device_), handle_(other.handle_) { other.handle_ = VK_NULL_HANDLE; } ShaderModule::~ShaderModule() { if (handle_ == VK_NULL_HANDLE) { return; } vkDestroyShaderModule(device_, handle_, nullptr); handle_ = VK_NULL_HANDLE; } void swap(ShaderModule& lhs, ShaderModule& rhs) noexcept { VkDevice tmp_device = lhs.device_; VkShaderModule tmp_handle = lhs.handle_; lhs.device_ = rhs.device_; lhs.handle_ = rhs.handle_; rhs.device_ = tmp_device; rhs.handle_ = tmp_handle; } // // ShaderLayoutCache // ShaderLayoutCache::ShaderLayoutCache(VkDevice device) : cache_mutex_{}, device_(device), cache_{} {} ShaderLayoutCache::ShaderLayoutCache(ShaderLayoutCache&& other) noexcept : cache_mutex_{}, device_(other.device_), cache_(std::move(other.cache_)) { std::lock_guard lock(other.cache_mutex_); } ShaderLayoutCache::~ShaderLayoutCache() { purge(); } VkDescriptorSetLayout ShaderLayoutCache::retrieve( const ShaderLayoutCache::Key& key) { std::lock_guard lock(cache_mutex_); auto it = cache_.find(key); if (cache_.cend() == it) { it = cache_.insert({key, ShaderLayoutCache::Value(device_, key)}).first; } return it->second.handle(); } void ShaderLayoutCache::purge() { std::lock_guard lock(cache_mutex_); cache_.clear(); } // // ShaderCache // ShaderCache::ShaderCache(VkDevice device) : cache_mutex_{}, device_(device), cache_{} {} ShaderCache::ShaderCache(ShaderCache&& other) noexcept : cache_mutex_{}, device_(other.device_), cache_(std::move(other.cache_)) { std::lock_guard lock(other.cache_mutex_); } ShaderCache::~ShaderCache() { purge(); } VkShaderModule ShaderCache::retrieve(const ShaderCache::Key& key) { std::lock_guard lock(cache_mutex_); auto it = cache_.find(key); if (cache_.cend() == it) { it = cache_.insert({key, ShaderCache::Value(device_, key)}).first; } return it->second.handle(); } void ShaderCache::purge() { cache_.clear(); } } // namespace vkapi } // namespace vkcompute