• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "gpu_program_vk.h"
17 
18 #include <cstdint>
19 
20 #include <base/containers/array_view.h>
21 #include <base/containers/vector.h>
22 #include <render/device/pipeline_layout_desc.h>
23 #include <render/namespace.h>
24 
25 #include "device/device.h"
26 #include "device/gpu_program_util.h"
27 #include "util/log.h"
28 #include "vulkan/device_vk.h"
29 #include "vulkan/shader_module_vk.h"
30 
31 using namespace BASE_NS;
32 
RENDER_BEGIN_NAMESPACE()33 RENDER_BEGIN_NAMESPACE()
34 GpuShaderProgramVk::GpuShaderProgramVk(const GpuShaderProgramCreateData& createData) : GpuShaderProgram()
35 {
36     PLUGIN_ASSERT(createData.vertShaderModule);
37     PLUGIN_ASSERT(createData.fragShaderModule);
38 
39     // combine vertex and fragment shader data
40     if (createData.vertShaderModule && createData.fragShaderModule) {
41         vertShaderModule_ = static_cast<ShaderModuleVk*>(createData.vertShaderModule);
42         fragShaderModule_ = static_cast<ShaderModuleVk*>(createData.fragShaderModule);
43         auto& pipelineLayout = reflection_.pipelineLayout;
44 
45         { // vert
46             const ShaderModuleVk& mod = *vertShaderModule_;
47             plat_.vert = ((const ShaderModulePlatformDataVk&)mod.GetPlatformData()).shaderModule;
48             pipelineLayout = mod.GetPipelineLayout();
49             const auto& sscv = mod.GetSpecilization();
50             // has sort inside
51             GpuProgramUtil::CombineSpecializationConstants(sscv.constants, constants_);
52 
53             // not owned, directly reflected from vertex shader module
54             const auto& vidv = mod.GetVertexInputDeclaration();
55             reflection_.vertexInputDeclarationView.bindingDescriptions = vidv.bindingDescriptions;
56             reflection_.vertexInputDeclarationView.attributeDescriptions = vidv.attributeDescriptions;
57         }
58         { // frag
59             const ShaderModuleVk& mod = *fragShaderModule_;
60             plat_.frag = ((const ShaderModulePlatformDataVk&)mod.GetPlatformData()).shaderModule;
61 
62             const auto& sscv = mod.GetSpecilization();
63             // has sort inside
64             GpuProgramUtil::CombineSpecializationConstants(sscv.constants, constants_);
65 
66             const auto& reflPl = mod.GetPipelineLayout();
67             // has sort inside
68             GpuProgramUtil::CombinePipelineLayouts({ &reflPl, 1u }, pipelineLayout);
69         }
70 
71         reflection_.shaderSpecializationConstantView.constants =
72             array_view<ShaderSpecialization::Constant const>(constants_.data(), constants_.size());
73     }
74 }
75 
GetPlatformData() const76 const GpuShaderProgramPlatformDataVk& GpuShaderProgramVk::GetPlatformData() const
77 {
78     return plat_;
79 }
80 
GetReflection() const81 const ShaderReflection& GpuShaderProgramVk::GetReflection() const
82 {
83     return reflection_;
84 }
85 
GpuComputeProgramVk(const GpuComputeProgramCreateData & createData)86 GpuComputeProgramVk::GpuComputeProgramVk(const GpuComputeProgramCreateData& createData) : GpuComputeProgram()
87 {
88     PLUGIN_ASSERT(createData.compShaderModule);
89 
90     if (createData.compShaderModule) {
91         shaderModule_ = static_cast<ShaderModuleVk*>(createData.compShaderModule);
92         {
93             const ShaderModuleVk& mod = *shaderModule_;
94             plat_.comp = ((const ShaderModulePlatformDataVk&)mod.GetPlatformData()).shaderModule;
95             // copy needed data
96             reflection_.pipelineLayout = mod.GetPipelineLayout();
97             const auto& tgs = mod.GetThreadGroupSize();
98             reflection_.threadGroupSizeX = Math::max(1u, tgs.x);
99             reflection_.threadGroupSizeY = Math::max(1u, tgs.y);
100             reflection_.threadGroupSizeZ = Math::max(1u, tgs.z);
101             const auto& sscv = mod.GetSpecilization();
102             constants_ =
103                 vector<ShaderSpecialization::Constant>(sscv.constants.cbegin().ptr(), sscv.constants.cend().ptr());
104         }
105 
106         reflection_.shaderSpecializationConstantView.constants =
107             array_view<ShaderSpecialization::Constant const>(constants_.data(), constants_.size());
108     }
109 }
110 
GetPlatformData() const111 const GpuComputeProgramPlatformDataVk& GpuComputeProgramVk::GetPlatformData() const
112 {
113     return plat_;
114 }
115 
GetReflection() const116 const ComputeShaderReflection& GpuComputeProgramVk::GetReflection() const
117 {
118     return reflection_;
119 }
120 RENDER_END_NAMESPACE()
121