1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "shader_module_gles.h"
17
18 #include <algorithm>
19 #include <cstdint>
20
21 #include <base/containers/array_view.h>
22 #include <base/containers/fixed_string.h>
23 #include <base/containers/string.h>
24 #include <base/containers/string_view.h>
25 #include <base/containers/type_traits.h>
26 #include <base/math/vector.h>
27 #include <render/device/pipeline_layout_desc.h>
28 #include <render/namespace.h>
29
30 #include "device/gpu_program_util.h"
31 #include "device/shader_manager.h"
32 #include "gles/spirv_cross_helpers_gles.h"
33 #include "util/log.h"
34
35 using namespace BASE_NS;
36
37 RENDER_BEGIN_NAMESPACE()
38 namespace {
Collect(const uint32_t set,const DescriptorSetLayoutBinding & binding,BASE_NS::vector<ShaderModulePlatformDataGLES::Bind> & sets)39 void Collect(const uint32_t set, const DescriptorSetLayoutBinding& binding,
40 BASE_NS::vector<ShaderModulePlatformDataGLES::Bind>& sets)
41 {
42 const auto name = "s" + to_string(set) + "_b" + to_string(binding.binding);
43 sets.push_back({ static_cast<uint8_t>(set), static_cast<uint8_t>(binding.binding),
44 static_cast<uint8_t>(binding.descriptorCount), string { name } });
45 }
46
CollectRes(const PipelineLayout & pipeline,ShaderModulePlatformDataGLES & plat_)47 void CollectRes(const PipelineLayout& pipeline, ShaderModulePlatformDataGLES& plat_)
48 {
49 struct Bind {
50 uint8_t set;
51 uint8_t bind;
52 };
53 vector<Bind> samplers;
54 vector<Bind> images;
55 for (const auto& set : pipeline.descriptorSetLayouts) {
56 if (set.set != PipelineLayoutConstants::INVALID_INDEX) {
57 for (const auto& binding : set.bindings) {
58 switch (binding.descriptorType) {
59 case DescriptorType::CORE_DESCRIPTOR_TYPE_SAMPLER:
60 samplers.push_back({ static_cast<uint8_t>(set.set), static_cast<uint8_t>(binding.binding) });
61 break;
62 case DescriptorType::CORE_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
63 Collect(set.set, binding, plat_.cbSets);
64 break;
65 case DescriptorType::CORE_DESCRIPTOR_TYPE_SAMPLED_IMAGE:
66 images.push_back({ static_cast<uint8_t>(set.set), static_cast<uint8_t>(binding.binding) });
67 break;
68 case DescriptorType::CORE_DESCRIPTOR_TYPE_STORAGE_IMAGE:
69 Collect(set.set, binding, plat_.ciSets);
70 break;
71 case DescriptorType::CORE_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER:
72 [[fallthrough]];
73 case DescriptorType::CORE_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER:
74 break;
75 case DescriptorType::CORE_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
76 Collect(set.set, binding, plat_.ubSets);
77 break;
78 case DescriptorType::CORE_DESCRIPTOR_TYPE_STORAGE_BUFFER:
79 Collect(set.set, binding, plat_.sbSets);
80 break;
81 case DescriptorType::CORE_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC:
82 [[fallthrough]];
83 case DescriptorType::CORE_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC:
84 break;
85 case DescriptorType::CORE_DESCRIPTOR_TYPE_INPUT_ATTACHMENT:
86 Collect(set.set, binding, plat_.siSets);
87 break;
88 case DescriptorType::CORE_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE:
89 [[fallthrough]];
90 case DescriptorType::CORE_DESCRIPTOR_TYPE_MAX_ENUM:
91 break;
92 }
93 }
94 }
95 }
96 for (const auto& sBinding : samplers) {
97 for (const auto& iBinding : images) {
98 const auto name = "s" + to_string(iBinding.set) + "_b" + to_string(iBinding.bind) + "_s" +
99 to_string(sBinding.set) + "_b" + to_string(sBinding.bind);
100 plat_.combSets.push_back({ sBinding.set, sBinding.bind, iBinding.set, iBinding.bind, string { name } });
101 }
102 }
103 }
104
CreateSpecInfos(array_view<const ShaderSpecialization::Constant> constants,vector<Gles::SpecConstantInfo> & outSpecInfo)105 void CreateSpecInfos(
106 array_view<const ShaderSpecialization::Constant> constants, vector<Gles::SpecConstantInfo>& outSpecInfo)
107 {
108 static_assert(static_cast<uint32_t>(Gles::SpecConstantInfo::Types::BOOL) ==
109 static_cast<uint32_t>(ShaderSpecialization::Constant::Type::BOOL));
110 for (const auto& constant : constants) {
111 Gles::SpecConstantInfo info { static_cast<Gles::SpecConstantInfo::Types>(constant.type), constant.id, 1U, 1U,
112 {} };
113 outSpecInfo.push_back(info);
114 }
115 }
116
SortSets(PipelineLayout & pipelineLayout)117 void SortSets(PipelineLayout& pipelineLayout)
118 {
119 for (auto& currSet : pipelineLayout.descriptorSetLayouts) {
120 if (currSet.set != PipelineLayoutConstants::INVALID_INDEX) {
121 std::sort(currSet.bindings.begin(), currSet.bindings.end(),
122 [](auto const& lhs, auto const& rhs) { return (lhs.binding < rhs.binding); });
123 }
124 }
125 }
126 } // namespace
namespace {
/**
 * Minimal forward-only reader over a little-endian byte stream.
 * No bounds checking is performed: the caller must guarantee that enough bytes remain for every read.
 */
struct Reader {
    // Current read position; every Get* call advances it past the bytes it consumed.
    const uint8_t* ptr;

