1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "shader_module_vk.h"
17
18 #include <cstdint>
19 #include <vulkan/vulkan_core.h>
20
21 #include <render/device/pipeline_layout_desc.h>
22 #include <render/namespace.h>
23
24 #include "device/device.h"
25 #include "device/gpu_program_util.h"
26 #include "device/shader_manager.h"
27 #include "util/log.h"
28 #include "vulkan/device_vk.h"
29 #include "vulkan/validate_vk.h"
30
31 using namespace BASE_NS;
32
33 RENDER_BEGIN_NAMESPACE()
34 namespace {
CreateShaderModule(const VkDevice device,array_view<const uint8_t> data)35 VkShaderModule CreateShaderModule(const VkDevice device, array_view<const uint8_t> data)
36 {
37 PLUGIN_ASSERT(!data.empty());
38 VkShaderModule shaderModule { VK_NULL_HANDLE };
39
40 constexpr VkShaderModuleCreateFlags shaderModuleCreateFlags { 0 };
41 const VkShaderModuleCreateInfo shaderModuleCreateInfo {
42 VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO, // sType
43 nullptr, // pNext
44 shaderModuleCreateFlags, // flags
45 static_cast<uint32_t>(data.size()), // codeSize
46 reinterpret_cast<const uint32_t*>(data.data()) // pCode
47 };
48
49 VALIDATE_VK_RESULT(vkCreateShaderModule(device, // device
50 &shaderModuleCreateInfo, // pCreateInfo
51 nullptr, // pAllocator
52 &shaderModule)); // pShaderModule
53
54 return shaderModule;
55 }
56 } // namespace
57
ShaderModuleVk(Device & device,const ShaderModuleCreateInfo & createInfo)58 ShaderModuleVk::ShaderModuleVk(Device& device, const ShaderModuleCreateInfo& createInfo)
59 : device_(device), shaderStageFlags_(createInfo.shaderStageFlags)
60 {
61 PLUGIN_ASSERT(!createInfo.spvData.empty());
62 PLUGIN_ASSERT(createInfo.shaderStageFlags & (ShaderStageFlagBits::CORE_SHADER_STAGE_VERTEX_BIT |
63 ShaderStageFlagBits::CORE_SHADER_STAGE_FRAGMENT_BIT |
64 ShaderStageFlagBits::CORE_SHADER_STAGE_COMPUTE_BIT));
65
66 bool valid = false;
67 if (createInfo.reflectionData.IsValid()) {
68 valid = true;
69 pipelineLayout_ = createInfo.reflectionData.GetPipelineLayout();
70
71 constants_ = createInfo.reflectionData.GetSpecializationConstants();
72 sscv_.constants = constants_;
73
74 if (shaderStageFlags_ == ShaderStageFlagBits::CORE_SHADER_STAGE_VERTEX_BIT) {
75 vertexInputAttributeDescriptions_ = createInfo.reflectionData.GetInputDescriptions();
76 for (const auto& attrib : vertexInputAttributeDescriptions_) {
77 VertexInputDeclaration::VertexInputBindingDescription bindingDesc;
78 bindingDesc.binding = attrib.binding;
79 bindingDesc.stride = GpuProgramUtil::FormatByteSize(attrib.format);
80 bindingDesc.vertexInputRate = VertexInputRate::CORE_VERTEX_INPUT_RATE_VERTEX;
81 vertexInputBindingDescriptions_.push_back(bindingDesc);
82 }
83 vidv_.bindingDescriptions = vertexInputBindingDescriptions_;
84 vidv_.attributeDescriptions = vertexInputAttributeDescriptions_;
85 } else if (shaderStageFlags_ == ShaderStageFlagBits::CORE_SHADER_STAGE_FRAGMENT_BIT) {
86 } else if (shaderStageFlags_ == ShaderStageFlagBits::CORE_SHADER_STAGE_COMPUTE_BIT) {
87 const Math::UVec3 tgs = createInfo.reflectionData.GetLocalSize();
88 stg_.x = tgs[0u];
89 stg_.y = tgs[1u];
90 stg_.z = tgs[2u];
91 } else {
92 PLUGIN_LOG_E("invalid shader stage flags for module creation");
93 valid = false;
94 }
95 }
96
97 // NOTE: sorting not needed?
98
99 if (valid) {
100 const VkDevice vkDevice = ((const DevicePlatformDataVk&)device_.GetPlatformData()).device;
101 plat_.shaderModule = CreateShaderModule(vkDevice, createInfo.spvData);
102 } else {
103 PLUGIN_LOG_E("invalid vulkan shader module");
104 }
105 }
106
~ShaderModuleVk()107 ShaderModuleVk::~ShaderModuleVk()
108 {
109 const VkDevice device = ((const DevicePlatformDataVk&)device_.GetPlatformData()).device;
110 if (plat_.shaderModule != VK_NULL_HANDLE) {
111 vkDestroyShaderModule(device, // device
112 plat_.shaderModule, // shaderModule
113 nullptr); // pAllocator
114 }
115 }
116
GetShaderStageFlags() const117 ShaderStageFlags ShaderModuleVk::GetShaderStageFlags() const
118 {
119 return shaderStageFlags_;
120 }
121
GetPlatformData() const122 const ShaderModulePlatformData& ShaderModuleVk::GetPlatformData() const
123 {
124 return plat_;
125 }
126
GetPipelineLayout() const127 const PipelineLayout& ShaderModuleVk::GetPipelineLayout() const
128 {
129 return pipelineLayout_;
130 }
131
GetSpecilization() const132 ShaderSpecializationConstantView ShaderModuleVk::GetSpecilization() const
133 {
134 return sscv_;
135 }
136
GetVertexInputDeclaration() const137 VertexInputDeclarationView ShaderModuleVk::GetVertexInputDeclaration() const
138 {
139 return vidv_;
140 }
141
GetThreadGroupSize() const142 ShaderThreadGroup ShaderModuleVk::GetThreadGroupSize() const
143 {
144 return stg_;
145 }
146 RENDER_END_NAMESPACE()
147