1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "gpu_program_vk.h"
17
18 #include <cstdint>
19
20 #include <base/containers/array_view.h>
21 #include <base/containers/vector.h>
22 #include <render/device/pipeline_layout_desc.h>
23 #include <render/namespace.h>
24
25 #include "device/device.h"
26 #include "device/gpu_program_util.h"
27 #include "util/log.h"
28 #include "vulkan/device_vk.h"
29 #include "vulkan/shader_module_vk.h"
30
31 using namespace BASE_NS;
32
RENDER_BEGIN_NAMESPACE()33 RENDER_BEGIN_NAMESPACE()
34 GpuShaderProgramVk::GpuShaderProgramVk(const GpuShaderProgramCreateData& createData) : GpuShaderProgram()
35 {
36 PLUGIN_ASSERT(createData.vertShaderModule);
37 PLUGIN_ASSERT(createData.fragShaderModule);
38
39 // combine vertex and fragment shader data
40 if (createData.vertShaderModule && createData.fragShaderModule) {
41 vertShaderModule_ = static_cast<ShaderModuleVk*>(createData.vertShaderModule);
42 fragShaderModule_ = static_cast<ShaderModuleVk*>(createData.fragShaderModule);
43 auto& pipelineLayout = reflection_.pipelineLayout;
44
45 { // vert
46 const ShaderModuleVk& mod = *vertShaderModule_;
47 plat_.vert = ((const ShaderModulePlatformDataVk&)mod.GetPlatformData()).shaderModule;
48 pipelineLayout = mod.GetPipelineLayout();
49 const auto& sscv = mod.GetSpecilization();
50 // has sort inside
51 GpuProgramUtil::CombineSpecializationConstants(sscv.constants, constants_);
52
53 // not owned, directly reflected from vertex shader module
54 const auto& vidv = mod.GetVertexInputDeclaration();
55 reflection_.vertexInputDeclarationView.bindingDescriptions = vidv.bindingDescriptions;
56 reflection_.vertexInputDeclarationView.attributeDescriptions = vidv.attributeDescriptions;
57 }
58 { // frag
59 const ShaderModuleVk& mod = *fragShaderModule_;
60 plat_.frag = ((const ShaderModulePlatformDataVk&)mod.GetPlatformData()).shaderModule;
61
62 const auto& sscv = mod.GetSpecilization();
63 // has sort inside
64 GpuProgramUtil::CombineSpecializationConstants(sscv.constants, constants_);
65
66 const auto& reflPl = mod.GetPipelineLayout();
67 // has sort inside
68 GpuProgramUtil::CombinePipelineLayouts({ &reflPl, 1u }, pipelineLayout);
69 }
70
71 reflection_.shaderSpecializationConstantView.constants =
72 array_view<ShaderSpecialization::Constant const>(constants_.data(), constants_.size());
73 }
74 }
75
GetPlatformData() const76 const GpuShaderProgramPlatformDataVk& GpuShaderProgramVk::GetPlatformData() const
77 {
78 return plat_;
79 }
80
GetReflection() const81 const ShaderReflection& GpuShaderProgramVk::GetReflection() const
82 {
83 return reflection_;
84 }
85
GpuComputeProgramVk(const GpuComputeProgramCreateData & createData)86 GpuComputeProgramVk::GpuComputeProgramVk(const GpuComputeProgramCreateData& createData) : GpuComputeProgram()
87 {
88 PLUGIN_ASSERT(createData.compShaderModule);
89
90 if (createData.compShaderModule) {
91 shaderModule_ = static_cast<ShaderModuleVk*>(createData.compShaderModule);
92 {
93 const ShaderModuleVk& mod = *shaderModule_;
94 plat_.comp = ((const ShaderModulePlatformDataVk&)mod.GetPlatformData()).shaderModule;
95 // copy needed data
96 reflection_.pipelineLayout = mod.GetPipelineLayout();
97 const auto& tgs = mod.GetThreadGroupSize();
98 reflection_.threadGroupSizeX = Math::max(1u, tgs.x);
99 reflection_.threadGroupSizeY = Math::max(1u, tgs.y);
100 reflection_.threadGroupSizeZ = Math::max(1u, tgs.z);
101 const auto& sscv = mod.GetSpecilization();
102 constants_ =
103 vector<ShaderSpecialization::Constant>(sscv.constants.cbegin().ptr(), sscv.constants.cend().ptr());
104 }
105
106 reflection_.shaderSpecializationConstantView.constants =
107 array_view<ShaderSpecialization::Constant const>(constants_.data(), constants_.size());
108 }
109 }
110
GetPlatformData() const111 const GpuComputeProgramPlatformDataVk& GpuComputeProgramVk::GetPlatformData() const
112 {
113 return plat_;
114 }
115
GetReflection() const116 const ComputeShaderReflection& GpuComputeProgramVk::GetReflection() const
117 {
118 return reflection_;
119 }
120 RENDER_END_NAMESPACE()
121