• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "shader_module_gles.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 
21 #include "device/gpu_program_util.h"
22 #include "device/shader_manager.h"
23 #include "gles/device_gles.h"
24 #include "gles/spirv_cross_helpers_gles.h"
25 #include "util/log.h"
26 
27 using namespace BASE_NS;
28 
29 RENDER_BEGIN_NAMESPACE()
30 namespace {
31 template<typename SetType>
Collect(const uint32_t set,const DescriptorSetLayoutBinding & binding,SetType & sets)32 void Collect(const uint32_t set, const DescriptorSetLayoutBinding& binding, SetType& sets)
33 {
34     const auto name = "s" + to_string(set) + "_b" + to_string(binding.binding);
35     sets.push_back({ static_cast<uint8_t>(set), static_cast<uint8_t>(binding.binding),
36         static_cast<uint8_t>(binding.descriptorCount), string { name } });
37 }
38 
CollectRes(const PipelineLayout & pipeline,ShaderModulePlatformDataGLES & plat_)39 void CollectRes(const PipelineLayout& pipeline, ShaderModulePlatformDataGLES& plat_)
40 {
41     struct Bind {
42         uint8_t set;
43         uint8_t bind;
44     };
45     vector<Bind> samplers;
46     vector<Bind> images;
47     for (const auto& set : pipeline.descriptorSetLayouts) {
48         if (set.set != PipelineLayoutConstants::INVALID_INDEX) {
49             for (const auto& binding : set.bindings) {
50                 switch (binding.descriptorType) {
51                     case DescriptorType::CORE_DESCRIPTOR_TYPE_SAMPLER:
52                         samplers.push_back({ static_cast<uint8_t>(set.set), static_cast<uint8_t>(binding.binding) });
53                         break;
54                     case DescriptorType::CORE_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
55                         Collect(set.set, binding, plat_.cbSets);
56                         break;
57                     case DescriptorType::CORE_DESCRIPTOR_TYPE_SAMPLED_IMAGE:
58                         images.push_back({ static_cast<uint8_t>(set.set), static_cast<uint8_t>(binding.binding) });
59                         break;
60                     case DescriptorType::CORE_DESCRIPTOR_TYPE_STORAGE_IMAGE:
61                         Collect(set.set, binding, plat_.ciSets);
62                         break;
63                     case DescriptorType::CORE_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER:
64                         break;
65                     case DescriptorType::CORE_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER:
66                         break;
67                     case DescriptorType::CORE_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
68                         Collect(set.set, binding, plat_.ubSets);
69                         break;
70                     case DescriptorType::CORE_DESCRIPTOR_TYPE_STORAGE_BUFFER:
71                         Collect(set.set, binding, plat_.sbSets);
72                         break;
73                     case DescriptorType::CORE_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC:
74                         break;
75                     case DescriptorType::CORE_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC:
76                         break;
77                     case DescriptorType::CORE_DESCRIPTOR_TYPE_INPUT_ATTACHMENT:
78                         Collect(set.set, binding, plat_.siSets);
79                         break;
80                     case DescriptorType::CORE_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE:
81                         break;
82                     case DescriptorType::CORE_DESCRIPTOR_TYPE_MAX_ENUM:
83                         break;
84                 }
85             }
86         }
87     }
88     for (const auto& sBinding : samplers) {
89         for (const auto& iBinding : images) {
90             const auto name = "s" + to_string(iBinding.set) + "_b" + to_string(iBinding.bind) + "_s" +
91                               to_string(sBinding.set) + "_b" + to_string(sBinding.bind);
92             plat_.combSets.push_back({ sBinding.set, sBinding.bind, iBinding.set, iBinding.bind, string { name } });
93         }
94     }
95 }
96 
CreateSpecInfos(array_view<const ShaderSpecialization::Constant> constants,vector<Gles::SpecConstantInfo> & outSpecInfo)97 void CreateSpecInfos(
98     array_view<const ShaderSpecialization::Constant> constants, vector<Gles::SpecConstantInfo>& outSpecInfo)
99 {
100     static_assert(static_cast<uint32_t>(Gles::SpecConstantInfo::Types::BOOL) ==
101                   static_cast<uint32_t>(ShaderSpecialization::Constant::Type::BOOL));
102     for (const auto& constant : constants) {
103         Gles::SpecConstantInfo info { static_cast<Gles::SpecConstantInfo::Types>(constant.type), constant.id, 1U, 1U,
104             {} };
105         outSpecInfo.push_back(info);
106     }
107 }
108 
SortSets(PipelineLayout & pipelineLayout)109 void SortSets(PipelineLayout& pipelineLayout)
110 {
111     pipelineLayout.descriptorSetCount = 0;
112     for (uint32_t idx = 0; idx < PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT; ++idx) {
113         DescriptorSetLayout& currSet = pipelineLayout.descriptorSetLayouts[idx];
114         if (currSet.set != PipelineLayoutConstants::INVALID_INDEX) {
115             pipelineLayout.descriptorSetCount++;
116             std::sort(currSet.bindings.begin(), currSet.bindings.end(),
117                 [](auto const& lhs, auto const& rhs) { return (lhs.binding < rhs.binding); });
118         }
119     }
120 }
121 } // namespace
namespace {
/**
 * Minimal little-endian cursor over a serialized reflection blob (used for push-constant data).
 * Kept in an anonymous namespace like the other file-local helpers, so this TU-private type
 * cannot collide (ODR) with another `Reader` in the same namespace elsewhere.
 * NOTE(review): no blob length is known here, so reads are unchecked. The producer of the blob
 * must guarantee it is well-formed and long enough.
 */
struct Reader {
    const uint8_t* ptr;