    /** Reads one byte. */
    uint8_t GetUint8()
    {
        return *ptr++;
    }

    /** Reads a little-endian 16-bit unsigned integer. */
    uint16_t GetUint16()
    {
        const auto value =
            static_cast<uint16_t>(static_cast<uint32_t>(ptr[0]) | (static_cast<uint32_t>(ptr[1]) << 8U));
        ptr += sizeof(uint16_t);
        return value;
    }

    /** Reads a little-endian 32-bit unsigned integer. */
    uint32_t GetUint32()
    {
        // Widen each byte to uint32_t before shifting. A byte promoted to (signed) int and shifted by 24
        // does not fit in int when it is >= 0x80, which is implementation-defined before C++20.
        const uint32_t value = static_cast<uint32_t>(ptr[0]) | (static_cast<uint32_t>(ptr[1]) << 8U) |
                               (static_cast<uint32_t>(ptr[2]) << 16U) | (static_cast<uint32_t>(ptr[3]) << 24U);
        ptr += sizeof(uint32_t);
        return value;
    }

    /**
     * Reads a 16-bit little-endian length followed by that many bytes.
     * The result is a view into the underlying buffer (no copy), so it is only valid while that buffer lives.
     */
    string_view GetStringView()
    {
        const uint16_t len = GetUint16();
        const string_view value(static_cast<const char*>(static_cast<const void*>(ptr)), len);
        ptr += len;
        return value;
    }
};
} // namespace
156 template<typename ShaderBase>
ProcessShaderModule(ShaderBase & me,const ShaderModuleCreateInfo & createInfo)157 void ProcessShaderModule(ShaderBase& me, const ShaderModuleCreateInfo& createInfo)
158 {
159 me.pipelineLayout_ = createInfo.reflectionData.GetPipelineLayout();
160 if (me.shaderStageFlags_ & CORE_SHADER_STAGE_VERTEX_BIT) {
161 me.vertexInputAttributeDescriptions_ = createInfo.reflectionData.GetInputDescriptions();
162 me.vertexInputBindingDescriptions_.reserve(me.vertexInputAttributeDescriptions_.size());
163 for (const auto& attrib : me.vertexInputAttributeDescriptions_) {
164 VertexInputDeclaration::VertexInputBindingDescription bindingDesc;
165 bindingDesc.binding = attrib.binding;
166 bindingDesc.stride = GpuProgramUtil::FormatByteSize(attrib.format);
167 bindingDesc.vertexInputRate = VertexInputRate::CORE_VERTEX_INPUT_RATE_VERTEX;
168 me.vertexInputBindingDescriptions_.push_back(bindingDesc);
169 }
170 me.vidv_.bindingDescriptions = { me.vertexInputBindingDescriptions_.data(),
171 me.vertexInputBindingDescriptions_.size() };
172 me.vidv_.attributeDescriptions = { me.vertexInputAttributeDescriptions_.data(),
173 me.vertexInputAttributeDescriptions_.size() };
174 }
175
176 if (me.shaderStageFlags_ & CORE_SHADER_STAGE_COMPUTE_BIT) {
177 const Math::UVec3 tgs = createInfo.reflectionData.GetLocalSize();
178 me.stg_.x = tgs.x;
179 me.stg_.y = tgs.y;
180 me.stg_.z = tgs.z;
181 }
182 if (auto* ptr = createInfo.reflectionData.GetPushConstants(); ptr) {
183 Reader read { ptr };
184 const auto constants = read.GetUint8();
185 for (uint8_t i = 0U; i < constants; ++i) {
186 Gles::PushConstantReflection refl;
187 refl.type = read.GetUint32();
188 refl.offset = read.GetUint16();
189 refl.size = read.GetUint16();
190 refl.arraySize = read.GetUint16();
191 refl.arrayStride = read.GetUint16();
192 refl.matrixStride = read.GetUint16();
193 refl.name = "CORE_PC_0";
194 refl.name += read.GetStringView();
195 refl.stage = me.shaderStageFlags_;
196 me.plat_.infos.push_back(move(refl));
197 }
198 }
199
200 me.constants_ = createInfo.reflectionData.GetSpecializationConstants();
201 me.sscv_.constants = { me.constants_.data(), me.constants_.size() };
202 CollectRes(me.pipelineLayout_, me.plat_);
203 CreateSpecInfos(me.constants_, me.specInfo_);
204 // sort bindings inside sets (and count them)
205 SortSets(me.pipelineLayout_);
206
207 me.source_.assign(
208 static_cast<const char*>(static_cast<const void*>(createInfo.spvData.data())), createInfo.spvData.size());
209 }
210
211 template<typename ShaderBase>
SpecializeShaderModule(const ShaderBase & base,const ShaderSpecializationConstantDataView & specData)212 string SpecializeShaderModule(const ShaderBase& base, const ShaderSpecializationConstantDataView& specData)
213 {
214 return Gles::Specialize(base.shaderStageFlags_, base.source_, base.constants_, specData);
215 }
216
ShaderModuleGLES(Device & device,const ShaderModuleCreateInfo & createInfo)217 ShaderModuleGLES::ShaderModuleGLES(Device& device, const ShaderModuleCreateInfo& createInfo)
218 : device_(device), shaderStageFlags_(createInfo.shaderStageFlags)
219 {
220 if (createInfo.reflectionData.IsValid() &&
221 (shaderStageFlags_ &
222 (CORE_SHADER_STAGE_VERTEX_BIT | CORE_SHADER_STAGE_FRAGMENT_BIT | CORE_SHADER_STAGE_COMPUTE_BIT))) {
223 ProcessShaderModule(*this, createInfo);
224 } else {
225 PLUGIN_LOG_E("invalid shader stages or invalid reflection data for shader module, invalid shader module");
226 }
227 }
228
// Defaulted out of line; no cleanup beyond the members' own destructors.
ShaderModuleGLES::~ShaderModuleGLES() = default;
230
GetShaderStageFlags() const231 ShaderStageFlags ShaderModuleGLES::GetShaderStageFlags() const
232 {
233 return shaderStageFlags_;
234 }
235
GetGLSL(const ShaderSpecializationConstantDataView & specData) const236 string ShaderModuleGLES::GetGLSL(const ShaderSpecializationConstantDataView& specData) const
237 {
238 return SpecializeShaderModule(*this, specData);
239 }
240
GetPlatformData() const241 const ShaderModulePlatformData& ShaderModuleGLES::GetPlatformData() const
242 {
243 return plat_;
244 }
245
GetPipelineLayout() const246 const PipelineLayout& ShaderModuleGLES::GetPipelineLayout() const
247 {
248 return pipelineLayout_;
249 }
250
GetSpecilization() const251 ShaderSpecializationConstantView ShaderModuleGLES::GetSpecilization() const
252 {
253 return sscv_;
254 }
255
GetVertexInputDeclaration() const256 VertexInputDeclarationView ShaderModuleGLES::GetVertexInputDeclaration() const
257 {
258 return vidv_;
259 }
260
GetThreadGroupSize() const261 ShaderThreadGroup ShaderModuleGLES::GetThreadGroupSize() const
262 {
263 return stg_;
264 }
265 RENDER_END_NAMESPACE()
266