• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "shader_module_gles.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 
21 #include <base/containers/array_view.h>
22 #include <base/containers/fixed_string.h>
23 #include <base/containers/string.h>
24 #include <base/containers/string_view.h>
25 #include <base/containers/type_traits.h>
26 #include <base/math/vector.h>
27 #include <render/device/pipeline_layout_desc.h>
28 #include <render/namespace.h>
29 
30 #include "device/gpu_program_util.h"
31 #include "device/shader_manager.h"
32 #include "gles/spirv_cross_helpers_gles.h"
33 #include "util/log.h"
34 
35 using namespace BASE_NS;
36 
37 RENDER_BEGIN_NAMESPACE()
38 namespace {
Collect(const uint32_t set,const DescriptorSetLayoutBinding & binding,BASE_NS::vector<ShaderModulePlatformDataGLES::Bind> & sets)39 void Collect(const uint32_t set, const DescriptorSetLayoutBinding& binding,
40     BASE_NS::vector<ShaderModulePlatformDataGLES::Bind>& sets)
41 {
42     const auto name = "s" + to_string(set) + "_b" + to_string(binding.binding);
43     sets.push_back({ static_cast<uint8_t>(set), static_cast<uint8_t>(binding.binding),
44         static_cast<uint8_t>(binding.descriptorCount), string { name } });
45 }
46 
CollectRes(const PipelineLayout & pipeline,ShaderModulePlatformDataGLES & plat_)47 void CollectRes(const PipelineLayout& pipeline, ShaderModulePlatformDataGLES& plat_)
48 {
49     struct Bind {
50         uint8_t set;
51         uint8_t bind;
52     };
53     vector<Bind> samplers;
54     vector<Bind> images;
55     for (const auto& set : pipeline.descriptorSetLayouts) {
56         if (set.set != PipelineLayoutConstants::INVALID_INDEX) {
57             for (const auto& binding : set.bindings) {
58                 switch (binding.descriptorType) {
59                     case DescriptorType::CORE_DESCRIPTOR_TYPE_SAMPLER:
60                         samplers.push_back({ static_cast<uint8_t>(set.set), static_cast<uint8_t>(binding.binding) });
61                         break;
62                     case DescriptorType::CORE_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
63                         Collect(set.set, binding, plat_.cbSets);
64                         break;
65                     case DescriptorType::CORE_DESCRIPTOR_TYPE_SAMPLED_IMAGE:
66                         images.push_back({ static_cast<uint8_t>(set.set), static_cast<uint8_t>(binding.binding) });
67                         break;
68                     case DescriptorType::CORE_DESCRIPTOR_TYPE_STORAGE_IMAGE:
69                         Collect(set.set, binding, plat_.ciSets);
70                         break;
71                     case DescriptorType::CORE_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER:
72                         [[fallthrough]];
73                     case DescriptorType::CORE_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER:
74                         break;
75                     case DescriptorType::CORE_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
76                         Collect(set.set, binding, plat_.ubSets);
77                         break;
78                     case DescriptorType::CORE_DESCRIPTOR_TYPE_STORAGE_BUFFER:
79                         Collect(set.set, binding, plat_.sbSets);
80                         break;
81                     case DescriptorType::CORE_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC:
82                         [[fallthrough]];
83                     case DescriptorType::CORE_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC:
84                         break;
85                     case DescriptorType::CORE_DESCRIPTOR_TYPE_INPUT_ATTACHMENT:
86                         Collect(set.set, binding, plat_.siSets);
87                         break;
88                     case DescriptorType::CORE_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE:
89                         [[fallthrough]];
90                     case DescriptorType::CORE_DESCRIPTOR_TYPE_MAX_ENUM:
91                         break;
92                 }
93             }
94         }
95     }
96     for (const auto& sBinding : samplers) {
97         for (const auto& iBinding : images) {
98             const auto name = "s" + to_string(iBinding.set) + "_b" + to_string(iBinding.bind) + "_s" +
99                               to_string(sBinding.set) + "_b" + to_string(sBinding.bind);
100             plat_.combSets.push_back({ sBinding.set, sBinding.bind, iBinding.set, iBinding.bind, string { name } });
101         }
102     }
103 }
104 
CreateSpecInfos(array_view<const ShaderSpecialization::Constant> constants,vector<Gles::SpecConstantInfo> & outSpecInfo)105 void CreateSpecInfos(
106     array_view<const ShaderSpecialization::Constant> constants, vector<Gles::SpecConstantInfo>& outSpecInfo)
107 {
108     static_assert(static_cast<uint32_t>(Gles::SpecConstantInfo::Types::BOOL) ==
109                   static_cast<uint32_t>(ShaderSpecialization::Constant::Type::BOOL));
110     for (const auto& constant : constants) {
111         Gles::SpecConstantInfo info { static_cast<Gles::SpecConstantInfo::Types>(constant.type), constant.id, 1U, 1U,
112             {} };
113         outSpecInfo.push_back(info);
114     }
115 }
116 
SortSets(PipelineLayout & pipelineLayout)117 void SortSets(PipelineLayout& pipelineLayout)
118 {
119     pipelineLayout.descriptorSetCount = 0;
120     for (auto& currSet : pipelineLayout.descriptorSetLayouts) {
121         if (currSet.set != PipelineLayoutConstants::INVALID_INDEX) {
122             pipelineLayout.descriptorSetCount++;
123             std::sort(currSet.bindings.begin(), currSet.bindings.end(),
124                 [](auto const& lhs, auto const& rhs) { return (lhs.binding < rhs.binding); });
125         }
126     }
127 }
128 } // namespace
namespace {
// Sequential little-endian reader over a raw byte blob (used for the push-constant reflection data).
// Kept in an unnamed namespace so this file-local helper has internal linkage and cannot collide
// (ODR) with another `Reader` declared in the same namespace elsewhere.
// NOTE(review): no bounds checking is performed; the caller must guarantee that the blob holds
// every byte that is read.
struct Reader {
    const uint8_t* ptr;