    /** Reads one byte and advances the cursor by 1. */
    uint8_t GetUint8()
    {
        return *ptr++;
    }

    /** Reads a little-endian 16-bit value and advances the cursor by 2. */
    uint16_t GetUint16()
    {
        const uint16_t value =
            static_cast<uint16_t>(static_cast<uint32_t>(ptr[0]) | (static_cast<uint32_t>(ptr[1]) << 8U));
        ptr += sizeof(uint16_t);
        return value;
    }

    /** Reads a little-endian 32-bit value and advances the cursor by 4. */
    uint32_t GetUint32()
    {
        // Widen each byte to uint32_t BEFORE shifting. Otherwise the byte is promoted to (signed) int,
        // and shifting a value >= 0x80 left by 24 overflows int, which is undefined behaviour before C++20.
        const uint32_t value = static_cast<uint32_t>(ptr[0]) | (static_cast<uint32_t>(ptr[1]) << 8U) |
                               (static_cast<uint32_t>(ptr[2]) << 16U) | (static_cast<uint32_t>(ptr[3]) << 24U);
        ptr += sizeof(uint32_t);
        return value;
    }

    /**
     * Reads a 16-bit length prefix followed by that many bytes.
     * @return A non-owning view into the blob; it is only valid while the blob is alive.
     */
    string_view GetStringView()
    {
        const uint16_t len = GetUint16();
        const string_view value(static_cast<const char*>(static_cast<const void*>(ptr)), len);
        ptr += len;
        return value;
    }
};
} // namespace
151 template<typename ShaderBase>
ProcessShaderModule(ShaderBase & me,const ShaderModuleCreateInfo & createInfo)152 void ProcessShaderModule(ShaderBase& me, const ShaderModuleCreateInfo& createInfo)
153 {
154     me.pipelineLayout_ = createInfo.reflectionData.GetPipelineLayout();
155     if (me.shaderStageFlags_ & CORE_SHADER_STAGE_VERTEX_BIT) {
156         me.vertexInputAttributeDescriptions_ = createInfo.reflectionData.GetInputDescriptions();
157         me.vertexInputBindingDescriptions_.reserve(me.vertexInputAttributeDescriptions_.size());
158         for (const auto& attrib : me.vertexInputAttributeDescriptions_) {
159             VertexInputDeclaration::VertexInputBindingDescription bindingDesc;
160             bindingDesc.binding = attrib.binding;
161             bindingDesc.stride = GpuProgramUtil::FormatByteSize(attrib.format);
162             bindingDesc.vertexInputRate = VertexInputRate::CORE_VERTEX_INPUT_RATE_VERTEX;
163             me.vertexInputBindingDescriptions_.push_back(bindingDesc);
164         }
165         me.vidv_.bindingDescriptions = { me.vertexInputBindingDescriptions_.data(),
166             me.vertexInputBindingDescriptions_.size() };
167         me.vidv_.attributeDescriptions = { me.vertexInputAttributeDescriptions_.data(),
168             me.vertexInputAttributeDescriptions_.size() };
169     }
170 
171     if (me.shaderStageFlags_ & CORE_SHADER_STAGE_COMPUTE_BIT) {
172         const Math::UVec3 tgs = createInfo.reflectionData.GetLocalSize();
173         me.stg_.x = tgs.x;
174         me.stg_.y = tgs.y;
175         me.stg_.z = tgs.z;
176     }
177     if (auto* ptr = createInfo.reflectionData.GetPushConstants(); ptr) {
178         Reader read { ptr };
179         const auto constants = read.GetUint8();
180         for (uint8_t i = 0U; i < constants; ++i) {
181             Gles::PushConstantReflection refl;
182             refl.type = read.GetUint32();
183             refl.offset = read.GetUint16();
184             refl.size = read.GetUint16();
185             refl.arraySize = read.GetUint16();
186             refl.arrayStride = read.GetUint16();
187             refl.matrixStride = read.GetUint16();
188             refl.name = "CORE_PC_0";
189             refl.name += read.GetStringView();
190             refl.stage = me.shaderStageFlags_;
191             me.plat_.infos.push_back(move(refl));
192         }
193     }
194 
195     me.constants_ = createInfo.reflectionData.GetSpecializationConstants();
196     me.sscv_.constants = { me.constants_.data(), me.constants_.size() };
197     CollectRes(me.pipelineLayout_, me.plat_);
198     CreateSpecInfos(me.constants_, me.specInfo_);
199     // sort bindings inside sets (and count them)
200     SortSets(me.pipelineLayout_);
201 
202     me.source_.assign(
203         static_cast<const char*>(static_cast<const void*>(createInfo.spvData.data())), createInfo.spvData.size());
204 }
205 
206 template<typename ShaderBase>
SpecializeShaderModule(const ShaderBase & base,const ShaderSpecializationConstantDataView & specData)207 string SpecializeShaderModule(const ShaderBase& base, const ShaderSpecializationConstantDataView& specData)
208 {
209     return Gles::Specialize(base.shaderStageFlags_, base.source_, base.constants_, specData);
210 }
211 
ShaderModuleGLES(Device & device,const ShaderModuleCreateInfo & createInfo)212 ShaderModuleGLES::ShaderModuleGLES(Device& device, const ShaderModuleCreateInfo& createInfo)
213     : device_(device), shaderStageFlags_(createInfo.shaderStageFlags)
214 {
215     if (createInfo.reflectionData.IsValid() &&
216         (shaderStageFlags_ &
217             (CORE_SHADER_STAGE_VERTEX_BIT | CORE_SHADER_STAGE_FRAGMENT_BIT | CORE_SHADER_STAGE_COMPUTE_BIT))) {
218         ProcessShaderModule(*this, createInfo);
219     } else {
220         PLUGIN_LOG_E("invalid shader stages or invalid reflection data for shader module, invalid shader module");
221     }
222 }
223 
~ShaderModuleGLES()224 ShaderModuleGLES::~ShaderModuleGLES() {}
225 
GetShaderStageFlags() const226 ShaderStageFlags ShaderModuleGLES::GetShaderStageFlags() const
227 {
228     return shaderStageFlags_;
229 }
230 
GetGLSL(const ShaderSpecializationConstantDataView & specData) const231 string ShaderModuleGLES::GetGLSL(const ShaderSpecializationConstantDataView& specData) const
232 {
233     return SpecializeShaderModule(*this, specData);
234 }
235 
GetPlatformData() const236 const ShaderModulePlatformData& ShaderModuleGLES::GetPlatformData() const
237 {
238     return plat_;
239 }
240 
GetPipelineLayout() const241 const PipelineLayout& ShaderModuleGLES::GetPipelineLayout() const
242 {
243     return pipelineLayout_;
244 }
245 
GetSpecilization() const246 ShaderSpecilizationConstantView ShaderModuleGLES::GetSpecilization() const
247 {
248     return sscv_;
249 }
250 
GetVertexInputDeclaration() const251 VertexInputDeclarationView ShaderModuleGLES::GetVertexInputDeclaration() const
252 {
253     return vidv_;
254 }
255 
GetThreadGroupSize() const256 ShaderThreadGroup ShaderModuleGLES::GetThreadGroupSize() const
257 {
258     return stg_;
259 }
260 RENDER_END_NAMESPACE()
261