    // Reads one byte and advances by 1.
    uint8_t GetUint8()
    {
        return *ptr++;
    }

    // Reads a little-endian 16-bit value and advances by 2.
    uint16_t GetUint16()
    {
        const auto value =
            static_cast<uint16_t>(static_cast<uint32_t>(ptr[0]) | (static_cast<uint32_t>(ptr[1]) << 8U));
        ptr += sizeof(uint16_t);
        return value;
    }

    // Reads a little-endian 32-bit value and advances by 4.
    uint32_t GetUint32()
    {
        // Widen every byte to uint32_t before shifting. A plain uint8_t is promoted to (signed) int,
        // and shifting a byte >= 0x80 left by 24 overflows int, which is undefined behaviour before C++20.
        const uint32_t value = static_cast<uint32_t>(ptr[0]) | (static_cast<uint32_t>(ptr[1]) << 8U) |
                               (static_cast<uint32_t>(ptr[2]) << 16U) | (static_cast<uint32_t>(ptr[3]) << 24U);
        ptr += sizeof(uint32_t);
        return value;
    }

    // Reads a 16-bit length prefix followed by that many characters. Returns a view into the blob
    // (no copy) and advances past the characters.
    string_view GetStringView()
    {
        const uint16_t len = GetUint16();
        const string_view value(static_cast<const char*>(static_cast<const void*>(ptr)), len);
        ptr += len;
        return value;
    }
};
} // namespace
158 template<typename ShaderBase>
ProcessShaderModule(ShaderBase & me,const ShaderModuleCreateInfo & createInfo)159 void ProcessShaderModule(ShaderBase& me, const ShaderModuleCreateInfo& createInfo)
160 {
161     me.pipelineLayout_ = createInfo.reflectionData.GetPipelineLayout();
162     if (me.shaderStageFlags_ & CORE_SHADER_STAGE_VERTEX_BIT) {
163         me.vertexInputAttributeDescriptions_ = createInfo.reflectionData.GetInputDescriptions();
164         me.vertexInputBindingDescriptions_.reserve(me.vertexInputAttributeDescriptions_.size());
165         for (const auto& attrib : me.vertexInputAttributeDescriptions_) {
166             VertexInputDeclaration::VertexInputBindingDescription bindingDesc;
167             bindingDesc.binding = attrib.binding;
168             bindingDesc.stride = GpuProgramUtil::FormatByteSize(attrib.format);
169             bindingDesc.vertexInputRate = VertexInputRate::CORE_VERTEX_INPUT_RATE_VERTEX;
170             me.vertexInputBindingDescriptions_.push_back(bindingDesc);
171         }
172         me.vidv_.bindingDescriptions = { me.vertexInputBindingDescriptions_.data(),
173             me.vertexInputBindingDescriptions_.size() };
174         me.vidv_.attributeDescriptions = { me.vertexInputAttributeDescriptions_.data(),
175             me.vertexInputAttributeDescriptions_.size() };
176     }
177 
178     if (me.shaderStageFlags_ & CORE_SHADER_STAGE_COMPUTE_BIT) {
179         const Math::UVec3 tgs = createInfo.reflectionData.GetLocalSize();
180         me.stg_.x = tgs.x;
181         me.stg_.y = tgs.y;
182         me.stg_.z = tgs.z;
183     }
184     if (auto* ptr = createInfo.reflectionData.GetPushConstants(); ptr) {
185         Reader read { ptr };
186         const auto constants = read.GetUint8();
187         for (uint8_t i = 0U; i < constants; ++i) {
188             Gles::PushConstantReflection refl;
189             refl.type = read.GetUint32();
190             refl.offset = read.GetUint16();
191             refl.size = read.GetUint16();
192             refl.arraySize = read.GetUint16();
193             refl.arrayStride = read.GetUint16();
194             refl.matrixStride = read.GetUint16();
195             refl.name = "CORE_PC_0";
196             refl.name += read.GetStringView();
197             refl.stage = me.shaderStageFlags_;
198             me.plat_.infos.push_back(move(refl));
199         }
200     }
201 
202     me.constants_ = createInfo.reflectionData.GetSpecializationConstants();
203     me.sscv_.constants = { me.constants_.data(), me.constants_.size() };
204     CollectRes(me.pipelineLayout_, me.plat_);
205     CreateSpecInfos(me.constants_, me.specInfo_);
206     // sort bindings inside sets (and count them)
207     SortSets(me.pipelineLayout_);
208 
209     me.source_.assign(
210         static_cast<const char*>(static_cast<const void*>(createInfo.spvData.data())), createInfo.spvData.size());
211 }
212 
213 template<typename ShaderBase>
SpecializeShaderModule(const ShaderBase & base,const ShaderSpecializationConstantDataView & specData)214 string SpecializeShaderModule(const ShaderBase& base, const ShaderSpecializationConstantDataView& specData)
215 {
216     return Gles::Specialize(base.shaderStageFlags_, base.source_, base.constants_, specData);
217 }
218 
ShaderModuleGLES(Device & device,const ShaderModuleCreateInfo & createInfo)219 ShaderModuleGLES::ShaderModuleGLES(Device& device, const ShaderModuleCreateInfo& createInfo)
220     : device_(device), shaderStageFlags_(createInfo.shaderStageFlags)
221 {
222     if (createInfo.reflectionData.IsValid() &&
223         (shaderStageFlags_ &
224             (CORE_SHADER_STAGE_VERTEX_BIT | CORE_SHADER_STAGE_FRAGMENT_BIT | CORE_SHADER_STAGE_COMPUTE_BIT))) {
225         ProcessShaderModule(*this, createInfo);
226     } else {
227         PLUGIN_LOG_E("invalid shader stages or invalid reflection data for shader module, invalid shader module");
228     }
229 }
230 
231 ShaderModuleGLES::~ShaderModuleGLES() = default;
232 
GetShaderStageFlags() const233 ShaderStageFlags ShaderModuleGLES::GetShaderStageFlags() const
234 {
235     return shaderStageFlags_;
236 }
237 
GetGLSL(const ShaderSpecializationConstantDataView & specData) const238 string ShaderModuleGLES::GetGLSL(const ShaderSpecializationConstantDataView& specData) const
239 {
240     return SpecializeShaderModule(*this, specData);
241 }
242 
GetPlatformData() const243 const ShaderModulePlatformData& ShaderModuleGLES::GetPlatformData() const
244 {
245     return plat_;
246 }
247 
GetPipelineLayout() const248 const PipelineLayout& ShaderModuleGLES::GetPipelineLayout() const
249 {
250     return pipelineLayout_;
251 }
252 
GetSpecilization() const253 ShaderSpecializationConstantView ShaderModuleGLES::GetSpecilization() const
254 {
255     return sscv_;
256 }
257 
GetVertexInputDeclaration() const258 VertexInputDeclarationView ShaderModuleGLES::GetVertexInputDeclaration() const
259 {
260     return vidv_;
261 }
262 
GetThreadGroupSize() const263 ShaderThreadGroup ShaderModuleGLES::GetThreadGroupSize() const
264 {
265     return stg_;
266 }
267 RENDER_END_NAMESPACE()
268